
Module Attributes


defns

Holds the definitional axioms for function symbols introduced by `define`.


Functions

axiom(thm[, by])

Assert an axiom.

beta_conv(lam, *args)

Beta conversion for lambda calculus.


consider(x)

Give a fresh name to a term. The purpose of this is to seed the solver with interesting terms.

define(name, args, body)

Make a non-recursive definition.

define_fix(name, args, retsort, fix_lam)

Make a recursive definition.

forget(ts, pf)

"Forget" a term using existentials.

forget2(ts, thm)

"Forget" a term using existentials.


fresh_const(q)

Generate fresh constants with the same sorts as the quantifier's bound variables.


herb(thm)

Herbrandize a theorem.

instan(ts, pf)

Instantiate a proof of a universally quantified formula (forall elimination).

instan2(ts, thm)

Instantiate a universally quantified formula, proving `(forall xs, P(xs)) -> P(ts)`.


lemma(thm[, by, admit, timeout, dump, solver])

Prove a theorem using a list of previously proved lemmas.


skolem(pf)

Skolemize an existential quantifier.


Classes

Defn(name, args, body, ax)

A record storing a definition.



import kdrag.smt as smt
from dataclasses import dataclass
from typing import Any
import logging
from . import config

# Shared, explicitly named logger (rather than __name__) so that kernel
# warnings such as admitted lemmas go to the "knuckledragger" channel.
logger = logging.getLogger(name="knuckledragger")

@dataclass(frozen=True)
class Proof(smt.Z3PPObject):
    """A theorem together with the reasons it was accepted.

    Instances are created only by the kernel functions in this module.
    The dataclass is frozen so that a proof's theorem cannot be swapped out
    after it has been checked.

    Attributes:
        thm: The proved formula.
        reason: The lemmas or explanations the proof was built from.
        admit: True if the theorem was admitted (asserted) rather than proved.
    """

    thm: smt.BoolRef
    reason: list[Any]
    admit: bool = False

    def _repr_html_(self):
        return "⊦" + repr(self.thm)

    # dataclass keeps this user-defined __repr__ instead of generating one.
    def __repr__(self):
        return "|- " + repr(self.thm)

# It is unlikely that users should be accessing the `Proof` constructor directly.
# This is not ironclad. If you really want the Proof constructor, I can't stop you.
# Keep a private handle to the class for the kernel's own use, then shadow the
# public name. Order matters: the alias must be taken before `Proof` is rebound.
# Note: annotations evaluated after this point (e.g. `-> Proof`) see `None`.
__Proof = Proof
Proof = None

def is_proof(p: object) -> bool:
    """Return True if `p` is an instance of the (hidden) Proof class."""
    return isinstance(p, __Proof)

class LemmaError(Exception):
    """Raised by `lemma` when a theorem cannot be accepted.

    The exception arguments carry the diagnostic context, for example the
    theorem and a countermodel, or the solver's non-unsat result.
    """
def lemma(
    thm: smt.BoolRef,
    by: list[Proof] = [],
    admit: bool = False,
    timeout: int = 1000,
    dump: bool = False,
    solver=None,
) -> Proof:
    """Prove a theorem using a list of previously proved lemmas.

    In essence `prove(Implies(by, thm))`. The solver is given every lemma in
    `by` as an assumption together with `Not(thm)`, and `thm` is accepted only
    if that query is unsatisfiable.

    Args:
        thm (smt.BoolRef): The theorem to prove.
        by (list[Proof]): A list of previously proved lemmas. It is only read,
            never mutated, and is stored as the proof's reason.
        admit (bool): If True, admit the theorem without proof. A warning is logged.
        timeout (int): Value passed to the solver's "timeout" option.
        dump (bool): If True, print the solver query (`s.sexpr()`) before checking.
        solver: A zero-argument solver factory. Defaults to `config.solver`.

    Returns:
        Proof: A proof object of thm

    Raises:
        LemmaError: If an element of `by` is not a Proof, if the solver finds a
            countermodel, or if the solver returns any result other than unsat
            (e.g. unknown).

    >>> lemma(smt.BoolVal(True))
    |- True
    >>> lemma(smt.RealVal(1) >= smt.RealVal(0))
    |- 1 >= 0
    """
    if admit:
        logger.warning("Admitting lemma %s", thm)
        return __Proof(thm, by, True)

    if solver is None:
        solver = config.solver
    s = solver()
    s.set("timeout", timeout)
    for p in by:
        if not isinstance(p, __Proof):
            raise LemmaError("In by reasons:", p, "is not a Proof object")
        s.add(p.thm)
    # Refutation: thm follows from the lemmas iff lemmas /\ Not(thm) is unsat.
    s.add(smt.Not(thm))
    if dump:
        print(s.sexpr())
    res = s.check()
    if res == smt.sat:
        raise LemmaError(thm, "Countermodel", s.model())
    if res != smt.unsat:
        raise LemmaError("lemma", thm, res)
    return __Proof(thm, by, False)

