kdrag.hypothesis

Helper functions for the hypothesis property based testing library.

This can be useful for:

- Giving counterexamples to poorly stated theorems before you spend much effort on them
- Sanity checking axioms
- Connecting formal models to other code
- Testing Knuckledragger facilities
- Testing Z3 intended meaning

Functions

binop(children, op)

binops(children)

compares(strat)

nitpick(thm[, deadline])

Run a hypothesis test that instantiates the variables of a forall with random values and checks that each resulting instance holds, reporting any counterexample found.

smt_datatype_val(s)

smt_generic_val(sort[, maxiter])

A hypothesis search strategy that uses SMT model generation to generate a value of a given SMT sort.

smt_seq_val(s)

sort_occurs(s, s2[, visited])

Check whether a sort occurs inside another sort (recursing through array, sequence, and datatype component sorts).

val_of_sort(s[, knot_tie, slow_generic])

Make a search strategy of values of a given SMT sort.

z3_array_val(dom, ran)

kdrag.hypothesis.binop(children, op) → SearchStrategy
Return type:

SearchStrategy

kdrag.hypothesis.binops(children) → SearchStrategy
Return type:

SearchStrategy

kdrag.hypothesis.compares(strat) → SearchStrategy
Return type:

SearchStrategy

kdrag.hypothesis.nitpick(thm: QuantifierRef, deadline=100, **hyp_settings)

Run a hypothesis test that instantiates the variables of a forall with random values and checks that each resulting instance holds, reporting any counterexample found.

Parameters:

thm (QuantifierRef)

kdrag.hypothesis.smt_datatype_val(s: DatatypeSortRef) → SearchStrategy[DatatypeRef]
Parameters:

s (DatatypeSortRef)

Return type:

SearchStrategy[DatatypeRef]

kdrag.hypothesis.smt_generic_val(sort: SortRef, maxiter=4) → SearchStrategy[ExprRef]

A hypothesis search strategy that uses SMT model generation to generate a value of a given SMT sort. It is slower and will have worse shrinkage. To be used as a fallback.

Parameters:

sort (SortRef)

Return type:

SearchStrategy[ExprRef]

kdrag.hypothesis.smt_seq_val(s: SortRef) → SearchStrategy[SeqRef]
Parameters:

s (SortRef)

Return type:

SearchStrategy[SeqRef]

kdrag.hypothesis.sort_occurs(s, s2, visited=None)

Check whether a sort occurs inside another sort (recursing through array, sequence, and datatype component sorts).

>>> import kdrag.theories.list as list
>>> sort_occurs(smt.IntSort(), list.List(smt.IntSort()))
True
>>> sort_occurs(smt.IntSort(), list.List(smt.BoolSort()))
False
>>> sort_occurs(smt.IntSort(), smt.IntSort())
True
kdrag.hypothesis.val_of_sort(s: SortRef, knot_tie: tuple[SortRef, SearchStrategy[ExprRef]] | None = None, slow_generic=False) → SearchStrategy[ExprRef]

Make a search strategy of values of a given SMT sort.

Parameters:
  • s (SortRef)

  • knot_tie (tuple[SortRef, SearchStrategy[ExprRef]] | None)

Return type:

SearchStrategy[ExprRef]

kdrag.hypothesis.z3_array_val(dom: SearchStrategy[ExprRef], ran: SearchStrategy[ExprRef]) → SearchStrategy[ArrayRef]
Parameters:
  • dom (SearchStrategy[ExprRef])

  • ran (SearchStrategy[ExprRef])

Return type:

SearchStrategy[ArrayRef]

"""
Helper functions for the hypothesis property based testing library.

This can be useful for:
- Giving counterexamples to poorly stated theorems before you spend much effort on them
- Sanity checking axioms
- Connecting formal models to other code
- Testing Knuckledragger facilities
- Testing Z3 intended meaning
"""

import kdrag as kd
import kdrag.smt as smt
import hypothesis
import hypothesis.strategies as st
import operator as op
from typing import Optional

# Strategy over SMT sorts: the scalar sorts are the leaves, and Array and Seq
# sorts are built recursively on top of previously generated sorts.
smt_sorts = st.recursive(
    st.sampled_from(
        [smt.BoolSort(), smt.IntSort(), smt.RealSort(), smt.StringSort()]
    ),
    lambda inner: st.one_of(
        st.tuples(inner, inner).map(lambda dom_ran: smt.ArraySort(*dom_ran)),
        inner.map(smt.SeqSort),
    ),
)

# Variable names come from a deliberately tiny pool: frequent name clashes are
# more likely to expose interesting bugs than exploring weird identifiers.
names = st.sampled_from(["x", "y", "z"])


