Introduction
In an ideal world, deploying machine learning models within SQL queries would be as simple as calling a built-in function. Unfortunately, many ML predictions live inside User-Defined Functions (UDFs) that traditional SQL planners can’t modify, preventing optimizations like predicate pushdowns.
This blog post will showcase how you can prune decision tree models based on query filters by dynamically rewriting your expression using Ibis and quickgrove, an experimental GBDT inference library built in Rust. We’ll also show how LetSQL can simplify this pattern further and integrate seamlessly into your ML workflows.
ML models meet SQL
When you deploy machine learning models (like a gradient-boosted trees model from XGBoost) in a data warehouse, you typically wrap them in a UDF. Something like:
SELECT
  my_udf_predict(carat, depth, color, clarity, ...)
FROM diamonds
WHERE color_i < 1 AND clarity_vvs2 < 1
The challenge is that SQL planners don’t know what’s happening inside the UDF. Even if you filter on color_i < 1, the full model is evaluated for every row, including tree paths the filter already rules out. With tree-based models, entire branches may be unreachable under the filter, so the ideal scenario is to prune those branches before the model is ever evaluated.
Smart UDFs with Ibis
Ibis is known for letting you write engine-agnostic deferred expressions in Python without losing the power of underlying engines like Spark, DuckDB, or BigQuery. Meanwhile, quickgrove provides a mechanism to prune Gradient Boosted Decision Tree (GBDT) models based on known filter conditions.
Key Ideas:
- Prune decision trees by removing branches that can never be reached, given the known filters
- Rewrite expressions with the pruned model into the query plan to skip unnecessary computations
Understanding tree pruning
Take a simple example: a decision tree that splits on color_i < 1. If your query also has a predicate color_i < 1, any branch that requires color_i >= 1 will never be taken. By removing that branch, the tree becomes smaller and faster to evaluate, especially when you have hundreds of trees (as in many gradient-boosted models).
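To make this concrete, here is a minimal, hypothetical sketch (not quickgrove’s actual implementation) of how a known upper bound on a feature collapses a split node: any node whose threshold is already implied by the predicate can be replaced by its reachable child.

from dataclasses import dataclass
from typing import Union

@dataclass
class Node:
    feature: str
    threshold: float
    left: Union["Node", float]   # branch taken when feature < threshold
    right: Union["Node", float]  # branch taken when feature >= threshold

def prune(node, upper_bounds):
    """Drop branches that a known 'feature < bound' predicate makes unreachable."""
    if not isinstance(node, Node):
        return node  # leaf value
    bound = upper_bounds.get(node.feature)
    if bound is not None and bound <= node.threshold:
        # feature < bound <= threshold, so the right branch can never be taken
        return prune(node.left, upper_bounds)
    node.left = prune(node.left, upper_bounds)
    node.right = prune(node.right, upper_bounds)
    return node

tree = Node("color_i", 1.0,
            left=Node("carat", 0.5, left=10.0, right=20.0),
            right=30.0)
print(prune(tree, {"color_i": 1.0}))  # the color_i split collapses to its left subtree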
Reference: Check out the Raven optimizer paper. It demonstrates how you can prune nodes in query plans for tree-based inference, so we’re taking a similar approach here for forests (GBDTs) using Ibis.
Quickgrove: prunable GBDT models
Quickgrove is an experimental package that can load GBDT JSON models and provides a .prune(...) API to remove unreachable branches. For example:
# pip install quickgrove
import quickgrove

model = quickgrove.json_load("diamonds_model.json")  # Load an XGBoost model
model.prune([quickgrove.Feature("color_i") < 0.2])  # Prune based on known predicate
Once pruned, the model is leaner to evaluate. Note: The results heavily depend on model splits and interactions with predicate pushdowns.
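For orientation, a hypothetical standalone call (feature names and values invented) using the same predict_arrays method the Ibis UDF below relies on; the model expects one PyArrow array per feature:

import pyarrow as pa

carat = pa.array([0.23, 0.31], type=pa.float64())
depth = pa.array([61.5, 63.3], type=pa.float64())
# ... one array per remaining model feature, in the order the model expects ...
predictions = model.predict_arrays([carat, depth])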
Scalar PyArrow UDFs in Ibis
Please note that we are using our own modified DataFusion backend. The DataFusion and DuckDB backends behave differently: DuckDB UDFs receive a ChunkedArray while DataFusion UDFs expect an ArrayRef. We are working on extending quickgrove to work with the DuckDB backend.
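If you want a single UDF body to tolerate both input shapes, one possible adapter (a sketch, not part of quickgrove) is to flatten chunked input before calling the model:

import pyarrow as pa

def as_contiguous(arr):
    """DuckDB hands UDFs a pyarrow ChunkedArray; DataFusion passes a plain Array."""
    if isinstance(arr, pa.ChunkedArray):
        return arr.combine_chunks()  # concatenate chunks into one contiguous Array
    return arr

# inside a UDF body:
#   arrays = [as_contiguous(a) for a in (carat, depth, ...)]
#   return model.predict_arrays(arrays)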
We’ll define a simple Ibis UDF that calls our model.predict_arrays under the hood:
import ibis
import ibis.expr.datatypes as dt

ibis.set_backend("datafusion")

@ibis.udf.scalar.pyarrow
def predict_gbdt(
    carat: dt.float64,
    depth: dt.float64,
    # ... other features ...
) -> dt.float32:
    array_list = [carat, depth, ...]
    return model.predict_arrays(array_list)
Currently, UDFs are opaque to Ibis. We need to teach Ibis how to rewrite a UDF based on the predicates it knows about.
Making Ibis UDFs predicate-aware
Here’s the general process:
- Collect predicates from the user’s filter (e.g. x < 0.3).
- Prune the model based on those predicates (removing unreachable tree branches).
- Rewrite a new UDF that references the pruned model, preserving the rest of the query plan.
1. Collecting predicates
from typing import List

from ibis.expr.operations import Field, Filter, Less, Literal


def collect_predicates(filter_op: Filter) -> List[dict]:
    """Extract 'column < value' predicates from a Filter operation."""
    predicates = []
    for pred in filter_op.predicates:
        if isinstance(pred, Less) and isinstance(pred.left, Field):
            if isinstance(pred.right, Literal):
                predicates.append({
                    "column": pred.left.name,
                    "op": "Less",
                    "value": pred.right.value,
                })
    return predicates
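As a quick sanity check, a hypothetical usage against a small filtered table expression might look like this (schema invented; the exact literal value depends on how Ibis types it):

import ibis

t = ibis.table({"carat": "float64", "color_i": "float64"}, name="diamonds")
expr = t.filter(t.color_i < 1)

print(collect_predicates(expr.op()))
# e.g. [{'column': 'color_i', 'op': 'Less', 'value': 1}]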
2. Pruning model and creating a new UDF
import functools

import ibis.expr.datatypes as dt
from ibis.common.collections import FrozenDict
from ibis.expr.operations import ScalarUDF


def create_pruned_udf(original_udf, model, predicates):
    """Create a new UDF using the pruned model based on the collected predicates."""
    from quickgrove import Feature

    # Prune the model
    pruned_model = model.prune([
        Feature(pred["column"]) < pred["value"]
        for pred in predicates
        if pred["op"] == "Less" and pred["value"] is not None
    ])

    # For simplicity, assume the pruned UDF takes the same features as the original.
    def fn_from_arrays(*arrays):
        return pruned_model.predict_arrays(list(arrays))

    # Construct a dynamic UDF class. `fields` (not shown here) maps the UDF's
    # parameter names to ibis Argument annotations, mirroring the original UDF's signature.
    meta = {
        "dtype": dt.float32,
        "__input_type__": "pyarrow",
        "__func__": property(lambda self: fn_from_arrays),
        "__config__": FrozenDict(volatility="immutable"),
        "__udf_namespace__": original_udf.__module__,
        "__module__": original_udf.__module__,
        "__func_name__": original_udf.__name__ + "_pruned",
    }

    # Create a new ScalarUDF node type on the fly
    node = type(original_udf.__name__ + "_pruned", (ScalarUDF,), {**fields, **meta})

    @functools.wraps(fn_from_arrays)
    def construct(*args, **kwargs):
        return node(*args, **kwargs).to_expr()

    construct.fn = fn_from_arrays
    return construct
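Hypothetically, wiring it up looks like this (assuming fields above has been populated to mirror the original UDF’s signature):

preds = [{"column": "color_i", "op": "Less", "value": 1}]
predict_gbdt_pruned = create_pruned_udf(predict_gbdt, model, preds)

# The returned constructor is called exactly like the original UDF:
expr = t.mutate(prediction=predict_gbdt_pruned(t.carat, t.depth))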
3. Rewriting the plan
Now we use an Ibis rewrite rule (or a custom function) to detect filters on the expression, prune the model, and produce a new project/filter node.
from ibis.expr.operations import Project

# `replace` and `p` come from Ibis's pattern-rewriting utilities
# (e.g. the replace decorator and operation namespace used in ibis.expr.rewrites).
@replace(p.Filter)
def prune_gbdt_model(filter_op, original_udf, model):
    """Rewrite rule to prune a GBDT model based on filter predicates."""
    predicates = collect_predicates(filter_op)
    if not predicates:
        # Nothing to prune if there are no relevant predicates
        return filter_op

    # In a real implementation you'd want to match on a ScalarUDF and ensure
    # that the model instance is one implemented with quickgrove.
    pruned_udf = create_pruned_udf(original_udf, model, predicates)

    parent_op = filter_op.parent

    # Build a new projection with the pruned UDF
    new_values = {}
    for name, value in parent_op.values.items():
        # If it's the column that calls the UDF, swap in the pruned version
        if name == "prediction":
            # For brevity, assume we pass the same columns to the pruned UDF
            new_values[name] = pruned_udf(value.op().args[0], value.op().args[1])
        else:
            new_values[name] = value

    new_project = Project(parent_op.parent, new_values)

    # Re-add the filter conditions on top
    new_predicates = []
    for pred in filter_op.predicates:
        if isinstance(pred, Less) and isinstance(pred.left, Field):
            new_predicates.append(
                Less(Field(new_project, pred.left.name), pred.right)
            )
        else:
            new_predicates.append(pred)

    return Filter(parent=new_project, predicates=new_predicates)
Diff
For a query like the following:
expr = (
    t.mutate(prediction=predict_gbdt(t.carat, t.depth, ...))
    .filter(
        (t["clarity_vvs2"] < 1),
        (t["color_i"] < 1),
        (t["color_j"] < 1),
    )
    .select("prediction")
)
See the diff below:
Notice that with pruning we can drop some of the projections in the UDF, e.g. color_i, color_j, and clarity_vvs2. The underlying engine (e.g. DataFusion) may optimize this further when pulling data for UDFs. We cannot completely drop these columns from the query expression.
- predict_gbdt_3(
+ predict_gbdt_pruned(
carat, depth, table, x, y, z,
cut_good, cut_ideal, cut_premium, cut_very_good,
- color_e, color_f, color_g, color_h, color_i, color_j,
+ color_e, color_f, color_g, color_h,
clarity_if, clarity_si1, clarity_si2, clarity_vs1,
- clarity_vs2, clarity_vvs1, clarity_vvs2
+ clarity_vs2, clarity_vvs1
)
Putting it all together
The complete example can be found here.
# 1. Load your dataset into Ibis
t = ibis.read_csv("diamonds_data.csv")

# 2. Build the deferred expression
expr = (
    t.mutate(prediction=predict_gbdt(t.carat, t.depth, ...))
    .filter(
        (t["clarity_vvs2"] < 1),
        (t["color_i"] < 1),
        (t["color_j"] < 1),
    )
    .select("prediction")
)

# 3. Apply your custom optimization
optimized_expr = prune_gbdt_model(expr.op(), predict_gbdt, model)

# 4. Execute the optimized query
result = optimized_expr.to_expr().execute()
When this is done, the model inside predict_gbdt will be pruned based on the expression’s filter conditions. This can yield significant speedups on large datasets (see the benchmark results below).
Performance impact
Here are the benchmark results, run on an Apple M2 Mac Mini (8 cores / 8 GB memory) with a model trained with 100 trees of depth 6 and the following filter conditions:
_.carat < 1,
_.clarity_vvs2 < 1,
_.color_i < 1,
_.color_j < 1,
Benchmark results:

| File Size | Regular (s) | Optimized (s) | Improvement |
|---|---|---|---|
| 5M | 0.82 ±0.02 | 0.67 ±0.02 | 18.0% |
| 25M | 4.16 ±0.01 | 3.46 ±0.05 | 16.7% |
| 100M | 16.80 ±0.17 | 14.07 ±0.11 | 16.3% |
Key takeaway: As data volume grows, skipping unneeded tree branches translates into real compute savings, though the gains depend heavily on how relevant the filter conditions are to the model’s splits.
LetSQL: simplifying UDF rewriting
LetSQL makes advanced UDF rewriting and multi-engine pipelines much simpler. It builds on the same ideas we explored here but wraps them in a higher-level API.
Here’s a quick glimpse of how LetSQL might simplify the pattern:
# pip install letsql
import letsql as ls
from letsql.expr.ml import make_quickgrove_udf, rewrite_quickgrove_expression

model_path = "xgboost_model.json"
predict_udf = make_quickgrove_udf(model_path)

# df: an in-memory DataFrame with the model's feature columns
t = ls.memtable(df).mutate(pred=predict_udf.on_expr).filter(ls._.carat < 1)
optimized_t = rewrite_quickgrove_expression(t)

result = ls.execute(optimized_t)
The complete example can be found here. With LetSQL, you get a shorter, more declarative approach to the same optimization logic we manually coded with Ibis. It abstracts away the gritty parts of rewriting your query plan.
Best practices & considerations
- Predicate Types: Currently, we demonstrated column < value logic. You can extend it to handle <=, >, BETWEEN, or even categorical splits (see the sketch after this list).
- Quickgrove Limitations: Quickgrove only supports a handful of objective functions and, most notably, does not have categorical support yet; in theory, categorical variables make better candidates for pruning based on filter conditions. It also only supports the XGBoost model format.
- Model Format: XGBoost JSON is straightforward to parse. Other formats (e.g. LightGBM, scikit-learn trees) require similar logic or conversion steps.
- Edge Cases: If the filter references columns not in the model features, or if multiple filters combine in more complex ways, your rewriting logic may need more robust parsing.
- When to Use: This approach is beneficial when queries often filter on the same columns your trees split on. For purely ad-hoc queries or rarely used filters, the overhead of rewriting might outweigh the benefit.
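For example, extending the predicate collection beyond the < case might look like the sketch below (a hypothetical helper; operation names follow ibis.expr.operations, and quickgrove’s prune API would also need to understand the extra comparison types):

from ibis.expr.operations import (
    Field, Greater, GreaterEqual, Less, LessEqual, Literal,
)

SUPPORTED_OPS = {
    Less: "Less",
    LessEqual: "LessEqual",
    Greater: "Greater",
    GreaterEqual: "GreaterEqual",
}

def collect_comparison_predicates(filter_op):
    """Like collect_predicates, but for a wider set of comparison operators."""
    predicates = []
    for pred in filter_op.predicates:
        op_name = SUPPORTED_OPS.get(type(pred))
        if op_name and isinstance(pred.left, Field) and isinstance(pred.right, Literal):
            predicates.append(
                {"column": pred.left.name, "op": op_name, "value": pred.right.value}
            )
    return predicates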
Conclusion
Combining Ibis with a prune-friendly framework like quickgrove lets you optimize large-scale ML inference directly inside your SQL workflows. By pushing filter predicates down into your decision trees, you can speed up queries measurably.
With LetSQL, you can streamline this entire process—especially if you’re looking for an out-of-the-box solution that integrates with multiple engines along with batteries included features like caching and aggregate/window UDFs. For the next steps, consider experimenting with more complex models, exploring different tree pruning strategies, or even extending this pattern to other ML models beyond GBDTs.
- Try it out: Explore the Ibis documentation to learn how to build custom UDFs.
- Dive deeper: Check out quickgrove or read the Raven optimizer paper.
- Experiment with LetSQL: If you need a polished solution for dynamic ML UDF rewriting, LetSQL may be just the ticket.