kdrag.kernel

The kernel holds the core proof datatypes and core inference rules. By and large, all proofs must flow through this module.

Module Attributes

defns

defns holds the definitional axioms for function symbols.

Functions

axiom(thm[, by])

Assert an axiom.

beta_conv(lam, *args)

Beta conversion for lambda calculus.

consider(x)

The purpose of this is to seed the solver with interesting terms.

define(name, args, body)

Define a non-recursive definition.

define_fix(name, args, retsort, fix_lam)

Define a recursive definition.

einstan(thm)

Skolemize an existential quantifier.

forget(ts, pf)

"Forget" a term using existentials.

forget2(ts, thm)

"Forget" a term using existentials.

fresh_const(q)

Generate fresh constants of same sort as quantifier.

herb(thm)

Herbrandize a theorem.

instan(ts, pf)

Instantiate a universally quantified formula.

instan2(ts, thm)

Instantiate a universally quantified formula: forall xs, P(xs) -> P(ts). This is forall elimination.

is_proof(p)

Check whether p is a Proof object.

lemma(thm[, by, admit, timeout, dump, solver])

Prove a theorem using a list of previously proved lemmas.

skolem(pf)

Skolemize an existential quantifier.

Classes

Defn(name, args, body, ax)

A record storing a definition.

Exceptions

LemmaError

class kdrag.kernel.Defn(name: str, args: list[ExprRef], body: ExprRef, ax: Proof)

Bases: object

A record storing a definition. It is useful to record definitions as special axioms because we often must unfold them.

args: list[ExprRef]
ax: Proof
body: ExprRef
name: str
exception kdrag.kernel.LemmaError

Bases: Exception

add_note()

Exception.add_note(note) – add a note to the exception

args
with_traceback()

Exception.with_traceback(tb) – set self.__traceback__ to tb and return self.

kdrag.kernel.axiom(thm: BoolRef, by=[]) → Proof

Assert an axiom.

Axioms are necessary and useful. But you must use great care.

Parameters:
  • thm – The axiom to assert.

  • by – A python object explaining why the axiom should exist. Often a string explaining the axiom.
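Because every proof object records its reasons and an admit flag, the axioms a result ultimately rests on can be audited. The sketch below is a stdlib-only toy model, not kdrag itself: ToyProof mirrors the thm/reason/admit fields of the kernel's Proof, theorems are plain strings, and toy_axiom and admitted_axioms are hypothetical helpers for illustration.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass(frozen=True)
class ToyProof:
    # Mirrors the fields of the kernel's Proof; thm is a plain string in this toy.
    thm: str
    reason: list[Any] = field(default_factory=list)
    admit: bool = False


def toy_axiom(thm: str, by: Any = "assumed") -> ToyProof:
    # Like kernel.axiom, nothing is checked and the result is flagged as admitted.
    # (The justification is wrapped in a list here for uniformity.)
    return ToyProof(thm, [by], admit=True)


def admitted_axioms(pf: ToyProof) -> list[str]:
    """Hypothetical audit helper: collect every admitted theorem that pf rests on."""
    found: list[str] = []
    seen: set[int] = set()
    stack = [pf]
    while stack:
        p = stack.pop()
        if id(p) in seen:
            continue
        seen.add(id(p))
        if p.admit:
            found.append(p.thm)
        # Reasons may also be strings or other objects; only proofs are followed.
        stack.extend(r for r in p.reason if isinstance(r, ToyProof))
    return found


ax = toy_axiom("forall x, f(x) = x + 1", by="definition of f")
lem = ToyProof("f(1) = 2", reason=[ax])
print(admitted_axioms(lem))  # ['forall x, f(x) = x + 1']
```

Keeping the axiom's justification in the proof object costs nothing at check time, but it is what makes this kind of after-the-fact inspection possible.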

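kdrag.kernel.beta_conv(lam: QuantifierRef, *args) → Proof

Beta conversion for lambda calculus: returns |- lam[args] == body, where body is the lambda's body with args substituted for its bound variables.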

kdrag.kernel.consider(x: ExprRef) → Proof

Seed the solver with an interesting term. This is an axiom schema: any term x may be given a fresh name, yielding |- c == x for a fresh constant c. It is an “anonymous” form of define. Pointing out the interesting terms is sometimes the essence of a proof.

kdrag.kernel.define(name: str, args: list[ExprRef], body: ExprRef) → FuncDeclRef

Define a non-recursive definition. Useful for shorthand and abstraction. Does not currently defend against ill-formed definitions. TODO: Check for bad circularity, record dependencies.

Parameters:
  • name – The name of the term to define.

  • args – The arguments of the term.

  • body – The body of the definition.

Returns:

The new function symbol, or the constant term f() itself when args is empty. The definitional axiom is recorded in defns and is available as f.defn.

Return type:

smt.FuncDeclRef (an smt.ExprRef when args is empty)
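define records each definitional axiom in the module-level defns registry: registering an identical axiom again is silent, while a different one prints a warning and replaces the entry. The stdlib-only sketch below models that bookkeeping with strings standing in for Z3 function symbols and axioms; ToyDefn, toy_defns and toy_define are hypothetical names. Unlike the kernel, which keys on the function symbol including its sorts (so several overloads of one name can coexist), this toy keys on the bare name.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyDefn:
    name: str
    args: list[str]
    body: str
    ax: str  # stands in for the definitional axiom |- ForAll(args, f(args) == body)


toy_defns: dict[str, ToyDefn] = {}


def toy_define(name: str, args: list[str], body: str) -> str:
    ax = f"ForAll({args}, {name}({', '.join(args)}) == {body})"
    if name in toy_defns and toy_defns[name].ax != ax:
        # Mirrors the kernel's soft warning; the old definition is still replaced.
        print("WARNING: Redefining function", name, "from", toy_defns[name].ax, "to", ax)
    toy_defns[name] = ToyDefn(name, args, body, ax)
    return name


toy_define("sqr", ["x"], "x*x")
toy_define("sqr", ["x"], "x*x")   # identical axiom: no warning
toy_define("sqr", ["x"], "x**2")  # different axiom: warning, entry replaced
print(toy_defns["sqr"].body)      # x**2
```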

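kdrag.kernel.define_fix(name: str, args: list[ExprRef], retsort, fix_lam) → FuncDeclRef

Define a recursive definition. fix_lam is a Python function that receives the function symbol being defined and returns the body, so the body may call the function recursively. Termination and well-foundedness are not checked.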
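define_fix creates the function symbol first and hands a call-recording wrapper around it to fix_lam, which builds the body in terms of that symbol. The sketch below reproduces that shape in plain Python with strings in place of Z3 terms; toy_define_fix, the factorial example and the string rendering are assumptions for illustration. Like the kernel, it does not check termination.

```python
from typing import Callable


def toy_define_fix(
    name: str, args: list[str], fix_lam: Callable[[Callable[..., str]], str]
) -> tuple[str, set[tuple[str, ...]]]:
    calls: set[tuple[str, ...]] = set()

    def record_f(*call_args: str) -> str:
        # Record the arguments of every recursive call, as the kernel's wrapper does.
        calls.add(call_args)
        return f"{name}({', '.join(call_args)})"

    body = fix_lam(record_f)
    # The kernel would now pass this body on to define(name, args, body).
    axiom = f"ForAll({args}, {name}({', '.join(args)}) == {body})"
    return axiom, calls


ax, calls = toy_define_fix(
    "fact", ["n"], lambda fact: f"If(n <= 0, 1, n * {fact('n - 1')})"
)
print(ax)     # ForAll(['n'], fact(n) == If(n <= 0, 1, n * fact(n - 1)))
print(calls)  # {('n - 1',)}
```

The recorded calls are exactly the information a future termination check would need; the kernel collects them but does not use them yet.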

kdrag.kernel.defns: dict[FuncDecl, Defn]

Initially an empty dict; define adds an entry for each new definition. When this page was generated, other modules had already registered definitions for absR, add, cauchy_mod, comp, conj, const, cont_at, cross, diff_at, dist, div_, dot, even, expi, floor, has_lim_at, ident, is_cauchy, is_cont, is_convergent, is_diff, join, max, meet, mid, min, mul, neg, nonneg, norm2, odd, pow, setof, sgn, sqr, sqrt, sub, tan, wf and width. Several names (add, div_, dot, mul, norm2, sub) appear more than once because they are defined for several different sorts. A typical entry reads:

absR: Defn(name='absR', args=[x], body=If(x >= 0, x, -x), ax=|- ForAll(x, absR(x) == If(x >= 0, x, -x)))

defns holds the definitional axioms for function symbols.

kdrag.kernel.einstan(thm: QuantifierRef) → tuple[list[ExprRef], Proof]

Skolemize an existential quantifier: returns fresh constants cs together with |- (exists xs, P(xs)) -> P(cs). This is existential instantiation; see https://en.wikipedia.org/wiki/Existential_instantiation

kdrag.kernel.forget(ts: list[ExprRef], pf: Proof) → Proof

“Forget” a term using existentials. This is existential introduction: from |- P(ts), derive |- exists xs, P(xs). This could be derived from forget2.

kdrag.kernel.forget2(ts: list[ExprRef], thm: QuantifierRef) → Proof

“Forget” a term using existentials. This is existential introduction: |- P(ts) -> exists xs, P(xs). Here thm is the existential formula exists xs, P(xs), and ts are the terms to substitute for its bound variables. forget follows easily. See https://en.wikipedia.org/wiki/Existential_generalization
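einstan and skolem only state P(cs) for fresh constants cs, whereas forget2 accepts arbitrary terms ts. A brute-force check over a three-element domain illustrates the asymmetry; it is a toy model in plain Python (unary predicates as truth tables), not a Z3 query. P(t) -> exists x, P(x) holds for every predicate and every term, but exists x, P(x) -> P(c) fails for some predicate whenever c is a fixed, already-known element, which is why the witness must be a fresh symbol nothing else mentions.

```python
from itertools import product

DOMAIN = (0, 1, 2)
# Every unary predicate on DOMAIN, represented by its truth table.
PREDICATES = [
    dict(zip(DOMAIN, bits)) for bits in product([False, True], repeat=len(DOMAIN))
]


def exists(P: dict[int, bool]) -> bool:
    return any(P[x] for x in DOMAIN)


def forget2_valid() -> bool:
    # Existential introduction: P(t) -> exists x, P(x), for every P and every t.
    return all((not P[t]) or exists(P) for P in PREDICATES for t in DOMAIN)


def fixed_witness_counterexample(c: int):
    # Instantiating an existential with a *fixed* c is unsound:
    # search for a P with exists x, P(x) but not P(c).
    for P in PREDICATES:
        if exists(P) and not P[c]:
            return P
    return None


print(forget2_valid())                  # True
print(fixed_witness_counterexample(0))  # {0: False, 1: False, 2: True}
```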

kdrag.kernel.fresh_const(q: QuantifierRef)

Generate fresh constants of same sort as quantifier.
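Z3 names fresh constants by appending `!` and a counter to a prefix (for example x!0), so re-freshening an already fresh name would pile up suffixes. fresh_const therefore strips everything from the first `!` before reusing the bound variable's name as a prefix. The stdlib sketch below imitates that behaviour with a hypothetical counter-based toy_fresh_const; in the kernel the numbering comes from Z3.

```python
from itertools import count

_counter = count()


def toy_fresh_const(prefix: str) -> str:
    # Imitates Z3's naming scheme for fresh constants: prefix!N with a global counter.
    return f"{prefix}!{next(_counter)}"


def fresh_like(var_name: str) -> str:
    # As in kernel.fresh_const: drop any existing "!N" suffix before re-freshening.
    return toy_fresh_const(var_name.split("!")[0])


a = fresh_like("x")  # x!0
b = fresh_like(a)    # x!1, not x!0!1
print(a, b)          # x!0 x!1
```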

kdrag.kernel.herb(thm: QuantifierRef) → tuple[list[ExprRef], Proof]

Herbrandize a theorem: returns fresh constants cs together with |- P(cs) -> forall xs, P(xs). It is sufficient to prove a theorem for fresh constants to prove a universal. Note: perhaps a lambdaized form is better? Return vars and a lambda that could receive |- P[vars].

kdrag.kernel.instan(ts: list[ExprRef], pf: Proof) → Proof

Instantiate a universally quantified formula: from |- forall xs, P(xs), derive |- P(ts). This is forall elimination.

kdrag.kernel.instan2(ts: list[ExprRef], thm: BoolRef) → Proof

Instantiate a universally quantified formula: |- (forall xs, P(xs)) -> P(ts). Unlike instan, thm need not already be proved. This is forall elimination.
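instan, instan2, einstan, skolem, herb, forget2 and beta_conv all call substitute_vars with the terms reversed. In Z3's de Bruijn representation the last-declared bound variable of a quantifier is Var(0), and substitute_vars replaces Var(i) with its i-th argument, so the terms must be reversed to line up with the binder order. The sketch below models that convention with a hypothetical tuple-based term type (Var, App and ForAll here are toy classes, not Z3's).

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Var:
    idx: int  # de Bruijn index: 0 is the last-declared bound variable


@dataclass(frozen=True)
class App:
    fn: str
    args: tuple["Term", ...]


Term = Union[Var, App, str]  # a str is a constant such as "a"


@dataclass(frozen=True)
class ForAll:
    names: tuple[str, ...]  # binder names in declaration order, e.g. ("x", "y")
    body: Term


def substitute_vars(t: Term, *vals: Term) -> Term:
    """Replace Var(i) with vals[i], following the convention of z3.substitute_vars."""
    if isinstance(t, Var):
        return vals[t.idx]
    if isinstance(t, App):
        return App(t.fn, tuple(substitute_vars(a, *vals) for a in t.args))
    return t


def instan(ts: list[Term], q: ForAll) -> Term:
    """forall xs, P(xs) ~> P(ts); reversed() lines ts up with the de Bruijn indices."""
    assert len(ts) == len(q.names)
    return substitute_vars(q.body, *reversed(ts))


# forall x y, lt(x, y): x is declared first, so it is Var(1); y is Var(0).
q = ForAll(("x", "y"), App("lt", (Var(1), Var(0))))
print(instan(["a", "b"], q))  # App(fn='lt', args=('a', 'b'))
```

Without the reversal, substitute_vars(q.body, "a", "b") would produce lt(b, a), silently swapping the instantiated arguments.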

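kdrag.kernel.is_proof(p: Proof) → bool

Return True if p is a Proof object.

kdrag.kernel.lemma(thm: BoolRef, by: list[Proof] = [], admit=False, timeout=1000, dump=False, solver=None) → Proof

Prove a theorem using a list of previously proved lemmas.

In essence prove(Implies(by, thm)).

Parameters:
  • thm (smt.BoolRef) – The theorem to prove.

  • by (list[Proof]) – A list of previously proved lemmas.

  • admit (bool) – If True, admit the theorem without proof.

  • timeout (int) – Solver timeout in milliseconds.

  • dump (bool) – If True, print the solver query as an S-expression before checking.

  • solver – Solver class to use. Defaults to config.solver.

Returns:

A proof object of thm

Return type:

Proof

Raises:

LemmaError – If an entry of by is not a Proof, or if the solver finds a countermodel or returns unknown.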

>>> lemma(smt.BoolVal(True))
|- True
>>> lemma(smt.RealVal(1) >= smt.RealVal(0))
|- 1 >= 0
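lemma proves thm from by by refutation: it asserts every fact in by together with Not(thm) and asks the solver for a model; unsat means proved, sat yields a countermodel. The stdlib-only sketch below mimics that loop for propositional formulas by enumerating every truth assignment instead of calling an SMT solver. The formula encoding (Python predicates over a dict of booleans) and the names toy_lemma and ToyLemmaError are assumptions for illustration.

```python
from itertools import product
from typing import Callable

Formula = Callable[[dict[str, bool]], bool]


class ToyLemmaError(Exception):
    pass


def toy_lemma(thm: Formula, by: list[Formula], atoms: list[str]) -> str:
    """Check that the facts in `by` entail `thm` by refuting by + Not(thm)."""
    for values in product([False, True], repeat=len(atoms)):
        env = dict(zip(atoms, values))
        # An assignment satisfying And(by) and Not(thm) is a countermodel.
        if all(p(env) for p in by) and not thm(env):
            raise ToyLemmaError("Countermodel", env)
    return "unsat: proved"


p_implies_q: Formula = lambda e: (not e["p"]) or e["q"]
p: Formula = lambda e: e["p"]
q: Formula = lambda e: e["q"]

print(toy_lemma(q, by=[p_implies_q, p], atoms=["p", "q"]))  # unsat: proved
try:
    toy_lemma(q, by=[p_implies_q], atoms=["p", "q"])
except ToyLemmaError as err:
    print(err.args)  # ('Countermodel', {'p': False, 'q': False})
```

Dropping p from by makes the lemma fail, just as lemma raises LemmaError with a countermodel when a needed fact is missing from by.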
kdrag.kernel.skolem(pf: Proof) → tuple[list[ExprRef], Proof]

Skolemize an existential quantifier: from |- exists xs, P(xs), return fresh constants cs together with |- P(cs).

"""
The kernel holds the core proof datatypes and core inference rules. By and large, all proofs must flow through this module.
"""

import kdrag.smt as smt
from dataclasses import dataclass
from typing import Any
import logging
from . import config

logger = logging.getLogger("knuckledragger")


@dataclass(frozen=True)
class Proof(smt.Z3PPObject):
    thm: smt.BoolRef
    reason: list[Any]
    admit: bool = False

    def _repr_html_(self):
        return "&#8870;" + repr(self.thm)

    def __repr__(self):
        return "|- " + repr(self.thm)


# It is unlikely that users should be accessing the `Proof` constructor directly.
# This is not ironclad. If you really want the Proof constructor, I can't stop you.
__Proof = Proof
Proof = None


def is_proof(p: __Proof) -> bool:
    """Return True if p is a Proof object."""
    return isinstance(p, __Proof)


class LemmaError(Exception):
    pass


def lemma(
    thm: smt.BoolRef,
    by: list[Proof] = [],
    admit=False,
    timeout=1000,
    dump=False,
    solver=None,
) -> __Proof:
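    """Prove a theorem using a list of previously proved lemmas.

    In essence `prove(Implies(by, thm))`.

    Args:
        thm (smt.BoolRef): The theorem to prove.
        by (list[Proof]): A list of previously proved lemmas.
        admit (bool): If True, admit the theorem without proof.
        timeout (int): Solver timeout in milliseconds.
        dump (bool): If True, print the solver query as an S-expression before checking.
        solver: Solver class to use. Defaults to `config.solver`.

    Returns:
        Proof: A proof object of thm

    Raises:
        LemmaError: If an entry of `by` is not a Proof, or if the solver
            finds a countermodel or returns unknown.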
    >>> lemma(smt.BoolVal(True))
    |- True
    >>> lemma(smt.RealVal(1) >= smt.RealVal(0))
    |- 1 >= 0
    """
    if admit:
        logger.warning("Admitting lemma {}".format(thm))
        return __Proof(thm, by, True)
    else:
        if solver is None:
            solver = config.solver
        s = solver()
        s.set("timeout", timeout)
        for p in by:
            if not isinstance(p, __Proof):
                raise LemmaError("In by reasons:", p, "is not a Proof object")
            s.add(p.thm)
        s.add(smt.Not(thm))
        if dump:
            print(s.sexpr())
        res = s.check()
        if res != smt.unsat:
            if res == smt.sat:
                raise LemmaError(thm, "Countermodel", s.model())
            raise LemmaError("lemma", thm, res)
        else:
            return __Proof(thm, by, False)


def axiom(thm: smt.BoolRef, by=[]) -> __Proof:
    """Assert an axiom.

    Axioms are necessary and useful. But you must use great care.

    Args:
        thm: The axiom to assert.
        by: A python object explaining why the axiom should exist. Often a string explaining the axiom.
    """
    return __Proof(thm, by, admit=True)


@dataclass(frozen=True)
class Defn:
    """
    A record storing a definition. It is useful to record definitions as special axioms because we often must unfold them.
    """

    name: str
    args: list[smt.ExprRef]
    body: smt.ExprRef
    ax: Proof


defns: dict[smt.FuncDecl, Defn] = {}
"""
defns holds the definitional axioms for function symbols.
"""
smt.FuncDeclRef.defn = property(lambda self: defns[self].ax)
smt.ExprRef.defn = property(lambda self: defns[self.decl()].ax)


def fresh_const(q: smt.QuantifierRef):
    """Generate fresh constants of same sort as quantifier."""
    # .split("!") is to remove ugly multiple freshness from names
    return [
        smt.FreshConst(q.var_sort(i), prefix=q.var_name(i).split("!")[0])
        for i in range(q.num_vars())
    ]


def define(name: str, args: list[smt.ExprRef], body: smt.ExprRef) -> smt.FuncDeclRef:
    """
    Define a non-recursive definition. Useful for shorthand and abstraction. Does not currently defend against ill-formed definitions.
    TODO: Check for bad circularity, record dependencies

    Args:
        name: The name of the term to define.
        args: The arguments of the term.
        body: The body of the definition.

    Returns:
        smt.FuncDeclRef: The new function symbol, or the constant term `f()` when `args` is empty.
        The definitional axiom is recorded in `defns` and is available as `f.defn`.
    """
    sorts = [arg.sort() for arg in args] + [body.sort()]
    f = smt.Function(name, *sorts)

    # TODO: This is getting too hairy for the kernel? Reassess. Maybe just a lambda flag? Autolift?
    if smt.is_quantifier(body) and body.is_lambda():
        # It is worth it to avoid having lambdas in definition.
        vs = fresh_const(body)
        # print(vs, f(*args)[tuple(vs)])
        # print(smt.substitute_vars(body.body(), *vs))
        def_ax = axiom(
            smt.ForAll(
                args + vs,
                smt.Eq(
                    f(*args)[tuple(vs)], smt.substitute_vars(body.body(), *reversed(vs))
                ),
            ),
            by="definition",
        )
    elif len(args) == 0:
        def_ax = axiom(smt.Eq(f(), body), by="definition")
    else:
        def_ax = axiom(smt.ForAll(args, smt.Eq(f(*args), body)), by="definition")
    # assert f not in __sig or __sig[f].eq(   def_ax.thm)  # Check for redefinitions. This is kind of painful. Hmm.
    # Soft warning is more pleasant.
    defn = Defn(name, args, body, def_ax)
    if f not in defns or defns[f].ax.thm.eq(def_ax.thm):
        defns[f] = defn
    else:
        print("WARNING: Redefining function", f, "from", defns[f].ax, "to", def_ax.thm)
        defns[f] = defn
    if len(args) == 0:
        return f()  # Convenience
    else:
        return f


def define_fix(name: str, args: list[smt.ExprRef], retsort, fix_lam) -> smt.FuncDeclRef:
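    """
    Define a recursive definition.

    `fix_lam` is a Python function that receives the function symbol being defined
    and returns the body, so the body may call the function recursively.
    Termination and well-foundedness are not checked.
    """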
    sorts = [arg.sort() for arg in args]
    sorts.append(retsort)
    f = smt.Function(name, *sorts)

    # wrapper to record calls
    calls = set()

    def record_f(*args):
        calls.add(args)
        return f(*args)

    defn = define(name, args, fix_lam(record_f))
    # TODO: check for well foundedness/termination, custom induction principle.
    return defn


def consider(x: smt.ExprRef) -> Proof:
    """
    Seed the solver with an interesting term.
    Axiom schema: any term `x` may be given a fresh name, yielding `|- c == x` for a fresh constant `c`.
    An "anonymous" form of define.
    Pointing out the interesting terms is sometimes the essence of a proof.
    """
    return axiom(smt.Eq(smt.FreshConst(x.sort(), prefix="consider"), x))


def instan(ts: list[smt.ExprRef], pf: Proof) -> Proof:
    """
    Instantiate a universally quantified formula.
    This is forall elimination
    """
    assert (
        is_proof(pf)
        and smt.is_quantifier(pf.thm)
        and pf.thm.is_forall()
        and len(ts) == pf.thm.num_vars()
    )

    return __Proof(smt.substitute_vars(pf.thm.body(), *reversed(ts)), reason=[pf])


def instan2(ts: list[smt.ExprRef], thm: smt.BoolRef) -> Proof:
    """
    Instantiate a universally quantified formula
    `forall xs, P(xs) -> P(ts)`
    This is forall elimination
    """
    assert smt.is_quantifier(thm) and thm.is_forall() and len(ts) == thm.num_vars()

    return __Proof(
        smt.Implies(thm, smt.substitute_vars(thm.body(), *reversed(ts))),
        reason="forall_elim",
    )


def forget(ts: list[smt.ExprRef], pf: Proof) -> Proof:
    """
    "Forget" a term using existentials. This is existential introduction.
    This could be derived from forget2
    """
    assert is_proof(pf)
    vs = [smt.FreshConst(t.sort()) for t in ts]
    return __Proof(smt.Exists(vs, smt.substitute(pf.thm, *zip(ts, vs))), reason=[pf])


def forget2(ts: list[smt.ExprRef], thm: smt.QuantifierRef) -> Proof:
    """
    "Forget" a term using existentials. This is existential introduction.
    `P(ts) -> exists xs, P(xs)`
    `thm` is an existential formula, and `ts` are terms to substitute those variables with.
    forget easily follows.
    https://en.wikipedia.org/wiki/Existential_generalization
    """
    assert smt.is_quantifier(thm) and thm.is_exists() and len(ts) == thm.num_vars()
    return __Proof(
        smt.Implies(smt.substitute_vars(thm.body(), *reversed(ts)), thm),
        reason="exists_intro",
    )


def einstan(thm: smt.QuantifierRef) -> tuple[list[smt.ExprRef], Proof]:
    """
    Skolemize an existential quantifier.
    `exists xs, P(xs) -> P(cs)` for fresh cs
    https://en.wikipedia.org/wiki/Existential_instantiation
    """
    # TODO: Hmm. Maybe we don't need to have a Proof? Lessen this to thm.
    assert smt.is_quantifier(thm) and thm.is_exists()

    skolems = fresh_const(thm)
    return skolems, __Proof(
        smt.Implies(thm, smt.substitute_vars(thm.body(), *reversed(skolems))),
        reason=["einstan"],
    )


def skolem(pf: Proof) -> tuple[list[smt.ExprRef], Proof]:
    """
    Skolemize an existential quantifier.
    """
    # TODO: Hmm. Maybe we don't need to have a Proof? Lessen this to thm.
    assert is_proof(pf) and smt.is_quantifier(pf.thm) and pf.thm.is_exists()

    skolems = fresh_const(pf.thm)
    return skolems, __Proof(
        smt.substitute_vars(pf.thm.body(), *reversed(skolems)), reason=[pf]
    )


def herb(thm: smt.QuantifierRef) -> tuple[list[smt.ExprRef], Proof]:
    """
    Herbrandize a theorem.
    It is sufficient to prove a theorem for fresh consts to prove a universal.
    Note: Perhaps lambdaized form is better? Return vars and lambda that could receive `|- P[vars]`
    """
    assert smt.is_quantifier(thm) and thm.is_forall()
    herbs = fresh_const(thm)
    return herbs, __Proof(
        smt.Implies(smt.substitute_vars(thm.body(), *reversed(herbs)), thm),
        reason="herband",
    )


def beta_conv(lam: smt.QuantifierRef, *args) -> Proof:
    """
    Beta conversion for lambda calculus.
    """
    # Check the quantifier kind first so a non-lambda fails the assertion cleanly.
    assert smt.is_quantifier(lam) and lam.is_lambda()
    assert len(args) == lam.num_vars()
    return axiom(smt.Eq(lam[args], smt.substitute_vars(lam.body(), *reversed(args))))