scala> spark.range(4).map(n => n * 2).filter(n => n < 3).explain(extended = true)
== Parsed Logical Plan ==
'TypedFilter <function1>, long, [StructField(value,LongType,false)], unresolveddeserializer(upcast(getcolumnbyordinal(0, LongType), LongType, - root class: "scala.Long"))
+- SerializeFromObject [input[0, bigint, true] AS value#185L]
   +- MapElements <function1>, class java.lang.Long, [StructField(value,LongType,true)], obj#184: bigint
      +- DeserializeToObject newInstance(class java.lang.Long), obj#183: java.lang.Long
         +- Range (0, 4, step=1, splits=Some(8))
== Analyzed Logical Plan ==
value: bigint
TypedFilter <function1>, long, [StructField(value,LongType,false)], cast(value#185L as bigint)
+- SerializeFromObject [input[0, bigint, true] AS value#185L]
   +- MapElements <function1>, class java.lang.Long, [StructField(value,LongType,true)], obj#184: bigint
      +- DeserializeToObject newInstance(class java.lang.Long), obj#183: java.lang.Long
         +- Range (0, 4, step=1, splits=Some(8))
== Optimized Logical Plan ==
SerializeFromObject [input[0, bigint, true] AS value#185L]
+- Filter <function1>.apply
   +- MapElements <function1>, class java.lang.Long, [StructField(value,LongType,true)], obj#184: bigint
      +- DeserializeToObject newInstance(class java.lang.Long), obj#183: java.lang.Long
         +- Range (0, 4, step=1, splits=Some(8))
== Physical Plan ==
*SerializeFromObject [input[0, bigint, true] AS value#185L]
+- *Filter <function1>.apply
   +- *MapElements <function1>, obj#184: bigint
      +- *DeserializeToObject newInstance(class java.lang.Long), obj#183: java.lang.Long
         +- *Range (0, 4, step=1, splits=Some(8))
EliminateSerialization Logical Plan Optimization
EliminateSerialization is an optimization rule for a logical plan in SparkOptimizer. It removes a serialization step that is unnecessary because the object is deserialized back to the same type right away.
EliminateSerialization optimizes logical plans with DeserializeToObject (after SerializeFromObject or TypedFilter), AppendColumns (after SerializeFromObject), and TypedFilter (after SerializeFromObject) logical operators.
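The first of these cases (DeserializeToObject directly after SerializeFromObject) can be illustrated with a toy model in plain Scala. This is only a sketch: the case classes and the eliminateSerialization function below are hypothetical stand-ins written for this illustration, not Spark's Catalyst classes, and object types are reduced to plain strings. It is meant to be pasted into the Scala REPL (with :paste) or run as a script.

```scala
// Toy model only: these case classes are hypothetical stand-ins,
// not Spark's Catalyst operators of the same names.
sealed trait ToyPlan
case class RangeOp(end: Long) extends ToyPlan
case class MapElements(outType: String, child: ToyPlan) extends ToyPlan
case class SerializeFromObject(objType: String, child: ToyPlan) extends ToyPlan
case class DeserializeToObject(objType: String, child: ToyPlan) extends ToyPlan
case class Project(child: ToyPlan) extends ToyPlan

// Bottom-up rewrite: optimize the child first, then drop a
// deserialize-after-serialize pair when both sides use the same object type.
def eliminateSerialization(plan: ToyPlan): ToyPlan = plan match {
  case DeserializeToObject(t, child) =>
    eliminateSerialization(child) match {
      case SerializeFromObject(s, grandChild) if s == t => Project(grandChild)
      case other                                        => DeserializeToObject(t, other)
    }
  case SerializeFromObject(t, child) => SerializeFromObject(t, eliminateSerialization(child))
  case MapElements(t, child)         => MapElements(t, eliminateSerialization(child))
  case Project(child)                => Project(eliminateSerialization(child))
  case leaf: RangeOp                 => leaf
}

// Rough shape of spark.range(3).map(_.toString).map(_.toInt) before optimization.
val before: ToyPlan =
  SerializeFromObject("Int",
    MapElements("Int",
      DeserializeToObject("String",
        SerializeFromObject("String",
          MapElements("String",
            DeserializeToObject("Long", RangeOp(3)))))))

val after = eliminateSerialization(before)
println(after)
```

In this model the String round trip between the two MapElements nodes is replaced by a Project, while the outermost SerializeFromObject and the innermost DeserializeToObject stay in place because they have no matching partner.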
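Examples include:
1. map followed by filter
2. map followed by another map
3. groupByKey followed by agg
Example — map followed by filter Logical Plan
The extended plans for spark.range(4).map(n => n * 2).filter(n => n < 3) are listed at the start of this section.
Example — map followed by another map Logical Plan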
// Notice unnecessary mapping between String and Int types
val query = spark.range(3).map(_.toString).map(_.toInt)
scala> query.explain(extended = true)
...
TRACE SparkOptimizer:
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.EliminateSerialization ===
SerializeFromObject [input[0, int, true] AS value#91] SerializeFromObject [input[0, int, true] AS value#91]
+- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#90: int +- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#90: int
! +- DeserializeToObject value#86.toString, obj#89: java.lang.String +- Project [obj#85 AS obj#89]
! +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true) AS value#86] +- MapElements <function1>, class java.lang.Long, [StructField(value,LongType,true)], obj#85: java.lang.String
! +- MapElements <function1>, class java.lang.Long, [StructField(value,LongType,true)], obj#85: java.lang.String +- DeserializeToObject newInstance(class java.lang.Long), obj#84: java.lang.Long
! +- DeserializeToObject newInstance(class java.lang.Long), obj#84: java.lang.Long +- Range (0, 3, step=1, splits=Some(8))
! +- Range (0, 3, step=1, splits=Some(8))
...
== Parsed Logical Plan ==
'SerializeFromObject [input[0, int, true] AS value#91]
+- 'MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#90: int
   +- 'DeserializeToObject unresolveddeserializer(upcast(getcolumnbyordinal(0, StringType), StringType, - root class: "java.lang.String").toString), obj#89: java.lang.String
      +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true) AS value#86]
         +- MapElements <function1>, class java.lang.Long, [StructField(value,LongType,true)], obj#85: java.lang.String
            +- DeserializeToObject newInstance(class java.lang.Long), obj#84: java.lang.Long
               +- Range (0, 3, step=1, splits=Some(8))
== Analyzed Logical Plan ==
value: int
SerializeFromObject [input[0, int, true] AS value#91]
+- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#90: int
   +- DeserializeToObject cast(value#86 as string).toString, obj#89: java.lang.String
      +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true) AS value#86]
         +- MapElements <function1>, class java.lang.Long, [StructField(value,LongType,true)], obj#85: java.lang.String
            +- DeserializeToObject newInstance(class java.lang.Long), obj#84: java.lang.Long
               +- Range (0, 3, step=1, splits=Some(8))
== Optimized Logical Plan ==
SerializeFromObject [input[0, int, true] AS value#91]
+- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#90: int
   +- MapElements <function1>, class java.lang.Long, [StructField(value,LongType,true)], obj#85: java.lang.String
      +- DeserializeToObject newInstance(class java.lang.Long), obj#84: java.lang.Long
         +- Range (0, 3, step=1, splits=Some(8))
== Physical Plan ==
*SerializeFromObject [input[0, int, true] AS value#91]
+- *MapElements <function1>, obj#90: int
   +- *MapElements <function1>, obj#85: java.lang.String
      +- *DeserializeToObject newInstance(class java.lang.Long), obj#84: java.lang.Long
         +- *Range (0, 3, step=1, splits=Some(8))
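The effect of the rule on this query can also be checked programmatically instead of reading the plans by eye. The following is a minimal sketch for spark-shell, assuming that spark and spark.implicits._ are already in scope; serdeOps is a hypothetical helper introduced here, while queryExecution, analyzed, optimizedPlan and collect are regular Spark APIs.

```scala
// Intended for spark-shell, where `spark` and spark.implicits._ are in scope.
import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LogicalPlan, SerializeFromObject}

// Hypothetical helper: counts the operators that convert between
// JVM objects and Spark's internal row format in a logical plan.
def serdeOps(plan: LogicalPlan): Int =
  plan.collect { case op @ (_: SerializeFromObject | _: DeserializeToObject) => op }.size

val query = spark.range(3).map(_.toString).map(_.toInt)

val analyzedCount  = serdeOps(query.queryExecution.analyzed)
val optimizedCount = serdeOps(query.queryExecution.optimizedPlan)

println(s"analyzed: $analyzedCount, optimized: $optimizedCount")
```

For the plans listed above, this would count four such operators in the analyzed plan and two in the optimized plan; the exact numbers may differ in other Spark versions.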
Example — groupByKey followed by agg Logical Plan
scala> import org.apache.spark.sql.expressions.scalalang.typed
scala> spark.range(4).map(n => (n, n % 2)).groupByKey(_._2).agg(typed.sum(_._2)).explain(true)
== Parsed Logical Plan ==
'Aggregate [value#454L], [value#454L, unresolvedalias(typedsumdouble(org.apache.spark.sql.execution.aggregate.TypedSumDouble@4fcb0de4, Some(unresolveddeserializer(newInstance(class scala.Tuple2), _1#450L, _2#451L)), Some(class scala.Tuple2), Some(StructType(StructField(_1,LongType,true), StructField(_2,LongType,false))), input[0, double, true] AS value#457, unresolveddeserializer(upcast(getcolumnbyordinal(0, DoubleType), DoubleType, - root class: "scala.Double"), value#457), input[0, double, true] AS value#456, DoubleType, DoubleType, false), Some(<function1>))]
+- AppendColumns <function1>, class scala.Tuple2, [StructField(_1,LongType,true), StructField(_2,LongType,false)], newInstance(class scala.Tuple2), [input[0, bigint, true] AS value#454L]
   +- SerializeFromObject [assertnotnull(input[0, scala.Tuple2, true], top level non-flat input object)._1.longValue AS _1#450L, assertnotnull(input[0, scala.Tuple2, true], top level non-flat input object)._2 AS _2#451L]
      +- MapElements <function1>, class java.lang.Long, [StructField(value,LongType,true)], obj#449: scala.Tuple2
         +- DeserializeToObject newInstance(class java.lang.Long), obj#448: java.lang.Long
            +- Range (0, 4, step=1, splits=Some(8))
== Analyzed Logical Plan ==
value: bigint, TypedSumDouble(scala.Tuple2): double
Aggregate [value#454L], [value#454L, typedsumdouble(org.apache.spark.sql.execution.aggregate.TypedSumDouble@4fcb0de4, Some(newInstance(class scala.Tuple2)), Some(class scala.Tuple2), Some(StructType(StructField(_1,LongType,true), StructField(_2,LongType,false))), input[0, double, true] AS value#457, cast(value#457 as double), input[0, double, true] AS value#456, DoubleType, DoubleType, false) AS TypedSumDouble(scala.Tuple2)#462]
+- AppendColumns <function1>, class scala.Tuple2, [StructField(_1,LongType,true), StructField(_2,LongType,false)], newInstance(class scala.Tuple2), [input[0, bigint, true] AS value#454L]
   +- SerializeFromObject [assertnotnull(input[0, scala.Tuple2, true], top level non-flat input object)._1.longValue AS _1#450L, assertnotnull(input[0, scala.Tuple2, true], top level non-flat input object)._2 AS _2#451L]
      +- MapElements <function1>, class java.lang.Long, [StructField(value,LongType,true)], obj#449: scala.Tuple2
         +- DeserializeToObject newInstance(class java.lang.Long), obj#448: java.lang.Long
            +- Range (0, 4, step=1, splits=Some(8))
== Optimized Logical Plan ==
Aggregate [value#454L], [value#454L, typedsumdouble(org.apache.spark.sql.execution.aggregate.TypedSumDouble@4fcb0de4, Some(newInstance(class scala.Tuple2)), Some(class scala.Tuple2), Some(StructType(StructField(_1,LongType,true), StructField(_2,LongType,false))), input[0, double, true] AS value#457, value#457, input[0, double, true] AS value#456, DoubleType, DoubleType, false) AS TypedSumDouble(scala.Tuple2)#462]
+- AppendColumnsWithObject <function1>, [assertnotnull(input[0, scala.Tuple2, true], top level non-flat input object)._1.longValue AS _1#450L, assertnotnull(input[0, scala.Tuple2, true], top level non-flat input object)._2 AS _2#451L], [input[0, bigint, true] AS value#454L]
   +- MapElements <function1>, class java.lang.Long, [StructField(value,LongType,true)], obj#449: scala.Tuple2
      +- DeserializeToObject newInstance(class java.lang.Long), obj#448: java.lang.Long
         +- Range (0, 4, step=1, splits=Some(8))
== Physical Plan ==
*HashAggregate(keys=[value#454L], functions=[typedsumdouble(org.apache.spark.sql.execution.aggregate.TypedSumDouble@4fcb0de4, Some(newInstance(class scala.Tuple2)), Some(class scala.Tuple2), Some(StructType(StructField(_1,LongType,true), StructField(_2,LongType,false))), input[0, double, true] AS value#457, value#457, input[0, double, true] AS value#456, DoubleType, DoubleType, false)], output=[value#454L, TypedSumDouble(scala.Tuple2)#462])
+- Exchange hashpartitioning(value#454L, 200)
   +- *HashAggregate(keys=[value#454L], functions=[partial_typedsumdouble(org.apache.spark.sql.execution.aggregate.TypedSumDouble@4fcb0de4, Some(newInstance(class scala.Tuple2)), Some(class scala.Tuple2), Some(StructType(StructField(_1,LongType,true), StructField(_2,LongType,false))), input[0, double, true] AS value#457, value#457, input[0, double, true] AS value#456, DoubleType, DoubleType, false)], output=[value#454L, value#463])
      +- AppendColumnsWithObject <function1>, [assertnotnull(input[0, scala.Tuple2, true], top level non-flat input object)._1.longValue AS _1#450L, assertnotnull(input[0, scala.Tuple2, true], top level non-flat input object)._2 AS _2#451L], [input[0, bigint, true] AS value#454L]
         +- MapElements <function1>, obj#449: scala.Tuple2
            +- DeserializeToObject newInstance(class java.lang.Long), obj#448: java.lang.Long
               +- *Range (0, 4, step=1, splits=Some(8))
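In the plans above, the AppendColumns operator sitting on top of SerializeFromObject in the analyzed plan shows up as a single AppendColumnsWithObject operator in the optimized plan. A minimal spark-shell sketch to confirm this for the same query follows; it assumes spark and spark.implicits._ are in scope, and the value names are hypothetical.

```scala
// Intended for spark-shell, where `spark` and spark.implicits._ are in scope.
import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, AppendColumnsWithObject}
import org.apache.spark.sql.expressions.scalalang.typed

val grouped = spark.range(4).map(n => (n, n % 2)).groupByKey(_._2).agg(typed.sum(_._2))

// Before optimization: groupByKey is planned as AppendColumns.
val analyzedHasAppendColumns =
  grouped.queryExecution.analyzed.collectFirst { case a: AppendColumns => a }.isDefined

// After optimization: look for the merged AppendColumnsWithObject operator.
val optimizedHasWithObject =
  grouped.queryExecution.optimizedPlan.collectFirst { case a: AppendColumnsWithObject => a }.isDefined

println(s"AppendColumns in analyzed plan: $analyzedHasAppendColumns")
println(s"AppendColumnsWithObject in optimized plan: $optimizedHasWithObject")
```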