kdrag.theories.bitvec

Theorems about bitvectors. These are theorems about the built in smtlib bitvector types

Module Attributes

BitVecN

Arbitrary length bitvectors.

Functions

BVNot(x)

BitVecNConst(name, N)

There is a lot of confusion possible with this construct.

BitVecNVal(x, N)

BitVecSort(N)

PopCount(x)

SelectConcat(a, addr, n[, le])

Concat out of an array.

StoreConcat(a, addr, data[, le])

Store multiple bytes into an array.

fromBV(x)

select64(outsize)

toBV(x, N)

kdrag.theories.bitvec.BVNot(x: DatatypeRef) DatatypeRef
>>> smt.simplify(BVNot(BitVecNVal(1, 3)))
BitVecN(Concat(Unit(0), Concat(Unit(1), Unit(1))))
Parameters:

x (DatatypeRef)

Return type:

DatatypeRef

kdrag.theories.bitvec.BitVecN = BitVecN

Arbitrary length bitvectors. Least significant bit comes first (index 0). Concat is unfortunately reversed compared to bitvector convetions.

Fix via Newtype wrapper? I guess I want overloading anyway

Parameters:

args (ExprRef)

Return type:

DatatypeRef

kdrag.theories.bitvec.BitVecNConst(name: str, N: int) DatatypeRef

There is a lot of confusion possible with this construct. Maybe it shouldn’t exist.

>>> BitVecNConst("x", 3)
BitVecN(Concat(Unit(x[0]), Concat(Unit(x[1]), Unit(x[2]))))
Parameters:
  • name (str)

  • N (int)

Return type:

DatatypeRef

kdrag.theories.bitvec.BitVecNVal(x: int, N: int) DatatypeRef
>>> BitVecNVal(6, 3)
BitVecN(Concat(Unit(0), Concat(Unit(1), Unit(1))))
Parameters:
  • x (int)

  • N (int)

Return type:

DatatypeRef

kdrag.theories.bitvec.BitVecSort(N)
>>> BV32 = BitVecSort(32)
>>> BV32.bvadd_comm
|- ForAll([x, y], x + y == y + x)
kdrag.theories.bitvec.PopCount(x: BitVecRef) ArithRef
>>> smt.simplify(PopCount(smt.BitVecVal(6, 3)))
2
Parameters:

x (BitVecRef)

Return type:

ArithRef

kdrag.theories.bitvec.SelectConcat(a: ArrayRef, addr: BitVecRef | int, n: int, le=True) BitVecRef

Concat out of an array. n is number of bytes. Flag is for little endian concatenation vs big endian.

>>> x = smt.Const("x", BitVecSort(8))
>>> a = smt.Lambda([x], x)
>>> smt.simplify(SelectConcat(a, 1, 1))
1
>>> smt.simplify(SelectConcat(a, 0, 2))
256
>>> smt.simplify(SelectConcat(a, 0, 2, le=False))
1
Parameters:
  • a (ArrayRef)

  • addr (BitVecRef | int)

  • n (int)

Return type:

BitVecRef

kdrag.theories.bitvec.StoreConcat(a: ArrayRef, addr: BitVecRef | int, data: BitVecRef, le=True) ArrayRef

Store multiple bytes into an array.

>>> a = smt.Array("a", smt.BitVecSort(8), smt.BitVecSort(8))
>>> smt.simplify(StoreConcat(a, 0, smt.BitVecVal(258, 16)))
Store(Store(a, 0, 2), 1, 1)
>>> smt.simplify(SelectConcat(StoreConcat(a, 6, smt.BitVecVal(258, 16)), 6, 2))
258
Parameters:
  • a (ArrayRef)

  • addr (BitVecRef | int)

  • data (BitVecRef)

Return type:

ArrayRef

kdrag.theories.bitvec.fromBV(x: BitVecRef) DatatypeRef
>>> fromBV(smt.BitVecVal(6, 3))
BitVecN(Concat(Unit(0), Concat(Unit(1), Unit(1))))
Parameters:

x (BitVecRef)

Return type:

DatatypeRef

kdrag.theories.bitvec.select64(outsize: int) FuncDeclRef
Parameters:

outsize (int)

Return type:

FuncDeclRef

kdrag.theories.bitvec.toBV(x: DatatypeRef, N: int) BitVecRef
>>> smt.simplify(toBV(BitVecNVal(6, 3), 3))
6
Parameters:
  • x (DatatypeRef)

  • N (int)

Return type:

BitVecRef