def axiom(thm: smt.BoolRef, by=[]) -> Proof:
    """Assert an axiom.

    Axioms are necessary and useful. But you must use great care.
    No solver is consulted: `thm` is accepted as-is and marked as admitted.

    Args:
        thm: The axiom to assert.
        by: A python object explaining why the axiom should exist. Often a string explaining the axiom.

    Returns:
        Proof: A proof object of `thm` with `admit=True` and `by` as its reason.
    """
    return __Proof(thm, by, admit=True)

@dataclass
class Defn:
    """A record storing a definition.

    It is useful to record definitions as special axioms because we often must unfold them.

    Attributes:
        name: The name of the defined symbol.
        args: The formal arguments of the definition.
        body: The defining expression.
        ax: The (admitted) defining axiom.
    """

    name: str
    args: list[smt.ExprRef]
    body: smt.ExprRef
    ax: Proof

# Registry of definitions made by `define`, keyed by the defined function symbol
# (the FuncDeclRef returned by smt.Function).
defns: dict[smt.FuncDeclRef, Defn] = {}
# Convenience accessors added onto smt's classes: `f.defn` for a function symbol
# and `e.defn` for an expression (looked up via its head symbol `e.decl()`)
# return the recorded defining axiom. An unregistered symbol raises KeyError.
smt.FuncDeclRef.defn = property(lambda self: defns[self].ax)
smt.ExprRef.defn = property(lambda self: defns[self.decl()].ax)

def fresh_const(q: smt.QuantifierRef) -> list[smt.ExprRef]:
    """Generate fresh constants of same sort as quantifier.

    Returns one fresh constant per bound variable of `q`, in declaration order.
    Each constant has that variable's sort and uses its name as the prefix.
    """
    return [
        smt.FreshConst(q.var_sort(i), prefix=q.var_name(i)) for i in range(q.num_vars())
    ]

def define(name: str, args: list[smt.ExprRef], body: smt.ExprRef) -> smt.FuncDeclRef:
    """Make a non-recursive definition.

    Useful for shorthand and abstraction. Does not currently defend against ill formed definitions.
    TODO: Check for bad circularity, record dependencies

    Args:
        name: The name of the term to define.
        args: The arguments of the term.
        body: The definition of the term.

    Returns:
        The defined function symbol `f`. When `args` is empty, the constant
        `f()` (an expression) is returned instead, for convenience.

    The defining axiom is recorded in `defns[f]`. Redefining `f` with a
    different axiom logs a warning and replaces the record.
    """
    args = list(args)
    sorts = [arg.sort() for arg in args] + [body.sort()]
    f = smt.Function(name, *sorts)

    # TODO: This is getting too hairy for the kernel? Reassess. Maybe just a lambda flag? Autolift?
    if smt.is_quantifier(body) and body.is_lambda():
        # It is worth it to avoid having lambdas in definition: state the
        # axiom pointwise, forall args vs. f(args)[vs] == body-of-lambda(vs).
        vs = fresh_const(body)
        def_ax = axiom(
            smt.ForAll(
                args + vs,
                f(*args)[tuple(vs)] == smt.substitute_vars(body.body(), *reversed(vs)),
            ),
            by="definition",
        )
    elif len(args) == 0:
        def_ax = axiom(f() == body, by="definition")
    else:
        def_ax = axiom(smt.ForAll(args, f(*args) == body), by="definition")

    # Redefinition is allowed but reported: a soft warning is more pleasant
    # than a hard failure.
    defn = Defn(name, args, body, def_ax)
    if f in defns and not defns[f].ax.thm.eq(def_ax.thm):
        logger.warning(
            "Redefining function %s from %s to %s", f, defns[f].ax, def_ax.thm
        )
    defns[f] = defn

    if len(args) == 0:
        return f()  # Convenience
    return f

