Egg-style Equality Saturation using only Z3
An egraph is a data structure for storing many ground equalities at once in a highly shared way. Egg style equality saturation searches over this egraph for patterns and asserts nondestructive rewrites into the egraph. Once you’re sick of doing that, you can extract a best term from the egraph. This may let you avoid the valleys that greedy rewriting can get stuck in, and it makes the system more robust to the order of rewrites.
Theorem provers, compilers and optimizers contain a lot of the same tech but are seen through the lens of distinct communities. Equational theorem provers are used in the mode term -> term -> bool, whereas a simplifier is used in the mode term -> term.
It is a design choice of these systems about whether to present a declarative or operational interface. Theorem provers have a tendency to present a declarative interface.
From the declarative logical interface of an SMT solver it is hard to predict exactly what model the thing will return even though a naive operational picture of the solving process might make you think this or that might happen.
Things like prolog or datalog or rewrite engines or egg have a semi-operational interface. It’s kind of a really fun sweet spot where expertise in the system can enable you to push it in the direction you want.
There are a number of cool pieces in SMT solvers and ATPs that as I’ve learned more about them, I wish I had direct operational access to. This is however a maintenance burden and a new API promise that the maintainers are right to be skeptical of offering. It isn’t what they are in the business of, and I’m kind of a niche weirdo who just likes fiddling.
Something that has been tantalizing is the egraph and ematching engine of Z3. Egg’s implementation of ematching is based off the paper describing z3’s. Z3 offers a ton of extra goodies, and really importantly has great bindings in many languages.
I’ve fiddled with sort of half way doing this https://www.philipzucker.com/ext_z3_egraph/ by reusing the AST of z3, but maintaining the egraph and performing the ematching on my own (via bottom up ematching, easy mode).
A couple of APIs have opened up fairly recently though that make it possible to reuse Z3’s egraph and ematcher in this way. I realized this while I was using the OnClause callback functionality to extract instantiations for universally quantified axioms in knuckledragger, my low barrier LCF-style proof assistant seamlessly built around Z3.
While you might think that asserting ground equalities into a Z3 solver object would be guaranteed to give you the expected minimal model (one that equates only what the equalities force), that isn’t always the case. If you only assert equalities, a valid model is one where everything maps to a single semantic object, while from the equality saturation perspective this is deriving more equalities.
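To make the point about degenerate models concrete, here is a tiny hand-rolled sketch. It does not call Z3 at all: terms are nested tuples and `eval_trivial` is a made-up evaluator for the one-element model, so the names here are illustrative assumptions only.

```python
# Toy illustration (not Z3): terms are nested tuples like ("f", ("a",)).
def eval_trivial(term):
    """Evaluate a term in the one-element model: every symbol denotes 0."""
    head, *args = term
    for arg in args:
        eval_trivial(arg)  # arguments are visited, but the value is always 0
    return 0

a = ("a",)
b = ("b",)
fa = ("f", a)
ffa = ("f", fa)

asserted = [(ffa, a)]            # the only equality we actually asserted
unintended = [(fa, a), (a, b)]   # never asserted, not forced by congruence

# The one-element model satisfies the asserted equality ...
assert all(eval_trivial(l) == eval_trivial(r) for l, r in asserted)
# ... but it also equates terms an egraph would keep apart.
assert all(eval_trivial(l) == eval_trivial(r) for l, r in unintended)
```

So a model alone cannot be read as "the equalities that were derived"; you need access to the egraph itself.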
Observing Quantifier Instantiations
There is an api I learned of from here https://microsoft.github.io/z3guide/programming/Proof%20Logs/#callbacks-for-clause-inferences where you can get a callback on inferences. This includes quantifier instantiations.
This is enough to infer a pile of ground equalities discovered by the e-matching engine. We can build a little function that records the equalities.
Turning off model based quantifier instantiation (mbqi) cleans up the results and I think makes it only use the e-matching engine.
I figured out which parts contain the substitution and rule by just playing around. I’m not sure this is documented anywhere.
A quantified axiom with an explicitly given pattern ought to behave similarly to a rewrite rule in an equality saturation system.
import z3
from collections import namedtuple
# little namedtuple to store the pieces of a match. The rule matched, the substitution
# and the ground equation derived by instantiating the rule with the subst
Instan = namedtuple("Instan", ["rule", "subst", "eq"])
def find_instans(t : z3.ExprRef, rules : list[z3.ExprRef], timeout=1000) -> list[Instan]:
    s = z3.Solver()
    s.set(auto_config=False, mbqi=False)
    instans = []
    def log_instance(pr, clause, myst):
        if z3.is_app(pr):
            if pr.decl().name() == "inst":
                instans.append(pr)
    onc = z3.OnClause(s, log_instance)
    s.set(timeout=timeout)
    myterm = z3.Const("myterm", t.sort()) # a marker to keep track of original term
    s.add(myterm == t) # seed the egraph
    for r in rules:
        s.add(r)
    res = s.check()
    instans1 = []
    for pr in instans:
        quant_thm = pr.arg(0) # rule used
        subst = pr.arg(2) # substitution into quantifier
        eq = z3.substitute_vars(quant_thm.body(), *reversed(subst.children())) # instantiate the rule
        instans1.append(Instan(quant_thm, subst, eq))
    return instans1
We can try it on a little example with a rule f(f(x)) -> x and a start term f(f(f(f(a)))).
T = z3.DeclareSort("T")
a,x = z3.Consts("a x", T)
f = z3.Function("f", T, T)
ffffa = f(f(f(f(a))))
r = z3.ForAll([x], f(f(x)) == x, patterns=[f(f(x))])
instans = find_instans(ffffa, [r])
instans
[Instan(rule=ForAll(x, f(f(x)) == x), subst=bind(a), eq=f(f(a)) == a),
Instan(rule=ForAll(x, f(f(x)) == x), subst=bind(f(a)), eq=f(f(f(a))) == f(a)),
Instan(rule=ForAll(x, f(f(x)) == x), subst=bind(f(f(a))), eq=f(f(f(f(a)))) == f(f(a)))]
You could package up the rewrite rules into a little helper function to auto insert the left hand side as the pattern. Saves a little typing.
def rw(vs,lhs,rhs):
    """A different helper to construct the rule"""
    return z3.ForAll(vs, lhs == rhs, patterns=[lhs])
Building an Egraph with SimpleSolver
The other thing that happened is that the egraph was exposed via the root and next API functions. Nikolaj mentioned this in a tweet at some point.
In order to use these, I think you need to use s = SimpleSolver(), assert your ground equalities and then run s.check(). This should always be sat.
The style of the API is a bit different from egg’s presentation partially because the actual implementation I suspect is more pointer-y. The eclass is stored as a linked list of equivalent terms, each of which has a union find pointer to the root of the equivalence class. More about this here https://z3prover.github.io/papers/z3internals.html#sec-equality-and-uninterpreted-functions .
s.root(e) will give back a canonical term representing the equivalence class of e. s.next(e) will give back the next term in the equivalence class of e.
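As a mental model for root and next, here is a toy sketch in plain Python of an eclass stored as a circular linked list with a union-find pointer. This is only my guess at the shape of the idea, not Z3's actual implementation; `ENode`, `union` and `eclass` are hypothetical names, and there is no congruence closure here.

```python
class ENode:
    """Toy e-node: a union-find parent pointer plus a circular 'next' link."""
    def __init__(self, name):
        self.name = name
        self.parent = self   # union-find pointer towards the class root
        self.next = self     # circular list of the equivalence class

def root(n):
    while n.parent is not n:
        n = n.parent
    return n

def union(x, y):
    rx, ry = root(x), root(y)
    if rx is ry:
        return
    ry.parent = rx
    # splice the two circular lists together by swapping next pointers
    rx.next, ry.next = ry.next, rx.next

def eclass(n):
    """Walk the circular list: start at the root, follow next until we wrap."""
    start = root(n)
    out, q = [], start
    while True:
        out.append(q.name)
        q = q.next
        if q is start:
            break
    return out

a, ffa, myterm = ENode("a"), ENode("f(f(a))"), ENode("myterm")
fa, fffa = ENode("f(a)"), ENode("f(f(f(a)))")
union(a, ffa)
union(a, myterm)
union(fa, fffa)
print(eclass(ffa))   # every member of a's class, starting from the root
print(eclass(fffa))
```

The `while True ... if q is start: break` loop is the same traversal shape you end up writing against s.root and s.next.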
So after finding the equalities using the previous solver, we can assert them into a new solver and then extract the best term. The amount we’re getting out of this is maybe marginal compared to implementing our own egraph in this style https://www.philipzucker.com/ext_z3_egraph/ but it is pleasing to reuse z3 parts.
type EGraph = z3.Solver
def build_egraph(t : z3.ExprRef, eqs : list[z3.BoolRef]) -> EGraph:
    s = z3.SimpleSolver()
    myterm = z3.Const("myterm", t.sort())
    s.add(t == myterm) # seed the egraph
    for eq in eqs:
        assert z3.is_eq(eq)
        s.add(eq)
    s.check()
    return s
s = build_egraph(ffffa, [i.eq for i in instans])
s
[f(f(a)) = a, f(f(a)) = a, f(f(f(a))) = f(a), myterm = a]
When we need to look through the egraph, we need to discover the graph by depth first search from a starting node. There isn’t a global way to look at the egraph so far as I know.
def egraph_dict(t : z3.ExprRef, s : EGraph) -> dict[z3.ExprRef, list[z3.ExprRef]]:
    """
    Returns a dictionary mapping each root of the egraph to the "enodes" in the egraph
    The children of these enode terms may be uncanonized.
    """
    enode_table = {}
    def dfs(t):
        root = s.root(t)
        if root in enode_table:
            return
        enodes = []
        enode_table[root] = enodes
        q = root
        while True:
            enodes.append(q)
            for c in q.children(): # visit the children of each enode in the class, not just of t
                dfs(c)
            q = s.next(q)
            if q.eq(root):
                break
    dfs(t)
    return enode_table
egraph_dict(ffffa, s)
{f(f(f(f(a)))): [f(f(f(f(a))))],
f(a): [f(a), f(f(f(a)))],
a: [a, myterm, f(f(a))]}
Here we see something rearing its head. Somehow f(f(f(f(a)))) is not really being shown as being in the same eclass as a, but it does know it. Clearly we are on shaky ground.
In any case, with this it is now possible to run a brute force extraction process. You can do better.
BestTerm = namedtuple("BestTerm", ["term", "cost"]) # unused below; the tables store plain dicts with these two fields
def init_eclass_table(t : z3.ExprRef, s : EGraph) -> dict[z3.ExprRef, dict]:
    """
    DFS the egraph and initialize a cost table to infinity
    """
    eclasses = {}
    def fill_eclass(t):
        root = s.root(t)
        if root not in eclasses:
            q = root
            while True:
                eclasses[root] = {"cost": float("inf"), "term":root}
                for c in q.children():
                    fill_eclass(c)
                q = s.next(q)
                if q.eq(root):
                    break
    fill_eclass(z3.Const("myterm", t.sort()))
    return eclasses
def find_best_term(t : z3.ExprRef, s : EGraph) -> dict:
    eclasses = init_eclass_table(t, s)
    myterm = z3.Const("myterm", t.sort())
    for i in range(20): # Actually you should run to fixed point.
        for eclass, cost_term in eclasses.items():
            cost = cost_term["cost"]
            q = eclass
            while True:
                if q.eq(myterm): # don't extract the marker. Give it infinite cost
                    newcost = float("inf")
                else:
                    newcost = 1 + sum([eclasses[s.root(c)]["cost"] for c in q.children()])
                if newcost < cost:
                    best = q.decl()([eclasses[s.root(c)]["term"] for c in q.children()])
                    eclasses[eclass] = {"cost": newcost, "term": best}
                q = s.next(q)
                if q.eq(eclass):
                    break
    return eclasses[s.root(myterm)]
find_best_term(ffffa, s)
{'cost': 1, 'term': a}
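The loop above runs a fixed 20 rounds. Here is a sketch of what running to fixed point looks like, written against a plain dict standing in for the egraph rather than a Z3 solver. The dict layout (eclass id -> list of (symbol, child eclass ids)) and the name `extract` are assumptions made up for this example.

```python
import math

def extract(egraph, root):
    """egraph: dict eclass_id -> list of (symbol, [child eclass ids])."""
    cost = {c: math.inf for c in egraph}
    best = {c: None for c in egraph}
    changed = True
    while changed:                      # stop once no eclass improves
        changed = False
        for c, enodes in egraph.items():
            for sym, kids in enodes:
                new = 1 + sum(cost[k] for k in kids)
                if new < cost[c]:
                    cost[c] = new
                    best[c] = (sym, kids)
                    changed = True
    def build(c):
        # rebuild a term string from the best enode of each eclass
        sym, kids = best[c]
        return sym if not kids else f"{sym}({', '.join(build(k) for k in kids)})"
    return cost[root], build(root)

# Hand-built egraph for f(f(x)) = x applied to f(f(f(f(a)))):
# class 0 = {a, f(f(a)), f(f(f(f(a))))}, class 1 = {f(a), f(f(f(a)))}
toy = {0: [("a", []), ("f", [1])], 1: [("f", [0])]}
print(extract(toy, 0))
print(extract(toy, 1))
```

Costs only ever decrease and are bounded below, so the `while changed` loop terminates; it is still the dumb quadratic version, just without the magic number.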
You can chain it together into an equality saturation routine
def eqsat(t : z3.ExprRef, rules : list[z3.ExprRef], timeout=1000) -> dict:
    inst = find_instans(t, rules, timeout=timeout)
    s = build_egraph(t, [i.eq for i in inst])
    best = find_best_term(t, s)
    return best # a dict with the best "term" and its "cost"
eqsat(ffffa, [r])
{'cost': 1, 'term': a}
Here is the canonical egg problem, as problematic as it may be. You don’t want a bitshift optimization to interfere with a mul div optimization.
T = z3.DeclareSort("T")
a,x,y = z3.Consts("a x y", T)
mul = z3.Function("mul", T, T, T)
div = z3.Function("div", T, T, T)
lsh = z3.Function("lsh", T, T, T)
#z3.ArithRef.__mul__ = mul
#z3.ArithRef.__div__ = div
one,two = z3.Consts("one two", T)
r1 = z3.ForAll([x], mul(x, two) == lsh(x,one) , patterns=[mul(x,two)])
r2 = z3.ForAll([x,y], div(mul(x,y),y) == x, patterns=[div(mul(x,y),y)])
find_instans(div(mul(a,two),two), [r1,r2])
[Instan(rule=ForAll(x, mul(x, two) == lsh(x, one)), subst=bind(a), eq=mul(a, two) == lsh(a, one)),
Instan(rule=ForAll([x, y], div(mul(x, y), y) == x), subst=bind(a, two), eq=div(mul(a, two), two) == a)]
eqsat(div(mul(a,two),two), [r1,r2])
{'cost': 1, 'term': a}
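For contrast, here is a small sketch of why destructive rewriting is order sensitive on this example. It uses plain nested tuples instead of Z3 terms, and the two rewrite functions are hand-written stand-ins for r1 and r2, not anything from a library.

```python
# Terms are nested tuples, e.g. ("div", ("mul", "a", "two"), "two")
def rewrite_mul_to_lsh(t):
    """mul(x, two) -> lsh(x, one), applied bottom-up and destructively."""
    if isinstance(t, tuple):
        t = (t[0],) + tuple(rewrite_mul_to_lsh(c) for c in t[1:])
        if t[0] == "mul" and t[2] == "two":
            return ("lsh", t[1], "one")
    return t

def rewrite_div_mul(t):
    """div(mul(x, y), y) -> x, applied bottom-up and destructively."""
    if isinstance(t, tuple):
        t = (t[0],) + tuple(rewrite_div_mul(c) for c in t[1:])
        if (t[0] == "div" and isinstance(t[1], tuple)
                and t[1][0] == "mul" and t[1][2] == t[2]):
            return t[1][1]
    return t

start = ("div", ("mul", "a", "two"), "two")
shift_first = rewrite_div_mul(rewrite_mul_to_lsh(start))
div_first = rewrite_mul_to_lsh(rewrite_div_mul(start))
print(shift_first)  # the mul is gone, so the div rule no longer matches
print(div_first)
```

Applying the shift rule first throws away the mul that the div rule needed. The egraph keeps both mul(a, two) and lsh(a, one) around, which is why eqsat above still finds a.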
Bits and Bobbles
While I did this in python, python accesses the C API so you could do it in C or C++. Then maybe the OnClause callback wouldn’t be that bad performance wise? Regardless of that, z3 is pretty fast.
Using the SimpleSolver egraph is dicey and possibly not worth it. Getting the ematcher through OnClause is the bigger win.
The whole thing is a hack job. YMMV.
If one made a horn clause rule, I’m not sure what it might do. Adding in boolean connectives like that may activate the SAT part of the solver.
Does this lead to a road of nice theory integration?
Zach I think has mentioned that hacking z3 to get an eq sat engine is something people have done, but it wasn’t sustainable/maintainable. Going into the guts of z3 is a daunting proposition. It’s nice that this uses surface features.
https://x.com/BjornerNikolaj/status/1764793235246076313
This was the pruner I was fiddling with for Knuckledragger. You can use the second stage solver to get unsat cores also to know which instantiations were actually useful. Giving this back to the user lets them instantiate them themselves in their proofs, leading to faster solver calls the next time around.
from kdrag.all import *
def instans(thm, by=[], timeout=1000):
    """
    Gather instantiations of quantifiers from Z3 trace
    """
    s = smt.Solver()
    s.set("timeout", timeout)
    for n, p in enumerate(by):
        s.assert_and_track(p, "KNUCKLEBY_{}".format(n))
    s.assert_and_track(smt.Not(thm), "KNUCKLEGOAL")
    instans = []
    def log_instance(pr, clause, myst):
        if smt.is_app(pr):
            if pr.decl().name() == "inst":
                quant_thm = pr.arg(0)
                target = pr.arg(1)
                subst = pr.arg(2)
                instans.append((quant_thm, subst.children()))
    onc = smt.OnClause(s, log_instance)
    s.check()
    return instans
R = smt.Function("R", smt.IntSort(), smt.IntSort(), smt.BoolSort())
x,y,z,w = smt.Ints("x y z w")
def prune_quant(thm, by=[]):
    insts = instans(thm, by=by)
    lemmas = [x for x in by if not isinstance(x, smt.QuantifierRef)] # could go deeper.
    insts2 = { thm(*args) : (thm,args) for thm, args in insts}
    lemmas.extend(insts2.keys())
    needed = kd.utils.prune(thm, by=lemmas)
    return { thm : inst for thm, inst in insts2.items() if thm in needed}
instans(R(y,w), [smt.ForAll([x], R(x, x)), smt.ForAll([x,z], R(x, z))])
prune_quant(R(y,w), [smt.ForAll([x], R(x, x)), smt.ForAll([x,z], R(x, z))])
{R(y, w): (ForAll([x, z], R(x, z)), [y, w])}
import z3
def eqsat(t : z3.ExprRef, rules : list[z3.ExprRef], timeout=1000):
    # Run a quantifier instantiation run, recording all instantiations.
    s = z3.Solver()
    s.set(timeout=timeout)
    s.set(auto_config=False, mbqi=False)
    instans = []
    def log_instance(pr, clause, myst):
        if z3.is_app(pr):
            if pr.decl().name() == "inst":
                instans.append(pr)
    onc = z3.OnClause(s, log_instance)
    myterm = z3.FreshConst(t.sort(), prefix="myterm")
    s.add(myterm == t) # seed the egraph
    for r in rules:
        s.add(r)
    res = s.check()
    #print(instans)
    # Build a new solver that won't see quantifiers. This is our egraph
    s = z3.SimpleSolver()
    s.add(t == myterm) # seed the egraph
    for pr in instans:
        quant_thm = pr.arg(0) # rule used
        subst = pr.arg(2) # substitution into quantifier
        # NOTE: no reversed() here, unlike find_instans above. With two variables
        # the arguments come out swapped (see the div1/mul1 output further down).
        print("assert unit", z3.substitute_vars(quant_thm.body(), *subst.children()))
        s.add(z3.substitute_vars(quant_thm.body(), *subst.children()))
    s.check()
    print(s)
    print(s.model())
    eclasses = {}
    def fill_eclass(t):
        root = s.root(t)
        if root not in eclasses:
            q = root
            while True:
                eclasses[root] = {"cost": float("inf"), "term":root}
                for c in q.children():
                    fill_eclass(c)
                q = s.next(q)
                if q.eq(root):
                    break
    fill_eclass(myterm)
    # Iterate through dumbly to find lowest cost term
    for i in range(10):
        for eclass, cost_term in eclasses.items():
            cost = cost_term["cost"]
            q = eclass
            while True:
                if myterm.eq(q):
                    newcost = float("inf")
                else:
                    newcost = 1 + sum([eclasses[s.root(c)]["cost"] for c in q.children()])
                if newcost < cost:
                    best = q.decl()([eclasses[s.root(c)]["term"] for c in q.children()])
                    eclasses[eclass] = {"cost": newcost, "term": best}
                q = s.next(q)
                if q.eq(eclass):
                    break
    print(eclasses)
    return eclasses[s.root(myterm)]
    # return model
T = z3.DeclareSort("T")
a,x = z3.Consts("a x", T)
f = z3.Function("f", T, T)
r = z3.ForAll([x], f(f(x)) == x, patterns=[f(f(x))])
eqsat(f(f(a)), [r])
assert unit f(f(a)) == a
[f(f(a)) == a, myterm!1818 == a]
[a = T!val!0,
myterm!1818 = T!val!0,
f = [T!val!1 -> T!val!0, else -> T!val!1]]
{a: {'cost': 1, 'term': a}, f(a): {'cost': 2, 'term': f(a)}}
{'cost': 1, 'term': a}
eqsat(f(f(f(f(a)))), [r])
{'cost': 1, 'term': a}
eqsat(f(f(f(f(f(a))))), [r])
{'cost': 2, 'term': f(a)}
T = z3.DeclareSort("T")
a,x,y = z3.Consts("a x y", T)
mul = z3.Function("mul1", T, T, T)
div = z3.Function("div1", T, T, T)
lsh = z3.Function("lsh1", T, T, T)
#z3.ArithRef.__mul__ = mul
#z3.ArithRef.__div__ = div
one,two = z3.Consts("one two", T)
r1 = z3.ForAll([x], mul(x, two) == lsh(x,one) , patterns=[mul(x,two)])
r2 = z3.ForAll([x,y], div(mul(x,y),y) == x, patterns=[div(mul(x,y),y)])
eqsat(div(mul(a,two),two), [r1,r2], timeout=4)
#instans(div(mul(a,two),two), [r1,r2])
assert unit mul1(a, two) == lsh1(a, one)
assert unit div1(mul1(two, a), a) == two
[div1(mul1(a, two), two) == myterm!1819,
mul1(a, two) == lsh1(a, one),
div1(mul1(two, a), a) == two]
[myterm!1819 = T!val!2,
one = T!val!3,
a = T!val!0,
two = T!val!1,
mul1 = [(T!val!1, T!val!0) -> T!val!5, else -> T!val!4],
div1 = [(T!val!5, T!val!0) -> T!val!1, else -> T!val!2],
lsh1 = [else -> T!val!4]]
{myterm!1819: {'cost': 5, 'term': div1(mul1(a, two), two)}, lsh1(a, one): {'cost': 3, 'term': mul1(a, two)}, a: {'cost': 1, 'term': a}, one: {'cost': 1, 'term': one}, two: {'cost': 1, 'term': two}, mul1(two, a): {'cost': 3, 'term': mul1(two, a)}}
{'cost': 5, 'term': div1(mul1(a, two), two)}
Prune Prune instantiation.
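import z3
rules = [] # placeholder: the quantified rewrite rules would go here
egraph = z3.SimpleSolver()
s = z3.Solver()
for r in rules:
    s.add(r)
instans = []
def on_clause(pr, clause, myst):
    # keep only the quantifier instantiation steps
    if z3.is_app(pr) and pr.decl().name() == "inst":
        instans.append(pr)
onc = z3.OnClause(s, on_clause)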