Unification is the logic-y word for equation solving. It means finding terms to fill in the holes of an implicitly existentially quantified goal equation, such as 3*X + 4*Y = 7 or foo(X) = foo(bar(Y)).

In first order syntactic unification, we can kind of follow our nose going down the term. I discussed this more previously https://www.philipzucker.com/unify/

We can also bolt various extra features onto our equation solver, for example intrinsic understanding of linear algebra, booleans, or algebraic properties like associativity or commutativity.

One pile of features you can toss onto unification is related to having other notions of variable besides unification holes in the terms. Examples of where you might want this include

  • Integral expressions $\int dx e^{-x^2}$ have dummy variables
  • Sums $\sum_i a_i$ have dummy indices
  • Lambdas $\lambda x. (x x)$ have bound variables
  • Minimizations $\min_x f(x)$ bind the variable being minimized over
  • Indexful tensor expressions $T_{ijk}\epsilon_{ij}$
  • Confusingly, modelling logic itself, with quantifiers $\forall x. P(x)$ and $\exists x. P(x)$

Vaguely this set of features is called higher order unification because the people who work on it mostly have in mind a lambda calculus.

There is some useful distinction to be made though. The point of the lambda calculus is substitution $(\lambda x. foo(x,x))10 \rightarrow foo(10,10)$, aka beta reduction. This is a useful feature, but it possibly is too much to ask for. It turns out that full lambda unification is undecidable, meaning you can encode very difficult problems into it.

A separate thing to ask for is to deal with dummy variables properly aka well-scoped alpha equivalence. For example, syntactically unifying $\int dx e^{-x^2} = \int dy e^{-y^2}$ shouldn’t fail.
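To make the alpha equivalence requirement concrete, here is a minimal sketch in plain Python rather than Z3. The tuple encoding of terms and the helper names `to_debruijn` and `alpha_eq` are assumptions made up for illustration: bound names get replaced by de Bruijn indices, after which alpha equivalence is plain structural equality.

```python
def to_debruijn(term, env=()):
    """Replace bound names by de Bruijn indices; free names are left alone.

    Assumed toy encoding: a term is a string (a name), ("lam", name, body)
    for a binder, or (head, arg1, ...) for the symbol `head` applied to args.
    """
    if isinstance(term, str):
        # env is ordered innermost binder first, so index 0 is the nearest binder
        return ("bvar", env.index(term)) if term in env else term
    if term[0] == "lam":
        _, name, body = term
        return ("lam", to_debruijn(body, (name,) + env))
    head, *args = term
    return (head, *(to_debruijn(a, env) for a in args))


def alpha_eq(t1, t2):
    """Alpha equivalence as structural equality of the de Bruijn forms."""
    return to_debruijn(t1) == to_debruijn(t2)


# int dx e^{-x^2}  versus  int dy e^{-y^2}, treating "int" as a symbol taking a binder
lhs = ("int", ("lam", "x", ("exp", ("neg", ("pow", "x", "2")))))
rhs = ("int", ("lam", "y", ("exp", ("neg", ("pow", "y", "2")))))
assert alpha_eq(lhs, rhs)
assert not alpha_eq(("lam", "x", "x"), ("lam", "x", "y"))
```

The second assertion is the flip side: renaming is only allowed for bound names, so a free `y` is not the same thing as a bound `x`.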

There are two related approaches here

  • Nominal unification
  • Higher order pattern unification aka Miller pattern unification

Pattern unification I think is actually fairly simple. It is almost the obvious thing to do (in hindsight of course) if you want to extend first order unification but keep things well scoped.

Try it out on colab: https://colab.research.google.com/github/philzook58/philzook58.github.io/blob/master/pynb/2024-11-11-ho_unify.ipynb

Pattern Matching

Before discussing unification, you should always try pattern matching first. It’s way simpler, probably faster and more useful.

Pattern matching is like one-way unification, or unification is like 2-way pattern matching.

First off, let us note that the following problem cannot be solved

$\exists T. (\lambda x. f(x,x) = \lambda x. T)$

What the heck am I talking about? “OBVIOUSLY”, $T = f(x,x)$, right? No, because T is bound outside of the lambda that binds $x$. This solution has $x$ escape its scope. It has bad hygiene. This statement gets more confusing when you leave the $\exists$ implicit and write $\lambda x. f(x,x) ?= \lambda x. T$.

Okay, well how about

$\exists T. (\lambda x. f(x,x) = \lambda x. T(x))$

Yes, this is perfectly solvable as $T = \lambda y. f(y,y)$.

What about this? $\exists T. (\lambda x. f(42,42) = \lambda x. T(42))$

Well, actually this has a couple solutions. $T = \lambda y. f(42,42)$ or $T = \lambda y. f(y,y)$ or $T = \lambda y. f(42,y)$ or $T = \lambda y. f(y,42)$.
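The four solutions can be enumerated mechanically. The following is a small sketch in plain Python (tuple terms and the helper names `abstractions` and `instantiate` are assumptions for illustration only): each occurrence of 42 in the target is either abstracted to the bound variable or left alone, independently of the others.

```python
from itertools import product


def abstractions(term, target, var):
    """Yield every term obtained from `term` by replacing some subset of
    the occurrences of the atom `target` with `var`."""
    if term == target:
        yield var   # abstract this occurrence
        yield term  # or leave it alone
    elif isinstance(term, tuple):
        head, *args = term
        options = [list(abstractions(a, target, var)) for a in args]
        for choice in product(*options):
            yield (head, *choice)
    else:
        yield term


def instantiate(body, var, value):
    """Substitute `value` for the name `var` in `body` (no binders in this toy)."""
    if body == var:
        return value
    if isinstance(body, tuple):
        head, *args = body
        return (head, *(instantiate(a, var, value) for a in args))
    return body


goal = ("f", "42", "42")
bodies = list(abstractions(goal, "42", "y"))
# every body b is a solution T = \y. b, because b[y := 42] gives back the goal
assert all(instantiate(b, "y", "42") == goal for b in bodies)
assert len(bodies) == 4
```

With n occurrences of the argument there are 2^n such bodies, which is one way to see why applying a unification variable to an arbitrary term loses uniqueness.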

Higher order pattern unification is a subset of higher order unification that forbids combos that have multiple possible solutions. In higher order pattern unification, unification variables may only be applied to distinct rigid bound variables, not to arbitrary terms. $\lambda x y. F[x,y]$ is ok. $\lambda x y. 1 + F[y]$ is ok. $\lambda x y. F[x,x]$ is not. $\lambda x y. F[g(x),y]$ is not.
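The restriction is easy to check syntactically. Here is a hedged sketch in plain Python; the tuple encoding, the `"meta"` tag for an applied unification variable, and the name `is_miller_pattern` are all assumptions for illustration.

```python
def is_miller_pattern(term, bound=()):
    """Check that every applied unification variable only sees distinct bound variables.

    Assumed toy encoding: a string is a name, ("lam", name, body) is a binder,
    ("meta", F, arg1, ...) is the unification variable F applied to arguments,
    and (head, arg1, ...) is an ordinary function symbol application.
    """
    if isinstance(term, str):
        return True
    if term[0] == "lam":
        _, name, body = term
        return is_miller_pattern(body, bound + (name,))
    if term[0] == "meta":
        _, _name, *args = term
        all_bound = all(isinstance(a, str) and a in bound for a in args)
        distinct = len(set(args)) == len(args)
        return all_bound and distinct
    _head, *args = term
    return all(is_miller_pattern(a, bound) for a in args)


def lam2(body):
    """Wrap a body in the two binders \\x y used by the examples in the text."""
    return ("lam", "x", ("lam", "y", body))


assert is_miller_pattern(lam2(("meta", "F", "x", "y")))
assert is_miller_pattern(lam2(("+", "1", ("meta", "F", "y"))))
assert not is_miller_pattern(lam2(("meta", "F", "x", "x")))
assert not is_miller_pattern(lam2(("meta", "F", ("g", "x"), "y")))
```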

The reason for this restriction is that it makes it quite obvious how to solve F[x,y] = t. We abstract out any x and y appearing inside of t, basically F = Lambda([x,y], t). If other scoped variables still appear in t, we fail, since the solution would have variables escape their scope. This is the minimal simple thing we can toss into first order unification to maintain hygiene.
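As a sketch of that solving step in plain Python (tuple terms and the names `free_names` and `solve_pattern` are assumptions for illustration, and the right hand side is assumed to contain no binders of its own):

```python
def free_names(term):
    """Collect the names occurring in argument positions of a binder-free toy term."""
    if isinstance(term, str):
        return {term}
    _head, *args = term
    return set().union(*(free_names(a) for a in args))


def solve_pattern(args, rhs, scoped):
    """Solve F[args] = rhs, where `scoped` lists the bound variables in scope.

    Returns the solution for F as ("lam", args, rhs), or None if a scoped
    variable that F was not applied to would escape through rhs.
    """
    escaping = (free_names(rhs) & set(scoped)) - set(args)
    if escaping:
        return None
    return ("lam", tuple(args), rhs)


# under \x y:  F[x,y] = f(x,y)  is solved by  F = \x y. f(x,y)
assert solve_pattern(["x", "y"], ("f", "x", "y"), scoped=["x", "y"]) == ("lam", ("x", "y"), ("f", "x", "y"))
# under \x:  T = f(x,x)  fails because x would escape, but  T[x] = f(x,x)  is fine
assert solve_pattern([], ("f", "x", "x"), scoped=["x"]) is None
assert solve_pattern(["x"], ("f", "x", "x"), scoped=["x"]) == ("lam", ("x",), ("f", "x", "x"))
```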

Really, the combination F[x,y] should be thought of as a single entity rather than a unification variable “beta applied” to a bound variable. I tend to not even think of pattern unification as lesser higher order unification at all. It’s a different thing that deals with well scoped alpha equivalence, not full beta equivalence (which is too much to ask for for some purposes).

Sometimes types can tell a whole story. Ivan once said something like, “show me the types, and I don’t need to see your implementation”. The ocaml datatype of a regular first order term with variables would be

type uvar = string
type term_fo =
    | UVar of uvar
    | Const of string * term_fo list

The ocaml datatype for higher order pattern terms would be

type bvar = int
type term_pat = 
    | UVar of uvar * bvar list (* When we hit the unification var, we know exactly what to do. *)
    | Binder of term_pat (* binder represented as de Bruijn *)
    | BVar of int
    | Const of string * term_pat list

And NOT like this (which would be a reasonable term type for full higher order unification)

type uvar = string
type bvar = int
type term_ho = 
    | UVar of uvar
    | Lam of term_ho
    | BVar of int
    | App of term_ho * term_ho (* app is playing two purposes here. *)
    | Const of string

We can see that we’ve kind of inlined the app node, and that the HO pattern term_pat is a closer relative of the first order term than of the full higher order term_ho.

We make a distinction between the binders of our term_pat type and the binders or application of the domain we are modelling inside our terms. A lambda calculus modelled using term_pat could represent its notion of application as a constant Const("app", [f; x]). Our built in syntax manipulation does not intrinsically understand this, just as it may not intrinsically understand properties of arithmetic or any other domain specific theory.

Ok Let’s do it on Z3’s AST

I’ve complained before that python doesn’t have a nice lambda manipulation library. That is not true. z3 is that library.

We still need some preliminaries for dealing with lambda terms.

Z3 ASTs are sadly not compared up to alpha equivalence. This is not that surprising; making alpha equivalence fast is hard.

! pip install z3-solver
from z3 import *
x,y = Reals("x y")
Lambda([x], x).eq(Lambda([y], y))
False

In the locally nameless approach to binders, you represent bound variables using de Bruijn indices and free variables as fresh constants (or alternatively de Bruijn levels). The lesson is that whenever you need to traverse a term under a binder, you should open the binder by replacing its variables with fresh constants. If you maintain this discipline, everything just kind of works. I don’t think I really know a deep reason why this is the right thing to do, but let’s take it on faith.
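For intuition about what opening a binder does when nobody handles the indices for you, here is a minimal sketch in plain Python. The tuple encoding and the names `fresh`, `instantiate` and `open_binder` are assumptions for this illustration and are unrelated to the Z3 API.

```python
import itertools

_counter = itertools.count()


def fresh(prefix="x"):
    """Make a fresh free variable. Uniqueness comes from a global counter."""
    return ("fvar", f"{prefix}!{next(_counter)}")


def instantiate(body, value, depth=0):
    """Replace the de Bruijn index that points at the binder being opened.

    Assumed toy encoding: ("bvar", i), ("fvar", name), ("lam", body),
    or (head, arg1, ...) for everything else.
    """
    tag = body[0]
    if tag == "bvar":
        return value if body[1] == depth else body
    if tag == "fvar":
        return body
    if tag == "lam":
        # going under another binder shifts which index refers to ours
        return ("lam", instantiate(body[1], value, depth + 1))
    head, *args = body
    return (head, *(instantiate(a, value, depth) for a in args))


def open_binder(lam):
    """Return a fresh free variable and the body with that variable plugged in."""
    assert lam[0] == "lam"
    v = fresh()
    return v, instantiate(lam[1], v)


v, body = open_binder(("lam", ("bvar", 0)))          # \x. x
assert body == v
w, body2 = open_binder(("lam", ("lam", ("bvar", 1))))  # \x. \y. x
assert body2 == ("lam", w)
```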

Z3, thank god, already deals with the de Bruijn indices. That would be a blog post on its own. It also offers convenient FreshConst functions.

A simple open_binder function gives us the body of a quantifier and a list of the new free variables.

VList = list[z3.ExprRef]
def open_binder(lam: z3.QuantifierRef) -> tuple[VList, z3.ExprRef]:
    vs = [
        z3.FreshConst(lam.var_sort(i), prefix=lam.var_name(i))
        for i in range(lam.num_vars())
    ]
    return vs, z3.substitute_vars(lam.body(), *reversed(vs))

Alpha equivalence is kind of the obvious traversal zipping together the two terms. The interesting bit maybe is when you hit a quantifier. You open one and then instantiate the other with the same free variables. Alternatively, you could open both and maintain a map saying which free vars correspond to which.

def quant_kind_eq(t1, t2):
    """Check both quantifiers are of the same kind"""
    return t1.is_forall() == t2.is_forall() and t1.is_exists() == t2.is_exists() and t1.is_lambda() == t2.is_lambda()

def alpha_eq(t1, t2):
    if t1.eq(t2): # fast path
        return True
    elif is_quantifier(t1) and is_quantifier(t2):
        if quant_kind_eq(t1,t2) and t1.num_vars() == t2.num_vars() and all(t1.var_sort(i) == t2.var_sort(i) for i in range(t1.num_vars())): # all(...) is needed, a non-empty list is always truthy
            vs, body1 = open_binder(t1)
            body2 = substitute_vars(t2.body(), *reversed(vs))
            return alpha_eq(body1, body2)
        else:
            return False
    elif is_app(t1) and is_app(t2):
        if t1.decl() == t2.decl():
            return all(alpha_eq(t1.arg(i), t2.arg(i)) for i in range(t1.num_args()))
        else:
            return False
    else:
        raise Exception("Unexpected terms in alpha_eq", t1, t2)
    # could instead maybe use a solver check or simplify tactic on Goal(t1 == t2)
    
assert alpha_eq(Lambda([x], x), Lambda([y], y))
assert not alpha_eq(ForAll([x], x == x), Exists([y], y == y))
t = Lambda([x,y], x + y)
vs, body = open_binder(t)
assert alpha_eq(t, Lambda(vs, body))
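
The alternative mentioned above, keeping a record of which variables of the two binders correspond to each other instead of instantiating both with the same constants, might look like the following sketch. It is plain Python over named tuple terms, not Z3, and the encoding and the `pairs` argument are assumptions for illustration.

```python
def alpha_eq(t1, t2, pairs=()):
    """Alpha equivalence by walking both terms in lockstep.

    `pairs` is a stack of (name in t1, name in t2) for the binders entered
    so far, innermost first. Assumed toy encoding: strings are names,
    ("lam", name, body) is a binder, (head, arg1, ...) is an application.
    """
    if isinstance(t1, str) and isinstance(t2, str):
        for n1, n2 in pairs:
            if t1 == n1 or t2 == n2:
                # the innermost binder mentioning either name must pair them up
                return t1 == n1 and t2 == n2
        return t1 == t2  # both free
    if isinstance(t1, str) or isinstance(t2, str):
        return False
    if t1[0] == "lam" or t2[0] == "lam":
        if t1[0] != t2[0]:
            return False
        return alpha_eq(t1[2], t2[2], ((t1[1], t2[1]),) + pairs)
    return (
        t1[0] == t2[0]
        and len(t1) == len(t2)
        and all(alpha_eq(a, b, pairs) for a, b in zip(t1[1:], t2[1:]))
    )


xy = ("lam", "x", ("lam", "y", ("f", "x", "y")))
assert alpha_eq(xy, ("lam", "a", ("lam", "b", ("f", "a", "b"))))
assert not alpha_eq(xy, ("lam", "a", ("lam", "b", ("f", "b", "a"))))
# a free y is not the same as a bound y
assert not alpha_eq(("lam", "x", "y"), ("lam", "y", "y"))
```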

Ok now onto pattern matching. Previously, I showed how to write first order pattern matching.

The pattern matching variables are the constants in the list vs.

You can write it in different styles. Here I use a todo queue and a loop. If you hit a variable in the pattern, look it up in the subst if it’s there, or add it to the subst. Otherwise, make sure the heads match and then match all the children of the pattern and term zipped together.

from z3 import *


def pmatch_fo(vs : VList, pat, t):
    if pat.sort() != t.sort():
        raise Exception("Sort mismatch", pat, t)
    subst = {}
    todo = [(pat, t)]
    def is_var(x):
        return any(x.eq(v) for v in vs)
    while todo:
        pat, t = todo.pop()
        if is_var(pat): # regular pattern
            if pat in subst:
                if not subst[pat].eq(t):
                    return None
            else: 
                subst[pat] = t
        elif is_app(pat):
            if not is_app(t) or pat.decl() != t.decl():
                return None
            todo.extend(zip(pat.children(), t.children())) 
        else:
            raise Exception("Unexpected pattern", t, pat)
    return subst

x,y,z = Ints("x y z")

assert pmatch_fo([x], x, IntVal(4)) == {x: IntVal(4)}
assert pmatch_fo([x], IntVal(3), IntVal(3)) == {}
assert pmatch_fo([x], IntVal(3), IntVal(4)) == None
assert pmatch_fo([x], x, IntVal(3)) == {x : IntVal(3)}
assert pmatch_fo([x], x + x, IntVal(3) + IntVal(4)) == None
assert pmatch_fo([x], x + x, IntVal(3) + IntVal(3)) == {x : IntVal(3)}
assert pmatch_fo([y], x + x, IntVal(3) + IntVal(3)) == None
assert pmatch_fo([x,y], x + y, IntVal(3) + IntVal(4)) == {x : IntVal(3), y : IntVal(4)}

The higher order pattern match adds a couple cases.

When we open binders, we note that the fresh variables should not escape into pattern variables by storing them in no_escape.

When we finally resolve a variable, we check that the thing it matched is free of the no_escape variables. This is similar to an occurs check in some respects.

When we hit a pattern of F[x1,x2,x3] = t, we process it into F = Lambda([x1,x2,x3], t) and continue. That is basically the magic of the miller pattern and the way you can smuggle terms containing bound vars out. You build a lambda abstraction to ferry them.

Z3 has an AST closer to the third type above, term_ho. This is why it’s slightly ugly that I need to peek into select to see if there is a pattern variable in there.

from z3 import *


def pmatch(vs, pat, t):
    if pat.sort() != t.sort():
        raise Exception("Sort mismatch", pat, t)
    subst = {}
    todo = [(pat, t)]
    no_escape = []
    def is_var(x):
        return any(x.eq(v) for v in vs)
    def check_escape(x):
        if any(x.eq(v) for v in no_escape):
            return False
        else:
            return all(check_escape(c) for c in x.children())
    while todo:
        pat, t = todo.pop()
        if is_var(pat): # regular pattern
            if pat in subst:
                if not alpha_eq(subst[pat], t):
                    return None
            else: 
                if check_escape(t): # check_escape is relative of occurs_check
                    subst[pat] = t
                else:
                    return None
        elif is_select(pat) and is_var(pat.arg(0)): #  higher order pattern. "select" is smt speak for apply.
            F = pat.arg(0)
            allowedvars = pat.children()[1:]
            if any(v not in no_escape for v in allowedvars):
                raise Exception("Improper higher order pattern", pat) # we could relax this to do syntactic unification here.
            t1 = Lambda(allowedvars, t) # abstract out the allowed vars
            todo.append((F, t1))
        #elif pat.eq(t): early stopping. Not allowed if t contains a pattern variable.
        #if pat.sort() != t.sort(): return None. another fast break 
        elif is_quantifier(pat):
            if not is_quantifier(t) or not quant_kind_eq(t,pat) or t.num_vars() != pat.num_vars():
                return None
            vs1, patbody = open_binder(pat)
            no_escape.extend(vs1)
            tbody = substitute_vars(t.body(), *reversed(vs1))
            todo.append((patbody, tbody))
        elif is_app(pat):
            if not is_app(t) or pat.decl() != t.decl():
                return None
            todo.extend(zip(pat.children(), t.children())) 
        else:
            raise Exception("Unexpected pattern", t, pat)
    return subst

x,y,z = Ints("x y z")
F,G = Consts("F G", ArraySort(IntSort(), IntSort()))
assert pmatch([x], x, IntVal(4)) == {x: IntVal(4)}
assert pmatch([x], IntVal(3), IntVal(3)) == {}
assert pmatch([x], IntVal(3), IntVal(4)) == None
assert pmatch([x], x, IntVal(3)) == {x : IntVal(3)}
assert pmatch([x], x + x, IntVal(3) + IntVal(4)) == None
assert pmatch([x], x + x, IntVal(3) + IntVal(3)) == {x : IntVal(3)}
assert pmatch([y], x + x, IntVal(3) + IntVal(3)) == None
assert pmatch([x,y], x + y, IntVal(3) + IntVal(4)) == {x : IntVal(3), y : IntVal(4)}

# alpha equiv terms should pmatch
assert pmatch([], Lambda([x], x == x), Lambda([y], y == y)) == {}
t = Lambda([x,y], x + y)
vs, body = open_binder(t)
assert pmatch([], t, Lambda(vs, body)) == {}

assert alpha_eq(pmatch([F], Lambda([x], F[x]), Lambda([x], x + 3))[F],
                Lambda([x], x + 3))
assert alpha_eq(pmatch([F], Lambda([x], F[x]), Lambda([y], y + 3))[F], 
                Lambda([z], z + 3))
assert alpha_eq(pmatch([F], Lambda([x], F[x]), Lambda([x], G[x]))[F], 
                Lambda([x], G[x]))

# Failing examples
# should we allow this? 
# pmatch([F], F[3], G[3]). Seems obvious what the answer should be {F:G}, but we're opening up a can of worms
assert pmatch([F], Lambda([x,y], F[x]), Lambda([x,y], G[y])) == None
assert pmatch([F], Lambda([x,y], F), Lambda([x,y], Lambda([z], x + 3))) == None

# This is the sort of thing you have to do if you want to apply an induction principle about (forall P) to a goal.
P = Const("P", ArraySort(IntSort(), BoolSort()))
assert alpha_eq(pmatch([P], ForAll([x], P[x]), ForAll([y], Or(y == 0, y > 0)))[P], 
                Lambda([z], Or(z == 0, z > 0)))

assert pmatch([F,G], Lambda([x,y],F[y] + F[y]), Lambda([x,y], x + y)) == None
assert pmatch([F,G], Lambda([x,y],F[y] + F[x]), Lambda([x,y], x + y)) == None
assert alpha_eq(pmatch([F,G], Lambda([x,y],F[x] + F[y]), Lambda([x,y], x + y))[F],
                Lambda([x], x))

Unification

Most of the difficulties that occur with unification are basically already dealt with in first order unification.

I mentioned in my previous post that the eager substitution loopy form of unification is the most clear to my eye. Eager substitution can only be easily written if you write unification as a todo queue, where you have access to your “stack” to perform the appropriate substitutions immediately everywhere.

With higher order unification, everything is even more confusing, so I’m going to stick to that style.

This is the basic first order unify routine. In this form, it isn’t that different from the pattern match code. The main difference is substituting away the unification variables in the todo and subst when we find them. Even in the first order form, we have a check to see if our solution is ok (the occurs check).

def unify(vs, p1, p2):
    subst = {}
    todo = [(p1,p2)]
    def is_var(x):
        return any(x.eq(v) for v in vs)
    def occurs(x, t):
        if is_var(t):
            return x.eq(t)
        if is_app(t):
            return any(occurs(x, t.arg(i)) for i in range(t.num_args()))
        return False
    while todo:
        p1,p2 = todo.pop()
        if p1.eq(p2): # delete
            continue
        elif is_var(p1): # elim
            if occurs(p1, p2):
                return None
            todo = [(substitute(t1,(p1,p2)), substitute(t2,(p1,p2))) for (t1,t2) in todo]
            subst = {k : substitute(v, (p1,p2)) for k,v in subst.items()}
            subst[p1] = p2
        elif is_var(p2): # orient
            todo.append((p2,p1))
        elif is_app(p1): # decompose
            if not is_app(p2) or p1.decl() != p2.decl():
                return None
            todo.extend(zip(p1.children(), p2.children())) 
        else:
            raise Exception("unexpected case", p1, p2)
    return subst

vs = Ints("x y z")
x,y,z = vs
assert unify(vs,IntVal(3), IntVal(3)) == {}
assert unify(vs,IntVal(3), IntVal(4)) == None
assert unify(vs,x, IntVal(3)) == {x : IntVal(3)}
assert unify(vs, x, y) == {x : y}
assert unify(vs, x + x, y + y) == {x : y}
assert unify(vs, x + x, y + z) == {x : y, z : y}
assert unify(vs,x + y, y + z) == {x : z, y : z}
assert unify(vs,y + z, x + y) == {y : x, z : x}
assert unify(vs, (x + x) + x, x + (x + x)) == None
assert unify(vs, 1 + x, x) == None

Another piece of infrastructure we need is $\beta_0$ normalization. $\beta_0$ normalization only performs substitution on rigid bound variables. It is the part of $\beta$ normalization we can deal with elegantly, but it’s enough to get us a lot of distance.

Basically, traverse the term and use z3 to do the substitutions.

When we substitute in a pattern variable during unification, we may create $\beta_0$ redexes. That’s why we might want this. Maybe it’s overkill.

from z3 import *
def beta0_norm(ctx,t):
    # reduce any lambdas applied to rigid variables
    if is_select(t) and is_quantifier(t.arg(0)): # (\x y z. body)[p1,p2,p3]
        params = t.children()[1:]
        if all(any(v.eq(v1) for v1 in ctx) for v in params):
            return beta0_norm(ctx, substitute_vars(t.arg(0).body(), *reversed(params)))
        else:
            return t.decl()(*[beta0_norm(ctx, x) for x in t.children()])
    elif is_quantifier(t):
        vs, body = open_binder(t)
        binder = Lambda if t.is_lambda() else Exists if t.is_exists() else ForAll
        return binder(vs, beta0_norm(ctx + vs, body)) # pass vs as a list so multi-variable binders also work
    elif is_app(t):
        return t.decl()(*[beta0_norm(ctx, x) for x in t.children()])
    else:
        raise Exception("Unexpected term in beta0_norm", t)

def substitute_beta0(ctx,t,*subs):
    t = substitute(t, *subs)
    return beta0_norm(ctx, t)

x,y,z = Ints("x y z")
F,G = Consts("F G", ArraySort(IntSort(), IntSort()))
beta0_norm([], Lambda([x], Lambda([y], 3 + y)[x]))
beta0_norm([], Lambda([x], Lambda([y], 3 + y)[4]))
substitute_beta0([], Lambda([x], F[x]), (F, Lambda([y], 3 + y)))

λx!34 : 3 + x!34
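
The same idea can be sketched without Z3 to see exactly which redexes fire. This is plain Python over named tuple terms. The encoding and function names are assumptions for illustration, and the toy `substitute` is not capture avoiding, so it assumes all bound names are distinct.

```python
def substitute(term, mapping):
    """Replace names by terms. Assumes bound names are all distinct (no capture)."""
    if isinstance(term, str):
        return mapping.get(term, term)
    if term[0] == "lam":
        _, params, body = term
        inner = {k: v for k, v in mapping.items() if k not in params}
        return ("lam", params, substitute(body, inner))
    head, *args = term
    return (head, *(substitute(a, mapping) for a in args))


def beta0_norm(term, rigid=frozenset()):
    """Contract (\\params. body)[args] only when every arg is a bound variable in scope.

    Assumed toy encoding: strings are names, ("lam", params, body) is a
    multi-arity binder, ("apply", f, arg1, ...) applies a term to arguments,
    and (head, arg1, ...) is an ordinary function symbol.
    """
    if isinstance(term, str):
        return term
    if term[0] == "lam":
        _, params, body = term
        return ("lam", params, beta0_norm(body, rigid | set(params)))
    head, *args = term
    args = [beta0_norm(a, rigid) for a in args]
    if head == "apply" and not isinstance(args[0], str) and args[0][0] == "lam":
        _, params, body = args[0]
        actuals = args[1:]
        if len(params) == len(actuals) and all(
            isinstance(a, str) and a in rigid for a in actuals
        ):
            return beta0_norm(substitute(body, dict(zip(params, actuals))), rigid)
    return (head, *args)


inner = ("lam", ("y",), ("+", "3", "y"))
# \x. (\y. 3 + y)[x]  reduces, since x is a bound variable in scope
assert beta0_norm(("lam", ("x",), ("apply", inner, "x"))) == ("lam", ("x",), ("+", "3", "x"))
# \x. (\y. 3 + y)[4]  is left alone, since 4 is not a bound variable
stuck = ("lam", ("x",), ("apply", inner, "4"))
assert beta0_norm(stuck) == stuck
```

Because a $\beta_0$ step only ever replaces variables by variables, the term never grows, which is part of why this fragment is so much tamer than full $\beta$ reduction.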

Unification looks very much like first order unification with the new quantifier and application dealing code spliced in.

def unify(vs, p1, p2):
    subst = {}
    todo = [(p1,p2)]
    no_escape = []
    def is_var(x):
        return any(x.eq(v) for v in vs)
    def occurs(x, t):
        if x.eq(t):
            return True
        elif any(t.eq(v) for v in no_escape): # t is an opened bound variable, which must not escape into a solution
            return True
        elif is_app(t):
            return any(occurs(x, t.arg(i)) for i in range(t.num_args()))
        return False
    while todo:
        # it can be instructive to print(todo) here
        p1,p2 = todo.pop()
        if p1.eq(p2): # delete
            continue
        elif is_var(p1): # elim
            if occurs(p1, p2):
                return None
            # use substitute_beta_0 instead of substitute
            todo = [(substitute_beta0(no_escape, t1,(p1,p2)), substitute_beta0(no_escape, t2,(p1,p2))) for (t1,t2) in todo]
            subst = {k : substitute_beta0(no_escape, v, (p1,p2)) for k,v in subst.items()}
            subst[p1] = p2
        elif is_select(p1) and is_var(p1.arg(0)): # higher order abstract. Unmodified from pattern matching basically
            allowedvars = p1.children()[1:]
            if any(v not in no_escape for v in allowedvars):
                raise Exception("Improper higher order pattern", p1, p2) # we could relax this to do syntactic unification here.
            p2 = Lambda(allowedvars, p2) # abstract out the allowed vars
            todo.append((p1.arg(0), p2))
        elif is_quantifier(p1): # also unchanged from pattern match
            if not is_quantifier(p2) or not quant_kind_eq(p1,p2) or p1.num_vars() != p2.num_vars():
                return None
            vs1, body1 = open_binder(p1)
            no_escape.extend(vs1)
            body2 = beta0_norm(no_escape,substitute_vars(p2.body(), *reversed(vs1)))
            todo.append((body1, body2))
        elif is_var(p2) or is_select(p2) and is_var(p2.arg(0)): # orient
            todo.append((p2,p1))
        elif is_app(p1): # decompose
            if not is_app(p2) or p1.decl() != p2.decl():
                return None
            todo.extend(zip(p1.children(), p2.children())) 
        else:
            raise Exception("unexpected case", p1, p2)
    return subst

vs = Ints("x y z")
x,y,z = vs

assert unify(vs,IntVal(3), IntVal(3)) == {}
assert unify(vs,IntVal(3), IntVal(4)) == None
assert unify(vs,x, IntVal(3)) == {x : IntVal(3)}
assert unify(vs, x, y) == {x : y}
assert unify(vs, x + x, y + y) == {x : y}
assert unify(vs, x + x, y + z) == {x : y, z : y}
assert unify(vs,x + y, y + z) == {x : z, y : z}
assert unify(vs,y + z, x + y) == {y : x, z : x}
assert unify(vs, (x + x) + x, x + (x + x)) == None
assert unify(vs, 1 + x, x) == None

assert alpha_eq(unify([F,G], Lambda([x], F[x]), Lambda([y], G[y]))[F],
                Lambda([x], G[x]))
assert unify([F,G], Lambda([x,y],F[y] + F[y]), Lambda([x,y], x + y)) == None
assert unify([F,G], Lambda([x,y],F[y] + F[x]), Lambda([x,y], x + y)) == None



assert alpha_eq(pmatch([F,G], Lambda([x,y],F[x] + F[y]), Lambda([x,y], x + y))[F],
                Lambda([x], x))
assert alpha_eq(unify([F,G], Lambda([x,y],F[x] + F[y]), Lambda([x,y], G[x] + G[y]))[F],
                Lambda([x], G[x]))

Bits and Bobbles

Thanks to Cody and Graham for useful discussions.

Useful for induction principles. I want an apply tactic in knuckledragger.

Implementing a lambda prolog. Next time!

Slotted egraphs are quite similar in flavor to Miller patterns. Sort of turn your term into eta maximal form, which ferries all variables down through explicit lambdas.

I’ll note that supporting multi-arity lambdas is a useful feature that is perhaps underrepresented in toy lambda calculi. Implementations definitely want to treat fully applied functions differently from repeated curried application (push-enter https://www.cs.tufts.edu/comp/150FP/archive/simon-peyton-jones/eval-apply-jfp.pdf , Leroy’s ZINC machine). But it also lets you avoid the need for full beta reduction for some purposes. True lambda is the ultimate unnecessarily powerful and expensive thing. Ask for less. Use less. Let also rules, independently of lambda. Implementing let via lambda is insane.

Abstract binding trees. If our programming languages offered nice support for binders, it’d be good. PFPL by Bob Harper. https://semantic-domain.blogspot.com/2015/03/abstract-binding-trees.html

HOAS. Sometimes hoas is just using your host language lambdas for your DSL. In most languages, you can only apply these lambdas, not introspect them. You can introspect by reifying into a first order syntax, but now you’re doing all the work yourself.

Some languages (lambda prolog) have a built in notion of lambdas that you can break down. It’s nice.

Most proof assistants basically have higher order unification in them. It’s how their type checkers work and infer. It’s how dependent pattern matching infers. The core of Isabelle (Pure) is basically bolting sequents together with higher order unification.

Python and Lambdas

Python does let us check for equality on lambdas. It is object equality. We can also pattern match on them. The only pattern that’ll match is a default catch all.

Does this chill your blood? Why?

match (lambda x: x + 3):
    case f:
        print(f(4))
7
(lambda x: x) == (lambda y: y)
False
(lambda x: x) == (lambda x: x)
False
f = lambda x: x
f == f
True

We are used to lambda being a pretty opaque thing, but it doesn’t have to be.
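
One hedged way to peek inside with only the standard library is to work on the source text of a lambda with the `ast` module. The helper `alpha_key` below is hypothetical: it renames parameters positionally and dumps the body, which gives a crude alpha equivalence check. It assumes there are no nested lambdas or comprehensions that rebind the parameter names.

```python
import ast


def alpha_key(src):
    """Return a string that is equal for lambdas differing only in parameter names."""
    lam = ast.parse(src, mode="eval").body
    assert isinstance(lam, ast.Lambda)
    # rename parameters by position: first parameter -> _v0, second -> _v1, ...
    renaming = {a.arg: f"_v{i}" for i, a in enumerate(lam.args.args)}

    class Rename(ast.NodeTransformer):
        def visit_Name(self, node):
            new = ast.Name(id=renaming.get(node.id, node.id), ctx=node.ctx)
            return ast.copy_location(new, node)

    body = Rename().visit(lam.body)
    # ast.dump leaves out line and column attributes by default
    return ast.dump(body)


assert alpha_key("lambda x: x + 3") == alpha_key("lambda y: y + 3")
assert alpha_key("lambda x: x + 3") != alpha_key("lambda x: x + 4")
# the function objects themselves still only compare by identity
assert (lambda x: x + 3) != (lambda y: y + 3)
```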

# z3 zipper list[FuncDeclRef, largs, rargs]

Manual beta0


def substitute_beta0(t,t1,t2):
    # replace t1 by t2 inside t, contracting t1[args] against the lambda t2 on the fly
    if t.eq(t1):
        return t2
    elif is_select(t) and t.arg(0).eq(t1):
        return substitute_vars(t2.body(), *reversed(t.children()[1:]))
    elif is_quantifier(t):
        vs, body = open_binder(t)
        binder = Lambda if t.is_lambda() else Exists if t.is_exists() else ForAll
        return binder(vs, substitute_beta0(body, t1, t2))
    elif is_app(t):
        return t.decl()(*[substitute_beta0(t.arg(i), t1, t2) for i in range(t.num_args())])
    else:
        raise Exception("Unexpected term in substitute_beta0", t)
#substitute_beta0(Lambda([x], F[x]), F, Lambda([y], 3 + y))

We can store the avars that e can hold. Or we could store the things it is forbidden to hold? Or we could skolemize and make E a higher order variable.

This is not going to be the most elegant because I wish I had a Var constructor to my term type. And it should take arguments.

type term =
   | App of decl * term list
   | Var of decl * term list (* var list *)
   | BVar of int
   | FVar of ident
   

from z3 import *

def prefix(goal):
    """
    Strip off quantifier prefix.
    """
    ectx = {}
    avars = []
    while is_quantifier(goal):
        vs, body = open_binder(goal) # open_binder returns (variables, body)
        if goal.is_forall():
            avars.extend(vs)
        elif goal.is_exists():
            for v in vs:
                avs = avars.copy()
                ectx[v] = avs
        goal = body
    return goal,ectx,avars
            
def gen_unify(goal):
    goal, ectx, avars = prefix(goal)
    assert goal.decl().name() == "="
    lhs,rhs = goal.children()
    return unify(lhs,rhs,ectx,avars)


def abstract(t, ts):
    []



def occurs(naughties, t):
    if t in naughties:
        return True
    if is_app(t):
        return any(occurs(naughty, c) for c in t.children())
    return False

def unify(p1,p2,ectx,avars):
    subst = {}
    todo = [(p1,p2)]
    while todo:
        p1,p2 = todo.pop() # we could pop _any_ of the todos, not just the top.
        if p1.eq(p2): # delete
            continue
        if p1 in ectx: # regular is_var
        elif is_app(p1) and p1.decl().name() == "Select" and p1.get_arg(0) in ectx: # higher order match
            if occurs(avars - ectx[p1] + [p1], p2):
                return None
            todo = [(substitute(t1,(p1,p2)), substitute(t2,(p1,p2))) for (t1,t2) in todo]
            subst = {k : substitute(v, (p1,p2)) for k,v in subst.items()}
            subst[p1] = p2
        elif is_app(p2) and p2.decl() in vs: # orient
            todo.append((p2,p1))
        if is_quantifier(p1):
            if p1.is_lambda() and is_quantifier(p2) and p1.is_quantifier():
                vs, b1 = open_binder(p1) # open_binder returns (variables, body)
                b2 = substitute_vars(p2.body(), *reversed(vs)) # instantiate
                avars.extend(vs) # These variable shall not escape
                todo.append((b1,b2))
        elif is_app(p1): # decompose
            if not is_app(p2) or p1.decl() != p2.decl():
                return None
            todo.extend(zip(p1.children(), p2.children())) 
        else:
            raise Exception("unexpected case", p1, p2)
    return subst

from z3 import *
f = Const("f", ArraySort(IntSort(), IntSort(), IntSort()))
x = Int("x")
f[x, x].children()


[f, x, x]
from z3 import *
def occurs(x, t):
    if x.eq(t):
        return True
    if is_app(t):
        return any(occurs(x, t.arg(i)) for i in range(t.num_args()))
    return False

def unify(p1,p2, vs):
    subst = {}
    todo = [(p1,p2)]
    while todo:
        p1,p2 = todo.pop() # we could pop _any_ of the todos, not just the top.
        if p1.eq(p2): # delete
            continue
        elif is_app(p1) and p1.decl() in vs: # elim
            if occurs(p1, p2):
                return None
            todo = [(substitute(t1,(p1,p2)), substitute(t2,(p1,p2))) for (t1,t2) in todo]
            subst = {k : substitute(v, (p1,p2)) for k,v in subst.items()}
            subst[p1] = p2
        elif is_app(p2) and p2.decl() in vs: # orient
            todo.append((p2,p1))
        elif is_app(p1): # decompose
            if not is_app(p2) or p1.decl() != p2.decl():
                return None
            todo.extend(zip(p1.children(), p2.children())) 
        else:
            raise Exception("unexpected case", p1, p2)
    return subst

vs = Ints("x y z")
x,y,z = vs
assert unify(IntVal(3), IntVal(3), vs) == {}
assert unify(IntVal(3), IntVal(4), vs) == None
assert unify(x, IntVal(3), vs) == {x : IntVal(3)}
assert unify(x, y, vs) == {x : y}
assert unify(x + x, y + y, vs) == {x : y}
assert unify(x + x, y + z, vs) == {x : y, z : y}
assert unify(x + y, y + z, vs) == {x : z, y : z}
assert unify(y + z, x + y, vs) == {y : x, z : x}
# non terminating if no occurs check
assert unify((x + x) + x, x + (x + x), vs) == None
assert unify(1 + x, x, vs) == None
---------------------------------------------------------------------------
Z3Exception                               Traceback (most recent call last)

Cell In[1], line 35
     33 x,y,z = vs
     34 assert unify(IntVal(3), IntVal(3), vs) == {}
---> 35 assert unify(IntVal(3), IntVal(4), vs) == None
     36 assert unify(x, IntVal(3), vs) == {x : IntVal(3)}
     37 assert unify(x, y, vs) == {x : y}

Cell In[1], line 16, in unify(p1, p2, vs)
     14 if p1.eq(p2): # delete
     15     continue
---> 16 elif is_app(p1) and p1.decl() in vs: # elim
     17     if occurs(p1, p2):
     18         return None

File ~/.local/lib/python3.10/site-packages/z3/z3.py:1033, in ExprRef.__eq__(self, other)
   1031 if other is None:
   1032     return False
-> 1033 a, b = _coerce_exprs(self, other)
   1034 return BoolRef(Z3_mk_eq(self.ctx_ref(), a.as_ast(), b.as_ast()), self.ctx)

    [... further frames inside z3.py and z3core.py elided ...]

Z3Exception: b'parser error'

from dataclasses import dataclass
@dataclass
class Var():
    x : ExprRef
    ctx : Context
    def collect_a(self):
        # walk outwards through the enclosing contexts, collecting universal variables
        res = []
        x = self 
        while x is not None:
            if isinstance(x, AVar):
                res.append(x)
            x = x.ctx
        return res


class EVar(Var):
    x : ExprRef
    ctx : Context
class AVar(Var):
    pass


def unify(avars, edecls, p1,p2):
    edecls = []

    def freshe():
        edecl = Function("e_", [x.sort() for x in avars])
        te = edecl(*avars)
        substitute_vars(t, te)


def unify(p1,p2, vctx):
    subst = {}
    ectx = {}
    todo = [([], p1,p2)]
    while todo:
        ctx,p1,p2 = todo.pop()
        if p1.eq(p2):
            continue
        if is_quantifier(p1):
            vs, body = open_binder(p1) # open_binder returns (variables, body)
            if p1.is_exists():
                for v in vs:
                    ectx[v] = ctx
                todo.append((ctx, body, p2))
            elif p1.is_forall():
                todo.append((ctx + vs, body, p2))
            elif p1.is_lambda():
                todo.append((ctx + vs, body, p2))
        elif is_var(p1):
            todo = [(substitute(t1,(p1,p2)), substitute(t2,(p1,p2))) for (t1,t2) in todo]
            subst = {k : substitute(v, (p1,p2)) for k,v in subst.items()}
            subst[p1] = p2
        elif is_var(p2):
            todo.append((ctx,p2,p1))
        elif is_app(p1):
            if not is_app(p2) or p1.decl() != p2.decl():
                return None
            todo.extend(zip(p1.children(), p2.children())) 
        else:
            raise Exception("unexpected case", p1, p2)
    return subst

def unify(p1,p2, vctx):
    subst = {}
    todo = [(p1,p2)]
    while todo:
        p1,p2 = todo.pop()
        if p1.eq(p2):
            continue
        elif is_var(p1):
            if # do "occurs"
                in vctx[p1] # narrow the contx  of any var in rhs.
            todo = [(substitute(t1,(p1,p2)), substitute(t
                                                        
def unify_miller(p1,p2, vs):
    if p1.decl() in vs: # is_var
        
        #substitute(   p1.children(),  
        # abstract out children
        subst[p1.decl()] = Lambda(freshes, substitute(p1, zip(p1.children(), freshes)))
        t = Lambda(freshes, substitute(p1, zip(p1.children(), freshes)))
        check_no_escape(t)
        todo = [  , substitute(p1)]
        subst[p1.decl()] = t
    todo = []
    noescape = []
    if is_quantifier(p1):
        if is_quantifier(p2) and p1.num_vars() == p2.num_vars(): # and the sorts should match
            vs, body = open_binder(p1)
            noescape += vs
            body2 = instantiate(p2, vs)
            todo.append((body,body2))


Exists x, p(x) = exists y, p(y) # I guess we could allow this? Kind of odd
ex(lam x, p(x))
all(lam x, p(x)) # all is basically treated as a lambda.

def pmatch(t : z3.ExprRef, vs : list[z3.ExprRef], pat : z3.ExprRef) -> z3.ExprRef:
    if len(vs) == 0:
        return t
    else:
        return z3.simplify(z3.substitute(t, list(zip(vs, [z3.FreshReal('v') for _ in vs]))))

def pmatch(vs, pat, t):
    subst = {}
    def worker(pat, t):
        if any(pat.eq(v) for v in vs):
            if pat in subst:
                return subst[pat].eq(t)
            else:
                subst[pat] = t
                return True
        if is_app(pat):
            if is_app(t) and pat.decl() == t.decl():
                return all(worker(pat.arg(i), t.arg(i)) for i in range(pat.num_args()))
            return False
    if worker(pat, t):
        return subst