def define_fix(name: str, args: list[smt.ExprRef], retsort, fix_lam) -> smt.FuncDeclRef:
    """Make a recursive definition.

    `fix_lam` is called with a callable standing for the function being
    defined and must return the body. Recursive calls go through that callable.

    Args:
        name: The name of the function to define.
        args: The formal arguments.
        retsort: The result sort. It must equal the sort of the body returned by `fix_lam`.
        fix_lam: A Python function mapping the recursive function to the body.

    Returns:
        Whatever `define(name, args, body)` returns: the function symbol, or
        `f()` when `args` is empty.
    """
    # The recursive symbol needs the full signature, including the result sort,
    # so that it is the very declaration `define` creates below.
    sorts = [arg.sort() for arg in args] + [retsort]
    f = smt.Function(name, *sorts)

    # wrapper to record calls
    # NOTE(review): `calls` is collected but not consumed yet; presumably meant
    # for the termination check in the TODO below.
    calls = []

    def record_f(*call_args):
        calls.append(call_args)
        return f(*call_args)

    body = fix_lam(record_f)
    assert body.sort().eq(retsort), "define_fix: body sort does not match retsort"
    defn = define(name, args, body)
    # TODO: check for well foundedness/termination, custom induction principle.
    return defn

def consider(x: smt.ExprRef) -> Proof:
    """Give a fresh name to a term.

    The purpose of this is to seed the solver with interesting terms.
    Axiom schema. We may give a fresh name to any term. An "anonymous" form of define.
    Pointing out the interesting terms is sometimes the essence of a proof.

    Returns:
        Proof: The admitted axiom `c == x` for a fresh constant `c` of `x`'s sort.
    """
    return axiom(smt.FreshConst(x.sort(), prefix="consider") == x)

def instan(ts: list[smt.ExprRef], pf: Proof) -> Proof:
    """Instantiate a universally quantified formula.

    This is forall elimination: from a proof of `forall xs, P(xs)`, derive `P(ts)`.

    Args:
        ts: One term per bound variable, in declaration order, each of that variable's sort.
        pf: A proof of a universally quantified formula.

    Returns:
        Proof: A proof of the instantiated body, with reason `[pf]`.
    """
    assert is_proof(pf) and smt.is_quantifier(pf.thm) and pf.thm.is_forall()
    q = pf.thm
    # Without these checks substitute_vars can leave de Bruijn variables
    # dangling or build ill-sorted instances, which a kernel must never accept.
    assert len(ts) == q.num_vars(), "instan: wrong number of terms"
    assert all(
        t.sort().eq(q.var_sort(i)) for i, t in enumerate(ts)
    ), "instan: sort mismatch"
    # De Bruijn index 0 is the last bound variable, hence reversed(ts).
    return __Proof(smt.substitute_vars(q.body(), *reversed(ts)), reason=[pf])

def instan2(ts: list[smt.ExprRef], thm: smt.BoolRef) -> Proof:
    """Instantiate a universally quantified formula.

    This is forall elimination stated as a theorem: `(forall xs, P(xs)) -> P(ts)`.
    No proof of `thm` is required.

    Args:
        ts: One term per bound variable, in declaration order, each of that variable's sort.
        thm: A universally quantified formula.

    Returns:
        Proof: A proof of `Implies(thm, P(ts))`.
    """
    assert smt.is_quantifier(thm) and thm.is_forall()
    assert len(ts) == thm.num_vars(), "instan2: wrong number of terms"
    assert all(
        t.sort().eq(thm.var_sort(i)) for i, t in enumerate(ts)
    ), "instan2: sort mismatch"
    # De Bruijn index 0 is the last bound variable, hence reversed(ts).
    return __Proof(
        smt.Implies(thm, smt.substitute_vars(thm.body(), *reversed(ts))),
        reason=["forall_elim"],
    )