"""
Theorems about bitvectors. These are theorems about the built in smtlib bitvector types
"""

import kdrag as kd
import kdrag.smt as smt
import kdrag.theories.seq as seq
import functools


@functools.cache
def BitVecSort(N):
    """

    >>> BV32 = BitVecSort(32)
    >>> BV32.bvadd_comm
    |- ForAll([x, y], x + y == y + x)
    """
    S = smt.BitVecSort(N)
    x, y, z = smt.BitVecs("x y z", N)
    zero = smt.BitVecVal(0, N)
    S.BVNot = (~x).decl()
    S.zero = zero
    one = smt.BitVecVal(1, N)
    S.one = one

    S.bvadd_comm = kd.prove(smt.ForAll([x, y], x + y == y + x))
    S.bvadd_assoc = kd.prove(smt.ForAll([x, y, z], (x + y) + z == x + (y + z)))
    S.bvadd_id = kd.prove(smt.ForAll([x], x + zero == x))
    S.bvadd_neg = kd.prove(smt.ForAll([x], x + (-x) == zero))

    S.bvsub_self = kd.prove(smt.ForAll([x], x - x == zero))
    S.bvsub_def = kd.prove(smt.ForAll([x, y], x - y == x + (-y)))

    S.bvmul_comm = kd.prove(smt.ForAll([x, y], x * y == y * x))
    S.bvmul_assoc = kd.prove(smt.ForAll([x, y, z], (x * y) * z == x * (y * z)))
    S.bvmul_id = kd.prove(smt.ForAll([x], x * smt.BitVecVal(1, N) == x))
    S.bvmul_zero = kd.prove(smt.ForAll([x], x * zero == zero))

    S.bvand_comm = kd.prove(smt.ForAll([x, y], x & y == y & x))
    S.bvand_assoc = kd.prove(smt.ForAll([x, y, z], (x & y) & z == x & (y & z)))
    S.bvand_id = kd.prove(smt.ForAll([x], x & smt.BitVecVal(-1, N) == x))
    S.bvand_zero = kd.prove(smt.ForAll([x], x & zero == zero))

    S.bvor_comm = kd.prove(smt.ForAll([x, y], x | y == y | x))
    S.bvor_assoc = kd.prove(smt.ForAll([x, y, z], (x | y) | z == x | (y | z)))
    S.bvor_id = kd.prove(smt.ForAll([x], x | zero == x))
    S.bvor_neg = kd.prove(smt.ForAll([x], x | ~x == smt.BitVecVal(-1, N)))

    S.bvxor_comm = kd.prove(smt.ForAll([x, y], x ^ y == y ^ x))
    S.bvxor_assoc = kd.prove(smt.ForAll([x, y, z], (x ^ y) ^ z == x ^ (y ^ z)))
    S.bvxor_id = kd.prove(smt.ForAll([x], x ^ zero == x))
    S.bvxor_self = kd.prove(smt.ForAll([x], x ^ x == zero))

    S.bvshl_zero = kd.prove(smt.ForAll([x], x << zero == x))
    S.bvshr_zero = kd.prove(smt.ForAll([x], smt.LShR(x, zero) == x))

    # Bitwise simplification rules
    S.bvand_self = kd.prove(smt.ForAll([x], x & x == x))
    S.bvor_self = kd.prove(smt.ForAll([x], x | x == x))
    S.bvxor_zero = kd.prove(smt.ForAll([x], x ^ zero == x))
    S.bvnot_self = kd.prove(smt.ForAll([x], ~x == -x - 1))

    # Rules for shifting and rotating
    S.bvshl_self = kd.prove(
        smt.ForAll([x, y], x << y == x * (one << y))
    )  # Left shift as multiplication
    # bvshr_self = kd.prove(smt.ForAll([x, y], smt.LShR(x, y) == x / (one << y)))  # Logical right shift as division
    # bvashr_self = kd.prove(smt.ForAll([x, y], smt.AShr(x, y) == smt.If(x >> 31 == 0, smt.LShR(x, y), ~smt.LShR(~x, y))))  # Arithmetic right shift rule

    # Simplification with negation and subtraction
    S.bvsub_zero = kd.prove(smt.ForAll([x], x - zero == x))
    S.bvsub_id = kd.prove(smt.ForAll([x], zero - x == -x))
    S.bvadd_sub = kd.prove(smt.ForAll([x, y], x + (-y) == x - y))
    S.bvsub_add = kd.prove(smt.ForAll([x, y], x - (-y) == x + y))

    # Bitwise AND, OR, and XOR with constants
    S.bvand_allones = kd.prove(smt.ForAll([x], x & smt.BitVecVal(-1, N) == x))
    S.bvor_allzeros = kd.prove(smt.ForAll([x], x | zero == x))
    S.bvxor_allzeros = kd.prove(smt.ForAll([x], x ^ zero == x))

    # Distribution and absorption laws
    S.bvand_or = kd.prove(smt.ForAll([x, y, z], x & (y | z) == (x & y) | (x & z)))
    S.bvor_and = kd.prove(smt.ForAll([x, y, z], x | (y & z) == (x | y) & (x | z)))
    S.bvand_absorb = kd.prove(smt.ForAll([x, y], x & (x | y) == x))
    S.bvor_absorb = kd.prove(smt.ForAll([x, y], x | (x & y) == x))

    # Shifting rules with zero and identity
    S.bvshl_zero_shift = kd.prove(smt.ForAll([x], x << zero == x))
    S.bvshr_zero_shift = kd.prove(smt.ForAll([x], smt.LShR(x, zero) == x))
    # bvashr_zero_shift = kd.prove(smt.ForAll([x], smt.AShr(x, zero) == x))  # Arithmetic right shift by zero is identity
    S.bvshl_allzeros = kd.prove(smt.ForAll([y], zero << y == zero))
    S.bvshr_allzeros = kd.prove(smt.ForAll([y], smt.LShR(zero, y) == zero))
    # bvashr_allzeros = kd.prove(smt.ForAll([y], smt.AShr(zero, y) == zero))  # Arithmetic right shift of zero is zero

    # Additional rules for combining operations
    # bvadd_and = kd.prove(smt.ForAll([x, y, z], (x & y) + (x & z) == x & (y + z)))  # AND distribution over addition
    S.bvor_and_not = kd.prove(smt.ForAll([x, y], (x & y) | (x & ~y) == x))
    # bvxor_and_not = kd.prove(smt.ForAll([x, y], (x & y) ^ (x & ~y) == y))  # Distribution of XOR and AND with negation

    # Properties involving shifts and bit manipulations
    S.bvshl_and = kd.prove(smt.ForAll([x, y, z], (x & y) << z == (x << z) & (y << z)))
    S.bvshr_and = kd.prove(
        smt.ForAll([x, y, z], smt.LShR(x & y, z) == smt.LShR(x, z) & smt.LShR(y, z))
    )
    return S