def binop(children, op) -> st.SearchStrategy:
    """
    Strategy that applies the binary function ``op`` to two values drawn
    independently from ``children``.
    """
    pairs = st.tuples(children, children)
    return pairs.map(lambda pair: op(*pair))


def binops(children) -> st.SearchStrategy:
    """
    Strategy that combines two ``children`` values with one of the arithmetic
    operators ``+``, ``-``, ``*`` or ``/``.
    """
    arithmetic = (op.add, op.sub, op.mul, op.truediv)
    return st.one_of(*[binop(children, fn) for fn in arithmetic])


# Concrete integer literals.
smt_int_val: st.SearchStrategy[smt.ArithRef] = st.integers().map(smt.IntVal)

# Integer-sorted expressions over literals and the variables in `names`,
# closed under the arithmetic operators (shared with `binops`) and
# if-then-else. `smt_bool_expr` is defined later in the module, so it is only
# referenced lazily through `st.deferred`.
smt_int_expr = st.recursive(
    st.one_of(smt_int_val, names.map(smt.Int)),
    lambda children: st.one_of(
        binops(children),
        st.deferred(
            lambda: st.tuples(smt_bool_expr, children, children).map(
                lambda x: smt.If(x[0], x[1], x[2])
            )
        ),
    ),
)


# Both Boolean literals (True listed first).
smt_bool_val: st.SearchStrategy[smt.BoolRef] = st.sampled_from(
    [smt.BoolVal(flag) for flag in (True, False)]
)

# Concrete rational literals built from Python fractions.
smt_real_val = st.fractions().map(smt.RealVal)

# Real-sorted expressions over literals and the variables in `names`, closed
# under the arithmetic operators (shared with `binops`) and if-then-else.
# `smt_bool_expr` is defined later in the module, so it is only referenced
# lazily through `st.deferred`.
smt_real_expr = st.recursive(
    st.one_of(
        smt_real_val,
        names.map(smt.Real),
    ),
    lambda children: st.one_of(
        binops(children),
        st.deferred(
            lambda: st.tuples(smt_bool_expr, children, children).map(
                lambda x: smt.If(x[0], x[1], x[2])
            )
        ),
    ),
)


def compares(strat) -> st.SearchStrategy:
    """
    Strategy that relates two ``strat`` values with one of ``==``, ``!=``,
    ``<``, ``<=``, ``>`` or ``>=``.
    """
    relations = [op.eq, op.ne, op.lt, op.le, op.gt, op.ge]
    return st.one_of(*(binop(strat, rel) for rel in relations))


# Boolean-sorted expressions. The leaves are literals, Boolean variables, and
# comparisons between integer or real expressions. These are closed under And,
# Or, Xor, equality and Implies.
smt_bool_expr = st.recursive(
    st.one_of(
        smt_bool_val,
        names.map(smt.Bool),
        compares(smt_int_expr),
        compares(smt_real_expr),
    ),
    lambda sub: st.one_of(
        *(
            binop(sub, connective)
            for connective in (smt.And, smt.Or, smt.Xor, op.eq, smt.Implies)
        )
    ),
)


# Arbitrary Python text lifted to SMT string literals.
smt_string_val: st.SearchStrategy[smt.SeqRef] = st.text().map(smt.StringVal)


def sort_occurs(s, s2, visited=None):
    """
    Check whether sort ``s`` occurs anywhere inside sort ``s2``.

    ``s2`` itself counts as an occurrence. The search recurses into array
    domains and ranges, sequence element sorts, and the field sorts of every
    datatype constructor. ``visited`` collects the composite sorts already
    entered, so recursive datatypes do not loop forever.

    >>> import kdrag.theories.list as list
    >>> sort_occurs(smt.IntSort(), list.List(smt.IntSort()))
    True
    >>> sort_occurs(smt.IntSort(), list.List(smt.BoolSort()))
    False
    >>> sort_occurs(smt.IntSort(), smt.IntSort())
    True
    """
    seen = set() if visited is None else visited
    if s2 in seen:
        return False
    if s == s2:
        return True

    if isinstance(s2, smt.ArraySortRef):
        parts = (s2.domain(), s2.range())
    elif isinstance(s2, smt.SeqSortRef):
        parts = (s2.basis(),)
    elif isinstance(s2, smt.DatatypeSortRef):
        # Lazily enumerate every field sort of every constructor.
        parts = (
            cons.domain(j)
            for cons in map(s2.constructor, range(s2.num_constructors()))
            for j in range(cons.arity())
        )
    else:
        return False

    # Mark s2 before descending so that cyclic datatypes terminate.
    seen.add(s2)
    return any(sort_occurs(s, part, visited=seen) for part in parts)