def forget(ts: list[smt.ExprRef], pf: Proof) -> Proof:
    """"Forget" a term using existentials. This is existential introduction.

    From a proof of `P(ts)`, derive `exists xs, P(xs)`. Every occurrence of each
    term in `ts` is replaced by a fresh variable of the same sort.

    Args:
        ts: The terms to forget.
        pf: A proof of `P(ts)`.

    Returns:
        Proof: A proof of `exists xs, P(xs)` with reason `[pf]`.
    """
    assert is_proof(pf)
    # One fresh variable per forgotten term. (fresh_const only applies to
    # quantifiers, and pf.thm is an arbitrary formula.)
    vs = [smt.FreshConst(t.sort()) for t in ts]
    return __Proof(smt.Exists(vs, smt.substitute(pf.thm, *zip(ts, vs))), reason=[pf])

def forget2(ts: list[smt.ExprRef], thm: smt.BoolRef) -> Proof:
    """"Forget" a term using existentials. This is existential introduction.

    Proves the implication `P(ts) -> exists xs, P(xs)` for `thm = P(ts)`.
    No proof of `thm` is required; `forget` easily follows.

    Args:
        ts: The terms to forget.
        thm: The formula `P(ts)`.

    Returns:
        Proof: A proof of `Implies(thm, Exists(xs, P(xs)))`.
    """
    # One fresh variable per forgotten term (thm is not a quantifier).
    vs = [smt.FreshConst(t.sort()) for t in ts]
    return __Proof(
        smt.Implies(thm, smt.Exists(vs, smt.substitute(thm, *zip(ts, vs)))),
        reason=["exists_intro"],
    )

def skolem(pf: Proof) -> tuple[list[smt.ExprRef], Proof]:
    """Skolemize an existential quantifier.

    From a proof of `exists xs, P(xs)`, introduce fresh constants `cs`.

    Returns:
        tuple: `(cs, proof)`, where `proof` proves `P(cs)` with reason `[pf]`.
    """
    # TODO: Hmm. Maybe we don't need to have a Proof? Lessen this to thm.
    assert is_proof(pf) and smt.is_quantifier(pf.thm) and pf.thm.is_exists()

    skolems = fresh_const(pf.thm)
    # De Bruijn index 0 is the last bound variable, hence reversed(skolems).
    return skolems, __Proof(
        smt.substitute_vars(pf.thm.body(), *reversed(skolems)), reason=[pf]
    )

def herb(thm: smt.QuantifierRef) -> tuple[list[smt.ExprRef], Proof]:
    """Herbrandize a theorem.

    It is sufficient to prove a theorem for fresh consts to prove a universal.
    Note: Perhaps lambdaized form is better?

    Args:
        thm: A universally quantified formula `forall xs, P(xs)`.

    Returns:
        tuple: `(herbs, proof)`, where `herbs` are fresh constants (one per bound
        variable) and `proof` proves `P(herbs) -> forall xs, P(xs)`.
    """
    assert smt.is_quantifier(thm) and thm.is_forall()
    herbs = fresh_const(thm)
    # De Bruijn index 0 is the last bound variable, hence reversed(herbs).
    return herbs, __Proof(
        smt.Implies(smt.substitute_vars(thm.body(), *reversed(herbs)), thm),
        reason=["herbrand"],
    )

def beta_conv(lam: smt.QuantifierRef, *args) -> Proof:
    """Beta conversion for lambda calculus.

    Returns the admitted axiom `lam[args] == body[args/vars]`.

    Args:
        lam: A lambda term.
        *args: One expression per bound variable, in declaration order, each of that variable's sort.
    """
    # Check the kind of term first, so a non-quantifier fails the assertion
    # instead of raising AttributeError on num_vars().
    assert smt.is_quantifier(lam) and lam.is_lambda()
    assert len(args) == lam.num_vars()
    assert all(
        a.sort().eq(lam.var_sort(i)) for i, a in enumerate(args)
    ), "beta_conv: sort mismatch"
    # De Bruijn index 0 is the last bound variable, hence reversed(args).
    return axiom(lam[args] == smt.substitute_vars(lam.body(), *reversed(args)))