Computing Group Percentages in PySpark: Four Approaches

Python
PySpark
Data Engineering
Four ways to compute each group's share of a total in PySpark: the simple collect-to-driver pattern, a deceptively expensive window function, a broadcast cross-join, and conditional aggregation.
Author

Marina Varfolomeeva

Published

June 5, 2026

Computing the percentage of a total that each group represents is a common task. In pandas it is a one-liner: df["count"] / df["count"].sum(). In PySpark it is harder, because the total is not available locally. The rows are spread across executors, so the total has to be either brought back to the driver by an action or computed as an extra step in the distributed plan. The approaches below differ considerably in performance and safety.

We use the Olist orders table as a running example: how many orders are in each status, and what percentage of all orders does each status represent? For dataset setup, see Customer Analytics with Olist, Part 1: Data Setup.

Setup

import os
import time
from contextlib import contextmanager
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from dotenv import load_dotenv

load_dotenv()
OLIST_DIR = os.path.join(os.environ["DATA_DIR"], "olist")

spark = (
    SparkSession.builder
    .appName("pyspark-group-percentages")
    .master("local[*]")
    .config("spark.sql.session.timeZone", "UTC")
    .getOrCreate()
)
spark.sparkContext.setLogLevel("WARN")

orders = spark.read.csv(
    os.path.join(OLIST_DIR, "olist_orders_dataset.csv"),
    header=True,
    inferSchema=True,
)

@contextmanager
def timer(label):
    t = time.perf_counter()
    yield
    print(f"{label}: {time.perf_counter() - t:.2f}s")

Approach 1: Collect a scalar to the driver

The simplest approach: trigger one action to get the total as a Python integer, then use it as a literal in a subsequent transformation.

n_orders = orders.count()

s1 = (
    orders
    .groupBy("order_status")
    .count()
    .withColumn("pct", F.round((F.col("count") / n_orders) * 100, 3))
)

s1.show()
+------------+-----+-----+
|order_status|count|  pct|
+------------+-----+-----+
|     shipped| 1107|1.113|
|    canceled|  625|0.629|
|    invoiced|  314|0.316|
|     created|    5|0.005|
|   delivered|96478|97.02|
| unavailable|  609|0.612|
|  processing|  301|0.303|
|    approved|    2|0.002|
+------------+-----+-----+

The data remain distributed across the cluster throughout.

  1. orders.count() is an action — it computes the total across all nodes and sends one integer to the driver.
  2. The driver embeds that integer in the query plan as a literal constant. It reaches the executors inside the serialized tasks, not through a separate broadcast.
  3. .withColumn() is a transformation — it adds the division step to the execution plan without running anything yet.
  4. .show() is an action — it executes the plan in parallel across all partitions.

This approach makes two full passes over the data: one for count() and one for the grouped aggregation. For Olist-sized data that is fine. For a multi-terabyte table, the first count() alone can be expensive.

s1.explain(mode="formatted")
== Physical Plan ==
AdaptiveSparkPlan (6)
+- Project (5)
   +- HashAggregate (4)
      +- Exchange (3)
         +- HashAggregate (2)
            +- Scan csv  (1)


(1) Scan csv 
Output [1]: [order_status#19]
Batched: false
Location: InMemoryFileIndex [file:/home/varmara/80_websites/varmara-data/olist/olist_orders_dataset.csv]
ReadSchema: struct<order_status:string>

(2) HashAggregate
Input [1]: [order_status#19]
Keys [1]: [order_status#19]
Functions [1]: [partial_count(1)]
Aggregate Attributes [1]: [count#52L]
Results [2]: [order_status#19, count#53L]

(3) Exchange
Input [2]: [order_status#19, count#53L]
Arguments: hashpartitioning(order_status#19, 200), ENSURE_REQUIREMENTS, [plan_id=122]

(4) HashAggregate
Input [2]: [order_status#19, count#53L]
Keys [1]: [order_status#19]
Functions [1]: [count(1)]
Aggregate Attributes [1]: [count(1)#46L]
Results [2]: [order_status#19, count(1)#46L AS count#37L]

(5) Project
Output [3]: [order_status#19, count#37L, round(((cast(count#37L as double) / 99441.0) * 100.0), 3) AS pct#48]
Input [2]: [order_status#19, count#37L]

(6) AdaptiveSparkPlan
Output [3]: [order_status#19, count#37L, pct#48]
Arguments: isFinalPlan=false

The plan shows a two-stage HashAggregate (partial counts, then final counts) with an Exchange (shuffle) in between to compute the per-status counts. Because the divisor was evaluated upfront on the driver, it appears in the final Project node as the literal 99441.0. The percentage division therefore adds no shuffle and runs locally on the post-shuffle partitions.

Approach 2: Window over the whole dataset

A window function can compute the total without collecting to the driver. The intuition is appealing: define a window over the entire table, sum the counts within it, and divide. The entire chain up to .withColumn() is a sequence of transformations that build one execution plan. .show() is the action that triggers it.

Never use this pattern.

s2 = (
    orders
    .groupBy("order_status")
    .count()
    .withColumn(
        "pct",
        F.round((F.col("count") / F.sum("count").over(Window.partitionBy())) * 100, 3)
    )
)

s2.show()
26/06/05 15:54:54 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+------------+-----+-----+
|order_status|count|  pct|
+------------+-----+-----+
|     shipped| 1107|1.113|
|    canceled|  625|0.629|
|    invoiced|  314|0.316|
|     created|    5|0.005|
|   delivered|96478|97.02|
| unavailable|  609|0.612|
|  processing|  301|0.303|
|    approved|    2|0.002|
+------------+-----+-----+
s2.explain(mode="formatted")
== Physical Plan ==
AdaptiveSparkPlan (8)
+- Project (7)
   +- Window (6)
      +- Exchange (5)
         +- HashAggregate (4)
            +- Exchange (3)
               +- HashAggregate (2)
                  +- Scan csv  (1)


(1) Scan csv 
Output [1]: [order_status#19]
Batched: false
Location: InMemoryFileIndex [file:/home/varmara/80_websites/varmara-data/olist/olist_orders_dataset.csv]
ReadSchema: struct<order_status:string>

(2) HashAggregate
Input [1]: [order_status#19]
Keys [1]: [order_status#19]
Functions [1]: [partial_count(1)]
Aggregate Attributes [1]: [count#78L]
Results [2]: [order_status#19, count#79L]

(3) Exchange
Input [2]: [order_status#19, count#79L]
Arguments: hashpartitioning(order_status#19, 200), ENSURE_REQUIREMENTS, [plan_id=241]

(4) HashAggregate
Input [2]: [order_status#19, count#79L]
Keys [1]: [order_status#19]
Functions [1]: [count(1)]
Aggregate Attributes [1]: [count(1)#71L]
Results [2]: [order_status#19, count(1)#71L AS count#62L]

(5) Exchange
Input [2]: [order_status#19, count#62L]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=244]

(6) Window
Input [2]: [order_status#19, count#62L]
Arguments: [sum(count#62L) windowspecdefinition(specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#74L]

(7) Project
Output [3]: [order_status#19, count#62L, round(((cast(count#62L as double) / cast(_we0#74L as double)) * 100.0), 3) AS pct#72]
Input [3]: [order_status#19, count#62L, _we0#74L]

(8) AdaptiveSparkPlan
Output [3]: [order_status#19, count#62L, pct#72]
Arguments: isFinalPlan=false


  1. The groupBy shuffle: Spark reads the data, computes partial counts within each input partition, and shuffles those partial counts (Exchange with hashpartitioning) so that all partial counts for a given status land in the same partition, where the final count for that status is computed.
  2. The single-partition shuffle: because the window specification has no partition key (Window.partitionBy()), Spark adds a second Exchange (SinglePartition) that moves every per-status count onto one partition, processed by a single task.
  3. The local sum and division: that single task computes the grand total, appends it as a new column, and runs the percentage division.

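Forcing rows onto a single partition removes all parallelism for that step. Here only eight status counts reach the window, so the query returns the correct result quickly, and that is exactly what makes the pattern a trap: the same code applied to a high-cardinality grouping, or to the ungrouped table, pushes every row through one task and can cause Out-Of-Memory (OOM) errors on that executor.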

Approach 3: Cross-join with a single-row DataFrame

This is the Aggregate Cross-Join Pattern, also known as the Scalar Join. We keep everything as DataFrames and avoid collecting to the driver.

Three steps:

  1. Compute a one-row DataFrame holding the total. orders.count() returns a Python scalar — we use orders.select(F.count("*")) to stay in the DataFrame world.
  2. Compute the per-group counts with groupBy.
  3. Cross-join the two DataFrames together.
n_orders_sdf = orders.select(F.count("*").alias("total_count"))

counts_by_cat = (
    orders
    .groupBy("order_status")
    .count()
)

s3 = (
    counts_by_cat
    .crossJoin(n_orders_sdf)
    .withColumn("pct", F.round((F.col("count") / F.col("total_count")) * 100, 3))
)

s3.show()
+------------+-----+-----------+-----+
|order_status|count|total_count|  pct|
+------------+-----+-----------+-----+
|     shipped| 1107|      99441|1.113|
|    canceled|  625|      99441|0.629|
|    invoiced|  314|      99441|0.316|
|     created|    5|      99441|0.005|
|   delivered|96478|      99441|97.02|
| unavailable|  609|      99441|0.612|
|  processing|  301|      99441|0.303|
|    approved|    2|      99441|0.002|
+------------+-----+-----------+-----+
s3.explain(mode="formatted")
== Physical Plan ==
AdaptiveSparkPlan (12)
+- Project (11)
   +- BroadcastNestedLoopJoin Cross BuildRight (10)
      :- HashAggregate (4)
      :  +- Exchange (3)
      :     +- HashAggregate (2)
      :        +- Scan csv  (1)
      +- BroadcastExchange (9)
         +- HashAggregate (8)
            +- Exchange (7)
               +- HashAggregate (6)
                  +- Scan csv  (5)


(1) Scan csv 
Output [1]: [order_status#19]
Batched: false
Location: InMemoryFileIndex [file:/home/varmara/80_websites/varmara-data/olist/olist_orders_dataset.csv]
ReadSchema: struct<order_status:string>

(2) HashAggregate
Input [1]: [order_status#19]
Keys [1]: [order_status#19]
Functions [1]: [partial_count(1)]
Aggregate Attributes [1]: [count#114L]
Results [2]: [order_status#19, count#115L]

(3) Exchange
Input [2]: [order_status#19, count#115L]
Arguments: hashpartitioning(order_status#19, 200), ENSURE_REQUIREMENTS, [plan_id=439]

(4) HashAggregate
Input [2]: [order_status#19, count#115L]
Keys [1]: [order_status#19]
Functions [1]: [count(1)]
Aggregate Attributes [1]: [count(1)#100L]
Results [2]: [order_status#19, count(1)#100L AS count#91L]

(5) Scan csv 
Output: []
Batched: false
Location: InMemoryFileIndex [file:/home/varmara/80_websites/varmara-data/olist/olist_orders_dataset.csv]
ReadSchema: struct<>

(6) HashAggregate
Input: []
Keys: []
Functions [1]: [partial_count(1)]
Aggregate Attributes [1]: [count#116L]
Results [1]: [count#117L]

(7) Exchange
Input [1]: [count#117L]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=441]

(8) HashAggregate
Input [1]: [count#117L]
Keys: []
Functions [1]: [count(1)]
Aggregate Attributes [1]: [count(1)#90L]
Results [1]: [count(1)#90L AS total_count#89L]

(9) BroadcastExchange
Input [1]: [total_count#89L]
Arguments: IdentityBroadcastMode, [plan_id=444]

(10) BroadcastNestedLoopJoin
Join type: Cross
Join condition: None

(11) Project
Output [4]: [order_status#19, count#91L, total_count#89L, round(((cast(count#91L as double) / cast(total_count#89L as double)) * 100.0), 3) AS pct#109]
Input [3]: [order_status#19, count#91L, total_count#89L]

(12) AdaptiveSparkPlan
Output [4]: [order_status#19, count#91L, total_count#89L, pct#109]
Arguments: isFinalPlan=false

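The plan shows a BroadcastNestedLoopJoin. Spark estimates n_orders_sdf, a global aggregate with no grouping keys, at a single row, far below spark.sql.autoBroadcastJoinThreshold, so the planner broadcasts it to every executor without any hint. Each task joins its grouped counts with that one row and performs the division locally; the grouped counts are not shuffled again for the join. Note that the plan still contains two Scan csv nodes, (1) and (5): the source is read twice, as in approach 1, but both reads belong to a single query rather than to two separate actions.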

This is the recommended pattern for large data: the computation stays fully distributed, and Spark’s optimiser handles the broadcast without explicit hints.

Approach 4: Conditional aggregation

This is the Conditional Aggregation Pattern, also known as Pivoted Aggregation. We avoid the groupBy entirely by scanning the dataset once and using conditional logic to increment a counter for each status type in parallel.

s4_static = (
    orders
    .agg(
        F.count(F.when(F.col("order_status") == "shipped", 1)).alias("shipped_count"),
        F.count(F.when(F.col("order_status") == "canceled", 1)).alias("canceled_count"),
        F.count(F.when(F.col("order_status") == "invoiced", 1)).alias("invoiced_count"),
        F.count(F.when(F.col("order_status") == "created", 1)).alias("created_count"),
        F.count(F.when(F.col("order_status") == "delivered", 1)).alias("delivered_count"),
        F.count(F.when(F.col("order_status") == "unavailable", 1)).alias("unavailable_count"),
        F.count(F.when(F.col("order_status") == "processing", 1)).alias("processing_count"),
        F.count(F.when(F.col("order_status") == "approved", 1)).alias("approved_count"),
        F.count("*").alias("total_count"),
    )
    .withColumn("shipped_pct", F.round((F.col("shipped_count") / F.col("total_count")) * 100, 3))
    .withColumn("canceled_pct", F.round((F.col("canceled_count") / F.col("total_count")) * 100, 3))
    .withColumn("invoiced_pct", F.round((F.col("invoiced_count") / F.col("total_count")) * 100, 3))
    .withColumn("created_pct", F.round((F.col("created_count") / F.col("total_count")) * 100, 3))
    .withColumn("delivered_pct", F.round((F.col("delivered_count") / F.col("total_count")) * 100, 3))
    .withColumn("unavailable_pct", F.round((F.col("unavailable_count") / F.col("total_count")) * 100, 3))
    .withColumn("processing_pct", F.round((F.col("processing_count") / F.col("total_count")) * 100, 3))
    .withColumn("approved_pct", F.round((F.col("approved_count") / F.col("total_count")) * 100, 3))
)

s4_static.show()
+-------------+--------------+--------------+-------------+---------------+-----------------+----------------+--------------+-----------+-----------+------------+------------+-----------+-------------+---------------+--------------+------------+
|shipped_count|canceled_count|invoiced_count|created_count|delivered_count|unavailable_count|processing_count|approved_count|total_count|shipped_pct|canceled_pct|invoiced_pct|created_pct|delivered_pct|unavailable_pct|processing_pct|approved_pct|
+-------------+--------------+--------------+-------------+---------------+-----------------+----------------+--------------+-----------+-----------+------------+------------+-----------+-------------+---------------+--------------+------------+
|         1107|           625|           314|            5|          96478|              609|             301|             2|      99441|      1.113|       0.629|       0.316|      0.005|        97.02|          0.612|         0.303|       0.002|
+-------------+--------------+--------------+-------------+---------------+-----------------+----------------+--------------+-----------+-----------+------------+------------+-----------+-------------+---------------+--------------+------------+

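One scan and no groupBy. The only exchange is the small one a global aggregation typically needs: each input partition sends a single partial result to one partition for the final merge. However, the code is verbose and produces a single wide row instead of one row per status. The status list is also hard-coded: if a new status appears in the data, its orders are still counted in total_count but get no column of their own, so the percentages silently stop summing to 100.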

We can make the status list dynamic by collecting the distinct values first, but the naïve version doesn’t handle nulls gracefully:

statuses = [row[0] for row in orders.select("order_status").distinct().collect()]

agg_exprs = [
    F.count(F.when(F.col("order_status") == status, 1)).alias(f"{status}_count")
    for status in statuses
]
agg_exprs.append(F.count("*").alias("total_count"))

s4_dynamic_buggy = orders.agg(*agg_exprs)

for status in statuses:
    s4_dynamic_buggy = s4_dynamic_buggy.withColumn(
        f"{status}_pct",
        F.round((F.col(f"{status}_count") / F.col("total_count")) * 100, 3),
    )

s4_dynamic_buggy.show()

If any order_status value is null, row[0] returns None. The list comprehension then builds the alias f"None_count" and the condition F.col("order_status") == None. In Spark's three-valued logic, comparing any value with null using == evaluates to null rather than True, so F.when returns null for every row and F.count, which skips nulls, reports 0 for that group. Nothing fails: the withColumn loop finds the None_count column and computes a None_pct of 0.0. The null orders disappear silently; they still count towards total_count, so the percentages no longer sum to 100.

We fix this with an explicit null check:

statuses = [row[0] for row in orders.select("order_status").distinct().collect()]

agg_exprs = [
    F.count(F.when(F.col("order_status") == s, 1)).alias(f"{s}_count")
    for s in statuses
    if s is not None
]
if None in statuses:
    agg_exprs.append(F.count(F.when(F.col("order_status").isNull(), 1)).alias("null_count"))
agg_exprs.append(F.count("*").alias("total_count"))

s4_dynamic = orders.agg(*agg_exprs)

for s in statuses:
    col_name = "null" if s is None else s
    s4_dynamic = s4_dynamic.withColumn(
        f"{col_name}_pct",
        F.round((F.col(f"{col_name}_count") / F.col("total_count")) * 100, 3),
    )

s4_dynamic.show()
+-------------+--------------+--------------+-------------+---------------+-----------------+----------------+--------------+-----------+-----------+------------+------------+-----------+-------------+---------------+--------------+------------+
|shipped_count|canceled_count|invoiced_count|created_count|delivered_count|unavailable_count|processing_count|approved_count|total_count|shipped_pct|canceled_pct|invoiced_pct|created_pct|delivered_pct|unavailable_pct|processing_pct|approved_pct|
+-------------+--------------+--------------+-------------+---------------+-----------------+----------------+--------------+-----------+-----------+------------+------------+-----------+-------------+---------------+--------------+------------+
|         1107|           625|           314|            5|          96478|              609|             301|             2|      99441|      1.113|       0.629|       0.316|      0.005|        97.02|          0.612|         0.303|       0.002|
+-------------+--------------+--------------+-------------+---------------+-----------------+----------------+--------------+-----------+-----------+------------+------------+-----------+-------------+---------------+--------------+------------+

The dynamic version still requires two passes — one distinct().collect() action to build the status list, then the aggregation — so the single-scan advantage of the static version is partially lost.

Benchmarking

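We cache orders first so I/O does not dominate the measurement. Keep two caveats in mind when reading the numbers. First, approaches 1 and 4b already ran one action each when they were defined above (orders.count() and distinct().collect()), so those passes are not included in their timings. Second, each figure is a single run of a sub-second job on about 100,000 rows, so differences of a few hundredths of a second say little about how the approaches compare at scale.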

orders.cache()
orders.count()

with timer("approach 1 — collect scalar"):
    s1.show()

with timer("approach 2 — empty window"):
    s2.show()

with timer("approach 3 — cross-join"):
    s3.show()

with timer("approach 4a — conditional aggregation"):
    s4_static.show()

with timer("approach 4b — conditional aggregation"):
    s4_dynamic.show()
+------------+-----+-----+
|order_status|count|  pct|
+------------+-----+-----+
|     shipped| 1107|1.113|
|    canceled|  625|0.629|
|    invoiced|  314|0.316|
|     created|    5|0.005|
|   delivered|96478|97.02|
| unavailable|  609|0.612|
|  processing|  301|0.303|
|    approved|    2|0.002|
+------------+-----+-----+

approach 1 — collect scalar: 0.21s
26/06/05 15:54:58 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+------------+-----+-----+
|order_status|count|  pct|
+------------+-----+-----+
|     shipped| 1107|1.113|
|    canceled|  625|0.629|
|    invoiced|  314|0.316|
|     created|    5|0.005|
|   delivered|96478|97.02|
| unavailable|  609|0.612|
|  processing|  301|0.303|
|    approved|    2|0.002|
+------------+-----+-----+

approach 2 — empty window: 0.27s
+------------+-----+-----------+-----+
|order_status|count|total_count|  pct|
+------------+-----+-----------+-----+
|     shipped| 1107|      99441|1.113|
|    canceled|  625|      99441|0.629|
|    invoiced|  314|      99441|0.316|
|     created|    5|      99441|0.005|
|   delivered|96478|      99441|97.02|
| unavailable|  609|      99441|0.612|
|  processing|  301|      99441|0.303|
|    approved|    2|      99441|0.002|
+------------+-----+-----------+-----+

approach 3 — cross-join: 0.28s
+-------------+--------------+--------------+-------------+---------------+-----------------+----------------+--------------+-----------+-----------+------------+------------+-----------+-------------+---------------+--------------+------------+
|shipped_count|canceled_count|invoiced_count|created_count|delivered_count|unavailable_count|processing_count|approved_count|total_count|shipped_pct|canceled_pct|invoiced_pct|created_pct|delivered_pct|unavailable_pct|processing_pct|approved_pct|
+-------------+--------------+--------------+-------------+---------------+-----------------+----------------+--------------+-----------+-----------+------------+------------+-----------+-------------+---------------+--------------+------------+
|         1107|           625|           314|            5|          96478|              609|             301|             2|      99441|      1.113|       0.629|       0.316|      0.005|        97.02|          0.612|         0.303|       0.002|
+-------------+--------------+--------------+-------------+---------------+-----------------+----------------+--------------+-----------+-----------+------------+------------+-----------+-------------+---------------+--------------+------------+

approach 4a — conditional aggregation: 0.21s
+-------------+--------------+--------------+-------------+---------------+-----------------+----------------+--------------+-----------+-----------+------------+------------+-----------+-------------+---------------+--------------+------------+
|shipped_count|canceled_count|invoiced_count|created_count|delivered_count|unavailable_count|processing_count|approved_count|total_count|shipped_pct|canceled_pct|invoiced_pct|created_pct|delivered_pct|unavailable_pct|processing_pct|approved_pct|
+-------------+--------------+--------------+-------------+---------------+-----------------+----------------+--------------+-----------+-----------+------------+------------+-----------+-------------+---------------+--------------+------------+
|         1107|           625|           314|            5|          96478|              609|             301|             2|      99441|      1.113|       0.629|       0.316|      0.005|        97.02|          0.612|         0.303|       0.002|
+-------------+--------------+--------------+-------------+---------------+-----------------+----------------+--------------+-----------+-----------+------------+------------+-----------+-------------+---------------+--------------+------------+

approach 4b — conditional aggregation: 0.15s

Summary

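For most workloads, approach 3 is the right choice: it stays fully distributed, Spark handles the broadcast automatically, and the code reads clearly. Approach 1 is simpler and works well when the initial count() is cheap. The static form of approach 4 scans the data only once, but it returns a single wide row and hard-codes the status list, and the dynamic form needs an extra pass to discover the statuses. Approach 2 is a trap: it gives the correct result on a small example, but every row that reaches the empty window is funnelled through a single task, which becomes a bottleneck, and an out-of-memory risk, as that row count grows.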