def smt_datatype_val(s: smt.DatatypeSortRef) -> st.SearchStrategy[smt.DatatypeRef]:
    """
    Strategy generating values of the algebraic datatype ``s``.

    Constructors none of whose field sorts mention ``s`` make up the base case.
    The remaining (recursive) constructors are layered on top with
    ``st.recursive``. Their fields are generated with ``val_of_sort(...,
    knot_tie=(s, children))``, so occurrences of ``s`` reuse the recursive
    strategy instead of starting a new one.
    """
    # TODO: with a lot of muscle grease, we could probably do better than a big deferred

    def build(cons):
        # Bind `cons` eagerly. Hypothesis evaluates `.map` lazily, so a
        # closure over the loop variable would see only the last constructor.
        def conswrap(args):
            return cons(*args)

        conswrap.__name__ = cons.name()  # hack to get slightly better output
        return conswrap

    constructors = [s.constructor(i) for i in range(s.num_constructors())]
    # A constructor is recursive if `s` occurs in any of its field sorts.
    recursive = [
        any(sort_occurs(s, c.domain(j)) for j in range(c.arity()))
        for c in constructors
    ]

    bases = []
    for cons, is_rec in zip(constructors, recursive):
        if cons.arity() == 0:
            bases.append(st.just(cons()))  # optimization
        elif not is_rec:
            args = [val_of_sort(cons.domain(j)) for j in range(cons.arity())]
            bases.append(st.tuples(*args).map(build(cons)))
    base = st.one_of(bases)

    def rec(children):
        cases = []
        for cons, is_rec in zip(constructors, recursive):
            if is_rec:
                args = [
                    val_of_sort(cons.domain(j), knot_tie=(s, children))
                    for j in range(cons.arity())
                ]
                cases.append(st.tuples(*args).map(build(cons)))
        return st.one_of(*cases)

    return st.recursive(base, rec)


def smt_seq_val(
    s: smt.SortRef,
    knot_tie: Optional[tuple[smt.SortRef, st.SearchStrategy[smt.ExprRef]]] = None,
) -> st.SearchStrategy[smt.SeqRef]:
    """
    Strategy generating sequence values whose elements have sort ``s``.

    Produces the empty sequence, a single-element sequence, or a concatenation
    of two or more unit sequences. ``knot_tie`` is forwarded to
    ``val_of_sort`` for the element sort. This lets a sequence nested inside a
    recursive datatype reuse that datatype's recursive strategy.
    """
    vsort = val_of_sort(s, knot_tie=knot_tie)
    return st.one_of(
        st.just(smt.Empty(smt.SeqSort(s))),
        vsort.map(lambda v: smt.Unit(v)),
        st.lists(vsort, min_size=2).map(
            lambda elems: smt.Concat(*[smt.Unit(x) for x in elems])
        ),
    )


def z3_array_val(
    dom: st.SearchStrategy[smt.ExprRef], ran: st.SearchStrategy[smt.ExprRef]
) -> st.SearchStrategy[smt.ArrayRef]:
    """
    Strategy generating array values from a non-empty list of (key, value) pairs.

    The last pair provides the default: its value fills a constant array ``K``,
    and its key is used only for the domain sort. Every other pair is then
    written on top with ``Store``, in list order.
    """

    def of_list(pairs: list[tuple[smt.ExprRef, smt.ExprRef]]) -> smt.ArrayRef:
        # Destructure instead of `pairs.pop()` so the drawn list is not mutated.
        *stores, (k, v) = pairs
        acc = smt.K(k.sort(), v)
        for key, val in stores:
            acc = smt.Store(acc, key, val)
        return acc

    return st.lists(st.tuples(dom, ran), min_size=1).map(of_list)