BV1 = BitVecSort(1)
BV8 = BitVecSort(8)
# Annoyingly slow ~ 1s
# BV16 = BitVecSort(16)
# BV32 = BitVecSort(32)
# Annoyingly slow ~ 1s
# BV64 = BitVecSort(64)


BitVecN = kd.NewType("BitVecN", seq.Seq(BV1))
"""
Arbitrary length bitvectors. Least significant bit comes first (index 0). Concat is unfortunately reversed compared to bitvector convetions.

Fix via Newtype wrapper? I guess I want overloading anyway

"""

BVN = BitVecN
BVN.empty = BVN(smt.Empty(seq.Seq(BV1)))  # type: ignore
x, y, z = smt.Consts("x y z", BitVecN)
to_int = smt.Function("to_int", BitVecN, smt.IntSort())
to_int = kd.notation.to_int.define(
    [x],
    smt.If(
        smt.Length(x.val) == 0,
        smt.IntVal(0),
        smt.BV2Int(x.val[0], is_signed=False)
        + 2 * to_int(BitVecN(smt.SubSeq(x.val, 1, smt.Length(x.val) - 1))),
    ),
)


def BitVecNVal(x: int, N: int) -> smt.DatatypeRef:
    """
    >>> BitVecNVal(6, 3)
    BitVecN(Concat(Unit(0), Concat(Unit(1), Unit(1))))
    """
    if N == 0:
        return BVN.empty
    elif N == 1:
        return BVN(smt.Unit(smt.BitVecVal(x, 1)))
    else:
        return BVN(
            smt.Concat([smt.Unit(smt.BitVecVal((x >> i) & 1, 1)) for i in range(N)])
        )


to_int_empty = kd.prove(to_int(BitVecNVal(0, 0)) == smt.IntVal(0), unfold=1)
to_int_false = kd.prove(BitVecNVal(0, 1).to_int() == smt.IntVal(0), by=[to_int.defn])
to_int_true = kd.prove(BitVecNVal(1, 1).to_int() == smt.IntVal(1), by=[to_int.defn])
# (x + y).to_int() == x.to_int() + 2**(smt.Length(x)) * y.to_int()


