Library Coq.MSets.MSetAVL


MSetAVL : Implementation of MSetInterface via AVL trees

This module implements finite sets using AVL trees. It follows the implementation from OCaml's standard library.
All operations given here expect and produce well-balanced trees (in the OCaml sense: heights of subtrees shouldn't differ by more than 2), and hence have low complexity (e.g. add is logarithmic in the size of the set). However, proving that balancing is preserved is not necessary for ensuring correct operational behavior, and hence for fulfilling the MSet interface. As a consequence, balancing results are no longer part of this file; they can now be found in MSetFullAVL.
Four operations (union, subset, compare and equal) have been slightly adapted in order to have only structurally recursive calls. The precise OCaml versions of these operations have also been formalized (thanks to Function+measure), see ocaml_union, ocaml_subset, ocaml_compare and ocaml_equal in MSetFullAVL. The structural variants compute faster in Coq, whereas the other variants produce nicer and/or (slightly) faster code after extraction.

Require Import FunInd MSetInterface MSetGenTree BinInt Int.

Set Implicit Arguments.

Ops : the pure functions


Module Ops (Import I:Int)(X:OrderedType) <: MSetInterface.Ops X.
Local Open Scope Int_scope.

Generic trees instantiated with integer height

We reuse a generic definition of trees where the information parameter is an Int.t. Functions like mem or fold are also provided by this generic functor.

Include MSetGenTree.Ops X I.

Definition t := tree.

Height of trees


Definition height (s : t) : int :=
  match s with
  | Leaf => 0
  | Node h _ _ _ => h
  end.

Singleton set


Definition singleton x := Node 1 Leaf x Leaf.

Helper functions

create l x r creates a node, assuming l and r to be balanced and |height l - height r| <= 2.

Definition create l x r :=
   Node (max (height l) (height r) + 1) l x r.
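
As a small illustration of how height and create interact, the following sketch instantiates the library on natural numbers and evaluates a few heights. The module names NatSet and R are chosen for this example only, and the short Require form assumes the standard library is on the default load path.

```coq
Require Import MSetAVL OrdersEx.

(* NatSet and R are names introduced for this sketch only. *)
Module NatSet := MSetAVL.Make Nat_as_OT.
Module R := NatSet.Raw.

(* height reads the stored field; a Leaf has height 0. *)
Compute R.height R.Leaf.
(* singleton stores height 1. *)
Compute R.height (R.singleton 5).
(* create stores max (height l) (height r) + 1, here 2. *)
Compute R.height (R.create (R.singleton 1) 2 (R.singleton 3)).
```

Note that create never inspects or repairs the shape of its arguments; it only recomputes the height stored at the new root.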

bal l x r acts as create, but performs one step of rebalancing if necessary, i.e. assumes |height l - height r| <= 3.

Definition assert_false := create.

Definition bal l x r :=
  let hl := height l in
  let hr := height r in
  if (hr+2) <? hl then
    match l with
     | Leaf => assert_false l x r
     | Node _ ll lx lr =>
       if (height lr) <=? (height ll) then
         create ll lx (create lr x r)
       else
         match lr with
          | Leaf => assert_false l x r
          | Node _ lrl lrx lrr =>
              create (create ll lx lrl) lrx (create lrr x r)
         end
    end
  else
    if (hl+2) <? hr then
      match r with
       | Leaf => assert_false l x r
       | Node _ rl rx rr =>
         if (height rl) <=? (height rr) then
            create (create l x rl) rx rr
         else
           match rl with
            | Leaf => assert_false l x r
            | Node _ rll rlx rlr =>
                create (create l x rll) rlx (create rlr rx rr)
           end
      end
    else
      create l x r.
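
To see what the single rebalancing step buys, the sketch below builds a left-leaning tree by hand and compares create with bal on the same arguments. The names NatSet, R and l are hypothetical; the heights in the comments were obtained by unfolding the definitions above.

```coq
Require Import MSetAVL OrdersEx.

Module NatSet := MSetAVL.Make Nat_as_OT.
Module R := NatSet.Raw.

(* A tree of height 3 whose right subtree is empty. *)
Definition l : R.t :=
  R.create (R.create (R.singleton 1) 2 R.Leaf) 3 R.Leaf.

(* create only recomputes the height: the result has height 4. *)
Compute R.height (R.create l 4 R.Leaf).

(* bal sees height l = 3 > 0 + 2 and rotates: the result has height 3. *)
Compute R.height (R.bal l 4 R.Leaf).

(* The elements are unchanged by the rotation: 1, 2, 3, 4 in order. *)
Compute R.elements (R.bal l 4 R.Leaf).
```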

Insertion


Fixpoint add x s := match s with
   | Leaf => Node 1 Leaf x Leaf
   | Node h l y r =>
      match X.compare x y with
         | Lt => bal (add x l) y r
         | Eq => Node h l y r
         | Gt => bal l y (add x r)
      end
  end.

Join

Same as bal but does not assume anything regarding heights of l and r.

Fixpoint join l : elt -> t -> t :=
  match l with
    | Leaf => add
    | Node lh ll lx lr => fun x =>
       fix join_aux (r:t) : t := match r with
          | Leaf => add x l
          | Node rh rl rx rr =>
               if (rh+2) <? lh then bal ll lx (join lr x r)
               else if (lh+2) <? rh then bal (join_aux rl) rx rr
               else create l x r
          end
  end.
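
The following sketch shows join handling a height difference that create alone would not repair. The trees l3 and l4 and the module names are assumptions made for the example; the heights in the comments follow from unfolding create, bal and join by hand.

```coq
Require Import MSetAVL OrdersEx.

Module NatSet := MSetAVL.Make Nat_as_OT.
Module R := NatSet.Raw.

(* l3 has height 3 and l4 has height 4. *)
Definition l3 : R.t :=
  R.create (R.create (R.singleton 1) 2 R.Leaf) 3 (R.singleton 4).
Definition l4 : R.t := R.create l3 5 (R.singleton 6).

(* Plain create stacks a new root on top: height 5. *)
Compute R.height (R.create l4 7 (R.singleton 8)).

(* join descends into the taller left tree first: height 4. *)
Compute R.height (R.join l4 7 (R.singleton 8)).

(* Both contain the same elements, 1 to 8 in increasing order. *)
Compute R.elements (R.join l4 7 (R.singleton 8)).
```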

Extraction of minimum element

Morally, remove_min is to be applied to a non-empty tree t = Node h l x r. Since we can't deal here with assert false for t=Leaf, we pre-unpack t (and forget about h).

Fixpoint remove_min l x r : t*elt :=
  match l with
    | Leaf => (r,x)
    | Node lh ll lx lr =>
       let (l',m) := remove_min ll lx lr in (bal l' x r, m)
  end.
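
A minimal sketch of the pre-unpacked calling convention: the tree Node _ (singleton 1) 2 (singleton 3) is passed as its three components. The names NatSet, R and res are hypothetical.

```coq
Require Import MSetAVL OrdersEx.

Module NatSet := MSetAVL.Make Nat_as_OT.
Module R := NatSet.Raw.

(* The arguments are the left subtree, the root element and the right subtree. *)
Definition res := R.remove_min (R.singleton 1) 2 (R.singleton 3).

(* Second component: the minimum element, here 1. *)
Compute snd res.

(* First component: the remaining tree, containing 2 and 3. *)
Compute R.elements (fst res).
```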

Merging two trees

merge t1 t2 builds the union of t1 and t2 assuming all elements of t1 to be smaller than all elements of t2, and |height t1 - height t2| <= 2.

Definition merge s1 s2 := match s1,s2 with
  | Leaf, _ => s2
  | _, Leaf => s1
  | _, Node _ l2 x2 r2 =>
        let (s2',m) := remove_min l2 x2 r2 in bal s1 m s2'
end.

Deletion


Fixpoint remove x s := match s with
  | Leaf => Leaf
  | Node _ l y r =>
      match X.compare x y with
         | Lt => bal (remove x l) y r
         | Eq => merge l r
         | Gt => bal l y (remove x r)
      end
   end.

Concatenation

Same as merge but does not assume anything about heights.

Definition concat s1 s2 :=
   match s1, s2 with
      | Leaf, _ => s2
      | _, Leaf => s1
      | _, Node _ l2 x2 r2 =>
            let (s2',m) := remove_min l2 x2 r2 in
            join s1 m s2'
   end.

Splitting

split x s returns a triple (l, present, r) where
  • l is the set of elements of s that are < x
  • r is the set of elements of s that are > x
  • present is true if and only if s contains x.

Record triple := mktriple { t_left:t; t_in:bool; t_right:t }.
Notation "<< l , b , r >>" := (mktriple l b r) (at level 9).

Fixpoint split x s : triple := match s with
  | Leaf => << Leaf, false, Leaf >>
  | Node _ l y r =>
     match X.compare x y with
      | Lt => let (ll,b,rl) := split x l in << ll, b, join rl y r >>
      | Eq => << l, true, r >>
      | Gt => let (rl,b,rr) := split x r in << join l y rl, b, rr >>
     end
 end.
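
The three components of the result can be inspected through the record projections t_left, t_in and t_right. The sketch below assumes the instantiation on nat used purely for illustration; t15 and parts are hypothetical names.

```coq
Require Import MSetAVL OrdersEx.

Module NatSet := MSetAVL.Make Nat_as_OT.
Module R := NatSet.Raw.

(* The set {1, 2, 3, 4, 5}, built by successive insertions. *)
Definition t15 : R.t :=
  R.add 5 (R.add 4 (R.add 3 (R.add 2 (R.add 1 R.empty)))).

Definition parts := R.split 3 t15.

Compute R.elements (R.t_left parts).   (* elements below 3: 1 and 2 *)
Compute R.t_in parts.                  (* true, since 3 belongs to t15 *)
Compute R.elements (R.t_right parts).  (* elements above 3: 4 and 5 *)
```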

Intersection


Fixpoint inter s1 s2 := match s1, s2 with
    | Leaf, _ => Leaf
    | _, Leaf => Leaf
    | Node _ l1 x1 r1, _ =>
            let (l2',pres,r2') := split x1 s2 in
            if pres then join (inter l1 l2') x1 (inter r1 r2')
            else concat (inter l1 l2') (inter r1 r2')
    end.

Difference


Fixpoint diff s1 s2 := match s1, s2 with
 | Leaf, _ => Leaf
 | _, Leaf => s1
 | Node _ l1 x1 r1, _ =>
    let (l2',pres,r2') := split x1 s2 in
    if pres then concat (diff l1 l2') (diff r1 r2')
    else join (diff l1 l2') x1 (diff r1 r2')
end.
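
Intersection and difference can be tried out through the encapsulated interface. This is only a usage sketch; NatSet, a and b are names chosen here.

```coq
Require Import MSetAVL OrdersEx.

Module NatSet := MSetAVL.Make Nat_as_OT.

Definition a := NatSet.add 1 (NatSet.add 2 (NatSet.add 3 NatSet.empty)).
Definition b := NatSet.add 2 (NatSet.add 3 (NatSet.add 4 NatSet.empty)).

(* Elements common to a and b: 2 and 3. *)
Compute NatSet.elements (NatSet.inter a b).

(* Elements of a that are not in b: only 1. *)
Compute NatSet.elements (NatSet.diff a b).
```

In both functions the root x1 of the first tree is used as the pivot for split on the second tree; the flag returned by split then decides whether x1 is kept (join) or dropped (concat).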

Union

In OCaml, the heights of s1 and s2 are compared each time in order to recursively perform the split on the smaller set. Unfortunately, this leads to a non-structural algorithm. The following code is a simplification of the OCaml version, with no comparison of heights. It might be slightly slower, but experimentally all the tests I've made in OCaml have shown this potential slowdown to be non-significant. In any case, the exact OCaml code has also been formalized thanks to Function+measure, see ocaml_union in MSetFullAVL.

Fixpoint union s1 s2 :=
 match s1, s2 with
  | Leaf, _ => s2
  | _, Leaf => s1
  | Node _ l1 x1 r1, _ =>
     let (l2',_,r2') := split x1 s2 in
     join (union l1 l2') x1 (union r1 r2')
 end.
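
Because this union always splits its second argument around the root of the first, the resulting tree shape may depend on the argument order, while the set of elements does not. The sketch below checks this on a small case at the raw level; the names a and b are hypothetical.

```coq
Require Import MSetAVL OrdersEx.

Module NatSet := MSetAVL.Make Nat_as_OT.
Module R := NatSet.Raw.

Definition a : R.t := R.add 1 (R.add 3 R.empty).
Definition b : R.t := R.add 2 (R.add 3 R.empty).

(* Both orders list the elements 1, 2, 3. *)
Compute R.elements (R.union a b).
Compute R.elements (R.union b a).

(* equal compares contents, not tree shapes: true. *)
Compute R.equal (R.union a b) (R.union b a).
```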

Filter


Fixpoint filter (f:elt->bool) s := match s with
  | Leaf => Leaf
  | Node _ l x r =>
    let l' := filter f l in
    let r' := filter f r in
    if f x then join l' x r' else concat l' r'
 end.

Partition


Fixpoint partition (f:elt->bool)(s : t) : t*t :=
  match s with
   | Leaf => (Leaf, Leaf)
   | Node _ l x r =>
      let (l1,l2) := partition f l in
      let (r1,r2) := partition f r in
      if f x then (join l1 x r1, concat l2 r2)
      else (concat l1 r1, join l2 x r2)
  end.
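
A short usage sketch for filter and partition, using Nat.even as the predicate. The names NatSet, s and p are assumptions of the example.

```coq
Require Import MSetAVL OrdersEx.

Module NatSet := MSetAVL.Make Nat_as_OT.

Definition s :=
  NatSet.add 1 (NatSet.add 2 (NatSet.add 3 (NatSet.add 4 NatSet.empty))).

(* Keep the even elements: 2 and 4. *)
Compute NatSet.elements (NatSet.filter Nat.even s).

(* partition returns the satisfying elements first, the others second. *)
Definition p := NatSet.partition Nat.even s.
Compute NatSet.elements (fst p).  (* 2 and 4 *)
Compute NatSet.elements (snd p).  (* 1 and 3 *)
```

The first component of partition coincides with filter on the same predicate, which is what partition_spec1' states at the raw level.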

End Ops.

MakeRaw

Functor of pure functions + a posteriori proofs of invariant preservation

Module MakeRaw (Import I:Int)(X:OrderedType) <: RawSets X.
Include Ops I X.

Generic definition of binary-search-trees and proofs of specifications for generic functions such as mem or fold.

Include MSetGenTree.Props X I.

Automation and dedicated tactics

Local Hint Immediate MX.eq_sym : core.
Local Hint Unfold In lt_tree gt_tree Ok : core.
Local Hint Constructors InT bst : core.
Local Hint Resolve MX.eq_refl MX.eq_trans MX.lt_trans ok : core.
Local Hint Resolve lt_leaf gt_leaf lt_tree_node gt_tree_node : core.
Local Hint Resolve lt_tree_not_in lt_tree_trans gt_tree_not_in gt_tree_trans : core.
Local Hint Resolve elements_spec2 : core.



Tactic Notation "factornode" ident(s) :=
 try clear s;
 match goal with
   | |- context [Node ?l ?x ?r ?h] =>
       set (s:=Node l x r h) in *; clearbody s; clear l x r h
   | _ : context [Node ?l ?x ?r ?h] |- _ =>
       set (s:=Node l x r h) in *; clearbody s; clear l x r h
 end.

Induction principles for some of the set operators

Functional Scheme bal_ind := Induction for bal Sort Prop.
Functional Scheme remove_min_ind := Induction for remove_min Sort Prop.
Functional Scheme merge_ind := Induction for merge Sort Prop.
Functional Scheme concat_ind := Induction for concat Sort Prop.
Functional Scheme inter_ind := Induction for inter Sort Prop.
Functional Scheme diff_ind := Induction for diff Sort Prop.
Functional Scheme union_ind := Induction for union Sort Prop.

Notations and helper lemma about pairs and triples


Notation "s #1" := (fst s) (at level 9, format "s '#1'") : pair_scope.
Notation "s #2" := (snd s) (at level 9, format "s '#2'") : pair_scope.
Notation "t #l" := (t_left t) (at level 9, format "t '#l'") : pair_scope.
Notation "t #b" := (t_in t) (at level 9, format "t '#b'") : pair_scope.
Notation "t #r" := (t_right t) (at level 9, format "t '#r'") : pair_scope.

Local Open Scope pair_scope.

Singleton set


Lemma singleton_spec : forall x y, InT y (singleton x) <-> X.eq y x.

Instance singleton_ok x : Ok (singleton x).

Helper functions


Lemma create_spec :
 forall l x r y, InT y (create l x r) <-> X.eq y x \/ InT y l \/ InT y r.

Instance create_ok l x r `(Ok l, Ok r, lt_tree x l, gt_tree x r) :
 Ok (create l x r).

Lemma bal_spec : forall l x r y,
 InT y (bal l x r) <-> X.eq y x \/ InT y l \/ InT y r.

Instance bal_ok l x r `(Ok l, Ok r, lt_tree x l, gt_tree x r) :
 Ok (bal l x r).

Insertion


Lemma add_spec' : forall s x y,
 InT y (add x s) <-> X.eq y x \/ InT y s.

Lemma add_spec : forall s x y `{Ok s},
 InT y (add x s) <-> X.eq y x \/ InT y s.

Instance add_ok s x `(Ok s) : Ok (add x s).
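
Client code normally uses these results through the encapsulated module rather than the raw one. The sketch below assumes the add_spec exposed by an instance of Make (stated with In rather than InT, as required by the MSet interface) and proves a small consequence; the lemma name in_add_same is hypothetical.

```coq
Require Import MSetAVL OrdersEx.

Module NatSet := MSetAVL.Make Nat_as_OT.

(* An element belongs to any set it has just been added to. *)
Lemma in_add_same (x : nat) (s : NatSet.t) : NatSet.In x (NatSet.add x s).
Proof.
  (* add_spec is an equivalence; apply uses its right-to-left direction. *)
  apply NatSet.add_spec.
  (* Left disjunct: x is equal to itself (equality on nat is Leibniz). *)
  left. exact eq_refl.
Qed.
```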

Local Open Scope Int_scope.

Join

Function/Functional Scheme can't deal with internal fix. Let's do its job by hand:

Ltac join_tac :=
 let l := fresh "l" in
 intro l; induction l as [| lh ll _ lx lr Hlr];
   [ | intros x r; induction r as [| rh rl Hrl rx rr _]; unfold join;
     [ | destruct ((rh+2) <? lh) eqn:LT;
       [ match goal with |- context b [ bal ?a ?b ?c] =>
           replace (bal a b c)
           with (bal ll lx (join lr x (Node rh rl rx rr))); [ | auto]
         end
       | destruct ((lh+2) <? rh) eqn:LT';
         [ match goal with |- context b [ bal ?a ?b ?c] =>
             replace (bal a b c)
             with (bal (join (Node lh ll lx lr) x rl) rx rr); [ | auto]
           end
         | ] ] ] ]; intros.

Lemma join_spec : forall l x r y,
 InT y (join l x r) <-> X.eq y x \/ InT y l \/ InT y r.

Instance join_ok : forall l x r `(Ok l, Ok r, lt_tree x l, gt_tree x r),
 Ok (join l x r).

Extraction of minimum element


Lemma remove_min_spec : forall l x r y h,
 InT y (Node h l x r) <->
  X.eq y (remove_min l x r)#2 \/ InT y (remove_min l x r)#1.

Instance remove_min_ok l x r : forall h `(Ok (Node h l x r)),
 Ok (remove_min l x r)#1.

Lemma remove_min_gt_tree : forall l x r h `{Ok (Node h l x r)},
 gt_tree (remove_min l x r)#2 (remove_min l x r)#1.
Local Hint Resolve remove_min_gt_tree : core.

Merging two trees


Lemma merge_spec : forall s1 s2 y,
 InT y (merge s1 s2) <-> InT y s1 \/ InT y s2.

Instance merge_ok s1 s2 : forall `(Ok s1, Ok s2)
 `(forall y1 y2 : elt, InT y1 s1 -> InT y2 s2 -> X.lt y1 y2),
 Ok (merge s1 s2).

Deletion


Lemma remove_spec : forall s x y `{Ok s},
 (InT y (remove x s) <-> InT y s /\ ~ X.eq y x).

Instance remove_ok s x `(Ok s) : Ok (remove x s).

Concatenation


Lemma concat_spec : forall s1 s2 y,
 InT y (concat s1 s2) <-> InT y s1 \/ InT y s2.

Instance concat_ok s1 s2 : forall `(Ok s1, Ok s2)
 `(forall y1 y2 : elt, InT y1 s1 -> InT y2 s2 -> X.lt y1 y2),
 Ok (concat s1 s2).

Splitting


Lemma split_spec1 : forall s x y `{Ok s},
 (InT y (split x s)#l <-> InT y s /\ X.lt y x).

Lemma split_spec2 : forall s x y `{Ok s},
 (InT y (split x s)#r <-> InT y s /\ X.lt x y).

Lemma split_spec3 : forall s x `{Ok s},
 ((split x s)#b = true <-> InT x s).

Lemma split_ok : forall s x `{Ok s}, Ok (split x s)#l /\ Ok (split x s)#r.

Instance split_ok1 s x `(Ok s) : Ok (split x s)#l.

Instance split_ok2 s x `(Ok s) : Ok (split x s)#r.

Intersection


Ltac destruct_split := match goal with
 | H : split ?x ?s = << ?u, ?v, ?w >> |- _ =>
   assert ((split x s)#l = u) by (rewrite H; auto);
   assert ((split x s)#b = v) by (rewrite H; auto);
   assert ((split x s)#r = w) by (rewrite H; auto);
   clear H; subst u w
 end.

Lemma inter_spec_ok : forall s1 s2 `{Ok s1, Ok s2},
 Ok (inter s1 s2) /\ (forall y, InT y (inter s1 s2) <-> InT y s1 /\ InT y s2).

Lemma inter_spec : forall s1 s2 y `{Ok s1, Ok s2},
 (InT y (inter s1 s2) <-> InT y s1 /\ InT y s2).

Instance inter_ok s1 s2 `(Ok s1, Ok s2) : Ok (inter s1 s2).

Difference


Lemma diff_spec_ok : forall s1 s2 `{Ok s1, Ok s2},
 Ok (diff s1 s2) /\ (forall y, InT y (diff s1 s2) <-> InT y s1 /\ ~InT y s2).

Lemma diff_spec : forall s1 s2 y `{Ok s1, Ok s2},
 (InT y (diff s1 s2) <-> InT y s1 /\ ~InT y s2).

Instance diff_ok s1 s2 `(Ok s1, Ok s2) : Ok (diff s1 s2).

Union


Lemma union_spec : forall s1 s2 y `{Ok s1, Ok s2},
 (InT y (union s1 s2) <-> InT y s1 \/ InT y s2).

Instance union_ok s1 s2 : forall `(Ok s1, Ok s2), Ok (union s1 s2).

Filter


Lemma filter_spec : forall s x f,
 Proper (X.eq==>Logic.eq) f ->
 (InT x (filter f s) <-> InT x s /\ f x = true).

Lemma filter_weak_spec : forall s x f,
 InT x (filter f s) -> InT x s.

Instance filter_ok s f `(H : Ok s) : Ok (filter f s).

Partition


Lemma partition_spec1' s f : (partition f s)#1 = filter f s.

Lemma partition_spec2' s f :
 (partition f s)#2 = filter (fun x => negb (f x)) s.

Lemma partition_spec1 s f :
 Proper (X.eq==>Logic.eq) f ->
 Equal (partition f s)#1 (filter f s).

Lemma partition_spec2 s f :
 Proper (X.eq==>Logic.eq) f ->
 Equal (partition f s)#2 (filter (fun x => negb (f x)) s).

Instance partition_ok1 s f `(Ok s) : Ok (partition f s)#1.

Instance partition_ok2 s f `(Ok s) : Ok (partition f s)#2.

End MakeRaw.

Encapsulation

Now, in order to really provide a functor implementing S, we need to encapsulate everything into a type of binary search trees. These trees also happen to be well-balanced, but this has no influence on the correctness of operations, so we won't state it here. See MSetFullAVL if you need more than just the MSet interface.

Module IntMake (I:Int)(X: OrderedType) <: S with Module E := X.
 Module Raw := MakeRaw I X.
 Include Raw2Sets X Raw.
End IntMake.


Module Make (X: OrderedType) <: S with Module E := X
 :=IntMake(Z_as_Int)(X).
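
A minimal usage sketch of Make, assuming Nat_as_OT from OrdersEx as the ordered type. The names NatSet, s and mem_4 are chosen for this example.

```coq
Require Import MSetAVL OrdersEx.

(* Finite sets of natural numbers backed by AVL trees. *)
Module NatSet := MSetAVL.Make Nat_as_OT.

Definition s : NatSet.t :=
  NatSet.add 3 (NatSet.add 1 (NatSet.add 2 NatSet.empty)).

(* elements lists the contents in increasing order: 1, 2, 3. *)
Compute NatSet.elements s.
Compute NatSet.mem 2 s.      (* true *)
Compute NatSet.cardinal s.   (* 3 *)

(* Membership tests on closed sets can be checked by computation. *)
Example mem_4 : NatSet.mem 4 s = false.
Proof. reflexivity. Qed.
```

Use IntMake instead of Make when another implementation of the Int signature should hold the tree heights; Make fixes it to Z_as_Int.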