import org.apache.spark.sql.catalyst.optimizer.DecimalAggregates
val da = DecimalAggregates(spark.sessionState.conf)
// Build analyzed logical plan
// with sum aggregate function and Decimal field
import org.apache.spark.sql.types.DecimalType
val query = spark.range(5).select(sum($"id" cast DecimalType(1,0)) as "sum")
scala> val plan = query.queryExecution.analyzed
plan: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan =
Aggregate [sum(cast(id#91L as decimal(1,0))) AS sum#95]
+- Range (0, 5, step=1, splits=Some(8))
// Apply DecimalAggregates rule
// Note MakeDecimal and UnscaledValue operators
scala> da.apply(plan)
res27: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan =
Aggregate [MakeDecimal(sum(UnscaledValue(cast(id#91L as decimal(1,0)))),11,0) AS sum#95]
+- Range (0, 5, step=1, splits=Some(8))
DecimalAggregates Logical Plan Optimization
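DecimalAggregates is a logical optimization rule in the Optimizer. It rewrites Sum and Average aggregate functions over fixed-precision DecimalType values (in WindowExpression and AggregateExpression expressions) so that they operate on UnscaledValue (unscaled Long) values instead. The aggregated result is then converted back to a decimal. The optimized plans in the examples show how: MakeDecimal for sum, and a cast to the original result type for avg.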
DecimalAggregates is the only rule in the Decimal Optimizations fixed-point batch of rules in the Optimizer.
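Tip: Import DecimalAggregates and apply the rule directly to the analyzed logical plan of a structured query to learn how the rule works. The example above does exactly that with da.apply(plan).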
Example: sum Aggregate Function on Decimal with Precision Smaller Than 9
// sum aggregate with Decimal field with precision <= 8
val q = "SELECT sum(cast(id AS DECIMAL(5,0))) FROM range(1)"
scala> sql(q).explain(true)
== Parsed Logical Plan ==
'Project [unresolvedalias('sum(cast('id as decimal(5,0))), None)]
+- 'UnresolvedTableValuedFunction range, [1]
== Analyzed Logical Plan ==
sum(CAST(id AS DECIMAL(5,0))): decimal(15,0)
Aggregate [sum(cast(id#104L as decimal(5,0))) AS sum(CAST(id AS DECIMAL(5,0)))#106]
+- Range (0, 1, step=1, splits=None)
== Optimized Logical Plan ==
Aggregate [MakeDecimal(sum(UnscaledValue(cast(id#104L as decimal(5,0)))),15,0) AS sum(CAST(id AS DECIMAL(5,0)))#106]
+- Range (0, 1, step=1, splits=None)
== Physical Plan ==
*HashAggregate(keys=[], functions=[sum(UnscaledValue(cast(id#104L as decimal(5,0))))], output=[sum(CAST(id AS DECIMAL(5,0)))#106])
+- Exchange SinglePartition
+- *HashAggregate(keys=[], functions=[partial_sum(UnscaledValue(cast(id#104L as decimal(5,0))))], output=[sum#108L])
+- *Range (0, 1, step=1, splits=None)
Example: avg Aggregate Function on Decimal with Precision Smaller Than 12
// avg aggregate with Decimal field with precision <= 11
val q = "SELECT avg(cast(id AS DECIMAL(10,0))) FROM range(1)"
scala> val q = "SELECT avg(cast(id AS DECIMAL(10,0))) FROM range(1)"
q: String = SELECT avg(cast(id AS DECIMAL(10,0))) FROM range(1)
scala> sql(q).explain(true)
== Parsed Logical Plan ==
'Project [unresolvedalias('avg(cast('id as decimal(10,0))), None)]
+- 'UnresolvedTableValuedFunction range, [1]
== Analyzed Logical Plan ==
avg(CAST(id AS DECIMAL(10,0))): decimal(14,4)
Aggregate [avg(cast(id#115L as decimal(10,0))) AS avg(CAST(id AS DECIMAL(10,0)))#117]
+- Range (0, 1, step=1, splits=None)
== Optimized Logical Plan ==
Aggregate [cast((avg(UnscaledValue(cast(id#115L as decimal(10,0)))) / 1.0) as decimal(14,4)) AS avg(CAST(id AS DECIMAL(10,0)))#117]
+- Range (0, 1, step=1, splits=None)
== Physical Plan ==
*HashAggregate(keys=[], functions=[avg(UnscaledValue(cast(id#115L as decimal(10,0))))], output=[avg(CAST(id AS DECIMAL(10,0)))#117])
+- Exchange SinglePartition
+- *HashAggregate(keys=[], functions=[partial_avg(UnscaledValue(cast(id#115L as decimal(10,0))))], output=[sum#120, count#121L])
+- *Range (0, 1, step=1, splits=None)