A Simple, Probably-Not-Exp-Time Disjoint Set in Coq
A data structure that I've been more and more interested in recently is the disjoint set data structure, also known as union-find. https://en.wikipedia.org/wiki/Disjoint-set_data_structure It's used in e-graphs, unification, Prolog, and graph connectivity.
I realized there is a cute representation that is easy to use and prove things about. It is, however, very inefficient compared to the usual version of disjoint set (outrageously so, since the inverse Ackermann complexity of union-find is a crown jewel of algorithms), hence the expectation-lowering title of the post. I believe find_root is linear in the number of unions. Oh well. It uses a simple functional representation based on the observation that the preimages of a function form disjoint sets. This representation is somewhat analogous to using functions to represent maps in Coq, as in Software Foundations, which is another inefficient but very convenient representation. https://softwarefoundations.cis.upenn.edu/lf-current/Maps.html
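The preimage observation can be checked on a tiny example. This is a sketch, not code from the post: mod3 and preimage are hypothetical names, and preimage only looks at the finite range 0..n-1, since nat itself is infinite.

```coq
Require Import List.
Import ListNotations.

Module PreimageSketch.
  (* A hypothetical example function: x is sent to x mod 3. *)
  Definition mod3 (x : nat) : nat := Nat.modulo x 3.

  (* The elements of 0..n-1 whose image under g is r, i.e. a finite slice
     of the preimage of r. Distinct values of r give disjoint lists. *)
  Definition preimage (g : nat -> nat) (r n : nat) : list nat :=
    filter (fun z => Nat.eqb (g z) r) (seq 0 n).

  Compute preimage mod3 0 9. (* = [0; 3; 6] *)
  Compute preimage mod3 1 9. (* = [1; 4; 7] *)
End PreimageSketch.
```

Each output value names one set, its preimage, and because a function sends every input to exactly one output, no element can land in two sets.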
The nice thing about this is that it really avoids any termination stickiness. Termination of the find_root operation of a union-find stored in an explicit array or tree is not at all obvious and requires a side proof or a refined type. Ultimately, termination comes from the fact that only a finite number of unions were used to construct the disjoint sets.
The code for this post is here: https://gist.github.com/philzook58/82f4a22c194587e3f63904d33c0f2b3d
Definitions
Require Import Arith.
For convenience, we define disjoint sets of nats. See the comments at the end of the post for some other options.
Definition ds := nat -> nat.
The completely disjoint set, where every element sits alone in its own set, is the identity function.
Definition init_ds : ds := fun x => x.
find_root is just function application.
Definition find_root (g : ds) x := g x.
in_same_set is just equality checking on the nat roots.
Definition in_same_set (g : ds) x y :=
Nat.eq_dec (g x) (g y).
And finally, the only interesting operation is union. It is useful to lift the find_root calls on x and y out of the body of the returned function, so that they are computed eagerly, once, and shared by every later lookup. Unfortunately, this version means that the cost of a find_root call grows in proportion to the number of union operations used to construct the ds, since each union wraps one more closure around the previous function and a lookup has to pass through all of them.
Definition union (g : ds) x y : ds :=
  let px := find_root g x in
  let py := find_root g y in
  fun z =>
    let pz := find_root g z in
    if Nat.eq_dec px pz
    then py
    else pz.
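To see what union actually builds, here is a small trace. It restates the post's definitions inside a module so it compiles on its own; the names UnionTrace, g1, g2, and g3 are just for illustration.

```coq
Require Import Arith List.
Import ListNotations.

Module UnionTrace.
  (* Restated from the post so this block compiles on its own. *)
  Definition ds := nat -> nat.
  Definition init_ds : ds := fun x => x.
  Definition find_root (g : ds) x := g x.
  Definition union (g : ds) x y : ds :=
    let px := find_root g x in
    let py := find_root g y in
    fun z =>
      let pz := find_root g z in
      if Nat.eq_dec px pz then py else pz.

  (* Merge {1,2}, then {3,4}, then the two resulting sets. *)
  Definition g1 : ds := union init_ds 1 2. (* root 1 is sent to 2 *)
  Definition g2 : ds := union g1 3 4.      (* root 3 is sent to 4 *)
  Definition g3 : ds := union g2 1 3.      (* root 2 is sent to 4 *)

  (* Evaluating g3 z goes through every layer: g3 calls g2, which calls g1,
     which calls init_ds. *)
  Compute map g3 [1; 2; 3; 4; 5]. (* = [4; 4; 4; 4; 5] *)
End UnionTrace.
```

The last union compares against 2, the root of 1, rather than against 1 itself, which is why both 1 and 2 end up at root 4.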
Some proofs
I couldn't find this useful lemma in the standard library. Maybe I just missed it?
Lemma nat_eq_refl : forall x : nat, exists p : x = x, Nat.eq_dec x x = left p.
Proof.
  intros x.
  destruct (Nat.eq_dec x x) as [e | n].
  - eexists. reflexivity.
  - unfold not in n. exfalso. apply n. reflexivity.
Qed.
A useful definition is In_Same_Set, which states that two nats x and y are in the same set if the in_same_set function returns a left.
Definition In_Same_Set g x y := exists p,
in_same_set g x y = left p.
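Since in_same_set just runs Nat.eq_dec on the two roots, In_Same_Set g x y should be equivalent to plain equality of the roots. Here is a sketch of that fact, with the relevant definitions restated so it stands alone; the lemma name In_Same_Set_iff is hypothetical and not from the post.

```coq
Require Import Arith.

Module SameSetIff.
  (* Restated from the post so this block compiles on its own. *)
  Definition ds := nat -> nat.
  Definition in_same_set (g : ds) x y := Nat.eq_dec (g x) (g y).
  Definition In_Same_Set g x y := exists p, in_same_set g x y = left p.

  (* Hypothetical helper: the sumbool wrapper adds nothing beyond
     equality of roots. *)
  Lemma In_Same_Set_iff : forall (g : ds) (x y : nat),
    In_Same_Set g x y <-> g x = g y.
  Proof.
    intros g x y. unfold In_Same_Set, in_same_set. split.
    - (* the witness p is exactly the equality of roots *)
      intros [p _]. exact p.
    - (* case on the decision; the right branch contradicts H *)
      intros H. destruct (Nat.eq_dec (g x) (g y)) as [e | n].
      + exists e. reflexivity.
      + contradiction.
  Qed.
End SameSetIff.
```

With this in hand, the reflexivity, symmetry, and transitivity lemmas below could alternatively be derived from the corresponding properties of =.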
An element is always in the same set as itself.
Lemma in_same_set_refl : forall (g : ds) (x : nat), In_Same_Set g x x.
Proof.
  intros g x. unfold In_Same_Set, in_same_set.
  destruct (Nat.eq_dec (g x) (g x)) as [e | n].
  - eexists. reflexivity.
  - exfalso. apply n. reflexivity.
Qed.
The In_Same_Set g relation is symmetric.
Lemma in_same_set_sym : forall (g : ds) (x y : nat),
  In_Same_Set g x y -> In_Same_Set g y x.
Proof.
  unfold In_Same_Set, in_same_set.
  intros g x y H. destruct H as [x0 H].
  destruct (Nat.eq_dec (g y) (g x)) as [e | n] eqn:E.
  - eexists. reflexivity.
  - exfalso. apply n. rewrite x0. reflexivity.
Qed.
The In_Same_Set g relation is transitive.
Lemma in_same_set_trans : forall (g : ds) (x y z : nat),
  In_Same_Set g x y -> In_Same_Set g y z -> In_Same_Set g x z.
Proof.
  unfold In_Same_Set, in_same_set.
  intros g x y z H H0.
  destruct H as [x0 H]. destruct H0 as [x1 H0].
  destruct (Nat.eq_dec (g x) (g z)) as [e | n] eqn:E.
  - eexists. reflexivity.
  - exfalso. apply n. rewrite x0. rewrite x1. reflexivity.
Qed.
In the initial disjoint set, only equal elements are in the same set.
Lemma init_ds_disjoint : forall x y : nat, x <> y ->
  exists p : init_ds x <> init_ds y, in_same_set init_ds x y = right p.
Proof.
  intros x y H.
  destruct (in_same_set init_ds x y) as [e | n].
  - exfalso. apply H. unfold init_ds in e. exact e.
  - eexists. reflexivity.
Qed.
After unioning x and y, they are now in the same set. This proof could be tightened up and automated more.
Lemma union_in_same_set : forall (g : ds) (x y : nat), In_Same_Set (union g x y) x y.
Proof.
  intros g x y.
  unfold In_Same_Set, in_same_set, union, find_root.
  destruct (Nat.eq_dec (g x) (g x)) as [e | n].
  - destruct (Nat.eq_dec (g x) (g y)); apply nat_eq_refl.
  - exfalso. apply n. reflexivity.
Qed.
Elements that are already in the same set stay in the same set after unioning any other elements.
Lemma union_preserves_same_set : forall (g : ds) (x y z w : nat),
  In_Same_Set g x y -> In_Same_Set (union g z w) x y.
Proof.
  intros g x y z w H.
  destruct H as [x0 H].
  unfold In_Same_Set, union.
  unfold in_same_set in H.
  unfold in_same_set, find_root.
  rewrite x0.
  apply nat_eq_refl.
Qed.
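Here's a small enhancement. We can define a new version of union that first checks whether the two elements being unioned are already in the same set and, if so, returns g unchanged instead of wrapping it in another layer. It agrees with union pointwise.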
Definition union2 (g : ds) x y : ds :=
  let px := find_root g x in
  let py := find_root g y in
  if Nat.eq_dec px py
  then g
  else fun z =>
    let pz := find_root g z in
    if Nat.eq_dec px pz
    then py
    else pz.
Lemma union2_union : forall (g : ds) (x y z : nat), union2 g x y z = union g x y z.
Proof.
  intros g x y z. unfold union2, union, find_root.
  destruct (Nat.eq_dec (g x) (g y)) eqn:E2;
  destruct (Nat.eq_dec (g x) (g z)) eqn:E1; congruence.
Qed.
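A quick check of what union2 buys you: when the two elements already share a root, union2 returns its input function itself, so no extra closure is added to the chain that later find_root calls have to walk. The definitions are restated in a module so the block compiles on its own; the module name, g1, and the example name are illustrative.

```coq
Require Import Arith.

Module Union2Sharing.
  (* Restated from the post so this block compiles on its own. *)
  Definition ds := nat -> nat.
  Definition init_ds : ds := fun x => x.
  Definition find_root (g : ds) x := g x.
  Definition union (g : ds) x y : ds :=
    let px := find_root g x in
    let py := find_root g y in
    fun z => let pz := find_root g z in if Nat.eq_dec px pz then py else pz.
  Definition union2 (g : ds) x y : ds :=
    let px := find_root g x in
    let py := find_root g y in
    if Nat.eq_dec px py then g
    else fun z => let pz := find_root g z in if Nat.eq_dec px pz then py else pz.

  Definition g1 : ds := union init_ds 1 2.

  (* 1 and 2 already share the root 2, so union2 hands back g1 itself
     rather than a new closure. This holds by computation alone. *)
  Example union2_no_new_layer : union2 g1 1 2 = g1.
  Proof. reflexivity. Qed.
End Union2Sharing.
```

Plain union g1 1 2 would still be extensionally equal to g1, but it would be a fresh closure with one more layer to evaluate.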
Possible Refinements
You could abstract over the type of the things you're forming disjoint sets over. This means you need an enumeration to define init_ds.
Definition ds' (a : Type) := a -> nat.
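A sketch of this version, assuming the enumeration is an injective function enum : a -> nat, so that distinct elements start with distinct roots. init_ds', union', color, and color_enum are hypothetical names for illustration. Since roots still live in nat, union' is the same as union apart from its type.

```coq
Require Import Arith.

Module GenericDomain.
  Definition ds' (a : Type) := a -> nat.

  (* Assumption: enum is injective, so it plays the role of init_ds and
     gives every element its own root. *)
  Definition init_ds' {a : Type} (enum : a -> nat) : ds' a := enum.

  (* Roots are nats, so Nat.eq_dec still compares them. *)
  Definition union' {a : Type} (g : ds' a) (x y : a) : ds' a :=
    let px := g x in
    let py := g y in
    fun z => let pz := g z in if Nat.eq_dec px pz then py else pz.

  (* A small example domain with a hand-written enumeration. *)
  Inductive color := Red | Green | Blue.
  Definition color_enum (c : color) : nat :=
    match c with Red => 0 | Green => 1 | Blue => 2 end.

  (* Put Red and Blue in the same set; both now have root 2. *)
  Definition g : ds' color := union' (init_ds' color_enum) Red Blue.
End GenericDomain.
```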
Or you can use a self-map. This requires decidable equality on a for checking canonical elements in the codomain.
Definition ds'' (a : Type) := a -> a.
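A sketch of the self-map version. Roots are now elements of a itself, so union'' takes the equality decision procedure eq_dec as an extra argument. The names init_ds'', union'', color, and color_eq_dec are hypothetical; color_eq_dec is generated with decide equality.

```coq
Module SelfMap.
  Definition ds'' (a : Type) := a -> a.

  (* Every element starts as its own canonical element. *)
  Definition init_ds'' {a : Type} : ds'' a := fun x => x.

  (* Assumption of this sketch: the caller supplies decidable equality on a,
     which is what lets us compare canonical elements. *)
  Definition union'' {a : Type}
      (eq_dec : forall u v : a, {u = v} + {u <> v})
      (g : ds'' a) (x y : a) : ds'' a :=
    let px := g x in
    let py := g y in
    fun z => let pz := g z in if eq_dec px pz then py else pz.

  Inductive color := Red | Green | Blue.
  Definition color_eq_dec : forall u v : color, {u = v} + {u <> v}.
  Proof. decide equality. Defined.

  (* Put Red and Blue in the same set; Blue becomes the canonical element. *)
  Definition g : ds'' color := union'' color_eq_dec init_ds'' Red Blue.
End SelfMap.
```

Ending color_eq_dec with Defined rather than Qed keeps it transparent, so the if in union'' can actually compute.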
Or you could do a fun proof-relevant version for an equivalence relation R. Now, to perform a union, you'll need to supply a proof of R x y.
Definition ds_pf (a : Type) (R : a -> a -> Prop) :=
forall x : a, {y : a | R x y}.
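Here is one way the proof-relevant version could look, as a sketch under stated assumptions: R is an equivalence relation (the hypotheses R_refl, R_sym, R_trans) and equality on a is decidable (eq_dec_a), which is needed to compare roots. init_pf, hop, and union_pf are hypothetical names. The root computation mirrors union exactly; the only new work is the chain R z px, R px x, R x y, R y py that justifies moving z to the root of y.

```coq
Module ProofRelevant.
Section WithRelation.
  Variable a : Type.
  Variable R : a -> a -> Prop.
  (* Assumptions of this sketch: R is an equivalence relation and
     equality on a is decidable. *)
  Hypothesis R_refl : forall x, R x x.
  Hypothesis R_sym : forall x y, R x y -> R y x.
  Hypothesis R_trans : forall x y z, R x y -> R y z -> R x z.
  Variable eq_dec_a : forall x y : a, {x = y} + {x <> y}.

  Definition ds_pf := forall x : a, {y : a | R x y}.

  (* Every element is its own root, justified by reflexivity. *)
  Definition init_pf : ds_pf := fun x => exist _ x (R_refl x).

  (* If z's root equals x's root, chain R z px, R px x, R x y, R y py. *)
  Lemma hop : forall z pz px x y py,
    R z pz -> px = pz -> R x px -> R x y -> R y py -> R z py.
  Proof.
    intros z pz px x y py Hz Heq Hx Hxy Hy. subst pz.
    exact (R_trans z px py Hz
             (R_trans px x py (R_sym x px Hx) (R_trans x y py Hxy Hy))).
  Qed.

  (* Same shape as union, but it demands evidence Hxy : R x y. *)
  Definition union_pf (g : ds_pf) (x y : a) (Hxy : R x y) : ds_pf :=
    let (px, Hx) := g x in
    let (py, Hy) := g y in
    fun z =>
      let (pz, Hz) := g z in
      match eq_dec_a px pz with
      | left Heq => exist _ py (hop z pz px x y py Hz Heq Hx Hxy Hy)
      | right _ => exist _ pz Hz
      end.
End WithRelation.
End ProofRelevant.
```

Because the witness component is computed just like union, proj1_sig recovers the plain root function while the proofs ride along.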
Another possibility is to use a fairly new feature of Coq, primitive persistent arrays:
- https://coq.inria.fr/refman/language/core/primitive.html#primitive-arrays
- https://coq.github.io/doc/master/stdlib/Coq.Array.PArray.html
This may indeed be why these were added. They do require dealing with termination, though. Here's one suggestion of an encoding; I dunno if it will bite you in the butt.
Require Import PArray.
Open Scope array_scope.
Fixpoint find_root' gas (a : array (sum nat Int63.int)) (i : Int63.int) :=
  match gas with
  | O => None
  | S gas' =>
    match a.[i] with
    | inl x => Some i                (* i is a root *)
    | inr j => find_root' gas' a j   (* follow the parent pointer *)
    end
  end.
Record disjoint_set := {
  gas : nat; (* unions, or height *)
  parents : array (sum nat Int63.int);
  term_pf : forall i, exists n, find_root' gas parents i = Some n
}.
Built using Alectryon https://github.com/cpitclaudel/alectryon
A related Stack Overflow question: https://stackoverflow.com/questions/66630519/how-to-implement-a-union-find-disjoint-set-data-structure-in-coq/66875872#66875872