kdrag.rewrite
Utilities for rewriting and simplification including pattern matching and unification.
Functions
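- apply(goal, vs, head, body): Rewrite a Boolean goal at the root using the rule head -> body.
- decl_index(rules): Build a dictionary of rules indexed by their lhs head function declaration.
- def_eq(e1, e2, trace=None): A notion of computational equality.
- horn_of_theorem(thm): Unpack a theorem of the form forall vs, body => head into a Rule tuple.
- kbo(vs, t1, t2): Knuth-Bendix ordering, naive implementation.
- lpo(vs, t1, t2): Lexicographic path ordering.
- rewrite(t, rules, trace=None): Sweep through a term once, performing rewrites.
- rewrite1(t, vs, lhs, rhs): Rewrite at the root a single time.
- rewrite1_rule(t, rule, trace=None): Rewrite at the root a single time, using a RewriteRule.
- rewrite_star(t, rules, trace=None): Repeat rewrite until no more rewrites are possible.
- rule_of_theorem(thm): Unpack a theorem of the form forall vs, lhs = rhs into a RewriteRule tuple.
- simp(e, trace=None, max_iter=None): Simplify using definitions and the built-in Z3 simplifier until no progress is made.
- simp1(t): Simplify a term using Z3's demodulator and simplify tactics.
- simp2(t): Simplify a term using Z3's elim-predicates tactic.
- unfold(e, decls=None, trace=None): Do a single unfold sweep, unfolding definitions defined by kd.define.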
Classes
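- Order: Result of comparing two terms with a term ordering (kbo, lpo).
- RewriteRule(vs, lhs, rhs): A rewrite rule tuple.
- Rule(vs, hyp, conc, pf): A Horn rule tuple for forall vs, hyp => conc, with an optional proof.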
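Exceptions
- RewriteRuleException: Raised by rule_of_theorem when a theorem is not a universally quantified equation.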
- class kdrag.rewrite.Order(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
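Bases: Enum
- EQ = 0
The terms are equal.
- GR = 1
The first term is greater than the second.
- NGE = 2
The first term is not greater than or equal to the second (the terms may be incomparable).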
- classmethod __contains__(value)
Return True if value is in cls.
value is in cls if: 1) value is a member of cls, or 2) value is the value of one of the cls’s members.
- classmethod __getitem__(name)
Return the member matching name.
- classmethod __iter__()
Return members in definition order.
- classmethod __len__()
Return the number of members (no aliases)
- class kdrag.rewrite.RewriteRule(vs: list[ExprRef], lhs: ExprRef, rhs: ExprRef)
Bases: NamedTuple
A rewrite rule tuple
- Parameters:
vs (list[ExprRef])
lhs (ExprRef)
rhs (ExprRef)
- count(value, /)
Return number of occurrences of value.
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- lhs: ExprRef
Alias for field number 1
- rhs: ExprRef
Alias for field number 2
- vs: list[ExprRef]
Alias for field number 0
- exception kdrag.rewrite.RewriteRuleException
Bases: Exception
Raised by rule_of_theorem when a theorem is not a universally quantified equation.
- add_note()
Exception.add_note(note) – add a note to the exception
- args
- with_traceback()
Exception.with_traceback(tb) – set self.__traceback__ to tb and return self.
- class kdrag.rewrite.Rule(vs, hyp, conc, pf)
Bases: NamedTuple
A Horn rule tuple representing forall vs, hyp => conc, with an optional proof pf.
- Parameters:
vs (list[ExprRef])
hyp (BoolRef)
conc (BoolRef)
pf (Proof | None)
- conc: BoolRef
Alias for field number 2
- count(value, /)
Return number of occurrences of value.
- hyp: BoolRef
Alias for field number 1
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- pf: Proof | None
Alias for field number 3
- vs: list[ExprRef]
Alias for field number 0
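- kdrag.rewrite.apply(goal: BoolRef, vs: list[ExprRef], head: BoolRef, body: BoolRef) → BoolRef | None
Rewrite goal at the root a single time using the rule head -> body, with vs as pattern variables. Returns None if goal does not match head.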
- Parameters:
goal (BoolRef)
vs (list[ExprRef])
head (BoolRef)
body (BoolRef)
- Return type:
BoolRef | None
- kdrag.rewrite.decl_index(rules: list[RewriteRule]) → dict[FuncDeclRef, RewriteRule]
Build a dictionary of rules indexed by their lhs head function declaration. If several rules share a head declaration, only the last one is kept.
- Parameters:
rules (list[RewriteRule])
- Return type:
dict[FuncDeclRef, RewriteRule]
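Because decl_index is a single dictionary comprehension keyed on the head of each lhs, rules that share a head symbol overwrite each other. This sketch reproduces that behaviour on a hypothetical tuple encoding (`(symbol, *args)`), with the head symbol string standing in for `lhs.decl()` and a local `RewriteRule` standing in for kdrag's class.

```python
from typing import NamedTuple


class RewriteRule(NamedTuple):
    vs: list
    lhs: tuple  # hypothetical encoding: (symbol, *args)
    rhs: tuple


def decl_index(rules):
    # Key each rule by the head symbol of its left-hand side.
    # A later rule with the same head replaces an earlier one.
    return {rule.lhs[0]: rule for rule in rules}


x, y, Z = ("x",), ("y",), ("Z",)
add_zero = RewriteRule([y], ("add", Z, y), y)
add_succ = RewriteRule([x, y], ("add", ("S", x), y), ("S", ("add", x, y)))
mul_zero = RewriteRule([y], ("mul", Z, y), Z)

index = decl_index([add_zero, add_succ, mul_zero])
print(sorted(index))             # ['add', 'mul']
print(index["add"] is add_succ)  # True: add_zero was overwritten
```

If every rule for a head must be kept, grouping into lists (for example with `collections.defaultdict(list)`) would be needed; that is a variation, not what decl_index does.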
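- kdrag.rewrite.def_eq(e1: ExprRef, e2: ExprRef, trace=None) → bool
A notion of computational equality: both terms are simplified with simp (which repeatedly unfolds definitions) and the results are compared up to alpha-equivalence.
>>> import kdrag.theories.nat as nat
>>> def_eq(nat.one + nat.one, nat.S(nat.S(nat.Z)))
True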
- Parameters:
e1 (ExprRef)
e2 (ExprRef)
- Return type:
bool
- kdrag.rewrite.horn_of_theorem(thm: ExprRef | Proof) → Rule
Unpack a theorem of the form forall vs, body => head into a Rule tuple. If thm is a Proof, it is stored in the pf field.
>>> x = smt.Real("x")
>>> horn_of_theorem(smt.ForAll([x], smt.Implies(x**2 == x*x, x > 0)))
Rule(vs=[X...], hyp=X...**2 == X...*X..., conc=X... > 0, pf=None)
- Parameters:
thm (ExprRef | Proof)
- Return type:
Rule
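The unpacking done by horn_of_theorem amounts to checking the formula's shape and splitting it into its parts. A minimal sketch, assuming a hypothetical tuple encoding of formulas (`("forall", vs, body)` and `("=>", hyp, conc)`): unlike kdrag, it does not open the binder with fresh variables (the `X...` names in the doctest), and the proof is passed in separately rather than read from a Proof object.

```python
from typing import NamedTuple, Optional


class Rule(NamedTuple):
    vs: list
    hyp: tuple
    conc: tuple
    pf: Optional[object] = None


def horn_of_theorem(thm, pf=None):
    """Split ("forall", vs, ("=>", hyp, conc)) into a Rule (hypothetical encoding)."""
    if thm[0] != "forall":
        raise ValueError("Not a universal quantifier", thm)
    _, vs, body = thm
    if body[0] != "=>":
        raise ValueError("Not an implication", body)
    _, hyp, conc = body
    return Rule(list(vs), hyp, conc, pf)


x = ("x",)
thm = ("forall", [x], ("=>", ("=", ("pow", x, ("2",)), ("mul", x, x)), (">", x, ("0",))))
rule = horn_of_theorem(thm)
print(rule.hyp)   # ('=', ('pow', ('x',), ('2',)), ('mul', ('x',), ('x',)))
print(rule.conc)  # ('>', ('x',), ('0',))
```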
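- kdrag.rewrite.kbo(vs: list[ExprRef], t1: ExprRef, t2: ExprRef) → Order
Knuth-Bendix ordering, naive implementation. Every symbol, including variables, has weight 1, and function symbols are ordered by comparing (name, id). Source: Term Rewriting and All That (Baader and Nipkow), section 5.4.4.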
- Parameters:
vs (list[ExprRef])
t1 (ExprRef)
t2 (ExprRef)
- Return type:
Order
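The KBO cases can be followed on a small example. This sketch mirrors the structure of kbo (variable-count check, unit weights, then KBO2a/2b/2c) but runs on a hypothetical tuple encoding of terms, and it uses plain string comparison of head symbols as the precedence instead of kdrag's (name, id) comparison.

```python
from enum import Enum


class Order(Enum):
    EQ = 0   # equal
    GR = 1   # greater
    NGE = 2  # not greater or equal


def kbo(vs, t1, t2):
    """Naive Knuth-Bendix ordering on tuple terms; every symbol has weight 1."""
    if t1 == t2:
        return Order.EQ

    def vcount(t):
        counts = {v: 0 for v in vs}
        todo = [t]
        while todo:
            s = todo.pop()
            if s in vs:
                counts[s] += 1
            else:
                todo.extend(s[1:])
        return counts

    c1, c2 = vcount(t1), vcount(t2)
    if not all(c1[v] >= c2[v] for v in vs):  # t2 may not have more of any variable
        return Order.NGE

    def weight(t):
        return 1 + sum(weight(a) for a in t[1:])

    w1, w2 = weight(t1), weight(t2)
    if w1 != w2:
        return Order.GR if w1 > w2 else Order.NGE
    if t2 in vs:  # KBO2a: t1 must be f(f(...f(t2))) for a unary f
        f = t1[0]
        while t1 != t2:
            if t1[0] != f or len(t1) != 2:
                return Order.NGE
            t1 = t1[1]
        return Order.GR
    if t1 in vs:
        return Order.NGE
    if t1[0] == t2[0]:  # KBO2c: same head, compare arguments left to right
        for a1, a2 in zip(t1[1:], t2[1:]):
            o = kbo(vs, a1, a2)
            if o != Order.EQ:
                return o
        raise ValueError("unexpected equality in kbo")
    # KBO2b: precedence on head symbols, here plain string comparison (an assumption)
    return Order.GR if t1[0] > t2[0] else Order.NGE


x, y, Z = ("x",), ("y",), ("Z",)
vs = [x, y]
print(kbo(vs, ("add", ("S", x), y), ("S", ("add", x, y))))  # Order.GR
print(kbo(vs, ("add", Z, y), y))                             # Order.GR
print(kbo(vs, x, y))                                         # Order.NGE
```

In the first comparison both sides contain the same variables and have weight 4, so the decision falls to KBO2b; with string comparison "add" happens to rank above "S", which orients the Peano rule in the usual direction.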
- kdrag.rewrite.lpo(vs: list[ExprRef], t1: ExprRef, t2: ExprRef) → Order
Lexicographic path ordering. Based on https://www21.in.tum.de/~nipkow/TRaAT/programs/termorders.ML. The precedence on function symbols is fixed to comparing (name, id). TODO: add an ordering parameter.
- Parameters:
vs (list[ExprRef])
t1 (ExprRef)
t2 (ExprRef)
- Return type:
Order
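A sketch of the same LPO case analysis, on a hypothetical tuple encoding of terms. It goes one step beyond the current implementation in the direction of the TODO: the precedence is an explicit `prec` dictionary from symbols to ranks, a hypothetical parameter that kdrag's lpo does not take.

```python
from enum import Enum


class Order(Enum):
    EQ = 0
    GR = 1
    NGE = 2


def is_subterm(s, t):
    return s == t or any(is_subterm(s, a) for a in t[1:])


def lpo(vs, t1, t2, prec):
    """Lexicographic path ordering on tuple terms; prec ranks head symbols (hypothetical)."""
    if t2 in vs:
        if t1 == t2:
            return Order.EQ
        return Order.GR if is_subterm(t2, t1) else Order.NGE
    if t1 in vs:
        return Order.NGE
    f, args1 = t1[0], t1[1:]
    g, args2 = t2[0], t2[1:]
    # Some argument of t1 is already greater than or equal to t2.
    if any(lpo(vs, a, t2, prec) != Order.NGE for a in args1):
        return Order.GR
    # Otherwise t1 must dominate every argument of t2.
    dominates = all(lpo(vs, t1, a, prec) == Order.GR for a in args2)
    if f == g:
        if not dominates:
            return Order.NGE
        for a1, a2 in zip(args1, args2):  # lexicographic comparison of arguments
            o = lpo(vs, a1, a2, prec)
            if o != Order.EQ:
                return o
        return Order.EQ
    if prec[f] > prec[g] and dominates:
        return Order.GR
    return Order.NGE


x, y = ("x",), ("y",)
vs = [x, y]
prec = {"add": 2, "S": 1, "Z": 0}
print(lpo(vs, ("add", ("S", x), y), ("S", ("add", x, y)), prec))  # Order.GR
print(lpo(vs, ("S", ("add", x, y)), ("add", ("S", x), y), prec))  # Order.NGE
```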
- kdrag.rewrite.rewrite(t: ExprRef, rules: list[RewriteRule], trace=None) → ExprRef
Sweep through a term once, bottom-up, performing rewrites: the children are rewritten first, then each rule is tried in order at the root.
>>> x = smt.Real("x")
>>> rule = RewriteRule([x], x**2, x*x)
>>> rewrite((x**2)**2, [rule])
x*x*x*x
- Parameters:
t (ExprRef)
rules (list[RewriteRule])
- Return type:
ExprRef
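- kdrag.rewrite.rewrite1(t: ExprRef, vs: list[ExprRef], lhs: ExprRef, rhs: ExprRef) → ExprRef | None
Rewrite at the root a single time. If t is an instance of lhs (with vs as pattern variables), return the corresponding instance of rhs; otherwise return None.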
- Parameters:
t (ExprRef)
vs (list[ExprRef])
lhs (ExprRef)
rhs (ExprRef)
- Return type:
ExprRef | None
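- kdrag.rewrite.rewrite1_rule(t: ExprRef, rule: RewriteRule, trace: list[tuple[RewriteRule, dict[ExprRef, ExprRef]]] | None = None) → ExprRef | None
Rewrite at the root a single time, like rewrite1 but taking a RewriteRule. The optional trace is a list of (rule, substitution) pairs.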
- Parameters:
t (ExprRef)
rule (RewriteRule)
trace (list[tuple[RewriteRule, dict[ExprRef, ExprRef]]] | None)
- Return type:
ExprRef | None
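The trace annotation, a list of (rule, substitution) pairs, suggests one entry per successful rewrite. The sketch below records exactly that, on a hypothetical tuple encoding of terms and with a local stand-in for RewriteRule; failed matches leave the trace untouched.

```python
from typing import NamedTuple


class RewriteRule(NamedTuple):
    vs: list
    lhs: tuple
    rhs: tuple


def pmatch(vs, pat, t, subst):
    if pat in vs:
        if pat in subst:
            return subst if subst[pat] == t else None
        subst[pat] = t
        return subst
    if pat[0] != t[0] or len(pat) != len(t):
        return None
    for p, a in zip(pat[1:], t[1:]):
        if pmatch(vs, p, a, subst) is None:
            return None
    return subst


def substitute(t, subst):
    if t in subst:
        return subst[t]
    return (t[0],) + tuple(substitute(a, subst) for a in t[1:])


def rewrite1_rule(t, rule, trace=None):
    subst = pmatch(rule.vs, rule.lhs, t, {})
    if subst is None:
        return None
    if trace is not None:
        trace.append((rule, subst))  # which rule fired and how it was instantiated
    return substitute(rule.rhs, subst)


x, Z = ("x",), ("Z",)
double_neg = RewriteRule([x], ("neg", ("neg", x)), x)
trace = []
print(rewrite1_rule(("neg", ("neg", Z)), double_neg, trace))  # ('Z',)
print(rewrite1_rule(("neg", Z), double_neg, trace))           # None
print(len(trace), trace[0][1])                                # 1 {('x',): ('Z',)}
```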
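- kdrag.rewrite.rewrite_star(t: ExprRef, rules: list[RewriteRule], trace=None) → ExprRef
Repeat rewrite until no more rewrites are possible, i.e. until a sweep leaves the term unchanged. There is no iteration bound, so this does not terminate if the rules are non-terminating.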
- Parameters:
t (ExprRef)
rules (list[RewriteRule])
- Return type:
ExprRef
- kdrag.rewrite.rule_of_theorem(thm: BoolRef | QuantifierRef) → RewriteRule
Unpack a theorem of the form forall vs, lhs = rhs into a RewriteRule tuple. Raises RewriteRuleException if an outer quantifier is not universal or the body is not an equation.
>>> x = smt.Real("x")
>>> rule_of_theorem(smt.ForAll([x], x**2 == x*x))
RewriteRule(vs=[X...], lhs=X...**2, rhs=X...*X...)
- Parameters:
thm (BoolRef | QuantifierRef)
- Return type:
RewriteRule
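rule_of_theorem peels off any number of nested universal quantifiers, collecting their variables, and then requires an equation. A sketch of that loop under a hypothetical tuple encoding of formulas (`("forall", vs, body)`, `("exists", vs, body)`, `("=", lhs, rhs)`), with local stand-ins for RewriteRule and RewriteRuleException; unlike kdrag, it keeps the original variable names instead of opening binders with fresh ones.

```python
from typing import NamedTuple


class RewriteRuleException(Exception):
    pass


class RewriteRule(NamedTuple):
    vs: list
    lhs: tuple
    rhs: tuple


def rule_of_theorem(thm):
    vs = []
    body = thm
    while body[0] in ("forall", "exists"):  # peel off the quantifier prefix
        if body[0] != "forall":
            raise RewriteRuleException("Not a universal quantifier", body)
        vs.extend(body[1])
        body = body[2]
    if body[0] != "=":
        raise RewriteRuleException("Not an equation", thm)
    return RewriteRule(vs, body[1], body[2])


x, y = ("x",), ("y",)
comm = ("forall", [x], ("forall", [y], ("=", ("add", x, y), ("add", y, x))))
print(rule_of_theorem(comm))
# RewriteRule(vs=[('x',), ('y',)], lhs=('add', ('x',), ('y',)), rhs=('add', ('y',), ('x',)))
```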
- kdrag.rewrite.simp(e: ExprRef, trace=None, max_iter=None) → ExprRef
Simplify using definitions and the built-in Z3 simplifier until no progress is made. Each iteration performs one unfold sweep followed by smt.simplify; max_iter, if given, bounds the number of iterations.
>>> import kdrag.theories.nat as nat
>>> simp(nat.one + nat.one + nat.S(nat.one))
S(S(S(S(Z))))
>>> p = smt.Bool("p")
>>> simp(smt.If(p, 42, 3))
If(p, 42, 3)
- Parameters:
e (ExprRef)
- Return type:
ExprRef
- kdrag.rewrite.simp1(t: ExprRef) → ExprRef
Simplify a term using Z3's demodulator and simplify tactics, with the axioms of all kd.define definitions added to the goal.
- Parameters:
t (ExprRef)
- Return type:
ExprRef
- kdrag.rewrite.simp2(t: ExprRef) → ExprRef
Simplify a term using Z3's elim-predicates tactic, with the axioms of all kd.define definitions added to the goal.
- Parameters:
t (ExprRef)
- Return type:
ExprRef
- kdrag.rewrite.unfold(e: ExprRef, decls=None, trace=None) → ExprRef
Do a single unfold sweep, unfolding definitions defined by kd.define. The optional trace parameter records a proof of each unfolding step. decls is an optional list of declarations to unfold; if None, all definitions are unfolded.
>>> x = smt.Int("x")
>>> f = kd.define("f", [x], x + 42)
>>> trace = []
>>> unfold(f(1), trace=trace)
1 + 42
>>> trace
[|- f(1) == 1 + 42]
- Parameters:
e (ExprRef)
- Return type:
ExprRef
"""
Utilities for rewriting and simplification including pattern matching and unification.
"""
import kdrag.smt as smt
import kdrag as kd
from enum import Enum
from typing import NamedTuple, Optional
import kdrag.utils as utils

def simp1(t: smt.ExprRef) -> smt.ExprRef:
    """Simplify a term using Z3's demodulator and simplify tactics, with the axioms of all kd.define definitions added to the goal."""
    expr = smt.FreshConst(t.sort(), prefix="knuckle_goal")
    G = smt.Goal()
    for v in kd.kernel.defns.values():
        G.add(v.ax.thm)
    G.add(expr == t)
    G2 = smt.Then(smt.Tactic("demodulator"), smt.Tactic("simplify")).apply(G)[0]
    # TODO make this extraction more robust
    return G2[len(G2) - 1].children()[1]

def simp2(t: smt.ExprRef) -> smt.ExprRef:
    """Simplify a term using Z3's elim-predicates tactic, with the axioms of all kd.define definitions added to the goal."""
    expr = smt.FreshConst(t.sort(), prefix="knuckle_goal")
    G = smt.Goal()
    for v in kd.kernel.defns.values():
        G.add(v.ax.thm)
    G.add(expr == t)
    G2 = smt.Tactic("elim-predicates").apply(G)[0]
    return G2[len(G2) - 1].children()[1]

# TODO: Doesn't seem to do anything?
# def factor(t: smt.ExprRef) -> smt.ExprRef:
#     """factor a term using z3 built in tactic"""
#     expr = smt.FreshConst(t.sort(), prefix="knuckle_goal")
#     G = smt.Goal()
#     for v in kd.kernel.defns.values():
#         G.add(v.ax.thm)
#     G.add(expr == t)
#     G2 = smt.Tactic("factor").apply(G)[0]
#     return G2[len(G2) - 1].children()[1]

def unfold(e: smt.ExprRef, decls=None, trace=None) -> smt.ExprRef:
    """
    Do a single unfold sweep, unfolding definitions defined by `kd.define`.
    The optional trace parameter records a proof of each unfolding step.
    `decls` is an optional list of declarations to unfold. If None, all definitions are unfolded.
    >>> x = smt.Int("x")
    >>> f = kd.define("f", [x], x + 42)
    >>> trace = []
    >>> unfold(f(1), trace=trace)
    1 + 42
    >>> trace
    [|- f(1) == 1 + 42]
    """
    if smt.is_app(e):
        decl = e.decl()
        # unfold the children first, then possibly the head itself
        children = [unfold(c, decls=decls, trace=trace) for c in e.children()]
        defn = kd.kernel.defns.get(decl)
        if defn is not None and (decls is None or decl in decls):
            e1 = smt.substitute(defn.body, *zip(defn.args, children))
            if trace is not None:
                if isinstance(defn.ax.thm, smt.QuantifierRef):
                    # instantiate the definitional axiom at the unfolded arguments
                    trace.append(defn.ax(*children))
                else:
                    trace.append(defn.ax)
            return e1
        else:
            return decl(*children)
    else:
        return e

def simp(e: smt.ExprRef, trace=None, max_iter=None) -> smt.ExprRef:
    """
    Simplify using definitions and the built-in Z3 simplifier until no progress is made.
    >>> import kdrag.theories.nat as nat
    >>> simp(nat.one + nat.one + nat.S(nat.one))
    S(S(S(S(Z))))
    >>> p = smt.Bool("p")
    >>> simp(smt.If(p, 42, 3))
    If(p, 42, 3)
    """
    i = 0
    while True:
        i += 1
        if max_iter is not None and i > max_iter:
            return e
        e = unfold(e, trace=trace)
        # TODO: Interesting options: som, sort_store, elim_ite, flat, split_concat_eq, sort_sums, sort_disjunctions
        e1 = smt.simplify(e)
        if e1.eq(e):
            return e1
        else:
            if trace is not None:
                trace.append(kd.kernel.prove(smt.Eq(e, e1)))
            e = e1

def def_eq(e1: smt.ExprRef, e2: smt.ExprRef, trace=None) -> bool:
    """
    A notion of computational equality: simplify both sides with `simp` and compare up to alpha-equivalence.
    >>> import kdrag.theories.nat as nat
    >>> def_eq(nat.one + nat.one, nat.S(nat.S(nat.Z)))
    True
    """
    e1 = simp(e1, trace=trace)
    e2 = simp(e2, trace=trace)
    return kd.utils.alpha_eq(e1, e2)
    # TODO: But we can have early stopping if we do these processes interleaved.
    # while not e1.eq(e2):
    #     e1 = unfold(e1, trace=trace)
    #     e2 = unfold(e2, trace=trace)

def rewrite1(
    t: smt.ExprRef, vs: list[smt.ExprRef], lhs: smt.ExprRef, rhs: smt.ExprRef
) -> Optional[smt.ExprRef]:
    """
    Rewrite `t` at the root a single time. Return None if `t` does not match `lhs`.
    """
    subst = utils.pmatch(vs, lhs, t)
    if subst is not None:
        return smt.substitute(rhs, *subst.items())
    return None
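
def apply(
    goal: smt.BoolRef, vs: list[smt.ExprRef], head: smt.BoolRef, body: smt.BoolRef
) -> Optional[smt.BoolRef]:
    """
    Rewrite a Boolean goal at the root using head -> body. Return None if `goal` does not match `head`.
    """
    res = rewrite1(goal, vs, head, body)
    assert res is None or isinstance(res, smt.BoolRef)
    return res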

class RewriteRule(NamedTuple):
    """A rewrite rule tuple"""

    vs: list[smt.ExprRef]
    lhs: smt.ExprRef
    rhs: smt.ExprRef

def rewrite1_rule(
    t: smt.ExprRef,
    rule: RewriteRule,
    trace: Optional[list[tuple[RewriteRule, dict[smt.ExprRef, smt.ExprRef]]]] = None,
) -> Optional[smt.ExprRef]:
    """
    Rewrite at the root a single time using `rule`.
    """
    subst = utils.pmatch(rule.vs, rule.lhs, t)
    if subst is not None:
        # record the rule that fired and its matching substitution
        if trace is not None:
            trace.append((rule, subst))
        return smt.substitute(rule.rhs, *subst.items())
    return None

def rewrite(t: smt.ExprRef, rules: list[RewriteRule], trace=None) -> smt.ExprRef:
    """
    Sweep through term once performing rewrites.
    >>> x = smt.Real("x")
    >>> rule = RewriteRule([x], x**2, x*x)
    >>> rewrite((x**2)**2, [rule])
    x*x*x*x
    """
    if smt.is_app(t):
        # rewrite children first (bottom-up), passing the trace along
        t = t.decl()(*[rewrite(arg, rules, trace=trace) for arg in t.children()])
    for rule in rules:
        res = rewrite1_rule(t, rule, trace=trace)
        if res is not None:
            t = res
    return t
class RewriteRuleException(Exception): ...

def rule_of_theorem(thm: smt.BoolRef | smt.QuantifierRef) -> RewriteRule:
    """
    Unpack theorem of form `forall vs, lhs = rhs` into a RewriteRule tuple
    >>> x = smt.Real("x")
    >>> rule_of_theorem(smt.ForAll([x], x**2 == x*x))
    RewriteRule(vs=[X...], lhs=X...**2, rhs=X...*X...)
    """
    vs = []
    thm1 = thm  # to help out pyright
    while isinstance(thm1, smt.QuantifierRef):
        if thm1.is_forall():
            vs1, thm1 = utils.open_binder(thm1)
            vs.extend(vs1)
        else:
            raise RewriteRuleException("Not a universal quantifier", thm1)
    if not smt.is_eq(thm1):
        raise RewriteRuleException("Not an equation", thm)
    lhs, rhs = thm1.children()
    return RewriteRule(vs, lhs, rhs)

def decl_index(rules: list[RewriteRule]) -> dict[smt.FuncDeclRef, RewriteRule]:
    """Build a dictionary of rules indexed by their lhs head function declaration."""
    return {rule.lhs.decl(): rule for rule in rules}

def rewrite_star(t: smt.ExprRef, rules: list[RewriteRule], trace=None) -> smt.ExprRef:
    """
    Repeat rewrite until no more rewrites are possible.
    """
    while True:
        t1 = rewrite(t, rules, trace=trace)
        if t1.eq(t):
            return t1
        t = t1

class Rule(NamedTuple):
    """A Horn rule tuple: forall vs, hyp => conc, with an optional proof pf."""

    vs: list[smt.ExprRef]
    hyp: smt.BoolRef
    conc: smt.BoolRef
    pf: Optional[kd.kernel.Proof] = None

def horn_of_theorem(thm: smt.ExprRef | kd.kernel.Proof) -> Rule:
    """Unpack theorem of form `forall vs, body => head` into a Rule tuple
    >>> x = smt.Real("x")
    >>> horn_of_theorem(smt.ForAll([x], smt.Implies(x**2 == x*x, x > 0)))
    Rule(vs=[X...], hyp=X...**2 == X...*X..., conc=X... > 0, pf=None)
    """
    pf = None
    if isinstance(thm, smt.ExprRef):
        pass
    elif kd.kernel.is_proof(thm):
        pf = thm
        thm = pf.thm
    if not isinstance(thm, smt.QuantifierRef) or not thm.is_forall():
        raise Exception("Not a universal quantifier", thm)
    vs, thm = utils.open_binder(thm)
    assert smt.is_implies(thm)
    return Rule(vs, hyp=thm.arg(0), conc=thm.arg(1), pf=pf)


"""
def apply_horn(thm: smt.BoolRef, horn: smt.BoolRef) -> smt.BoolRef:
    pat = horn
    obl = []
    if smt.is_quantifier(pat) and pat.is_forall():
        pat = pat.body()
    while True:
        if smt.is_implies(pat):
            obl.append(pat.arg(0))
            pat = pat.arg(1)
        else:
            break
    return kd.utils.z3_match(thm, pat)


def horn_split(horn: smt.BoolRef) -> smt.BoolRef:
    body = []
    vs = []
    while True:
        if smt.is_quantifier(horn) and horn.is_forall():
            vs1, horn = open_binder(horn)
            vs.extend(vs1)
        if smt.is_implies(horn):
            body.append(horn.arg(0))
            horn = horn.arg(1)
        else:
            break
    head = horn
    return vs, body, head
"""

class Order(Enum):
    EQ = 0  # Equal
    GR = 1  # Greater
    NGE = 2  # Not Greater or Equal

def lpo(vs: list[smt.ExprRef], t1: smt.ExprRef, t2: smt.ExprRef) -> Order:
    """
    Lexicographic path ordering.
    Based on https://www21.in.tum.de/~nipkow/TRaAT/programs/termorders.ML
    TODO add ordering parameter.
    """

    def is_var(x):
        return any(x.eq(v) for v in vs)

    if is_var(t2):
        if t1.eq(t2):
            return Order.EQ
        elif utils.is_subterm(t2, t1):
            return Order.GR
        else:
            return Order.NGE
    elif is_var(t1):
        return Order.NGE
    elif smt.is_app(t1) and smt.is_app(t2):
        decl1, decl2 = t1.decl(), t2.decl()
        args1, args2 = t1.children(), t2.children()
        if all(lpo(vs, a, t2) == Order.NGE for a in args1):
            if decl1 == decl2:
                if all(lpo(vs, t1, a) == Order.GR for a in args2):
                    for a1, a2 in zip(args1, args2):
                        ord = lpo(vs, a1, a2)
                        if ord == Order.GR:
                            return Order.GR
                        elif ord == Order.NGE:
                            return Order.NGE
                    return Order.EQ
                else:
                    return Order.NGE
            elif (decl1.name(), decl1.get_id()) > (decl2.name(), decl2.get_id()):
                if all(lpo(vs, t1, a) == Order.GR for a in args2):
                    return Order.GR
                else:
                    return Order.NGE
            else:
                return Order.NGE
        else:
            return Order.GR
    else:
        raise Exception("Unexpected terms in lpo", t1, t2)

def kbo(vs: list[smt.ExprRef], t1: smt.ExprRef, t2: smt.ExprRef) -> Order:
    """
    Knuth Bendix Ordering, naive implementation.
    All weights are 1.
    Source: Term Rewriting and All That section 5.4.4
    """
    if t1.eq(t2):
        return Order.EQ

    def is_var(x):
        return any(x.eq(v) for v in vs)

    def vcount(t):
        todo = [t]
        vcount1 = {v: 0 for v in vs}
        while todo:
            t = todo.pop()
            if is_var(t):
                vcount1[t] += 1
            elif smt.is_app(t):
                todo.extend(t.children())
        return vcount1

    vcount1, vcount2 = vcount(t1), vcount(t2)
    if not all(vcount1[v] >= vcount2[v] for v in vs):
        return Order.NGE

    def weight(t):
        todo = [t]
        w = 0
        while todo:
            t = todo.pop()
            w += 1
            if smt.is_app(t):
                todo.extend(t.children())
        return w

    w1, w2 = weight(t1), weight(t2)
    if w1 > w2:
        return Order.GR
    elif w1 < w2:
        return Order.NGE
    else:
        if is_var(t2):  # KBO2a
            decl = t1.decl()
            if decl.arity() != 1:
                return Order.NGE
            while not t1.eq(t2):
                if t1.decl() != decl:
                    return Order.NGE
                else:
                    t1 = t1.arg(0)
            return Order.GR
        elif is_var(t1):
            return Order.NGE
        elif smt.is_app(t1) and smt.is_app(t2):
            decl1, decl2 = t1.decl(), t2.decl()
            if decl1 == decl2:  # KBO2c
                args1, args2 = t1.children(), t2.children()
                for a1, a2 in zip(args1, args2):
                    ord = kbo(vs, a1, a2)
                    if ord == Order.GR:
                        return Order.GR
                    elif ord == Order.NGE:
                        return Order.NGE
                raise Exception("Unexpected equality reached in kbo")
            elif (decl1.name(), decl1.get_id()) > (
                decl2.name(),
                decl2.get_id(),
            ):  # KBO2b
                return Order.GR
            else:
                return Order.NGE
        else:
            raise Exception("Unexpected terms in kbo", t1, t2)