Bottom Up Egraph Ematching Plays Nicer with Theories (AC, etc)
An open problem in the egraph community is how to cleanly integrate domain specific information / normalizing rewrite systems / algorithms with egraph rewriting. AC, lambda, graphs, and cheap arithmetical reasoning are all quite desirable.
I claim that “bottom up” ematching makes it easier to think about these things. It is by far the simplest way to do ematching and has some good properties with regards to theories.
Pattern Matching: top down, bottom up, side to side
Top Down Ematching
In much of the literature, a top down ematcher is described. You start with a pattern foo(bar(X),Y) -> biz(X) and an eclass e1. One looks for an enode with symbol foo in eclass e1, of which there may be many: e1 = foo(e2,e3), e1 = foo(e4,e5), ....
For each of these possibilities, you recursively ematch the subpatterns bar(X) and Y against the argument eclasses.
The deeper the pattern, the more nondeterministic search it exposes.
For the case of seeking a query pattern over a whole database of terms/eclasses, other possibilities become reasonable.
Note that when top down ematching over all eclasses, you do not really have a starting e1 to work with: you first have to do a top level scan over the whole egraph to find all the possible e1.
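For contrast, here is a rough sketch of that recursive top down matcher. It assumes an enodes_in(eclass) index from each eclass to the enodes it contains (a hypothetical helper, and exactly the index bottom up matching will avoid), and represents patterns as either a variable name like "X" or a (head, [subpatterns]) tuple.
# rough sketch of top down ematching; enodes_in is an assumed index, not defined in this post
def ematch(pattern, eclass, subst):
    if isinstance(pattern, str):              # pattern variable, e.g. "X"
        if pattern in subst:
            if subst[pattern] == eclass:
                yield subst
        else:
            yield {**subst, pattern: eclass}
        return
    head, subpats = pattern                    # e.g. ("foo", [("bar", ["X"]), "Y"])
    for enode in enodes_in(eclass):            # the nondeterminism: many candidate enodes
        if enode.head == head and len(enode.args) == len(subpats):
            substs = [subst]
            for p, child in zip(subpats, enode.args):
                substs = [s2 for s in substs for s2 in ematch(p, child, s)]
            yield from substs

# matching foo(bar(X), Y) against eclass e1:
# list(ematch(("foo", [("bar", ["X"]), "Y"]), e1, {}))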
Relational Ematching
A general perspective is that of relational ematching, which allows many possible pattern expansion directions.
Relational ematching's lack of opinion about join ordering is both a strength and a weakness. Perhaps you can build a good query planner, but where should you start?
Bottom Up Ematching
An underdiscussed method is that of bottom up ematching. It is incredibly simple, even simpler than top down.
First you guess the leaves (the variables). In other words, for each variable, you scan over all possible eclasses. Then you just construct the term bottom up by hashcons lookup and see if it is in the egraph already.
The search over the variables is quite reminiscent of the worst case optimal join algorithm GJ (see figure 2b). It is also reminiscent of how people write send-more-money puzzles using the nondeterminism monad.
# (rewrite (foo (bar X) Y) -> (biz X))
for X in eclasses:
    for Y in eclasses:
        try:
            lhs = foo[bar[X], Y]    # possibly failing lookup in hashcons
            rhs = biz(X)            # construct rhs
            egraph.union(lhs, rhs)  # set equal
        except KeyError:
            pass
At first glance, this feels extremely wasteful. Scans are "expensive" and every nesting of them costs asymptotically more.
However, there are big payoffs.
- Very simple.
- Fewer indices are needed. You only need the most minimal hash consing lookup behavior. You do not need to expand eclasses to enodes.
- No non determinism while traversing the pattern. Depth does not really matter.
- You already pay at least 1 scan in top down. If you only have 1 variable, bottom up is definitely better, and YMMV for more variables.
- Possibly more parallelizable / gpuable because it’s so simple (?).
- Easily written as an idiom. You can hand tweak your loops: change the ordering of variable expansion, maybe prune by sort, lift common subexpressions out to fail early. Your host language understands loops quite well and might do some of this for you. Reminds me of datafrog. A demonstration of loop invariant code motion / hoisting for patterns:
for X in eclasses:
    try:
        t1 = bar[X]  # this lookup can be lifted out of the inner loop
        for Y in eclasses:
            try:
                lhs = foo[t1, Y]        # possibly failing lookup in hashcons
                rhs = biz(X)            # construct rhs
                egraph.union(lhs, rhs)  # set equal
            except KeyError:
                pass
    except KeyError:
        pass
- During pattern matching you have ground terms to work with. This property, I claim, gives better theory integration.
Theories
The reason bottom up ematching is better for theories is that you immediately ground everything. Ground things are much easier to normalize than things with variables still in them. During top-down or relational ematching, the pattern gets partially filled in and you are still left with chunks of pattern that may have variables in them. A set or multiset, for example, is easily normalized by sorting its list representation, or you can use a trie or some other data structure. See python's frozenset and the discussion by Richard O'Keefe, How to Hash a Set. More on this some other day.
You don't exactly have a ground term, because the leaf pattern variables get filled in with eclasses rather than terms. But you do have enough structure exposed that you can typically normalize easily. In a sense you have done a partial extraction, so this is perhaps related to the "extract, normalize, reinsert" method.
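A tiny illustration with made up eclass ids: once the variables are bound, a multiset key normalizes by plain sorting, while a chunk that still contains a variable has no canonical order yet.
ground_args = (7, 3, 3, 5)                   # eclass ids the pattern variables were bound to
key = ("add", tuple(sorted(ground_args)))    # canonical key: ("add", (3, 3, 5, 7))
# a half instantiated chunk like ("add", (3, "X", 5)) can't be sorted into a canonical
# form yet, since we don't know where X will land in the order.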
A Small Implementation in Python
An egraph ~ union find + hash cons, so first we need a union find. Path compression and balancing are not important for correctness, only performance, so I'll ignore them. As an aside, path compression is actually the source of many conceptual problems, and keeping an uncompressed copy is the start of having a proof producing union find.
class UnionFind():
    def __init__(self):
        self.uf = []
    def makeset(self):
        uf = self.uf
        uf.append(len(uf))
        return len(uf) - 1
    def find(self, x):
        while self.uf[x] != x:
            x = self.uf[x]
        return x
    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)
        self.uf[x] = y
        return y
    def __len__(self):
        return len(self.uf)
    def __repr__(self):
        return "UnionFind({})".format(self.uf)
EClass = int
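Just to make the interface concrete, a quick check:
uf = UnionFind()
a, b, c = uf.makeset(), uf.makeset(), uf.makeset()
uf.union(a, b)
assert uf.find(a) == uf.find(b)   # a and b are now in the same class
assert uf.find(a) != uf.find(c)   # c is untouched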
Next we define the keys that can go into the egraph. Regular ENodes, literals, AC enodes and Set nodes are all possible and pleasingly treated on the same footing. By giving each class a notion of rebuilding, we make the datatypes supported by the egraph extensible. During rebuilding, we will call these functions to canonicalize the egraph's keys.
“AC” is normalized by just sorting its arguments. AC is essentially a multiset.
All of this is somewhat similar to how collections work in egglog.
from dataclasses import dataclass
from typing import Any, Dict

@dataclass(frozen=True)
class Lit():
    data: Any
    def rebuild(self):
        return self

@dataclass(frozen=True)
class ENode():
    head: Any
    args: tuple[EClass, ...]
    def rebuild(self):
        return ENode(self.head, tuple(map(egraph.uf.find, self.args)))

@dataclass(frozen=True)
class ESet():
    data: frozenset[EClass]
    def rebuild(self):
        return ESet(frozenset(map(egraph.uf.find, self.data)))

@dataclass(frozen=True)
class AC_ENode():  # AC enode
    head: Any
    args: tuple[EClass, ...]  # always built sorted
    # I'd like to give this a smart constructor, but frozen=True makes that awkward.
    #def __init__(self, head, args):
    #    self.head = head
    #    self.args = sorted(args)
    def rebuild(self):
        # canonicalize the argument eclasses, then re-sort via make_ac
        return make_ac(self.head, map(egraph.uf.find, self.args))

def make_ac(head, args):
    return AC_ENode(head, tuple(sorted(args)))
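A quick check that argument order washes out (the eclass ids here are made up):
assert make_ac("add", (3, 1, 2)) == make_ac("add", (2, 3, 1))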
The egraph itself is basically a bundle of a hashcons and the union find. It's convenient to forward some operations, so a lot of this is mostly boilerplate.
Rebuild is a bit interesting. I have a very naive rebuild loop. Run it repeatedly to propagate congruences. Maybe parent pointers are important. Maybe. Again, not important for correctness.
class EGraph:
    def __init__(self):
        self.uf = UnionFind()
        self.hashcons: Dict[Any, EClass] = {}
    def __getitem__(self, key):
        return self.hashcons[key]
    def union(self, x: EClass, y: EClass) -> EClass:
        return self.uf.union(x, y)
    def make(self, t: Any) -> EClass:
        t1 = self.hashcons.get(t)
        if t1 is None:
            v = self.uf.makeset()
            self.hashcons[t] = v
            return v
        else:
            return t1
    def __iter__(self):
        return iter(range(len(self.uf.uf)))
    def __repr__(self):
        return f"EGraph(uf={self.uf},hashcons={self.hashcons})"
    def rebuild(self):
        # simple naive dumb rebuild step
        # iterate over a snapshot, since make() may insert new canonical keys mid-loop
        for k, v in list(self.hashcons.items()):
            v = self.uf.find(v)
            self.hashcons[k] = v
            v1 = self.make(k.rebuild())
            self.union(v, v1)
Finally we can show it in use a bit. The Function class makes an enode builder object with some nice notation to distinguish between "making" lookup and "non-making" lookup.
The rewrite rule is directly expressed in python. This is easiest really, and I think it reads pretty simply. I used exception throwing so that any lookup that fails deep in the pattern will just be caught. I could instead have defined Function to propagate the None of a get up, but this seemed expedient.
# A cute little helper to make functions easier to use.
# f(x,y) will construct the term if it doesn't exist. It's for right hand sides of rewrite rules.
# f[x,y] will not construct, and throws an error if it doesn't exist. It's for left hand sides.
@dataclass
class Function():
    def __init__(self, egraph, func):
        self.egraph = egraph
        self.func = func
    def __call__(self, *args) -> EClass:
        return self.egraph.make(ENode(self.func, args))
    def __getitem__(self, args) -> EClass:
        if not isinstance(args, tuple):
            args = (args,)
        return self.egraph[ENode(self.func, args)]
egraph = EGraph()
foo = Function(egraph, "foo")
bar = Function(egraph, "bar")
biz = Function(egraph, "biz")

one = egraph.make(Lit(1))
one1 = egraph.make(Lit(1))
# hash consing: the same literal gets the same eclass identifier
assert one == one1
two = egraph.make(Lit(2))
fred = egraph.make(Lit("fred"))
t1 = foo(two, two)
t2 = bar(two)

# just testing some things
print(egraph)
egraph.union(t2, two)
for i in range(10):
    egraph.rebuild()
print(egraph)

# ematching foo(bar(X),Y) -> biz(X)
for X in egraph:
    for Y in egraph:
        try:
            lhs = foo[bar[X], Y]
            rhs = biz(X)
            egraph.union(lhs, rhs)
        except KeyError:
            pass
egraph.rebuild()
assert foo[bar[two], two] == biz(two)
EGraph(uf=UnionFind([0, 1, 2, 3, 4]),hashcons={Lit(data=1): 0, Lit(data=2): 1, Lit(data='fred'): 2, ENode(head='foo', args=(1, 1)): 3, ENode(head='bar', args=(1,)): 4})
EGraph(uf=UnionFind([0, 1, 2, 3, 1]),hashcons={Lit(data=1): 0, Lit(data=2): 1, Lit(data='fred'): 2, ENode(head='foo', args=(1, 1)): 3, ENode(head='bar', args=(1,)): 1})
Here's a small example showing that "obviously" it finds a multiset node when you look for it, regardless of argument ordering. Note that this may not support every possible notion of what you might mean by AC pattern; see more here: https://www.philipzucker.com/relational-ac-matching/. Everything that is cheap is quite obvious though.
t3 = egraph.make(make_ac("foo", (t1, one, one, two)))
for X in egraph:
    try:
        lhs = egraph[make_ac("foo", (one, two, X, t1))]
        print("AC match: ", (one, two, X, t1))
    except KeyError:
        pass
AC match: (0, 1, 0, 3)
Bits and Bobbles
One could view bottom up or top down ematching as merely particular choices within relational ematching. I think relational ematching's lack of opinion on join ordering is both a feature and a weakness. I also suspect the fully flattened table representation perspective, foo(bar(1),2) ~ {foo(e1,e2,e3), bar(e4,e1), const(1,e4), const(2,e2)}, makes it more difficult to recover notions of context. Even hash conses really mess up context by design.
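For concreteness, here is one hypothetical way those flattened atoms could sit in tables, with the enode's own eclass in the last column:
# foo(bar(1), 2), fully flattened into relations; the last slot of each row is the enode's eclass
tables = {
    "const": {(1,): "e4", (2,): "e2"},
    "bar":   {("e4",): "e1"},
    "foo":   {("e1", "e2"): "e3"},
}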
An AC term is roughly the same thing as a multiset: (a + b + c + a + b + a) ~ {a: 3, b: 2, c: 1}. This multiset is canonical regardless of the ordering of the pieces.
AC pattern matching requires a search over the partitions of the multiset argument. If I want to match X + Y over the multiset, that will generate all of its two-part splits.
def match_X_Y(term):
    # e.g. term = {"a": 3, "b": 3, "c": 1}; enumerate its two-part splits (sketched below)
    ...
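A minimal sketch of that enumeration, treating the multiset as a plain dict from element to multiplicity (the helper name splits_X_Y is made up):
from itertools import product

def splits_X_Y(term):
    # enumerate ways to split a multiset into two nonempty multisets X and Y
    keys = list(term)
    for counts in product(*[range(term[k] + 1) for k in keys]):
        X = {k: c for k, c in zip(keys, counts) if c > 0}
        Y = {k: term[k] - c for k, c in zip(keys, counts) if term[k] - c > 0}
        if X and Y:
            yield X, Y

# {"a": 3, "b": 3, "c": 1} already yields 4*4*2 - 2 = 30 splits
print(len(list(splits_X_Y({"a": 3, "b": 3, "c": 1}))))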
This is the fairly natural extension of typical pattern matching over single normal terms in languages like ocaml, haskell, etc. In that case there can't be multiple solutions to a pattern match over a single term and there is no search; in a typical implementation you just chase pointers down the term.
Analyses are one example of a mechanism that handles some of these domain specific problems. Another approach I have seen is the "extract term, do domain specific reasoning on the term, reinsert into egraph" loop: for example, extract a lambda term, normalize it, reinsert it.
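A schematic of that loop, with extract, normalize, and insert as assumed helper functions (not anything defined in this post):
# hypothetical "extract, normalize, reinsert" pass over every eclass
def normalize_pass(egraph, extract, normalize, insert):
    for eclass in list(egraph):
        t = extract(egraph, eclass)      # pick some representative term
        t2 = normalize(t)                # domain specific reasoning, e.g. beta reduction
        egraph.union(eclass, insert(egraph, t2))
    egraph.rebuild()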
AC matching in alt-ergo: hard to understand this paper. AC matching in theorem provers: I believe Vampire and E have special treatment of AC. This is a long running concern beyond the egg based egraph solvers. My AC matching blog post: https://www.philipzucker.com/relational-ac-matching/. I realized that bottom up ematching was useful when I stumbled on it here: https://www.philipzucker.com/ground-rewrite-2/. That was phrased in terms of the ground rewrite system idea, which is perhaps complicated in its own right.
When you're building an egraph saturator, you need to build an egraph, an ematcher, and an extractor.
Bottom up ematching also feels somewhat similar to the aegraph technique, which also has some bottom up flavor. Maybe I’m firing on nothing though.
One feels that there are good stratagems for dealing with many cases. Many rewrite systems are decent when applied greedily.
Many, many formulations have some kind of rule in which associativity and/or commutativity shows up. Of course mathematical expressions, but also in many cases you have sequences represented by lists. You need to discover some magic association in order to apply the appropriate rewrite rule. Sometimes you can get away with associating to the right and putting all your rules in that form, explicitly carrying the tail parameter: foo(1, foo(2,X)) -> bar(7, X)
There is a paradox here. Commutativity is the canonical unorientable rule. It’s too symmetrical. The promise of egraph rewriting is that you don’t care about which direction you orient your equations that much and also you never commit to anything by applying a rewrite.
There is a combinatorial explosion associated with associativity and commutativity that feels quite painful. Such simple and common rules! One feels that surely there is some way to special case them.
There is also something to be said for doing this style of pattern matching as an idiom rather than a library. I've gone down a rabbithole of trying to libraryize this idea and I think it was not good.
Idioms are kind of a bummer because the library author is doing less work for you. They are awesome because they use ordinary constructs of your language and are more tweakable as the need arises. They are also not as declarative.
Reminiscent of F-algebra talk.
Another idea I rather like is that of the "mergedict". The python defaultdict is quite useful: it gives a default value on key lookup. A "mergedict" is a related idea where the dict applies a "merge" function to combine values when you write conflicting data to the same key. Probably merge should often be some kind of lattice operation like min, max, etc. Counter could perhaps be seen as a merge dict.
from collections import UserDict

class MergeDict(UserDict):
    def __init__(self, mergefunc):
        super().__init__()
        self.mergefunc = mergefunc
    def __setitem__(self, key, value):
        if key in self.data:
            # merge with what is already there instead of overwriting
            self.data[key] = self.mergefunc(self.data[key], value)
        else:
            self.data[key] = value
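For instance, merging with max behaves like a little lattice join:
d = MergeDict(max)
d["x"] = 1
d["x"] = 3   # conflicting write gets merged rather than overwritten
assert d["x"] == 3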
# if I do it in order, I can store a seen index
# I can keep a separate run length encoding array.
for x in uf:
    for y in uf:
        try:
            e = f(g(x))
        except KeyError:
            pass
egraph = {}
def rebuild(egraph):
    for k, v in list(egraph.items()):
        egraph[norm(k)] = uf.find(v)
#with egraph as f,g,h:
f, g, h = egraph.functions("f g h")
Simple examples: + with constant prop?
grammar = """
"""
class EGraph:
    def __init__(self):
        self.uf: list[int] = []
        self.funcs: dict[str, dict[Any, int]] = {}
    def find(self, x):
        while x != self.uf[x]:
            x = self.uf[x]
        return x
    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)
        if x != y:
            self.uf[x] = y
        return y
    def rebuild(self):
        for f, tbl in self.funcs.items():
            for x, y in list(tbl.items()):
                x1 = tuple(map(self.find, x))
                y1 = self.find(y)
                tbl[x1] = y1
    def matcher(self, name):
        def res(*args):
            #if len(args) == 1 and args[0] == None:
            #    return None
            #return self.funcs[name].get(tuple(map(self.find, args)))
            return self.funcs[name][tuple(map(self.find, args))]
        return res
    def add(self, name, *args):
        if name not in self.funcs:
            self.funcs[name] = {}  # MergeDict
        if args not in self.funcs[name]:
            self.funcs[name][args] = len(self.uf)
            self.uf.append(len(self.uf))
        return self.funcs[name][args]
    def bind(self, varnum, f):
        for x in range(len(self.uf)):
            pass
    def __iter__(self):
        return iter(range(len(self.uf)))
egraph = EGraph()
f = egraph.matcher("f")
a = egraph.matcher("a")
# This is nice. Shallow embedding of the environment.
for x in egraph:  # or loop over self.funcs.values()
    for y in egraph:
        try:
            t1 = f(x, y)                # lookup f(x,y); raises KeyError if absent
            t2 = egraph.add("f", y, x)  # construct f(y,x)
            egraph.union(t1, t2)
        except KeyError:
            pass
egraph.bind(2, lambda x, y: egraph.union(f(x, y), f_(y, x)))  # f[(x,y)] = f[(y,x)]
# [f[(x,y)] := g[z,w] for z in egraph for w in egraph]
egraph.rebuild()
class MergeDict(dict):  # defaultdict
    def __init__(self, merge, default):
        self.merge = merge
        self.default = default
    def __getitem__(self, key):
        if key in self:
            return super().__getitem__(key)
        else:
            super().__setitem__(key, self.default)  # ?
            return self.default
    def __setitem__(self, key, value):
        if key in self:
            super().__setitem__(key, self.merge(super().__getitem__(key), value))
        else:
            super().__setitem__(key, value)
egraph = {
    ("a", ()): 1,
    ("b", ()): 2,
    ("+", (1, 2)): 3,
}
# eh, why bother.