kdrag.rewrite

Utilities for rewriting and simplification including pattern matching and unification.

Functions

apply(goal, vs, head, body)

Match goal against head at the root and return the instantiated body, or None.

decl_index(rules)

Build a dictionary of rules indexed by their lhs head function declaration.

def_eq(e1, e2[, trace])

A notion of computational equality.

horn_of_theorem(thm)

Unpack a theorem of the form forall vs, body => head into a Rule tuple.

kbo(vs, t1, t2)

Knuth-Bendix ordering, naive implementation.

lpo(vs, t1, t2)

Lexicographic path ordering.

rewrite(t, rules[, trace])

Sweep through term once performing rewrites.

rewrite1(t, vs, lhs, rhs)

Rewrite at root a single time.

rewrite1_rule(t, rule[, trace])

Rewrite at root a single time.

rewrite_star(t, rules[, trace])

Repeat rewrite until no more rewrites are possible.

rule_of_theorem(thm)

Unpack a theorem of the form forall vs, lhs = rhs into a RewriteRule tuple.

simp(e[, trace, max_iter])

Simplify using definitions and the built-in Z3 simplifier until no progress is made.

simp1(t)

Simplify a term using Z3's demodulator and simplify tactics.

simp2(t)

Simplify a term using Z3's elim-predicates tactic.

unfold(e[, decls, trace])

Do a single unfold sweep, unfolding definitions defined by kd.define.

Classes

Order(value[, names, module, qualname, ...])

RewriteRule(vs, lhs, rhs)

A rewrite rule tuple

Rule(vs, hyp, conc, pf)

A Horn clause tuple for forall vs, hyp => conc, with an optional proof pf.

Exceptions

RewriteRuleException

class kdrag.rewrite.Order(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: Enum

EQ = 0 (equal)

GR = 1 (greater)

NGE = 2 (not greater or equal)

classmethod __contains__(value)

Return True if value is in cls.

value is in cls if: 1) value is a member of cls, or 2) value is the value of one of the cls’s members.

classmethod __getitem__(name)

Return the member matching name.

classmethod __iter__()

Return members in definition order.

classmethod __len__()

Return the number of members (no aliases)

class kdrag.rewrite.RewriteRule(vs: list[ExprRef], lhs: ExprRef, rhs: ExprRef)

Bases: NamedTuple

A rewrite rule tuple

Parameters:
  • vs (list[ExprRef])

  • lhs (ExprRef)

  • rhs (ExprRef)

count(value, /)

Return number of occurrences of value.

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

lhs: ExprRef

Alias for field number 1

rhs: ExprRef

Alias for field number 2

vs: list[ExprRef]

Alias for field number 0

exception kdrag.rewrite.RewriteRuleException

Bases: Exception

add_note()

Exception.add_note(note) – add a note to the exception

args
with_traceback()

Exception.with_traceback(tb) – set self.__traceback__ to tb and return self.

class kdrag.rewrite.Rule(vs, hyp, conc, pf)

Bases: NamedTuple

Parameters:
  • vs (list[ExprRef])

  • hyp (BoolRef)

  • conc (BoolRef)

  • pf (Proof | None)

conc: BoolRef

Alias for field number 2

count(value, /)

Return number of occurrences of value.

hyp: BoolRef

Alias for field number 1

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

pf: Proof | None

Alias for field number 3

vs: list[ExprRef]

Alias for field number 0

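kdrag.rewrite.apply(goal: BoolRef, vs: list[ExprRef], head: BoolRef, body: BoolRef) → BoolRef | None

Match goal against head at the root, treating vs as pattern variables. If the match succeeds, return body under the resulting substitution; otherwise return None.
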
Parameters:
  • goal (BoolRef)

  • vs (list[ExprRef])

  • head (BoolRef)

  • body (BoolRef)

Return type:

BoolRef | None

kdrag.rewrite.decl_index(rules: list[RewriteRule]) → dict[FuncDeclRef, RewriteRule]

Build a dictionary of rules indexed by their lhs head function declaration. If several rules share a head declaration, only the last one is kept.

Parameters:

rules (list[RewriteRule])

Return type:

dict[FuncDeclRef, RewriteRule]
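To make the indexing (and its last-one-wins behavior) concrete, here is a minimal pure-Python sketch. It assumes a toy term encoding (a variable is a `str`, an application is a tuple `(head, *args)`) and a hypothetical `ToyRule` tuple in place of `RewriteRule`; the head string `lhs[0]` plays the role of `rule.lhs.decl()`. The multi-valued `decl_index_all` is a hypothetical variant, not part of kdrag.

```python
from collections import defaultdict
from typing import NamedTuple, Union

# Toy term encoding (assumption): a variable is a str, an application is (head, *args).
Term = Union[str, tuple]


class ToyRule(NamedTuple):
    """Hypothetical stand-in for RewriteRule over toy terms."""

    vs: list
    lhs: tuple
    rhs: Term


def decl_index(rules: list) -> dict:
    """Index rules by lhs head symbol; a later rule with the same head replaces an earlier one."""
    return {rule.lhs[0]: rule for rule in rules}


def decl_index_all(rules: list) -> dict:
    """Hypothetical variant (not in kdrag): keep every rule for each head symbol."""
    index = defaultdict(list)
    for rule in rules:
        index[rule.lhs[0]].append(rule)
    return dict(index)


# Peano-style rules: two share the head "add".
add_zero = ToyRule(["y"], ("add", ("Z",), "y"), "y")
add_succ = ToyRule(["x", "y"], ("add", ("S", "x"), "y"), ("S", ("add", "x", "y")))
mul_zero = ToyRule(["y"], ("mul", ("Z",), "y"), ("Z",))

index = decl_index([add_zero, add_succ, mul_zero])  # index["add"] is add_succ
```

Because both `add` rules have the same head, `decl_index` keeps only `add_succ`; a caller that needs every rule per head would have to group them, for example as `decl_index_all` does.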

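kdrag.rewrite.def_eq(e1: ExprRef, e2: ExprRef, trace=None) → bool

A notion of computational equality: both terms are normalized with simp (which unfolds definitions and applies the Z3 simplifier), and the results are compared up to alpha-equivalence.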

>>> import kdrag.theories.nat as nat
>>> def_eq(nat.one + nat.one, nat.S(nat.S(nat.Z)))
True
Parameters:
  • e1 (ExprRef)

  • e2 (ExprRef)

Return type:

bool

kdrag.rewrite.horn_of_theorem(thm: ExprRef | Proof) → Rule

Unpack a theorem of the form forall vs, body => head into a Rule tuple, with body as hyp and head as conc. If thm is a Proof, it is stored in pf.

>>> x = smt.Real("x")
>>> horn_of_theorem(smt.ForAll([x], smt.Implies(x**2 == x*x, x > 0)))
Rule(vs=[X...], hyp=X...**2 == X...*X..., conc=X... > 0, pf=None)
Parameters:

thm (ExprRef | Proof)

Return type:

Rule

kdrag.rewrite.kbo(vs: list[ExprRef], t1: ExprRef, t2: ExprRef) → Order

Knuth-Bendix ordering, naive implementation. All weights are 1, and function symbols are ordered by comparing (name, id) pairs. Source: Term Rewriting and All That, section 5.4.4.

Parameters:
  • vs (list[ExprRef])

  • t1 (ExprRef)

  • t2 (ExprRef)

Return type:

Order
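The following is a minimal pure-Python sketch of the same uniform-weight KBO, to show how the variable-count check, the weight comparison, and the KBO2b/KBO2c cases fit together. It assumes a toy term encoding (a variable is a `str` listed in `vs`; every application, including a constant such as `("a",)`, is a tuple `(head, *args)`) and, as a stand-in for kdrag's `(name, id)` comparison, orders head symbols by plain string comparison. One observation it relies on: when every symbol and variable weighs 1, a term of equal weight that contains the variable `t2` can only be `t2` itself, so the KBO2a case (which needs a weight-0 unary symbol) never fires and is omitted here.

```python
from collections import Counter
from enum import Enum
from typing import Union

# Toy term encoding (assumption): a variable is a str listed in `vs`,
# an application (including a constant such as ("a",)) is a tuple (head, *args).
Term = Union[str, tuple]


class Order(Enum):
    EQ = 0
    GR = 1
    NGE = 2


def var_counts(t: Term, vs: list) -> Counter:
    """Count occurrences of each variable of `vs` in `t`."""
    counts: Counter = Counter()
    todo = [t]
    while todo:
        u = todo.pop()
        if isinstance(u, str) and u in vs:
            counts[u] += 1
        elif isinstance(u, tuple):
            todo.extend(u[1:])
    return counts


def weight(t: Term) -> int:
    """Every symbol and variable occurrence weighs 1, so the weight is the term size."""
    return 1 + sum(weight(a) for a in t[1:]) if isinstance(t, tuple) else 1


def kbo(vs: list, t1: Term, t2: Term) -> Order:
    if t1 == t2:
        return Order.EQ
    # Every variable must occur in t1 at least as often as in t2.
    c1, c2 = var_counts(t1, vs), var_counts(t2, vs)
    if any(c1[v] < n for v, n in c2.items()):
        return Order.NGE
    w1, w2 = weight(t1), weight(t2)
    if w1 != w2:
        return Order.GR if w1 > w2 else Order.NGE
    # Equal weights: t2 cannot be a variable here (see the note above),
    # but t1 may be a variable facing a constant, which is never greater.
    if isinstance(t1, str):
        return Order.NGE
    if t1[0] == t2[0]:  # KBO2c: same head, compare arguments left to right
        for a1, a2 in zip(t1[1:], t2[1:]):
            o = kbo(vs, a1, a2)
            if o != Order.EQ:
                return o
        raise AssertionError("unreachable: equal terms are handled above")
    # KBO2b: different heads, decided by the precedence (string order, an assumption)
    return Order.GR if t1[0] > t2[0] else Order.NGE
```

For example, `kbo(["x", "y"], ("add", ("S", "x"), "y"), ("S", ("add", "x", "y")))` returns `Order.GR`: both sides have weight 4 and the same variables, so the result falls to the head precedence, where `"add" > "S"` under Python string ordering.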

kdrag.rewrite.lpo(vs: list[ExprRef], t1: ExprRef, t2: ExprRef) → Order

Lexicographic path ordering. Function symbols are ordered by comparing (name, id) pairs; TODO: add an ordering parameter. Based on https://www21.in.tum.de/~nipkow/TRaAT/programs/termorders.ML

Parameters:
  • vs (list[ExprRef])

  • t1 (ExprRef)

  • t2 (ExprRef)

Return type:

Order
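For comparison with the KBO sketch, here is a minimal pure-Python sketch of the lexicographic path ordering with the same case structure as the code in this module. It assumes the same toy term encoding (a variable is a `str` listed in `vs`; every application, including a constant, is a tuple `(head, *args)`) and uses plain string comparison of head names as the precedence, instead of kdrag's `(name, id)` comparison.

```python
from enum import Enum
from typing import Union

# Toy term encoding (assumption): a variable is a str listed in `vs`,
# an application (including a constant such as ("a",)) is a tuple (head, *args).
Term = Union[str, tuple]


class Order(Enum):
    EQ = 0
    GR = 1
    NGE = 2


def occurs(x: str, t: Term) -> bool:
    """True if variable `x` occurs in `t`."""
    return t == x or (isinstance(t, tuple) and any(occurs(x, a) for a in t[1:]))


def lpo(vs: list, t1: Term, t2: Term) -> Order:
    if isinstance(t2, str) and t2 in vs:
        if t1 == t2:
            return Order.EQ
        # A term is greater than any variable occurring strictly inside it.
        return Order.GR if occurs(t2, t1) else Order.NGE
    if isinstance(t1, str) and t1 in vs:
        return Order.NGE
    f, args1 = t1[0], t1[1:]
    g, args2 = t2[0], t2[1:]
    # If some argument of t1 is >= t2, then t1 > t2.
    if not all(lpo(vs, a, t2) == Order.NGE for a in args1):
        return Order.GR
    if f == g:
        # Same head: t1 must dominate every argument of t2, then compare lexicographically.
        if not all(lpo(vs, t1, a) == Order.GR for a in args2):
            return Order.NGE
        for a1, a2 in zip(args1, args2):
            o = lpo(vs, a1, a2)
            if o != Order.EQ:
                return o
        return Order.EQ
    # Bigger head (string order, an assumption) and t1 dominates every argument of t2.
    if f > g and all(lpo(vs, t1, a) == Order.GR for a in args2):
        return Order.GR
    return Order.NGE
```

With this precedence, `lpo(["x", "y"], ("add", ("S", "x"), "y"), ("S", ("add", "x", "y")))` is `Order.GR`, so the Peano rule `add(S(x), y) -> S(add(x, y))` is oriented left to right, while `f(x)` and `g(y)` are incomparable in both directions.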

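kdrag.rewrite.rewrite(t: ExprRef, rules: list[RewriteRule], trace=None) → ExprRef

Sweep through the term once, bottom-up, performing rewrites: the children are rewritten first, then each rule in turn is tried at the root. Subterms produced by a rewrite are not revisited in the same sweep.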

>>> x = smt.Real("x")
>>> rule = RewriteRule([x], x**2, x*x)
>>> rewrite((x**2)**2, [rule])
x*x*x*x
Parameters:
  • t (ExprRef)

  • rules (list[RewriteRule])

Return type:

ExprRef
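The traversal order of a single sweep can be sketched in pure Python as follows. This assumes a toy term encoding (a leaf is a `str`; an application, including a constant such as `("2",)`, is a tuple `(head, *args)`) and represents each rule as a hypothetical Python function that rewrites at the root or returns `None`, standing in for a pattern-matched `RewriteRule`.

```python
from typing import Callable, Optional, Union

# Toy term encoding (assumption): a leaf is a str, an application is (head, *args).
Term = Union[str, tuple]
# Hypothetical rule type: rewrite at the root, or return None if it does not apply.
RootRule = Callable[[Term], Optional[Term]]


def rewrite(t: Term, rules: list[RootRule]) -> Term:
    """One bottom-up sweep: rewrite the children, then try each rule once at the root."""
    if isinstance(t, tuple):
        t = (t[0],) + tuple(rewrite(a, rules) for a in t[1:])
        for rule in rules:
            res = rule(t)
            if res is not None:
                t = res  # later rules are tried on the rewritten term
    return t


def square_to_mul(t: Term) -> Optional[Term]:
    """Toy version of the rule x**2 -> x*x."""
    if isinstance(t, tuple) and len(t) == 3 and t[0] == "pow" and t[2] == ("2",):
        return ("mul", t[1], t[1])
    return None


two = ("2",)
print(rewrite(("pow", ("pow", "x", two), two), [square_to_mul]))
# ('mul', ('mul', 'x', 'x'), ('mul', 'x', 'x'))
```

This mirrors the doctest for `(x**2)**2`: the inner square is rewritten first, after which the outer square matches with `x*x` in place of `x`.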

kdrag.rewrite.rewrite1(t: ExprRef, vs: list[ExprRef], lhs: ExprRef, rhs: ExprRef) → ExprRef | None

Rewrite t at the root a single time with lhs -> rhs, treating vs as pattern variables. Returns None if lhs does not match t.

Parameters:
  • t (ExprRef)

  • vs (list[ExprRef])

  • lhs (ExprRef)

  • rhs (ExprRef)

Return type:

ExprRef | None
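Root rewriting is first-order matching followed by substitution. The sketch below shows that combination on a toy term encoding (a `str` is a variable or named leaf; an application, including a constant such as `("2",)`, is a tuple `(head, *args)`). kdrag delegates the matching step to `utils.pmatch` on Z3 terms; this `pmatch` is a generic stand-in written for illustration and is not claimed to reproduce that function's exact behavior.

```python
from typing import Optional, Union

# Toy term encoding (assumption): a str is a variable or named leaf,
# an application (including a constant such as ("2",)) is a tuple (head, *args).
Term = Union[str, tuple]


def pmatch(vs: set, pat: Term, t: Term, subst: Optional[dict] = None) -> Optional[dict]:
    """Match `pat` against `t`, binding only the pattern variables in `vs`."""
    subst = {} if subst is None else subst
    if isinstance(pat, str) and pat in vs:
        if pat in subst:  # a repeated pattern variable must bind the same subterm
            return subst if subst[pat] == t else None
        subst[pat] = t
        return subst
    if isinstance(pat, str) or isinstance(t, str):
        return subst if pat == t else None
    if pat[0] != t[0] or len(pat) != len(t):
        return None
    for p_arg, t_arg in zip(pat[1:], t[1:]):
        if pmatch(vs, p_arg, t_arg, subst) is None:
            return None
    return subst


def substitute(t: Term, subst: dict) -> Term:
    """Replace variables in `t` by their bindings in `subst`."""
    if isinstance(t, str):
        return subst.get(t, t)
    return (t[0],) + tuple(substitute(a, subst) for a in t[1:])


def rewrite1(t: Term, vs: list, lhs: Term, rhs: Term) -> Optional[Term]:
    """Rewrite `t` at the root with lhs -> rhs, or return None if lhs does not match."""
    subst = pmatch(set(vs), lhs, t)
    return None if subst is None else substitute(rhs, subst)


two = ("2",)
print(rewrite1(("pow", "a", two), ["x"], ("pow", "x", two), ("mul", "x", "x")))
# ('mul', 'a', 'a')
print(rewrite1(("pow", "a", ("3",)), ["x"], ("pow", "x", two), ("mul", "x", "x")))
# None
```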

kdrag.rewrite.rewrite1_rule(t: ExprRef, rule: RewriteRule, trace: list[tuple[RewriteRule, dict[ExprRef, ExprRef]]] | None = None) → ExprRef | None

Rewrite t at the root a single time using rule. Returns None if rule.lhs does not match t.

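Parameters:
  • t (ExprRef)

  • rule (RewriteRule)

  • trace (list[tuple[RewriteRule, dict[ExprRef, ExprRef]]] | None)
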
Return type:

ExprRef | None

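kdrag.rewrite.rewrite_star(t: ExprRef, rules: list[RewriteRule], trace=None) → ExprRef

Repeat rewrite until a sweep leaves the term unchanged. This does not terminate if the rules can rewrite forever (for example, a commutativity rule).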

Parameters:
  • t (ExprRef)

  • rules (list[RewriteRule])

Return type:

ExprRef
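Repeating the sweep matters because a single bottom-up pass does not revisit redexes that a rewrite creates underneath the new root. The pure-Python sketch below shows this with Peano addition. It assumes a toy term encoding (`("Z",)`, `("S", t)`, `("add", t1, t2)`) and hypothetical rule functions that rewrite at the root or return `None`.

```python
from typing import Callable, Optional, Union

# Toy term encoding (assumption): a leaf is a str, an application is (head, *args);
# Peano numerals are ("Z",), ("S", ("Z",)), ...
Term = Union[str, tuple]
RootRule = Callable[[Term], Optional[Term]]  # hypothetical: rewrite at root or None


def rewrite(t: Term, rules: list[RootRule]) -> Term:
    """One bottom-up sweep (children first, then each rule once at the root)."""
    if isinstance(t, tuple):
        t = (t[0],) + tuple(rewrite(a, rules) for a in t[1:])
        for rule in rules:
            res = rule(t)
            if res is not None:
                t = res
    return t


def rewrite_star(t: Term, rules: list[RootRule]) -> Term:
    """Sweep until a sweep changes nothing. Loops forever on non-terminating rules."""
    while True:
        t1 = rewrite(t, rules)
        if t1 == t:
            return t1
        t = t1


def add_zero(t: Term) -> Optional[Term]:
    # add(Z, y) -> y
    if isinstance(t, tuple) and t[0] == "add" and t[1] == ("Z",):
        return t[2]
    return None


def add_succ(t: Term) -> Optional[Term]:
    # add(S(x), y) -> S(add(x, y))
    if (isinstance(t, tuple) and t[0] == "add"
            and isinstance(t[1], tuple) and t[1][0] == "S"):
        return ("S", ("add", t[1][1], t[2]))
    return None


Z = ("Z",)
one, two = ("S", Z), ("S", ("S", Z))
rules = [add_zero, add_succ]
once = rewrite(("add", two, one), rules)         # S(add(S(Z), S(Z))) in Peano notation
normal = rewrite_star(("add", two, one), rules)  # S(S(S(Z)))
```

One sweep stops at `S(add(S(Z), S(Z)))` because the new `add` redex sits below the rewritten root; the fixpoint loop needs three changing sweeps, plus one final sweep that confirms nothing changes, to reach `S(S(S(Z)))`.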

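kdrag.rewrite.rule_of_theorem(thm: BoolRef | QuantifierRef) → RewriteRule

Unpack a theorem of the form forall vs, lhs = rhs into a RewriteRule tuple. Nested universal quantifiers are flattened into vs. Raises RewriteRuleException if a leading quantifier is not universal or the body is not an equation.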

>>> x = smt.Real("x")
>>> rule_of_theorem(smt.ForAll([x], x**2 == x*x))
RewriteRule(vs=[X...], lhs=X...**2, rhs=X...*X...)
Parameters:

thm (BoolRef | QuantifierRef)

Return type:

RewriteRule
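The unpacking loop (peel universal quantifiers, collect their variables, then split the equation) can be sketched in pure Python. The sketch assumes a toy encoding in which quantifiers are `("forall", (vars...), body)` or `("exists", (vars...), body)` and equations are `("eq", lhs, rhs)`, and it uses a hypothetical `ToyRewriteRule` in place of `RewriteRule`. Unlike kdrag, which opens each binder with `utils.open_binder` and so introduces fresh constants (the `X...` in the doctest), the sketch keeps the original variable names.

```python
from typing import NamedTuple, Union

# Toy encoding (assumption): ("forall", (vars...), body), ("exists", (vars...), body),
# ("eq", lhs, rhs); other terms are str variables or (head, *args) tuples.
Term = Union[str, tuple]


class RewriteRuleException(Exception):
    """Plays the same role as kdrag.rewrite.RewriteRuleException in this sketch."""


class ToyRewriteRule(NamedTuple):
    """Hypothetical stand-in for RewriteRule over toy terms."""

    vs: list
    lhs: Term
    rhs: Term


def rule_of_theorem(thm: Term) -> ToyRewriteRule:
    """Strip a prefix of universal quantifiers, then split the equation into lhs and rhs."""
    vs: list = []
    body = thm
    while isinstance(body, tuple) and body[0] in ("forall", "exists"):
        if body[0] == "exists":
            raise RewriteRuleException("Not a universal quantifier", body)
        vs.extend(body[1])  # nested foralls are flattened into one variable list
        body = body[2]
    if not (isinstance(body, tuple) and body[0] == "eq"):
        raise RewriteRuleException("Not an equation", thm)
    return ToyRewriteRule(vs, body[1], body[2])


comm = ("forall", ("x",), ("forall", ("y",), ("eq", ("add", "x", "y"), ("add", "y", "x"))))
rule = rule_of_theorem(comm)  # vs == ["x", "y"], lhs == add(x, y), rhs == add(y, x)
```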

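kdrag.rewrite.simp(e: ExprRef, trace=None, max_iter=None) → ExprRef

Simplify using definitions and the built-in Z3 simplifier until no progress is made. Each round unfolds definitions once and then calls the Z3 simplifier; max_iter, if given, bounds the number of rounds.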

>>> import kdrag.theories.nat as nat
>>> simp(nat.one + nat.one + nat.S(nat.one))
S(S(S(S(Z))))
>>> p = smt.Bool("p")
>>> simp(smt.If(p, 42, 3))
If(p, 42, 3)
Parameters:

e (ExprRef)

Return type:

ExprRef
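The control flow of the loop (unfold, simplify, stop when simplification returns what it was given, or when `max_iter` rounds have run) can be sketched in pure Python. The sketch assumes toy terms (ints are literals, tuples are `(head, *args)`), a hypothetical definition `double(a) := add(a, a)` in place of `kd.define` definitions, and constant folding of `add` in place of `smt.simplify`.

```python
from typing import Optional, Union

# Toy terms (assumption): ints are literals, strs are atoms, tuples are (head, *args).
Term = Union[int, str, tuple]


def toy_unfold(t: Term) -> Term:
    """Hypothetical definition double(a) := add(a, a), unfolded in one bottom-up sweep."""
    if not isinstance(t, tuple):
        return t
    args = tuple(toy_unfold(a) for a in t[1:])
    if t[0] == "double":
        return ("add", args[0], args[0])
    return (t[0],) + args


def toy_simplify(t: Term) -> Term:
    """Stand-in for smt.simplify: fold additions of integer literals."""
    if not isinstance(t, tuple):
        return t
    args = tuple(toy_simplify(a) for a in t[1:])
    if t[0] == "add" and all(isinstance(a, int) for a in args):
        return sum(args)
    return (t[0],) + args


def simp(e: Term, max_iter: Optional[int] = None) -> Term:
    """Alternate unfolding and simplification until simplification changes nothing."""
    i = 0
    while True:
        i += 1
        if max_iter is not None and i > max_iter:
            return e
        e = toy_unfold(e)
        e1 = toy_simplify(e)
        if e1 == e:
            return e1
        e = e1
```

As with `smt.If(p, 42, 3)` in the doctest, a term that cannot be reduced further comes back in its unfolded form: `simp(("double", "p"))` gives `("add", "p", "p")`, while `simp(("double", ("double", 3)))` folds all the way to `12`.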

kdrag.rewrite.simp1(t: ExprRef) → ExprRef

Simplify a term using Z3's demodulator and simplify tactics, with the definitional axioms of all kd.define definitions added to the goal.

Parameters:

t (ExprRef)

Return type:

ExprRef

kdrag.rewrite.simp2(t: ExprRef) → ExprRef

Simplify a term using Z3's elim-predicates tactic, with the definitional axioms of all kd.define definitions added to the goal.

Parameters:

t (ExprRef)

Return type:

ExprRef

kdrag.rewrite.unfold(e: ExprRef, decls=None, trace=None) → ExprRef

Do a single bottom-up unfold sweep, unfolding definitions defined by kd.define. Bodies introduced by unfolding are not unfolded again in the same sweep. If trace is given, the proof of each unfolding step is appended to it. decls is an optional list of declarations to unfold; if None, all definitions are unfolded.

>>> x = smt.Int("x")
>>> f = kd.define("f", [x], x + 42)
>>> trace = []
>>> unfold(f(1), trace=trace)
1 + 42
>>> trace
[|- f(1) == 1 + 42]
Parameters:

e (ExprRef)

Return type:

ExprRef
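The single-sweep behavior, the `decls` filter, and the trace can be illustrated with a pure-Python sketch. It assumes toy terms (a `str` is a parameter name, an int is a literal, a tuple is `(head, *args)`), a hypothetical definition table `DEFNS` standing in for `kd.kernel.defns`, and a trace that records `(call, result)` pairs where kdrag records proofs.

```python
from typing import Optional, Union

# Toy terms (assumption): str = parameter name, int = literal, tuple = (head, *args).
Term = Union[str, int, tuple]

# Hypothetical definition table standing in for kd.kernel.defns:
# name -> (parameter names, body term).
DEFNS = {
    "f": (("x",), ("add", "x", 42)),
    "g": (("y",), ("f", ("f", "y"))),
}


def substitute(t: Term, subst: dict) -> Term:
    """Replace parameter names in `t` by their bindings in `subst`."""
    if isinstance(t, str):
        return subst.get(t, t)
    if isinstance(t, tuple):
        return (t[0],) + tuple(substitute(a, subst) for a in t[1:])
    return t


def unfold(t: Term, decls: Optional[set] = None, trace: Optional[list] = None) -> Term:
    """One bottom-up unfolding sweep; bodies introduced here are not unfolded again."""
    if not isinstance(t, tuple):
        return t
    head = t[0]
    children = tuple(unfold(c, decls, trace) for c in t[1:])
    defn = DEFNS.get(head)
    if defn is not None and (decls is None or head in decls):
        params, body = defn
        e1 = substitute(body, dict(zip(params, children)))
        if trace is not None:
            trace.append(((head,) + children, e1))  # records f(args) == body[args]
        return e1
    return (head,) + children
```

Unfolding `("g", 1)` once yields `("f", ("f", 1))`, since the body introduced for `g` is not revisited in the same sweep; a second call unfolds both `f`s. Passing `decls={"f"}` leaves `g` untouched.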

"""
Utilities for rewriting and simplification including pattern matching and unification.
"""

import kdrag.smt as smt
import kdrag as kd
from enum import Enum
from typing import NamedTuple, Optional
import kdrag.utils as utils


def simp1(t: smt.ExprRef) -> smt.ExprRef:
    """Simplify a term using Z3's demodulator and simplify tactics, with all definitional axioms."""
    expr = smt.FreshConst(t.sort(), prefix="knuckle_goal")
    G = smt.Goal()
    for v in kd.kernel.defns.values():
        G.add(v.ax.thm)
    G.add(expr == t)
    G2 = smt.Then(smt.Tactic("demodulator"), smt.Tactic("simplify")).apply(G)[0]
    # TODO make this extraction more robust
    return G2[len(G2) - 1].children()[1]


def simp2(t: smt.ExprRef) -> smt.ExprRef:
    """Simplify a term using Z3's elim-predicates tactic, with all definitional axioms."""
    expr = smt.FreshConst(t.sort(), prefix="knuckle_goal")
    G = smt.Goal()
    for v in kd.kernel.defns.values():
        G.add(v.ax.thm)
    G.add(expr == t)
    G2 = smt.Tactic("elim-predicates").apply(G)[0]
    return G2[len(G2) - 1].children()[1]


# TODO: Doesn't seem to do anything?
# def factor(t: smt.ExprRef) -> smt.ExprRef:
#    """factor a term using z3 built in tactic"""
#    expr = smt.FreshConst(t.sort(), prefix="knuckle_goal")
#    G = smt.Goal()
#    for v in kd.kernel.defns.values():
#        G.add(v.ax.thm)
#    G.add(expr == t)
#    G2 = smt.Tactic("factor").apply(G)[0]
#    return G2[len(G2) - 1].children()[1]


def unfold(e: smt.ExprRef, decls=None, trace=None) -> smt.ExprRef:
    """
    Do a single unfold sweep, unfolding definitions defined by `kd.define`.
    The optional trace parameter will record proof along the way.
    `decls` is an optional list of declarations to unfold. If None, all definitions are unfolded.

    >>> x = smt.Int("x")
    >>> f = kd.define("f", [x], x + 42)
    >>> trace = []
    >>> unfold(f(1), trace=trace)
    1 + 42
    >>> trace
    [|- f(1) == 1 + 42]
    """
    if smt.is_app(e):
        decl = e.decl()
        # Unfold the children first, so the sweep is bottom-up.
        children = [unfold(c, decls=decls, trace=trace) for c in e.children()]
        defn = kd.kernel.defns.get(decl)
        if defn is not None and (decls is None or decl in decls):
            # Instantiate the definition body at the unfolded children.
            # The result is not unfolded again in this sweep.
            e1 = smt.substitute(defn.body, *zip(defn.args, children))
            if trace is not None:
                if isinstance(defn.ax.thm, smt.QuantifierRef):
                    # Instantiate the universally quantified definitional axiom.
                    trace.append(defn.ax(*children))
                else:
                    trace.append(defn.ax)
            return e1
        else:
            return decl(*children)
    else:
        return e


def simp(e: smt.ExprRef, trace=None, max_iter=None) -> smt.ExprRef:
    """
    Simplify using definitions and built in z3 simplifier until no progress is made.

    >>> import kdrag.theories.nat as nat
    >>> simp(nat.one + nat.one + nat.S(nat.one))
    S(S(S(S(Z))))

    >>> p = smt.Bool("p")
    >>> simp(smt.If(p, 42, 3))
    If(p, 42, 3)
    """
    i = 0
    while True:
        i += 1
        if max_iter is not None and i > max_iter:
            return e
        e = unfold(e, trace=trace)
        # TODO: Interesting options: som, sort_store, elim_ite, flat, split_concat_eq, sort_sums, sort_disjunctions
        e1 = smt.simplify(e)
        if e1.eq(e):
            return e1
        else:
            if trace is not None:
                trace.append(kd.kernel.prove(smt.Eq(e, e1)))
            e = e1


def def_eq(e1: smt.ExprRef, e2: smt.ExprRef, trace=None) -> bool:
    """
    A notion of computational equality. Unfold and simp.

    >>> import kdrag.theories.nat as nat
    >>> def_eq(nat.one + nat.one, nat.S(nat.S(nat.Z)))
    True
    """
    # TODO: We could stop early by interleaving the unfolding of both sides:
    # while not e1.eq(e2):
    #     e1 = unfold(e1, trace=trace)
    #     e2 = unfold(e2, trace=trace)
    e1 = simp(e1, trace=trace)
    e2 = simp(e2, trace=trace)
    return kd.utils.alpha_eq(e1, e2)


def rewrite1(
    t: smt.ExprRef, vs: list[smt.ExprRef], lhs: smt.ExprRef, rhs: smt.ExprRef
) -> Optional[smt.ExprRef]:
    """
    Rewrite at root a single time. Returns None if `lhs` does not match `t`.
    """
    subst = utils.pmatch(vs, lhs, t)
    if subst is not None:
        return smt.substitute(rhs, *subst.items())
    return None


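def apply(
    goal: smt.BoolRef, vs: list[smt.ExprRef], head: smt.BoolRef, body: smt.BoolRef
) -> Optional[smt.BoolRef]:
    """
    Match `goal` against `head` at the root, treating `vs` as pattern variables.
    If the match succeeds, return `body` under the resulting substitution; otherwise return None.
    """
    res = rewrite1(goal, vs, head, body)
    assert res is None or isinstance(res, smt.BoolRef)
    return res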


class RewriteRule(NamedTuple):
    """A rewrite rule tuple"""

    vs: list[smt.ExprRef]
    lhs: smt.ExprRef
    rhs: smt.ExprRef


def rewrite1_rule(
    t: smt.ExprRef,
    rule: RewriteRule,
    trace: Optional[list[tuple[RewriteRule, dict[smt.ExprRef, smt.ExprRef]]]] = None,
) -> Optional[smt.ExprRef]:
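    """
    Rewrite at root a single time.
    If `trace` is given, the applied rule and its matching substitution are appended to it.
    """
    subst = utils.pmatch(rule.vs, rule.lhs, t)
    if subst is not None:
        # Record the rewrite step before returning the rewritten term.
        if trace is not None:
            trace.append((rule, subst))
        return smt.substitute(rule.rhs, *subst.items())
    return None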


def rewrite(t: smt.ExprRef, rules: list[RewriteRule], trace=None) -> smt.ExprRef:
    """
    Sweep through term once performing rewrites.

    >>> x = smt.Real("x")
    >>> rule = RewriteRule([x], x**2, x*x)
    >>> rewrite((x**2)**2, [rule])
    x*x*x*x
    """
    if smt.is_app(t):
        # Rewrite the children first, passing the trace down so their steps are recorded too.
        t = t.decl()(*[rewrite(arg, rules, trace=trace) for arg in t.children()])
        for rule in rules:
            res = rewrite1_rule(t, rule, trace=trace)
            if res is not None:
                t = res
    return t


class RewriteRuleException(Exception): ...


def rule_of_theorem(thm: smt.BoolRef | smt.QuantifierRef) -> RewriteRule:
    """
    Unpack a theorem of the form `forall vs, lhs = rhs` into a RewriteRule tuple.

    >>> x = smt.Real("x")
    >>> rule_of_theorem(smt.ForAll([x], x**2 == x*x))
    RewriteRule(vs=[X...], lhs=X...**2, rhs=X...*X...)
    """
    vs = []
    thm1 = thm  # to help out pyright
    while isinstance(thm1, smt.QuantifierRef):
        if thm1.is_forall():
            vs1, thm1 = utils.open_binder(thm1)
            vs.extend(vs1)
        else:
            raise RewriteRuleException("Not a universal quantifier", thm1)
    if not smt.is_eq(thm1):
        raise RewriteRuleException("Not an equation", thm)
    lhs, rhs = thm1.children()
    return RewriteRule(vs, lhs, rhs)


def decl_index(rules: list[RewriteRule]) -> dict[smt.FuncDeclRef, RewriteRule]:
    """
    Build a dictionary of rules indexed by their lhs head function declaration.
    If several rules share a head declaration, only the last one is kept.
    """
    return {rule.lhs.decl(): rule for rule in rules}


def rewrite_star(t: smt.ExprRef, rules: list[RewriteRule], trace=None) -> smt.ExprRef:
    """
    Repeat rewrite until no more rewrites are possible.
    """
    while True:
        t1 = rewrite(t, rules, trace=trace)
        if t1.eq(t):
            return t1
        t = t1


class Rule(NamedTuple):
    """A Horn clause `forall vs, hyp => conc`, with an optional proof `pf`."""

    vs: list[smt.ExprRef]
    hyp: smt.BoolRef
    conc: smt.BoolRef
    pf: Optional[kd.kernel.Proof] = None


def horn_of_theorem(thm: smt.ExprRef | kd.kernel.Proof) -> Rule:
    """Unpack a theorem of the form `forall vs, body => head` into a Rule tuple (hyp=body, conc=head).

    >>> x = smt.Real("x")
    >>> horn_of_theorem(smt.ForAll([x], smt.Implies(x**2 == x*x, x > 0)))
    Rule(vs=[X...], hyp=X...**2 == X...*X..., conc=X... > 0, pf=None)
    """
    pf = None
    if isinstance(thm, smt.ExprRef):
        pass
    elif kd.kernel.is_proof(thm):
        pf = thm
        thm = pf.thm
    if not isinstance(thm, smt.QuantifierRef) or not thm.is_forall():
        raise Exception("Not a universal quantifier", thm)
    vs, thm = utils.open_binder(thm)
    assert smt.is_implies(thm)
    return Rule(vs, hyp=thm.arg(0), conc=thm.arg(1), pf=pf)


"""
def apply_horn(thm: smt.BoolRef, horn: smt.BoolRef) -> smt.BoolRef:
    pat = horn
    obl = []
    if smt.is_quantifier(pat) and pat.is_forall():
        pat = pat.body()
    while True:
        if smt.is_implies(pat):
            obl.append(pat.arg(0))
            pat = pat.arg(1)
        else:
            break
    return kd.utils.z3_match(thm, pat)


def horn_split(horn: smt.BoolRef) -> smt.BoolRef:
    body = []
    vs = []
    while True:
        if smt.is_quantifier(horn) and horn.is_forall():
            vs1, horn = open_binder(horn)
            vs.extend(vs1)
        if smt.is_implies(horn):
            body.append(horn.arg(0))
            horn = horn.arg(1)
        else:
            break
    head = horn
    return vs, body, head
"""


class Order(Enum):
    EQ = 0  # Equal
    GR = 1  # Greater
    NGE = 2  # Not Greater or Equal


def lpo(vs: list[smt.ExprRef], t1: smt.ExprRef, t2: smt.ExprRef) -> Order:
    """
    Lexicographic path ordering.
    Based on https://www21.in.tum.de/~nipkow/TRaAT/programs/termorders.ML
    The symbol precedence is fixed: declarations are compared by (name, id).
    TODO add ordering parameter.
    """

    def is_var(x):
        return any(x.eq(v) for v in vs)

    if is_var(t2):
        if t1.eq(t2):
            return Order.EQ
        elif utils.is_subterm(t2, t1):
            return Order.GR
        else:
            return Order.NGE
    elif is_var(t1):
        return Order.NGE
    elif smt.is_app(t1) and smt.is_app(t2):
        decl1, decl2 = t1.decl(), t2.decl()
        args1, args2 = t1.children(), t2.children()
        if all(lpo(vs, a, t2) == Order.NGE for a in args1):
            if decl1 == decl2:
                if all(lpo(vs, t1, a) == Order.GR for a in args2):
                    for a1, a2 in zip(args1, args2):
                        ord = lpo(vs, a1, a2)
                        if ord == Order.GR:
                            return Order.GR
                        elif ord == Order.NGE:
                            return Order.NGE
                    return Order.EQ
                else:
                    return Order.NGE
            elif (decl1.name(), decl1.get_id()) > (decl2.name(), decl2.get_id()):
                if all(lpo(vs, t1, a) == Order.GR for a in args2):
                    return Order.GR
                else:
                    return Order.NGE

            else:
                return Order.NGE
        else:
            return Order.GR
    else:
        raise Exception("Unexpected terms in lpo", t1, t2)


def kbo(vs: list[smt.ExprRef], t1: smt.ExprRef, t2: smt.ExprRef) -> Order:
    """
    Knuth-Bendix ordering, naive implementation.
    All weights are 1, and the symbol precedence compares declarations by (name, id).
    Source: Term Rewriting and All That section 5.4.4
    """
    if t1.eq(t2):
        return Order.EQ

    def is_var(x):
        return any(x.eq(v) for v in vs)

    def vcount(t):
        todo = [t]
        vcount1 = {v: 0 for v in vs}
        while todo:
            t = todo.pop()
            if is_var(t):
                vcount1[t] += 1
            elif smt.is_app(t):
                todo.extend(t.children())
        return vcount1

    vcount1, vcount2 = vcount(t1), vcount(t2)
    if not all(vcount1[v] >= vcount2[v] for v in vs):
        return Order.NGE

    def weight(t):
        todo = [t]
        w = 0
        while todo:
            t = todo.pop()
            w += 1
            if smt.is_app(t):
                todo.extend(t.children())
        return w

    w1, w2 = weight(t1), weight(t2)
    if w1 > w2:
        return Order.GR
    elif w1 < w2:
        return Order.NGE
    else:
        if is_var(t2):  # KBO2a
            decl = t1.decl()
            if decl.arity() != 1:
                return Order.NGE
            while not t1.eq(t2):
                if t1.decl() != decl:
                    return Order.NGE
                else:
                    t1 = t1.arg(0)
            return Order.GR
        elif is_var(t1):
            return Order.NGE
        elif smt.is_app(t1) and smt.is_app(t2):
            decl1, decl2 = t1.decl(), t2.decl()
            if decl1 == decl2:  # KBO2c
                args1, args2 = t1.children(), t2.children()
                for a1, a2 in zip(args1, args2):
                    ord = kbo(vs, a1, a2)
                    if ord == Order.GR:
                        return Order.GR
                    elif ord == Order.NGE:
                        return Order.NGE
                raise Exception("Unexpected equality reached in kbo")
            elif (decl1.name(), decl1.get_id()) > (
                decl2.name(),
                decl2.get_id(),
            ):  # KBO2b
                return Order.GR
            else:
                return Order.NGE
        else:
            raise Exception("Unexpected terms in kbo", t1, t2)