Last time we spent a bit of effort building an e-graph data structure in Julia. The e-graph is a data structure that compactly stores and maintains equivalence classes of terms.

We can use the data structure for equational rewriting by iteratively pattern matching terms inside of it and adding new equalities found by instantiating those patterns. For example, we could use the pattern ~x + 0 to instantiate the rule ~x + 0 == ~x, find that ~x binds to foo23(a,b,c) in the e-graph and then insert into the e-graph the equality foo23(a,b,c) + 0 == foo23(a,b,c) and do this over and over. The advantage of the e-graph vs just rewriting terms is that we never lose information while doing this or rewrite ourselves into a corner. The disadvantage is that we maintain this huge data structure.

The rewrite rule problem can be viewed as something like a graph. The vertices of the graph are every possible term. A rewrite rule that is applicable to a vertex describes a directed edge out of it.

This graph is very large, often infinite, so it can only be described implicitly.

This graph perspective fails to capture some useful properties though. Treating each term as an indivisible vertex fails to capture that rewriting can happen in only a small part of the term, with the vast majority of it left unchanged. There is a lot of opportunity for shared computations.

The EGraph from this perspective is a data structure for holding the already seen set of vertices efficiently and sharing these computations.

You can use this procedure for theorem proving. Insert the two terms you want to prove equal into the e-graph. Iteratively rewrite using your equalities. At each iteration, check whether the two terms have ended up in the same equivalence class. If so, congrats, problem proved! This is the analog of finding a path between two nodes in a graph by doing a bidirectional search from the start and endpoint. This is roughly the main method by which Z3 reasons about syntactic terms under equality in the presence of quantified equations. There are also methods by which to reconstruct the proof found if you want it.

Another application of basically the same algorithm is finding optimized rewrites. Apply rewrites until you’re sick of it, then extract from the equivalence class of your query term your favorite equivalent term. This is a very intriguing application to me as it produces more than just a yes/no answer.
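
To make the extraction step a bit more concrete, here is a minimal sketch on a toy stand-in for an e-graph. Everything here is an assumption for illustration: ENode, the classes dictionary and extract_smallest are hypothetical names, not the EGraphs.jl API, and "favorite" is taken to mean "fewest nodes". The cost of each class is computed by iterating to a fixed point, and then the cheapest node of each class is read back out.

```julia
# Toy e-node: a head symbol plus the ids of the argument classes (hypothetical type).
struct ENode
    head::Symbol
    args::Vector{Int}
end

# class 1 = {a + 0, a}, class 2 = {zero}, class 3 = {f(a)}
classes = Dict(
    1 => [ENode(:+, [1, 2]), ENode(:a, Int[])],
    2 => [ENode(:zero, Int[])],
    3 => [ENode(:f, [1])],
)

function extract_smallest(classes, root)
    cost = Dict(id => Inf for id in keys(classes))  # best known size per class
    best = Dict{Int,ENode}()                        # node achieving that size
    changed = true
    while changed                                   # iterate to a fixed point
        changed = false
        for (id, nodes) in classes, n in nodes
            c = 1 + sum(Float64[cost[a] for a in n.args])
            if c < cost[id]
                cost[id] = c
                best[id] = n
                changed = true
            end
        end
    end
    # Rebuild a Julia expression from the chosen nodes.
    build(id) = isempty(best[id].args) ? best[id].head :
                Expr(:call, best[id].head, (build(a) for a in best[id].args)...)
    return build(root)
end

smallest = extract_smallest(classes, 3)   # picks f(a) over f(a + 0)
```

A real extractor would plug in a more interesting cost function, but the shape of the computation stays the same.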

Pattern Matching

Pattern matching takes a pattern with named holes in it and tries to find a way to fill the holes to match a term. The package SymbolicUtils.jl is useful for pattern matching against single terms, which is what I’ll describe in this section, but not quite set up to pattern match in an e-graph.

Pattern matching against a term is a very straightforward process once you sit down to do it. Where it gets complicated is trying to be efficient. But one should walk before jogging. I sometimes find myself so overwhelmed by performance considerations that I never get started.

First, one needs a data type for describing patterns (well, one doesn’t need it, but it’s nice. You could directly use pattern matching combinators).

struct PatVar
    id::Symbol
end

struct PatTerm
    head::Symbol
    args::Array{Union{PatTerm, PatVar}}
end

Pattern = Union{PatTerm,PatVar}

There is an analogous data structure in SymbolicUtils which I probably should’ve just used. For example Slot is basically PatVar.

The pattern matching function takes a Term and a Pattern and returns a dictionary Dict{PatVar,Term} of bindings of pattern variables to terms, or nothing if it fails.

Matching a pattern variable builds a dictionary of bindings from that variable to a term.

match(t::Term,  p::PatVar ) = Dict(p => t)

Matching a PatTerm requires we check if anything is obviously wrong (wrong head, or wrong number of args) and then recurse on the args. The resulting binding dictionaries are checked for consistency of their bindings and merged.


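function merge_consistent(ds)
    newd = Dict()
    for d in ds
        # A failed sub-match shows up as `nothing`; the whole match then fails.
        if d === nothing
            return nothing
        end
        for (k,v) in d
            if haskey(newd,k)
                # The same pattern variable must bind to the same term everywhere.
                if newd[k] != v
                    return nothing
                end
            else
                newd[k] = v
            end
        end
    end
    return newd
end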

function match(t::Term, p::PatTerm) 
    if t.head != p.head || length(t.args) != length(p.args)
        return nothing
    else
        merge_consistent( [ match(t1,p1) for (t1,p1) in zip(t.args, p.args) ])
    end
end

There are other ways of arranging this computation, such as passing a dictionary parameter along and modifying it, but I find the above purely functional definition the clearest.
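
For comparison, here is a rough sketch of the dictionary-threading variant. It is self-contained, so it assumes a simple shape for Term (a head and a vector of argument terms) with structural equality, which may differ from the Term of the previous post. The names match! and trymatch are made up for this example.

```julia
# Assumed term shape for this sketch.
struct Term
    head::Symbol
    args::Vector{Term}
end
Base.:(==)(a::Term, b::Term) = a.head == b.head && a.args == b.args

struct PatVar
    id::Symbol
end

struct PatTerm
    head::Symbol
    args::Vector{Union{PatTerm, PatVar}}
end

# Mutate `bindings` in place; return true on success and false on failure.
function match!(bindings, t::Term, p::PatVar)
    if haskey(bindings, p)
        return bindings[p] == t      # a repeated variable must agree with its binding
    end
    bindings[p] = t
    return true
end

function match!(bindings, t::Term, p::PatTerm)
    (t.head == p.head && length(t.args) == length(p.args)) || return false
    # `all` short-circuits, so we stop at the first failing argument.
    return all(match!(bindings, t1, p1) for (t1, p1) in zip(t.args, p.args))
end

# Wrapper with the same result convention as before: a dictionary or nothing.
function trymatch(t, p)
    bindings = Dict{PatVar,Term}()
    return match!(bindings, t, p) ? bindings : nothing
end

a = Term(:a, Term[])
t = Term(:+, [a, Term(:zero, Term[])])
found = trymatch(t, PatTerm(:+, [PatVar(:x), PatTerm(:zero, [])]))
```

The price of this version is that a failed match leaves partial bindings behind in the dictionary, which is why the wrapper throws the dictionary away on failure.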

Pattern Matching in E-Graphs

The twist that E-Graphs throw into the pattern matching process is that instead of checking whether a child of a term matches the pattern, we need to check for any possible child matching in the equivalence class of children. This means our pattern matching procedure contains a backtracking search.

The de Moura and Bjorner e-matching paper describes how to do this efficiently via an interpreted virtual machine. An explicit virtual machine is important for them because the patterns change during the Z3 search process, but it seems less relevant for the use case here or in egg. It may be better to use partial evaluation techniques to remove the existence of the virtual machine at runtime like in A Typed, Algebraic Approach to Parsing but I haven’t figured out how to do it yet.

The Simplify paper describes a much simpler inefficient pattern matching algorithm in terms of generators. This can be directly translated to Julia using Channels. This is a nice blog post describing how to build generators in Julia that work like python generators. Basically, as soon as you define your generator function, wrap the whole thing in a Channel() do c and when you want to yield myvalue call put!(c, myvalue).
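
As a small warm-up for that style, here is a sketch of Channel-based generators on their own, with nothing e-graph specific in it. The function names upto and pairs_summing_to are made up for this example. The nested loops over generators are the same shape that the backtracking search takes.

```julia
# A generator that yields 1, 2, ..., n.
function upto(n)
    Channel() do c
        for i in 1:n
            put!(c, i)      # the analog of python's `yield i`
        end
    end
end

# Generators compose: yield every pair (i, j) drawn from 1:n with i + j == total.
function pairs_summing_to(total, n)
    Channel() do c
        for i in upto(n)
            for j in upto(n)
                i + j == total && put!(c, (i, j))
            end
        end
    end
end

found_pairs = collect(pairs_summing_to(4, 3))
```

Each generator runs as its own task and blocks on put! until the consumer asks for the next value, so results are produced lazily.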

# https://www.hpl.hp.com/techreports/2003/HPL-2003-148.pdf
function ematchlist(e::EGraph, t::Array{Union{PatTerm, PatVar}} , v::Array{Id}, sub)
    Channel() do c
        if length(t) == 0
            put!(c, sub)
        else
            for sub1 in ematch(e, t[1], v[1], sub)
                for sub2 in ematchlist(e, t[2:end], v[2:end], sub1)
                    put!(c, sub2)
                end
            end
        end
    end
end
# sub should be a map from pattern variables to Id
function ematch(e::EGraph, t::PatVar, v::Id, sub)
    Channel() do c
        if haskey(sub, t)
            if find_root!(e, sub[t]) == find_root!(e, v)
                put!(c, sub)
            end
        else
            put!(c,  Base.ImmutableDict(sub, t => find_root!(e, v)))
        end
    end
end

    
function ematch(e::EGraph, t::PatTerm, v::Id, sub)
    Channel() do c
        for n in e.classes[find_root!(e,v)].nodes
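            # Also check the arity so ematchlist never indexes past the end of n.args.
            if n.head == t.head && length(n.args) == length(t.args)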
                for sub1 in ematchlist(e, t.args , n.args , sub)
                    put!(c,sub1)
                end
            end
        end
    end
end
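
To see the backtracking in action without the full EGraph machinery, here is a self-contained toy version. It is only a sketch under loud assumptions: there is no union-find, the "e-graph" is just a dictionary from class ids to nodes, and PVar, PTerm, ENode and the toy_ functions are hypothetical names rather than the ones in the repo.

```julia
struct PVar
    id::Symbol
end
struct PTerm
    head::Symbol
    args::Vector{Any}
end
struct ENode
    head::Symbol
    args::Vector{Int}
end

# class 1 = {a}, class 2 = {b}, class 3 = {a + b, b + a}
classes = Dict(
    1 => [ENode(:a, Int[])],
    2 => [ENode(:b, Int[])],
    3 => [ENode(:+, [1, 2]), ENode(:+, [2, 1])],
)

function toy_ematch(classes, p::PVar, id::Int, sub)
    Channel() do c
        if haskey(sub, p)
            sub[p] == id && put!(c, sub)             # must agree with the earlier binding
        else
            put!(c, Base.ImmutableDict(sub, p => id))
        end
    end
end

function toy_ematch(classes, p::PTerm, id::Int, sub)
    Channel() do c
        for n in classes[id]                         # try every node in the class
            if n.head == p.head && length(n.args) == length(p.args)
                for s in toy_ematchlist(classes, p.args, n.args, sub)
                    put!(c, s)
                end
            end
        end
    end
end

function toy_ematchlist(classes, ps, ids, sub)
    Channel() do c
        if isempty(ps)
            put!(c, sub)
        else
            for s1 in toy_ematch(classes, ps[1], ids[1], sub)
                for s2 in toy_ematchlist(classes, ps[2:end], ids[2:end], s1)
                    put!(c, s2)
                end
            end
        end
    end
end

empty_sub = Base.ImmutableDict{PVar,Int}()
pat = PTerm(:+, [PVar(:x), PVar(:y)])
results = [(s[PVar(:x)], s[PVar(:y)]) for s in toy_ematch(classes, pat, 3, empty_sub)]
```

The pattern ~x + ~y matches class 3 in two different ways, once per node in the class, which is exactly the extra search that matching against a single term never has to do.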

You can then instantiate patterns with the returned dictionaries via


function instantiate(e::EGraph, p::PatVar , sub)
    sub[p]
end

function instantiate(e::EGraph, p::PatTerm , sub)
    push!( e, Term(p.head, [ instantiate(e,a,sub) for a in p.args ] ))
end

And build rewrite rules

struct Rule
    lhs::Pattern
    rhs::Pattern
end

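function rewrite!(e::EGraph, r::Rule)
    # Sentinel entry so the substitution has the concrete type ImmutableDict{PatVar, Id}.
    EMPTY_DICT2 = Base.ImmutableDict{PatVar, Id}(PatVar(:____),  Id(-1))
    # Collect the substitutions first. Instantiating a pattern may add new classes
    # (assuming push! can insert into e.classes), and mutating a dictionary while
    # iterating over it is not safe.
    subs = []
    for (n, cls) in e.classes
        for sub in ematch(e, r.lhs, n, EMPTY_DICT2)
            push!(subs, sub)
        end
    end
    # Instantiate both sides of the rule for every match, then assert the equalities.
    matches = [ ( instantiate(e, r.lhs, sub), instantiate(e, r.rhs, sub) ) for sub in subs ]
    for (lhs, rhs) in matches
        union!(e, lhs, rhs)
    end
    rebuild!(e)
end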

Here’s a very simple equation proving function that takes in a pile of rules and gives up after a fixed number of rounds.

function prove_eq(e::EGraph, t1::Id, t2::Id , rules)
    for i in 1:3
        if in_same_set(e,t1,t2)
            return true
        end
        for r in rules
            rewrite!(e,r) # I should split this stuff up. We only need to rebuild at the end
        end
    end
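    # Check once more after the final round; `nothing` means "not proven within the budget".
    return in_same_set(e,t1,t2) ? true : nothing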
end

As sometimes happens, I’m losing steam on this and would like to work on something different for a bit. But I loaded up the WIP at https://github.com/philzook58/EGraphs.jl.

Bits and Bobbles

Catlab

You can find piles of equational axioms in Catlab like here for example

A tricky thing is that some equational axioms of categories seem to produce variables out of thin air. Consider the rewrite rule ~f => id(~A) ⋅ ~f. Where does the A come from? It’s because we’ve highly suppressed all the typing information. The A should come from the type of f.

I think the simplest way to fix this is the “type tag” approach I mentioned here https://www.philipzucker.com/theorem-proving-for-catlab-2-lets-try-z3-this-time-nope/ and can be read about here https://people.mpi-inf.mpg.de/~jblanche/mono-trans.pdf. You wrap every term in a tagging function so that f becomes tag(f, Hom(a,b)). This approach makes sense because it turns GATs purely equational without the implicit typing context, I think. It is a bummer that it will probably choke the e-graph with junk though.
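
Here is a rough sketch of what the tagging translation could look like on plain Julia Expr terms. It only knows about generators, id and compose, and the names gen_types, typeof_term and tag_term are made up for this example. None of this is Catlab API.

```julia
# Types of the generators (an assumption for this example).
gen_types = Dict(:A => :Ob, :B => :Ob, :f => :(Hom(A, B)), :g => :(Hom(B, A)))

# Compute the type of an untagged term, returned as a Symbol or Expr.
function typeof_term(t, gens)
    if t isa Symbol
        return gens[t]
    elseif t.args[1] == :id
        A = t.args[2]
        return :(Hom($A, $A))
    elseif t.args[1] == :compose
        tf = typeof_term(t.args[2], gens)   # Hom(A, B)
        tg = typeof_term(t.args[3], gens)   # Hom(B, C)
        return :(Hom($(tf.args[2]), $(tg.args[3])))
    else
        error("unknown head $(t.args[1])")
    end
end

# Wrap every subterm in tag(term, type).
function tag_term(t, gens)
    ty = typeof_term(t, gens)
    if t isa Symbol
        return :(tag($t, $ty))
    else
        args = [tag_term(a, gens) for a in t.args[2:end]]
        return :(tag($(Expr(:call, t.args[1], args...)), $ty))
    end
end

tagged = tag_term(:(compose(f, g)), gen_types)
```

With the tags in place, the left-hand side of the troublesome rule becomes something like tag(~f, Hom(~A, ~B)), so ~A is bound by the match instead of appearing out of thin air. You can also see the junk problem already: every node roughly doubles.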

It may be best to build a TypedTerm to use rather than Term that has an explicit field for the type. This brings it close to the representation that Catlab already uses for syntax, so maybe I should just directly use that. I like having control over my own type definitions though and Catlab plays some monkey business with the Julia level types. :(

struct TypedTerm
    type
    head
    args
end

As I have written it so far, I don’t allow the patterns to match on the heads of terms, although there isn’t any technical reason to not allow it. The natural way of using a TypedTerm would probably want to do so though, since you’d want to match on the type sometimes while keeping the head a pattern. Another possible way to do this that avoids all these issues is to actually make terms a bit like how regular Julia Expr are made, which usually has :call in the head position. Then by convention the first arg could be the actual head, the second arg the type, and the rest of the arguments the regular arguments.

A confusing example sketch:

TypedTerm(:mcopy, typedterm(Hom(a,otimes(a,a))) , [TypedTerm(a,Ob,[])]) would become Term(:call, [:mcopy, term!(Hom(a,otimes(a,a))) , term!(a) ])

Catlab already does some simplifications on its syntax, like collecting up associative operations into lists. It is confusing how to integrate this with the e-graph structure so I think the first step is to just not. That stuff isn’t on by default, so you can get rid of it by defining your own versions using @syntax.

using Catlab
using Catlab.Theories
import Catlab.Theories: compose, Ob, Hom
@syntax MyFreeCartesianCategory{ObExpr,HomExpr} CartesianCategory begin
  #otimes(A::Ob, B::Ob) = associate_unit(new(A,B), munit)
  #otimes(f::Hom, g::Hom) = associate(new(f,g))
  #compose(f::Hom, g::Hom) = associate(new(f,g; strict=true))

  #pair(f::Hom, g::Hom) = compose(mcopy(dom(f)), otimes(f,g))
  #proj1(A::Ob, B::Ob) = otimes(id(A), delete(B))
  #proj2(A::Ob, B::Ob) = otimes(delete(A), id(B))
end

# we can translate the axiom sets programmatically.
names(Main.MyFreeCartesianCategory)
#A = Ob(MyFreeCartesianCategory, :A)

A = Ob(MyFreeCartesianCategory, :A)
B = Ob(MyFreeCartesianCategory, :B)
f = Hom(:f, A,B)
g = Hom(:g, B, A)
h = compose(compose(f,g),f)
# dump(h)
dump(f)

#=
Main.MyFreeCartesianCategory.Hom{:generator}
  args: Array{Any}((3,))
    1: Symbol f
    2: Main.MyFreeCartesianCategory.Ob{:generator}
      args: Array{Symbol}((1,))
        1: Symbol A
      type_args: Array{GATExpr}((0,))
    3: Main.MyFreeCartesianCategory.Ob{:generator}
      args: Array{Symbol}((1,))
        1: Symbol B
      type_args: Array{GATExpr}((0,))
  type_args: Array{GATExpr}((2,))
    1: Main.MyFreeCartesianCategory.Ob{:generator}
      args: Array{Symbol}((1,))
        1: Symbol A
      type_args: Array{GATExpr}((0,))
    2: Main.MyFreeCartesianCategory.Ob{:generator}
      args: Array{Symbol}((1,))
        1: Symbol B
      type_args: Array{GATExpr}((0,))
=#

Rando thought: Can an EGraph data structure itself be expressed in Catlab? I note that IntDisjointSet does show up in a Catlab definition of pushouts. The difficulty and scariness of the EGraph is maintaining its invariants, which may be expressible as some kind of composition condition on the various maps that make up the EGraph.

The rewrite rule problem can be viewed as something like a graph. The vertices of the graph are every possible term. A rewrite rule that is applicable to a vertex describes a directed edge out of it.

This graph is very large, often infinite, so it can only be described implicitly.

Proving the equivalence of two terms is like finding a path in this graph. You could attempt a greedy approach or a breadth first search approach, or any variety of your favorite search algorithm like A*.
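
As a sketch of the plain graph search view, here is a breadth first search over terms represented as Julia Expr, with no e-graph at all. The two rules (commutativity of + and x + 0 => x) and all the function names are just assumptions for this example.

```julia
is_plus(t) = t isa Expr && t.head == :call && length(t.args) == 3 && t.args[1] == :+

# Each rule rewrites at the root of a term, or returns nothing if it does not apply.
comm(t) = is_plus(t) ? Expr(:call, :+, t.args[3], t.args[2]) : nothing
add_zero(t) = (is_plus(t) && t.args[3] == 0) ? t.args[2] : nothing

# The edges of the implicit graph: apply each rule at the root or inside any argument.
function neighbors(t, rules)
    out = Any[]
    for r in rules
        t2 = r(t)
        t2 === nothing || push!(out, t2)
    end
    if t isa Expr && t.head == :call
        for i in 2:length(t.args)
            for s in neighbors(t.args[i], rules)
                t2 = copy(t)          # the rest of the term is copied unchanged
                t2.args[i] = s
                push!(out, t2)
            end
        end
    end
    return out
end

function bfs_prove(start, goal, rules; maxdepth = 5)
    seen = Set{Any}([start])
    frontier = Any[start]
    for depth in 0:maxdepth
        goal in seen && return true
        next = Any[]
        for t in frontier, t2 in neighbors(t, rules)
            if !(t2 in seen)
                push!(seen, t2)
                push!(next, t2)
            end
        end
        frontier = next
    end
    return goal in seen
end

proved = bfs_prove(:((x + 0) + y), :(y + x), [comm, add_zero])
```

Notice that every neighbor is a full copy of the term with one small piece changed, which is exactly the lack of sharing that the e-graph addresses.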

The EGraph from this perspective is a data structure for holding the already seen vertices.

I suspect some heuristics would be helpful, like applying the “simplifying” direction of the equations more often than the “complicating” direction. In a 5:1 ratio let’s say.

A natural algorithm to consider for optimizing terms is to take the best expression found so far, destroy the e-graph and place just that expression in it. Try rewrite rules. If the best is still the old query, don’t destroy the e-graph and apply a new round of rewrites to widen the radius of the e-graph. Gives you kind of a combo greedy + complete.

Partial Evaluation

Partial evaluation seems like a good technique for optimizing pattern matches. The ideal pattern match code expands out to a nicely nested set of if statements. One alternative to passing a dictionary might be to assign each pattern variable to an integer at compile time and instead pass an array, which would be a bit better. However, by using metaprogramming we can insert local variables into generated code and avoid the need for a dictionary being passed around at runtime. Then the Julia compiler can register allocate and what not like it usually does (and quite efficiently I’d expect).

See this post (in particular the reference links) https://www.philipzucker.com/metaocaml-style-partial-evaluation-in-coq/ for more on partial evaluation.

A first thing to consider is that we’re building up and destroying dictionaries at a shocking rate.

Secondly dictionaries themselves are (relatively) expensive things.

I experimented a bit with using a curried form for match to see if maybe the Julia compiler was smart enough to sort of evaluate away my use of dictionaries, but that did not seem to be the case.

I found examining the @code_llvm and @code_native of the following simple experiments illuminating as to what Julia can and cannot get rid of when the dictionary is obviously known at compile time to human eyes. Mostly it needs to be very obvious. I suspect Julia does not have any built-in special compiler magic for reasoning about dictionaries, and trying to automatically infer what’s safely optimizable by actually expanding the definition of a dictionary op sounds impossible.

function foo() #bad
    x = Dict(:a => 4)
    return x[:a]
end
function foo2()
    return 4
end


function foo3() # bad
    x = Dict(:a => 4)
    () -> x[:a]
end

function foo4() # much better. But we do pay a closure cost.
    x = Dict(:a => 4)
    r = x[:a]
    () -> r
end

function foo5()
    x = Dict(:a => 4)
    :( () -> $(x[:a]) )
end

function foo10() # renamed: a second foo7 is defined below and would overwrite this one
    return foo()
end


function foo6() # julia does however do constant propagation and deletion of unnecessary bullcrap
    x = 4
    y = 7
    y += x
    return x
end

function foo7() # this compiles well
    x = Base.ImmutableDict( :a => 4)
    return x[:a]
end

function foo8()
    x = Base.ImmutableDict( :a => 4)
    r = x[:a]
    z -> z == r # still has a closure indirection. It's pretty good though
end

function foo9()
    x = Base.ImmutableDict( :a => 4)
    r = x[:a]
    z -> z == 4 
end


using InteractiveUtils # provides @code_llvm and @code_native outside the REPL
#@code_native foo7()

z = foo4()
#@code_llvm z()

z = eval(foo5())
@code_native z()
@code_native foo2()
@code_llvm foo6()
z = foo8()
@code_native z(3)
z = foo9()
@code_native z(3)
@code_native foo7()

# so it's possible that using an ImmutableDict is sufficient for julia to inline almost everything itself.
# You want the indexing to happen outside the closure.

A possibly useful technique is partial evaluation. We have in a sense built an interpreter for the pattern language. Specializing this interpreter gives a fast pattern matcher. We can explicitly build up the code Expr that will do the pattern matching by interpreting the pattern.

So here’s a version that takes in a pattern and builds code to perform the pattern match in terms of if then else statements and runtime bindings.

Here’s a sketch. This version only returns false on failure rather than returning the bindings. I probably need to cps it to get actual bindings out. I think that’s what I was starting to do with the c parameter.

function prematch(p::PatVar, env, t, c )
    if haskey(env, p)
       :(  $(env[p]) != $t  && return false )
    else
       env[p] = gensym(p.id) #make fresh variable
       :(  $(env[p]) = $t  ) #bind it at runtime to the current t
    end
end


function sequence(code)
    foldl(
        (s1,s2) -> quote
                        $s1
                        $s2
                    end
        , code)
end

function prematch(p::PatTerm, env, t, c)
    if length(p.args) == 0
        # Fail unless the head matches and the term has no arguments.
        :( ( $(t).head != $( QuoteNode(p.head)  ) || length($(t).args) != 0 ) && return false   )
    else
        quote
            if $(t).head != $( QuoteNode(p.head) ) || $( length(p.args) ) != length($(t).args)
                return false
            else
                $( sequence([prematch(a, env, :( $(t).args[$n] ), c)   for (n,a) in enumerate(p.args) ]))
            end
        end 
    end
end


println( prematch( PatTerm(:f, []) , Dict(),  :w , nothing) )
println( prematch( PatTerm(:f, [PatVar(:a)]) , Dict(),  :t , nothing) )
println( prematch( PatTerm(:f, [PatVar(:a), PatVar(:a)]) , Dict(),  :t , nothing) )
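
One possible way to get the bindings out without going full cps is to generate a flat list of statements and finish with a return of the collected bindings. This is only a sketch under assumptions: it redefines simplified Term, PatVar and PatTerm types so that it stands alone, and compile_pat! and compile_matcher are hypothetical names.

```julia
struct Term
    head::Symbol
    args::Vector{Term}
end
Base.:(==)(a::Term, b::Term) = a.head == b.head && a.args == b.args

struct PatVar
    id::Symbol
end
struct PatTerm
    head::Symbol
    args::Vector{Union{PatTerm, PatVar}}
end

# Append statements to `stmts` that check the runtime expression `t` against `p`.
# `env` maps each pattern variable to the fresh local variable holding its binding.
function compile_pat!(stmts, p::PatVar, env, t)
    if haskey(env, p)
        push!(stmts, :( $(env[p]) == $t || return nothing ))
    else
        env[p] = gensym(p.id)
        push!(stmts, :( $(env[p]) = $t ))
    end
end

function compile_pat!(stmts, p::PatTerm, env, t)
    push!(stmts, :( ($(t).head == $(QuoteNode(p.head)) &&
                     length($(t).args) == $(length(p.args))) || return nothing ))
    for (n, a) in enumerate(p.args)
        compile_pat!(stmts, a, env, :( $(t).args[$n] ))
    end
end

# Build and eval an anonymous function t -> Dict of bindings, or nothing on failure.
function compile_matcher(p)
    env = Dict{PatVar,Symbol}()
    stmts = Any[]
    compile_pat!(stmts, p, env, :t)
    bindings = [ :( $(QuoteNode(v.id)) => $s ) for (v, s) in env ]
    body = Expr(:block, stmts..., :( return Dict($(bindings...)) ))
    return eval(:( t -> $body ))
end

matcher = compile_matcher(PatTerm(:f, [PatVar(:a), PatVar(:a)]))
x = Term(:x, Term[])
# invokelatest sidesteps world age issues when calling a freshly eval'd function.
result = Base.invokelatest(matcher, Term(:f, [x, x]))
```

The dictionary now only shows up once, at the very end, to package the answer. During the match itself the bindings live in ordinary local variables.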