def val_of_sort(
    s: smt.SortRef,
    knot_tie: Optional[tuple[smt.SortRef, st.SearchStrategy[smt.ExprRef]]] = None,
    slow_generic=False,
) -> st.SearchStrategy[smt.ExprRef]:
    """
    Make a search strategy of values of a given SMT sort.

    Parameters:
        s: the sort to generate values of.
        knot_tie: optional ``(sort, strategy)`` pair. Whenever ``sort`` is
            reached, ``strategy`` is returned instead of building a new one.
            This includes occurrences inside array and sequence component
            sorts. Recursive datatypes use it to refer back to their own
            recursive strategy.
        slow_generic: for sorts with no dedicated strategy, fall back to the
            slow solver-based ``smt_generic_val`` instead of raising.

    Raises:
        NotImplementedError: if ``s`` has no dedicated strategy and
            ``slow_generic`` is false.
    """
    if knot_tie is not None and knot_tie[0] == s:
        return knot_tie[1]
    if s == smt.BoolSort():
        return smt_bool_val
    elif s == smt.IntSort():
        return smt_int_val
    elif s == smt.RealSort():
        return smt_real_val
    elif s == smt.StringSort():
        return smt_string_val
    elif isinstance(s, smt.ArraySortRef):
        # Thread knot_tie through so that e.g. a datatype field of sort
        # Array(Int, D) ties back to D's recursive strategy.
        return z3_array_val(
            val_of_sort(s.domain(), knot_tie=knot_tie),
            val_of_sort(s.range(), knot_tie=knot_tie),
        )
    elif isinstance(s, smt.SeqSortRef):
        return smt_seq_val(s.basis(), knot_tie=knot_tie)
    elif isinstance(s, smt.DatatypeSortRef):
        return smt_datatype_val(s)
    elif slow_generic:
        # Solver-based generation is slow and shrinks poorly, so it is opt-in.
        return smt_generic_val(s)
    else:
        raise NotImplementedError(f"Don't know how to generate values for {s}")


# def expr_of_sort(s: smt.SortRef):


@st.composite
def smt_generic_val(draw: st.DrawFn, sort: smt.SortRef, maxiter=4) -> smt.ExprRef:
    """
    A hypothesis search strategy that uses SMT model generation to generate a
    value of a given SMT sort. It is slower and will have worse shrinkage. To
    be used as a fallback.

    Z3's Solver is deterministic, and seeding it from a drawn integer did not
    help. Variety comes instead from drawing a round count in
    ``[0, maxiter]``. Each round reads the current model's value and then
    blocks it, so larger draws walk further through distinct values.
    """
    x, y = smt.Consts("x y", sort)
    solver = smt.Solver()
    # NOTE(review): presumably constrains x so that it appears in the model.
    solver.add(x == y)
    res = solver.check()
    assert res == smt.sat
    v = solver.model()[x]
    rounds = draw(st.integers(min_value=0, max_value=maxiter))
    for _ in range(rounds):
        if res == smt.unsat:
            break
        if res == smt.sat:
            v = solver.model()[x]
            solver.add(v != x)
            res = solver.check()
    return v


def nitpick(thm: smt.QuantifierRef, deadline=100, **hyp_settings):
    """
    Search for counterexamples to a universally quantified theorem.

    Hypothesis draws a concrete value for every bound variable of ``thm``
    (via ``val_of_sort`` on each variable's sort) and substitutes them into the
    body. It then simplifies the result with ``kd.rewrite.simp``. If the
    simplified instance is not literally true, Z3 is asked, with a 100 ms
    timeout, whether its negation is satisfiable:

    - sat: a counterexample exists; an ``AssertionError`` carrying the model is raised.
    - unsat: the instance is valid and the example passes.
    - anything else: an ``AssertionError`` is raised as well.

    Parameters:
        thm: a ``ForAll`` quantifier.
        deadline: per-example Hypothesis deadline (milliseconds).
        **hyp_settings: extra keyword arguments for ``hypothesis.settings``.
    """
    assert isinstance(thm, smt.QuantifierRef) and thm.is_forall()
    nvars = thm.num_vars()
    strategies = {str(i): val_of_sort(thm.var_sort(i)) for i in range(nvars)}
    body = thm.body()

    # Todo: could specialize to arity of the quantifier. Probably not worth it.
    @hypothesis.settings(deadline=deadline, **hyp_settings)
    @hypothesis.given(**strategies)
    def nitpick(**kwargs):
        # De Bruijn index 0 is the last declared variable, so the
        # substitution list walks the declared variables in reverse.
        instance = smt.substitute_vars(
            body, *[kwargs[str(i)] for i in reversed(range(nvars))]
        )
        hypothesis.note(("Starting point: ", instance))
        simplified = kd.rewrite.simp(instance, max_iter=1000000000)
        hypothesis.note(("Simplifies to: ", simplified))
        if smt.is_true(simplified):
            return
        solver = smt.Solver()
        solver.set("timeout", 100)
        solver.add(smt.Not(simplified))
        res = solver.check()
        if res == smt.sat:
            model = solver.model()
            hypothesis.note(("Counterexample: ", model))
            raise AssertionError("Found a counterexample", model)
        if res != smt.unsat:
            raise AssertionError("Could not find a counterexample")

    nitpick()