kdrag.rewrite

Utilities for rewriting and simplification including pattern matching and unification.

Functions

apply(goal, vs, head, body)

backward_rule(r, tgt)

Match a rule's conclusion against a target and return the substitution and the instantiated hypothesis.

beta(e)

Do one pass of beta normalization.

decl_index(rules)

Build a dictionary of rules indexed by their lhs head function declaration.

def_eq(e1, e2[, trace])

A notion of computational equality.

forward_rule(r, tgt)

Match a rule's hypothesis against a target and return the instantiated conclusion.

full_simp(e[, trace])

Fully simplify using definitions and built in z3 simplifier until no progress is made.

kbo(vs, t1, t2)

Knuth Bendix Ordering, naive implementation.

lpo(vs, t1, t2)

Lexicographic path ordering.

rewrite(t, rules[, trace])

Sweep through term once performing rewrites.

rewrite1(t, vs, lhs, rhs)

Rewrite at root a single time.

rewrite1_rule(t, rule[, trace])

Rewrite at root a single time.

rewrite_star(t, rules[, trace])

Repeat rewrite until no more rewrites are possible.

rule_of_expr(pf_or_thm)

Unpack a theorem of the form forall vs, hyp => conc into a Rule tuple.

rule_of_theorem(thm)

Unpack a theorem of the form forall vs, lhs = rhs into a RewriteRule tuple.

simp(e[, trace, max_iter])

Simplify by alternating definition unfolding and the built-in z3 simplifier until no progress is made or max_iter rounds have run.

simp1(t)

Simplify a term with the z3 demodulator and simplify tactics.

simp2(t)

Simplify a term with the z3 elim-predicates tactic.

unfold(e[, decls, trace])

Do a single unfold sweep, unfolding definitions defined by kd.define.

Classes

Order(*values)

RewriteRule(vs, lhs, rhs)

A rewrite rule tuple

Rule(vs, hyp, conc, pf)

Exceptions

RewriteRuleException

class kdrag.rewrite.Order(*values)

Bases: Enum

EQ = 0

Equal.

GR = 1

Greater.

NGE = 2

Not greater or equal.

classmethod __contains__(value)

Return True if value is in cls.

value is in cls if: 1) value is a member of cls, or 2) value is the value of one of the cls’s members.

classmethod __getitem__(name)

Return the member matching name.

classmethod __iter__()

Return members in definition order.

classmethod __len__()

Return the number of members (no aliases)

class kdrag.rewrite.RewriteRule(vs: list[ExprRef], lhs: ExprRef, rhs: ExprRef)

Bases: NamedTuple

A rewrite rule tuple

Parameters:
  • vs (list[ExprRef])

  • lhs (ExprRef)

  • rhs (ExprRef)

count(value, /)

Return number of occurrences of value.

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

lhs: ExprRef

Alias for field number 1

rhs: ExprRef

Alias for field number 2

vs: list[ExprRef]

Alias for field number 0

exception kdrag.rewrite.RewriteRuleException

Bases: Exception

add_note(object, /)

Exception.add_note(note) – add a note to the exception

args
with_traceback(object, /)

Exception.with_traceback(tb) – set self.__traceback__ to tb and return self.

class kdrag.rewrite.Rule(vs, hyp, conc, pf)

Bases: NamedTuple

Parameters:
  • vs (list[ExprRef])

  • hyp (BoolRef)

  • conc (BoolRef)

  • pf (Proof | None)

conc: BoolRef

Alias for field number 2

count(value, /)

Return number of occurrences of value.

hyp: BoolRef

Alias for field number 1

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

pf: Proof | None

Alias for field number 3

vs: list[ExprRef]

Alias for field number 0

kdrag.rewrite.apply(goal: BoolRef, vs: list[ExprRef], head: BoolRef, body: BoolRef) -> BoolRef | None

Parameters:
  • goal (BoolRef)

  • vs (list[ExprRef])

  • head (BoolRef)

  • body (BoolRef)

Return type:

BoolRef | None

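kdrag.rewrite.backward_rule(r: Rule, tgt: BoolRef) -> tuple[dict[ExprRef, ExprRef], BoolRef] | None

Apply a rule backward to a target term. The rule's conclusion is matched against tgt; on success the result is the substitution together with the instantiated hypothesis, otherwise None.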

Parameters:
  • r (Rule)

  • tgt (BoolRef)

Return type:

tuple[dict[ExprRef, ExprRef], BoolRef] | None

kdrag.rewrite.beta(e)

Do one pass of beta normalization.

>>> x = smt.Int("x")
>>> y = smt.String("y")
>>> f = smt.Function("f", smt.IntSort(), smt.IntSort())
>>> beta(f(x))
f(x)
>>> beta(f(smt.Lambda([x], f(x))[1]))
f(f(1))
>>> beta(f(smt.Select(smt.Lambda([x,y], x), 1, smt.StringVal("fred"))))
f(1)

kdrag.rewrite.decl_index(rules: list[RewriteRule]) -> dict[FuncDeclRef, RewriteRule]

Build a dictionary of rules indexed by their lhs head function declaration. If several rules share a head, only the last one is kept.

Parameters:

rules (list[RewriteRule])

Return type:

dict[FuncDeclRef, RewriteRule]
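
The index is built with a plain dictionary comprehension, so it holds at most one rule per head symbol. The following is a minimal sketch of that behavior only. It uses hypothetical `(head, lhs, rhs)` string triples instead of real RewriteRule values and z3 declarations, so none of the names below are part of kdrag.

```python
# Toy stand-in for RewriteRule: (head symbol, lhs text, rhs text).
# Illustrative only; the real function keys on the z3 declaration of rule.lhs.
rules = [
    ("add", "add(Z, y)", "y"),
    ("mul", "mul(Z, y)", "Z"),
    ("add", "add(S(x), y)", "S(add(x, y))"),
]

# Same shape as the comprehension used by decl_index:
# a later rule with the same head replaces an earlier one.
index = {head: (lhs, rhs) for head, lhs, rhs in rules}

print(sorted(index))   # ['add', 'mul']
print(index["add"])    # ('add(S(x), y)', 'S(add(x, y))')
```

If every rule for a head must be kept, a dictionary of lists would be needed instead.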

kdrag.rewrite.def_eq(e1: ExprRef, e2: ExprRef, trace=None) -> bool

A notion of computational equality. Unfold and simp.

>>> import kdrag.theories.nat as nat
>>> def_eq(nat.one + nat.one, nat.S(nat.S(nat.Z)))
True
Parameters:
  • e1 (ExprRef)

  • e2 (ExprRef)

Return type:

bool

kdrag.rewrite.forward_rule(r: Rule, tgt: BoolRef)

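Apply a rule forward to a target term. The rule's hypothesis is matched against tgt; on success the result is the instantiated conclusion, otherwise None.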

Parameters:
  • r (Rule)

  • tgt (BoolRef)
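
backward_rule and forward_rule differ only in which side of the rule is matched against the target. The sketch below illustrates that difference on a toy term representation. It assumes variables are bare strings and applications are tuples `(head, arg, ...)`; the helpers `pmatch`, `instantiate`, `backward` and `forward` here are illustrative stand-ins, not the kdrag functions.

```python
def pmatch(vs, pat, t, subst=None):
    """Match pattern `pat` against term `t`; return a substitution dict or None."""
    subst = {} if subst is None else subst
    if isinstance(pat, str) and pat in vs:  # pattern variable
        if pat in subst:
            return subst if subst[pat] == t else None
        subst[pat] = t
        return subst
    if isinstance(pat, tuple) and isinstance(t, tuple):
        if pat[0] != t[0] or len(pat) != len(t):
            return None
        for p, a in zip(pat[1:], t[1:]):
            if pmatch(vs, p, a, subst) is None:
                return None
        return subst
    return subst if pat == t else None


def instantiate(t, subst):
    """Replace pattern variables in `t` by their bindings."""
    if isinstance(t, tuple):
        return (t[0],) + tuple(instantiate(a, subst) for a in t[1:])
    return subst.get(t, t)


# Toy rule: forall x, even(x) => even(S(S(x)))
vs, hyp, conc = ["x"], ("even", "x"), ("even", ("S", ("S", "x")))


def backward(target):
    """Match the conclusion; return (substitution, new subgoal) or None."""
    s = pmatch(vs, conc, target)
    return None if s is None else (s, instantiate(hyp, s))


def forward(fact):
    """Match the hypothesis; return the derived fact or None."""
    s = pmatch(vs, hyp, fact)
    return None if s is None else instantiate(conc, s)


zero = ("Z",)
two = ("S", ("S", zero))
print(backward(("even", two)))          # ({'x': ('Z',)}, ('even', ('Z',)))
print(forward(("even", zero)))          # ('even', ('S', ('S', ('Z',))))
print(backward(("even", ("S", zero))))  # None
```

Backward application reduces a goal to a subgoal, while forward application derives a new fact from a known one.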

kdrag.rewrite.full_simp(e: ExprRef, trace=None) -> ExprRef

Fully simplify using definitions and the built-in z3 simplifier until no progress is made.

>>> import kdrag.theories.nat as nat
>>> full_simp(nat.one + nat.one + nat.S(nat.one))
S(S(S(S(Z))))
>>> p = smt.Bool("p")
>>> full_simp(smt.If(p, 42, 3))
If(p, 42, 3)
Parameters:

e (ExprRef)

Return type:

ExprRef

kdrag.rewrite.kbo(vs: list[ExprRef], t1: ExprRef, t2: ExprRef) -> Order

Knuth Bendix Ordering, naive implementation. All weights are 1. Source: Term Rewriting and All That section 5.4.4

Parameters:
  • vs (list[ExprRef])

  • t1 (ExprRef)

  • t2 (ExprRef)

Return type:

Order
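
Before it looks at head symbols, kbo performs two cheap checks: every variable must occur in t1 at least as often as in t2, and the term weights are compared with every node counting 1. The sketch below covers only those two checks on a toy representation. It assumes variables are bare strings and applications (including constants) are tuples `(head, arg, ...)`, and the name `kbo_precheck` is hypothetical.

```python
from collections import Counter


def weight(t):
    """Number of nodes in the term; every symbol and variable weighs 1."""
    if isinstance(t, tuple):
        return 1 + sum(weight(a) for a in t[1:])
    return 1


def var_counts(t):
    """Count occurrences of each variable (bare string) in the term."""
    if isinstance(t, tuple):
        counts = Counter()
        for a in t[1:]:
            counts += var_counts(a)
        return counts
    return Counter([t])


def kbo_precheck(t1, t2):
    """Return 'NGE' or 'GR' when the first two checks decide, else 'TIE'."""
    c1, c2 = var_counts(t1), var_counts(t2)
    if any(c1[v] < n for v, n in c2.items()):
        return "NGE"  # variable condition fails
    w1, w2 = weight(t1), weight(t2)
    if w1 > w2:
        return "GR"
    if w1 < w2:
        return "NGE"
    return "TIE"  # the real kbo goes on to compare head symbols


# double(x) -> add(x, x) duplicates x, so the variable condition rejects it.
print(kbo_precheck(("double", "x"), ("add", "x", "x")))  # NGE
# add(x, Z) -> x strictly decreases the weight.
print(kbo_precheck(("add", "x", ("Z",)), "x"))           # GR
```

Equal-weight pairs fall through to the symbol comparison, which this sketch leaves out.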

kdrag.rewrite.lpo(vs: list[ExprRef], t1: ExprRef, t2: ExprRef) -> Order

Lexicographic path ordering. Based on https://www21.in.tum.de/~nipkow/TRaAT/programs/termorders.ML. Function symbols are currently compared by (name, id). TODO: add an ordering parameter.

Parameters:
  • vs (list[ExprRef])

  • t1 (ExprRef)

  • t2 (ExprRef)

Return type:

Order
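
A term ordering such as lpo is typically used to orient equations into terminating rewrite rules. The sketch below ports the same case analysis to a toy representation so the recursion can be followed by hand. It assumes variables are bare strings, applications are tuples `(head, arg, ...)`, and the precedence on heads is plain string comparison, which is a simplification of the (name, id) comparison in the source.

```python
from enum import Enum


class Order(Enum):
    EQ = 0
    GR = 1
    NGE = 2


def occurs(v, t):
    """True if variable `v` occurs in term `t`."""
    if isinstance(t, tuple):
        return any(occurs(v, a) for a in t[1:])
    return t == v


def lpo(t1, t2):
    if isinstance(t2, str):  # t2 is a variable
        if t1 == t2:
            return Order.EQ
        return Order.GR if occurs(t2, t1) else Order.NGE
    if isinstance(t1, str):  # a variable is never greater than an application
        return Order.NGE
    f, args1 = t1[0], t1[1:]
    g, args2 = t2[0], t2[1:]
    # If some argument of t1 is already >= t2, then t1 > t2.
    if not all(lpo(a, t2) == Order.NGE for a in args1):
        return Order.GR
    if f == g:
        if not all(lpo(t1, a) == Order.GR for a in args2):
            return Order.NGE
        for a1, a2 in zip(args1, args2):  # lexicographic comparison
            o = lpo(a1, a2)
            if o != Order.EQ:
                return o
        return Order.EQ
    if f > g:  # toy precedence: string order on head names
        ok = all(lpo(t1, a) == Order.GR for a in args2)
        return Order.GR if ok else Order.NGE
    return Order.NGE


lhs = ("add", ("S", "x"), "y")
rhs = ("S", ("add", "x", "y"))
print(lpo(lhs, rhs))  # Order.GR
print(lpo(rhs, lhs))  # Order.NGE
```

With "add" above "S" in the precedence, the Peano addition equation is oriented left to right.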

kdrag.rewrite.rewrite(t: ExprRef, rules: list[RewriteRule], trace=None) -> ExprRef

Sweep through term once performing rewrites.

>>> x = smt.Real("x")
>>> rule = RewriteRule([x], x**2, x*x)
>>> rewrite((x**2)**2, [rule])
x*x*x*x
Parameters:
  • t (ExprRef)

  • rules (list[RewriteRule])

Return type:

ExprRef
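
A single sweep rewrites the children of a term first and then tries each rule at the root. The sketch below mirrors that traversal on a toy representation, assuming variables are bare strings and applications are tuples `(head, arg, ...)`. The helpers `pmatch`, `instantiate` and `sweep` are illustrative stand-ins rather than kdrag functions.

```python
def pmatch(vs, pat, t, subst=None):
    """Match pattern `pat` against term `t`; return a substitution dict or None."""
    subst = {} if subst is None else subst
    if isinstance(pat, str) and pat in vs:  # pattern variable
        if pat in subst:
            return subst if subst[pat] == t else None
        subst[pat] = t
        return subst
    if isinstance(pat, tuple) and isinstance(t, tuple):
        if pat[0] != t[0] or len(pat) != len(t):
            return None
        for p, a in zip(pat[1:], t[1:]):
            if pmatch(vs, p, a, subst) is None:
                return None
        return subst
    return subst if pat == t else None


def instantiate(t, subst):
    """Replace pattern variables in `t` by their bindings."""
    if isinstance(t, tuple):
        return (t[0],) + tuple(instantiate(a, subst) for a in t[1:])
    return subst.get(t, t)


def sweep(t, rules):
    """One bottom-up pass: rewrite children, then try each rule at the root."""
    if isinstance(t, tuple):
        t = (t[0],) + tuple(sweep(a, rules) for a in t[1:])
        for vs, lhs, rhs in rules:
            s = pmatch(vs, lhs, t)
            if s is not None:
                t = instantiate(rhs, s)
    return t


# Toy version of the rule x**2 -> x*x applied to (a**2)**2.
rule = (["x"], ("pow", "x", 2), ("mul", "x", "x"))
term = ("pow", ("pow", "a", 2), 2)
result = sweep(term, [rule])
print(result)  # ('mul', ('mul', 'a', 'a'), ('mul', 'a', 'a'))
```

Because children are rewritten before the root, the inner power is already expanded when the rule fires on the outer one.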

kdrag.rewrite.rewrite1(t: ExprRef, vs: list[ExprRef], lhs: ExprRef, rhs: ExprRef) -> ExprRef | None

Rewrite at root a single time.

Parameters:
  • t (ExprRef)

  • vs (list[ExprRef])

  • lhs (ExprRef)

  • rhs (ExprRef)

Return type:

ExprRef | None

kdrag.rewrite.rewrite1_rule(t: ExprRef, rule: RewriteRule, trace: list[tuple[RewriteRule, dict[ExprRef, ExprRef]]] | None = None) -> ExprRef | None

Rewrite at root a single time.

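Parameters:
  • t (ExprRef)

  • rule (RewriteRule)

  • trace (list[tuple[RewriteRule, dict[ExprRef, ExprRef]]] | None)

Return type: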

ExprRef | None

kdrag.rewrite.rewrite_star(t: ExprRef, rules: list[RewriteRule], trace=None) -> ExprRef

Repeat rewrite until no more rewrites are possible.

Parameters:
  • t (ExprRef)

  • rules (list[RewriteRule])

Return type:

ExprRef
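
rewrite_star is a fixpoint loop around a single sweep, and the loop in the source has no iteration bound, so it stops only when a sweep leaves the term unchanged. The sketch below shows the same loop shape on plain strings, with hypothetical `(old, new)` substring pairs standing in for rewrite rules.

```python
def sweep(s, rules):
    """One pass: apply each (old, new) rule across the whole string."""
    for old, new in rules:
        s = s.replace(old, new)
    return s


def rewrite_to_fixpoint(s, rules):
    """Repeat `sweep` until a pass leaves the string unchanged."""
    passes = 0
    while True:
        s1 = sweep(s, rules)
        passes += 1
        if s1 == s:
            return s1, passes
        s = s1


result, passes = rewrite_to_fixpoint("aaaaa", [("aa", "a")])
print(result, passes)  # a 4
```

A rule set that can grow terms forever, such as `("a", "aa")` in this toy setting, would make the loop run without end. Term orderings are one way to check that each rule decreases its term.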

kdrag.rewrite.rule_of_expr(pf_or_thm: ExprRef | Proof) -> Rule

Unpack a theorem of the form forall vs, hyp => conc into a Rule tuple. A theorem that is not an implication gets hyp = True.

>>> x = smt.Real("x")
>>> rule_of_expr(smt.ForAll([x], smt.Implies(x**2 == x*x, x > 0)))
Rule(vs=[X...], hyp=X...**2 == X...*X..., conc=X... > 0, pf=None)
>>> rule_of_expr(x > 0)
Rule(vs=[], hyp=True, conc=x > 0, pf=None)
Parameters:

pf_or_thm (ExprRef | Proof)

Return type:

Rule

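kdrag.rewrite.rule_of_theorem(thm: BoolRef | QuantifierRef) -> RewriteRule

Unpack a theorem of the form forall vs, lhs = rhs into a RewriteRule tuple. Raises RewriteRuleException if the theorem is not a universally quantified equation.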

>>> x = smt.Real("x")
>>> rule_of_theorem(smt.ForAll([x], x**2 == x*x))
RewriteRule(vs=[X...], lhs=X...**2, rhs=X...*X...)
Parameters:

thm (BoolRef | QuantifierRef)

Return type:

RewriteRule

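kdrag.rewrite.simp(e: ExprRef, trace=None, max_iter=3) -> ExprRef

Simplify by alternating definition unfolding and the built-in z3 simplifier until no progress is made or max_iter rounds have run. The result is the smallest term seen, measured by the length of its s-expression.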

>>> import kdrag.theories.nat as nat
>>> simp(nat.one + nat.one + nat.S(nat.one))
S(S(S(S(Z))))
>>> p = smt.Bool("p")
>>> simp(smt.If(p, 42, 3))
If(p, 42, 3)
Parameters:

e (ExprRef)

Return type:

ExprRef
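
In the source, simp bounds the number of rounds with max_iter and returns the smallest term it has seen rather than the last one. The sketch below isolates that control flow. It assumes terms are plain strings, size is string length, and one round is a lookup in a hypothetical transition table, so it is not the kdrag implementation.

```python
def simp_best(e, step, max_iter=3):
    """Iterate `step`, keeping the shortest string seen; stop at a fixpoint."""
    best = e
    for _ in range(max_iter):
        e1 = step(e)
        if len(e1) < len(best):
            best = e1
        if e1 == e:  # no progress
            break
        e = e1
    return best


# Hypothetical one-round simplifier: each entry maps a term to its successor.
table = {"double(zero)": "zero+zero", "zero+zero": "zero"}


def step(e):
    return table.get(e, e)


print(simp_best("double(zero)", step))              # zero
print(simp_best("double(zero)", step, max_iter=1))  # zero+zero
```

Tracking the best candidate means an unfolding that only makes the term larger does not make the result worse than the input.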

kdrag.rewrite.simp1(t: ExprRef) -> ExprRef

Simplify a term with the z3 demodulator and simplify tactics, using the axioms of all registered definitions.

Parameters:

t (ExprRef)

Return type:

ExprRef

kdrag.rewrite.simp2(t: ExprRef) -> ExprRef

Simplify a term with the z3 elim-predicates tactic, using the axioms of all registered definitions.

Parameters:

t (ExprRef)

Return type:

ExprRef

kdrag.rewrite.unfold(e: ExprRef, decls=None, trace=None) -> ExprRef

Do a single unfold sweep, unfolding definitions defined by kd.define. If trace is a list, the proofs used for each unfolding are appended to it. decls is an optional list of declarations to unfold. If None, all definitions are unfolded.

>>> x = smt.Int("x")
>>> f = kd.define("f", [x], x + 42)
>>> trace = []
>>> unfold(f(1), trace=trace)
1 + 42
>>> trace
[|- f(1) == 1 + 42, |- ForAll(x, f(x) == x + 42)]
>>> unfold(smt.Lambda([x], f(x)))
Lambda(x, x + 42)
Parameters:

e (ExprRef)

Return type:

ExprRef
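
A single unfold sweep replaces each defined head by its body once and does not unfold again inside the freshly substituted body. The sketch below illustrates that on a toy representation. It assumes variables are bare strings, applications are tuples `(head, arg, ...)`, and definitions live in a hypothetical `defs` dictionary mapping a name to `(params, body)`. The trace here records only names, whereas the real trace holds proofs.

```python
def subst(t, env):
    """Replace parameters in `t` by their arguments."""
    if isinstance(t, tuple):
        return (t[0],) + tuple(subst(a, env) for a in t[1:])
    return env.get(t, t)


def unfold(t, defs, decls=None, trace=None):
    """One sweep: unfold children, then unfold the head if it is defined."""
    if not isinstance(t, tuple):
        return t
    head = t[0]
    args = tuple(unfold(a, defs, decls, trace) for a in t[1:])
    if head in defs and (decls is None or head in decls):
        params, body = defs[head]
        if trace is not None:
            trace.append(head)  # toy breadcrumb: which definition was used
        return subst(body, dict(zip(params, args)))
    return (head,) + args


defs = {
    "f": (["x"], ("add", "x", 42)),
    "g": (["x"], ("f", ("f", "x"))),
}

once = unfold(("g", 1), defs)
print(once)   # ('f', ('f', 1))

trace = []
twice = unfold(once, defs, trace=trace)
print(twice)  # ('add', ('add', 1, 42), 42)
print(trace)  # ['f', 'f']

# Restricting decls leaves other definitions folded.
print(unfold(("f", 1), defs, decls=["g"]))  # ('f', 1)
```

Reaching a fully unfolded term therefore takes repeated sweeps.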

"""
Utilities for rewriting and simplification including pattern matching and unification.
"""

import kdrag.smt as smt
import kdrag as kd
from enum import Enum
from typing import NamedTuple, Optional
import kdrag.utils as utils


def simp1(t: smt.ExprRef) -> smt.ExprRef:
    """Simplify a term with the z3 `demodulator` and `simplify` tactics, using the axioms of all registered definitions."""
    expr = smt.FreshConst(t.sort(), prefix="knuckle_goal")
    G = smt.Goal()
    for v in kd.kernel.defns.values():
        G.add(v.ax.thm)
    G.add(expr == t)
    G2 = smt.Then(smt.Tactic("demodulator"), smt.Tactic("simplify")).apply(G)[0]
    # TODO make this extraction more robust
    return G2[len(G2) - 1].children()[1]


def simp2(t: smt.ExprRef) -> smt.ExprRef:
    """Simplify a term with the z3 `elim-predicates` tactic, using the axioms of all registered definitions."""
    expr = smt.FreshConst(t.sort(), prefix="knuckle_goal")
    G = smt.Goal()
    for v in kd.kernel.defns.values():
        G.add(v.ax.thm)
    G.add(expr == t)
    G2 = smt.Tactic("elim-predicates").apply(G)[0]
    return G2[len(G2) - 1].children()[1]


# TODO: Doesn't seem to do anything?
# def factor(t: smt.ExprRef) -> smt.ExprRef:
#    """factor a term using z3 built in tactic"""
#    expr = smt.FreshConst(t.sort(), prefix="knuckle_goal")
#    G = smt.Goal()
#    for v in kd.kernel.defns.values():
#        G.add(v.ax.thm)
#    G.add(expr == t)
#    G2 = smt.Tactic("factor").apply(G)[0]
#    return G2[len(G2) - 1].children()[1]


def unfold(e: smt.ExprRef, decls=None, trace=None) -> smt.ExprRef:
    """
    Do a single unfold sweep, unfolding definitions defined by `kd.define`.
    If `trace` is a list, the proofs used for each unfolding are appended to it.
    `decls` is an optional list of declarations to unfold. If None, all definitions are unfolded.

    >>> x = smt.Int("x")
    >>> f = kd.define("f", [x], x + 42)
    >>> trace = []
    >>> unfold(f(1), trace=trace)
    1 + 42
    >>> trace
    [|- f(1) == 1 + 42, |- ForAll(x, f(x) == x + 42)]

    >>> unfold(smt.Lambda([x], f(x)))
    Lambda(x, x + 42)
    """
    if smt.is_app(e):
        decl = e.decl()
        children = [unfold(c, decls=decls, trace=trace) for c in e.children()]
        defn = kd.kernel.defns.get(decl)
        if defn is not None and (decls is None or decl in decls):
            e1 = smt.substitute(defn.body, *zip(defn.args, children))
            e = e1
            if trace is not None:
                if isinstance(defn.ax.thm, smt.QuantifierRef):
                    trace.append((defn.ax(*children)))
                # else:
                # TODO: It would be better to only emit the forall axiom when necessary.
                # When we have traversed under a binder, things get complicated.
                # instantiating the axiom on the fresh variable will not be useful.
                # so just emitting the needed axiom is a hack for now.
                trace.append(defn.ax)
            return e1
        else:
            return decl(*children)
    elif isinstance(e, smt.QuantifierRef):
        vs, e1 = kd.utils.open_binder_unhygienic(e)
        if e.is_forall():
            # TODO: When we go under a quantifier, any trace breadcrumb should really be re-quantified in order to be useful.
            # maybe herbrandize?
            return smt.ForAll(vs, unfold(e1, decls=decls, trace=trace))
        elif e.is_exists():
            return smt.Exists(vs, unfold(e1, decls=decls, trace=trace))
        elif e.is_lambda():
            return smt.Lambda(vs, unfold(e1, decls=decls, trace=trace))
        else:
            raise Exception("Unexpected quantifier", e)
    else:
        return e


def beta(e):
    """
    Do one pass of beta normalization.

    >>> x = smt.Int("x")
    >>> y = smt.String("y")
    >>> f = smt.Function("f", smt.IntSort(), smt.IntSort())
    >>> beta(f(x))
    f(x)
    >>> beta(f(smt.Lambda([x], f(x))[1]))
    f(f(1))
    >>> beta(f(smt.Select(smt.Lambda([x,y], x), 1, smt.StringVal("fred"))))
    f(1)
    """
    if (
        smt.is_select(e)
        and isinstance(e.arg(0), smt.QuantifierRef)
        and e.arg(0).is_lambda()
    ):
        args = [beta(c) for c in e.children()[1:]]
        f = e.arg(0)
        return smt.substitute_vars(f.body(), *reversed(args))
    elif smt.is_app(e):
        decl = e.decl()
        children = [beta(c) for c in e.children()]
        return decl(*children)
    elif isinstance(e, smt.QuantifierRef):
        vs, e1 = kd.utils.open_binder_unhygienic(e)
        if e.is_forall():
            return smt.ForAll(vs, beta(e1))
        elif e.is_exists():
            return smt.Exists(vs, beta(e1))
        elif e.is_lambda():
            return smt.Lambda(vs, beta(e1))
        else:
            raise Exception("Unexpected quantifier", e)
    else:
        raise Exception("Unexpected term", e)


def full_simp(e: smt.ExprRef, trace=None) -> smt.ExprRef:
    """
    Fully simplify using definitions and the built-in z3 simplifier until no progress is made.

    >>> import kdrag.theories.nat as nat
    >>> full_simp(nat.one + nat.one + nat.S(nat.one))
    S(S(S(S(Z))))

    >>> p = smt.Bool("p")
    >>> full_simp(smt.If(p, 42, 3))
    If(p, 42, 3)
    """
    while True:
        e = unfold(e, trace=trace)
        e1 = smt.simplify(e)
        if e1.eq(e):
            return e
        else:
            e = e1


def simp(e: smt.ExprRef, trace=None, max_iter=3) -> smt.ExprRef:
    """
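    Simplify by alternating definition unfolding and the built-in z3 simplifier until no
    progress is made or `max_iter` rounds have run. The result is the smallest term seen,
    measured by the length of its s-expression.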

    >>> import kdrag.theories.nat as nat
    >>> simp(nat.one + nat.one + nat.S(nat.one))
    S(S(S(S(Z))))

    >>> p = smt.Bool("p")
    >>> simp(smt.If(p, 42, 3))
    If(p, 42, 3)
    """
    i = 0
    ebest = e
    bestsize = len(e.sexpr())
    while True:
        i += 1
        if max_iter is not None and i > max_iter:
            return ebest
        e = unfold(e, trace=trace)
        if (newsize := len(e.sexpr())) < bestsize:
            ebest = e
            bestsize = newsize
        # TODO: Interesting options: som, sort_store, elim_ite, flat, split_concat_eq, sort_sums, sort_disjunctions
        e1 = smt.simplify(e)
        if (newsize := len(e1.sexpr())) < bestsize:
            ebest = e1
            bestsize = newsize
        if e1.eq(e):
            return ebest
        else:
            if trace is not None:
                trace.append(kd.kernel.prove(smt.Eq(e, e1)))
            e = e1


def def_eq(e1: smt.ExprRef, e2: smt.ExprRef, trace=None) -> bool:
    """
    A notion of computational equality. Unfold and simp.

    >>> import kdrag.theories.nat as nat
    >>> def_eq(nat.one + nat.one, nat.S(nat.S(nat.Z)))
    True
    """
    e1 = simp(e1, trace=trace)
    e2 = simp(e2, trace=trace)
    # TODO: Interleaving the two normalizations would allow early stopping:
    #     while not e1.eq(e2):
    #         e1 = unfold(e1, trace=trace)
    #         e2 = unfold(e2, trace=trace)
    return kd.utils.alpha_eq(e1, e2)


def rewrite1(
    t: smt.ExprRef, vs: list[smt.ExprRef], lhs: smt.ExprRef, rhs: smt.ExprRef
) -> Optional[smt.ExprRef]:
    """
    Rewrite at root a single time.
    """
    subst = utils.pmatch(vs, lhs, t)
    if subst is not None:
        return smt.substitute(rhs, *subst.items())
    return None


def apply(
    goal: smt.BoolRef, vs: list[smt.ExprRef], head: smt.BoolRef, body: smt.BoolRef
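) -> Optional[smt.BoolRef]:
    """
    Apply a Horn clause backward: match `head` against `goal` and return the
    instantiated `body`, or None if `head` does not match.
    """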
    res = rewrite1(goal, vs, head, body)
    assert res is None or isinstance(res, smt.BoolRef)
    return res


class RewriteRule(NamedTuple):
    """A rewrite rule tuple"""

    vs: list[smt.ExprRef]
    lhs: smt.ExprRef
    rhs: smt.ExprRef


def rewrite1_rule(
    t: smt.ExprRef,
    rule: RewriteRule,
    trace: Optional[list[tuple[RewriteRule, dict[smt.ExprRef, smt.ExprRef]]]] = None,
) -> Optional[smt.ExprRef]:
    """
    Rewrite at root a single time.
    """
    subst = utils.pmatch(rule.vs, rule.lhs, t)
    if subst is not None:
        # Record which rule fired and with what substitution before returning.
        if trace is not None:
            trace.append((rule, subst))
        return smt.substitute(rule.rhs, *subst.items())
    return None


def rewrite(t: smt.ExprRef, rules: list[RewriteRule], trace=None) -> smt.ExprRef:
    """
    Sweep through term once performing rewrites.

    >>> x = smt.Real("x")
    >>> rule = RewriteRule([x], x**2, x*x)
    >>> rewrite((x**2)**2, [rule])
    x*x*x*x
    """
    if smt.is_app(t):
        # Rewrite children first, threading the trace through the recursion.
        t = t.decl()(*[rewrite(arg, rules, trace=trace) for arg in t.children()])
        for rule in rules:
            res = rewrite1_rule(t, rule, trace=trace)
            if res is not None:
                t = res
    return t


class RewriteRuleException(Exception): ...


def rule_of_theorem(thm: smt.BoolRef | smt.QuantifierRef) -> RewriteRule:
    """
    Unpack a theorem of the form `forall vs, lhs = rhs` into a RewriteRule tuple.
    Raises RewriteRuleException if the theorem is not a universally quantified equation.

    >>> x = smt.Real("x")
    >>> rule_of_theorem(smt.ForAll([x], x**2 == x*x))
    RewriteRule(vs=[X...], lhs=X...**2, rhs=X...*X...)
    """
    vs = []
    thm1 = thm  # to help out pyright
    while isinstance(thm1, smt.QuantifierRef):
        if thm1.is_forall():
            vs1, thm1 = utils.open_binder(thm1)
            vs.extend(vs1)
        else:
            raise RewriteRuleException("Not a universal quantifier", thm1)
    if not smt.is_eq(thm1):
        raise RewriteRuleException("Not an equation", thm)
    lhs, rhs = thm1.children()
    return RewriteRule(vs, lhs, rhs)


def decl_index(rules: list[RewriteRule]) -> dict[smt.FuncDeclRef, RewriteRule]:
    """
    Build a dictionary of rules indexed by their lhs head function declaration.
    If several rules share a head, only the last one is kept.
    """
    return {rule.lhs.decl(): rule for rule in rules}


def rewrite_star(t: smt.ExprRef, rules: list[RewriteRule], trace=None) -> smt.ExprRef:
    """
    Repeat rewrite until no more rewrites are possible.
    """
    while True:
        t1 = rewrite(t, rules, trace=trace)
        if t1.eq(t):
            return t1
        t = t1


class Rule(NamedTuple):
    """A rule `forall vs, hyp => conc` with an optional proof `pf`."""

    vs: list[smt.ExprRef]
    hyp: smt.BoolRef
    conc: smt.BoolRef
    pf: Optional[kd.kernel.Proof] = None


def rule_of_expr(pf_or_thm: smt.ExprRef | kd.kernel.Proof) -> Rule:
    """Unpack a theorem of the form `forall vs, hyp => conc` into a Rule tuple.
    A theorem that is not an implication gets `hyp = True`.

    >>> x = smt.Real("x")
    >>> rule_of_expr(smt.ForAll([x], smt.Implies(x**2 == x*x, x > 0)))
    Rule(vs=[X...], hyp=X...**2 == X...*X..., conc=X... > 0, pf=None)
    >>> rule_of_expr(x > 0)
    Rule(vs=[], hyp=True, conc=x > 0, pf=None)
    """
    if isinstance(pf_or_thm, smt.ExprRef):
        thm = pf_or_thm
        pf = None
    elif kd.kernel.is_proof(pf_or_thm):
        pf = pf_or_thm
        thm = pf.thm
    else:
        raise ValueError("Expected proof or theorem")
    if isinstance(thm, smt.QuantifierRef) and thm.is_forall():
        vs, thm = utils.open_binder(thm)
    else:
        vs = []
    if smt.is_implies(thm):
        return Rule(vs, hyp=thm.arg(0), conc=thm.arg(1), pf=pf)
    else:
        assert isinstance(thm, smt.BoolRef)
        return Rule(vs, hyp=smt.BoolVal(True), conc=thm, pf=pf)


def backward_rule(
    r: Rule, tgt: smt.BoolRef
) -> Optional[tuple[dict[smt.ExprRef, smt.ExprRef], smt.BoolRef]]:
    """
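    Apply a rule backward to a target term. The rule's conclusion is matched against `tgt`;
    on success return the substitution and the instantiated hypothesis, otherwise None.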
    """
    subst = kd.utils.pmatch(r.vs, r.conc, tgt)
    if subst is not None:
        return subst, smt.substitute(r.hyp, *subst.items())
    else:
        return None


def forward_rule(r: Rule, tgt: smt.BoolRef):
    """
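    Apply a rule forward to a target term. The rule's hypothesis is matched against `tgt`;
    on success return the instantiated conclusion, otherwise None.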
    """
    subst = kd.utils.pmatch(r.vs, r.hyp, tgt)
    if subst is not None:
        return smt.substitute(r.conc, *subst.items())
    else:
        return None


"""
def apply_horn(thm: smt.BoolRef, horn: smt.BoolRef) -> smt.BoolRef:
    pat = horn
    obl = []
    if smt.is_quantifier(pat) and pat.is_forall():
        pat = pat.body()
    while True:
        if smt.is_implies(pat):
            obl.append(pat.arg(0))
            pat = pat.arg(1)
        else:
            break
    return kd.utils.z3_match(thm, pat)


def horn_split(horn: smt.BoolRef) -> smt.BoolRef:
    body = []
    vs = []
    while True:
        if smt.is_quantifier(horn) and horn.is_forall():
            vs1, horn = open_binder(horn)
            vs.extend(vs1)
        if smt.is_implies(horn):
            body.append(horn.arg(0))
            horn = horn.arg(1)
        else:
            break
    head = horn
    return vs, body, head
"""


class Order(Enum):
    EQ = 0  # Equal
    GR = 1  # Greater
    NGE = 2  # Not Greater or Equal


def lpo(vs: list[smt.ExprRef], t1: smt.ExprRef, t2: smt.ExprRef) -> Order:
    """
    Lexicographic path ordering.
    Based on https://www21.in.tum.de/~nipkow/TRaAT/programs/termorders.ML
    TODO add ordering parameter.
    """

    def is_var(x):
        return any(x.eq(v) for v in vs)

    if is_var(t2):
        if t1.eq(t2):
            return Order.EQ
        elif utils.is_subterm(t2, t1):
            return Order.GR
        else:
            return Order.NGE
    elif is_var(t1):
        return Order.NGE
    elif smt.is_app(t1) and smt.is_app(t2):
        decl1, decl2 = t1.decl(), t2.decl()
        args1, args2 = t1.children(), t2.children()
        if all(lpo(vs, a, t2) == Order.NGE for a in args1):
            if decl1 == decl2:
                if all(lpo(vs, t1, a) == Order.GR for a in args2):
                    for a1, a2 in zip(args1, args2):
                        ord = lpo(vs, a1, a2)
                        if ord == Order.GR:
                            return Order.GR
                        elif ord == Order.NGE:
                            return Order.NGE
                    return Order.EQ
                else:
                    return Order.NGE
            elif (decl1.name(), decl1.get_id()) > (decl2.name(), decl2.get_id()):
                if all(lpo(vs, t1, a) == Order.GR for a in args2):
                    return Order.GR
                else:
                    return Order.NGE

            else:
                return Order.NGE
        else:
            return Order.GR
    else:
        raise Exception("Unexpected terms in lpo", t1, t2)


def kbo(vs: list[smt.ExprRef], t1: smt.ExprRef, t2: smt.ExprRef) -> Order:
    """
    Knuth Bendix Ordering, naive implementation.
    All weights are 1.
    Source: Term Rewriting and All That section 5.4.4
    """
    if t1.eq(t2):
        return Order.EQ

    def is_var(x):
        return any(x.eq(v) for v in vs)

    def vcount(t):
        todo = [t]
        vcount1 = {v: 0 for v in vs}
        while todo:
            t = todo.pop()
            if is_var(t):
                vcount1[t] += 1
            elif smt.is_app(t):
                todo.extend(t.children())
        return vcount1

    vcount1, vcount2 = vcount(t1), vcount(t2)
    if not all(vcount1[v] >= vcount2[v] for v in vs):
        return Order.NGE

    def weight(t):
        todo = [t]
        w = 0
        while todo:
            t = todo.pop()
            w += 1
            if smt.is_app(t):
                todo.extend(t.children())
        return w

    w1, w2 = weight(t1), weight(t2)
    if w1 > w2:
        return Order.GR
    elif w1 < w2:
        return Order.NGE
    else:
        if is_var(t2):  # KBO2a
            decl = t1.decl()
            if decl.arity() != 1:
                return Order.NGE
            while not t1.eq(t2):
                if t1.decl() != decl:
                    return Order.NGE
                else:
                    t1 = t1.arg(0)
            return Order.GR
        elif is_var(t1):
            return Order.NGE
        elif smt.is_app(t1) and smt.is_app(t2):
            decl1, decl2 = t1.decl(), t2.decl()
            if decl1 == decl2:  # KBO2c
                args1, args2 = t1.children(), t2.children()
                for a1, a2 in zip(args1, args2):
                    ord = kbo(vs, a1, a2)
                    if ord == Order.GR:
                        return Order.GR
                    elif ord == Order.NGE:
                        return Order.NGE
                raise Exception("Unexpected equality reached in kbo")
            elif (decl1.name(), decl1.get_id()) > (
                decl2.name(),
                decl2.get_id(),
            ):  # KBO2b
                return Order.GR
            else:
                return Order.NGE
        else:
            raise Exception("Unexpected terms in kbo", t1, t2)