def fromBV(x: smt.BitVecRef) -> smt.DatatypeRef:
    """
    >>> fromBV(smt.BitVecVal(6, 3))
    BitVecN(Concat(Unit(0), Concat(Unit(1), Unit(1))))
    """
    return smt.simplify(
        BitVecN(smt.Concat([smt.Unit(smt.Extract(i, i, x)) for i in range(x.size())]))
    )


def toBV(x: smt.DatatypeRef, N: int) -> smt.BitVecRef:
    """
    >>> smt.simplify(toBV(BitVecNVal(6, 3), 3))
    6
    """
    BV = BitVecSort(N)
    undef = smt.Function("toBV_undef", BitVecN, BV)
    unpack = smt.Concat(*reversed([x.val[i] for i in range(N)]))
    # could possibly raise error here if we _know_ you've reduced to undef
    # assert not smt.simplify(expr).eq(smt.simplify(undef(x)))
    return smt.If(
        smt.Length(x.val) == N,
        unpack,  # smt.Int2BV(to_int(x), N),  # Or could do full unpack
        undef(x),
    )


def BitVecNConst(name: str, N: int) -> smt.DatatypeRef:
    """
    There is a lot of confusion possible with this construct. Maybe it shouldn't exist.

    >>> BitVecNConst("x", 3)
    BitVecN(Concat(Unit(x[0]), Concat(Unit(x[1]), Unit(x[2]))))
    """
    x = smt.Array(name, smt.IntSort(), BV1)  # array vs function vs seq?
    return BitVecN(smt.Concat([smt.Unit(x[i]) for i in range(N)]))


def BVNot(x: smt.DatatypeRef) -> smt.DatatypeRef:
    """
    >>> smt.simplify(BVNot(BitVecNVal(1, 3)))
    BitVecN(Concat(Unit(0), Concat(Unit(1), Unit(1))))
    """
    z = smt.Const("z", smt.BitVecSort(1))
    return BitVecN(smt.SeqMap(smt.Lambda([z], BV1.BVNot(z)), x.val))


# BVAdd SeqFold


def SelectConcat(
    a: smt.ArrayRef, addr: smt.BitVecRef | int, n: int, le=True
) -> smt.BitVecRef:
    """
    Concat out of an array.
    n is number of bytes.
    Flag is for little endian concatenation vs big endian.

    >>> x = smt.Const("x", BitVecSort(8))
    >>> a = smt.Lambda([x], x)
    >>> smt.simplify(SelectConcat(a, 1, 1))
    1
    >>> smt.simplify(SelectConcat(a, 0, 2))
    256
    >>> smt.simplify(SelectConcat(a, 0, 2, le=False))
    1
    """
    assert n > 0
    if n == 1:
        return a[addr]
    elif le:
        return smt.Concat([a[addr + n - i - 1] for i in range(n)])
    else:
        return smt.Concat([a[addr + i] for i in range(n)])


def StoreConcat(
    a: smt.ArrayRef, addr: smt.BitVecRef | int, data: smt.BitVecRef, le=True
) -> smt.ArrayRef:
    """
    Store multiple bytes into an array.

    >>> a = smt.Array("a", smt.BitVecSort(8), smt.BitVecSort(8))
    >>> smt.simplify(StoreConcat(a, 0, smt.BitVecVal(258, 16)))
    Store(Store(a, 0, 2), 1, 1)
    >>> smt.simplify(SelectConcat(StoreConcat(a, 6, smt.BitVecVal(258, 16)), 6, 2))
    258
    """
    n = data.size()
    assert n % 8 == 0
    for offset in range(n // 8):
        if le:
            a = smt.Store(
                a, addr + offset, smt.Extract(8 * offset + 7, 8 * offset, data)
            )
        else:
            a = smt.Store(
                a, addr + offset, smt.Extract(8 * offset, 8 * offset + 7, data)
            )
    return a


@functools.cache
def select64(outsize: int) -> smt.FuncDeclRef:
    addr = smt.BitVec("addr", 64)
    a = smt.Array("a", smt.BitVecSort(64), smt.BitVecSort(8))
    return kd.define("select64", [a, addr], SelectConcat(a, addr, outsize))


def PopCount(x: smt.BitVecRef) -> smt.ArithRef:
    """
    >>> smt.simplify(PopCount(smt.BitVecVal(6, 3)))
    2
    """
    return smt.Sum([smt.BV2Int(smt.Extract(i, i, x)) for i in range(x.size())])