toy-smt/sat.ml

236 lines
7.6 KiB
OCaml

open Common
(* A literal is a solver atom. This file only ever matches the
   [Solver.Equal] and [Solver.Different] constructors, each carrying a
   pair of ints. *)
type literal = Solver.atom
(* A formula in conjunctive normal form: a list of clauses, each clause
   being a list (disjunction) of literals. *)
type cnf = literal list list
(* The facts accumulated so far, as maintained by [Solver]. *)
type known = Solver.state
(* The literal we used to get here, and the state before we added it *)
(* NOTE(review): within this file, only the commented-out solver further
   down refers to this type; the active code uses [backjump] instead. *)
type backjumps = (literal * known) list
(* Raised by [next_state] when no clause remains to be satisfied; carries
   the final solver state. Caught in [cnf_solve]. *)
exception Solved of known
(* Raised by [recheck] when a clause has every literal known false and
   there is no decision left to undo. Caught in [cnf_solve]. *)
exception Unsat
(** [add_known lit known] records [lit] in the solver state [known] by
    delegating to [Solver.add_atom], with the arguments flipped so that the
    literal comes first. Returns the resulting state. *)
let add_known (lit : literal) (known : known) : known =
  Solver.add_atom known lit
(** [print_lit ff lit] pretty-prints [lit] on the formatter [ff], as
    [a = b] for an equality and [a <> b] for a disequality. *)
let print_lit ff lit =
  match lit with
  | Solver.Equal (a, b) -> Format.fprintf ff "%d = %d" a b
  | Solver.Different (a, b) -> Format.fprintf ff "%d <> %d" a b
(** [literal_negate lit] is the negation of [lit], as computed by
    [Solver.negate]. *)
let literal_negate lit = Solver.negate lit
(** [literal_known_true known lit] asks [Solver.known_true] whether [lit]
    is already established as true in the state [known]. *)
let literal_known_true = Solver.known_true
(** [literal_known_false known lit] asks [Solver.known_false] whether [lit]
    is already established as false in the state [known]. *)
let literal_known_false = Solver.known_false
(** [disj_known_true known l] holds when at least one literal of the
    clause [l] is known true in [known]. *)
let disj_known_true known l =
  List.exists (fun lit -> literal_known_true known lit) l
(** [disj_known_false known l] holds when every literal of the clause [l]
    is known false in [known] (vacuously true for the empty clause). *)
let disj_known_false known l =
  List.for_all (fun lit -> literal_known_false known lit) l
(** [cnf_known_true known cnf] holds when every clause of [cnf] already
    has a literal known true in [known]. *)
let cnf_known_true known cnf =
  List.for_all (fun clause -> disj_known_true known clause) cnf
(** [cnf_known_false known cnf] holds when some clause of [cnf] has all of
    its literals known false in [known]. *)
let cnf_known_false known cnf =
  List.exists (fun clause -> disj_known_false known clause) cnf
(*
let next_state cnf (known, backjumps) =
if cnf_known_true known cnf then
raise (Solved known)
else if cnf_known_false known cnf then
match backjumps with
| (lit, kn) :: bcjs ->
(* Format.printf "Backjump@."; *)
(add_known (literal_negate lit) kn, bcjs)
| [] -> raise Unsat
else
(* Unit *)
let rec unit_decision cl =
match cl with
| [] -> None
| lt :: cls when literal_known_true known lt -> None
| lt :: cls when literal_known_false known lt -> unit_decision cls
| lt :: cls when disj_known_false known cls -> Some lt
| _ -> None
in
let rec find_decide_lit l =
match l with
| [] -> raise Not_found
| cl :: ls -> try
literal_negate (List.find (fun lit ->
not ((literal_known_true known lit) ||
(literal_known_false known lit))) cl)
with Not_found -> find_decide_lit ls
in
let rec unit l =
match l with
| cls :: ls -> begin match unit_decision cls with
| None -> unit ls
| Some lit ->
(* Format.printf "Unit %a@." print_lit lit; *)
(add_known lit known, backjumps)
end
| [] -> let lit = find_decide_lit cnf in
(* Format.printf "Decide %a@." print_lit lit; *)
(add_known lit known, (lit, known) :: backjumps)
in unit cnf
let cnf_solve nvars cnf =
let rec loop state = loop (next_state cnf state) in
try loop (Solver.create nvars, []) with
| Unsat -> None
| Solved known -> Some (Solver.build_model known)
*)
(* Maps keyed by unordered pairs of ints: [(a, b)] and [(b, a)] denote the
   same key, because both sides are put in (smaller, larger) order before
   being compared. *)
module UMap = Map.Make (struct
  type t = int * int

  (* Put the smaller component first. *)
  let normalize (a, b) = if a <= b then (a, b) else (b, a)

  let compare p q = compare (normalize p) (normalize q)
end)
(** [get_pair lit] is the pair of ints carried by [lit], regardless of
    whether it is an equality or a disequality, ordered as
    (smaller, larger). *)
let get_pair lit =
  let a, b =
    match lit with
    | Solver.Equal (x, y) -> (x, y)
    | Solver.Different (x, y) -> (x, y)
  in
  (min a b, max a b)
(* Search state, threaded through [recheck] and [next_state] as a 5-tuple
   (remaining_clauses, watched, watchset, known, backjumps):
   - the set of indices of clauses not yet marked as solved;
   - for each clause index, the two literals currently watched for it;
   - for each pair of ints (see [get_pair]), the set of clause indices
     watching a literal over that pair;
   - the solver facts accumulated so far;
   - the most recent decision that can still be undone, if any. *)
type state = ISet.t * (literal * literal) Parray.t * ISet.t UMap.t * Solver.state * backjump
(* [Backjump (lit, st)] records a decision: [lit] is the literal that was
   assumed and [st] is the whole state from just before it was assumed.
   [Empty] means there is no decision left to undo. *)
and backjump = Backjump of literal * state | Empty
(** [get lit m] is the set bound in [m] to the pair of ints of [lit]
    (see [get_pair]), or the empty set when that pair has no binding. *)
let get lit m =
  match UMap.find (get_pair lit) m with
  | watchers -> watchers
  | exception Not_found -> ISet.empty
(** [add lit i m] returns [m] with [i] added to the set bound to the pair
    of ints of [lit]; the set is created if the pair had no binding. *)
let add lit i m =
  let watchers = ISet.add i (get lit m) in
  UMap.add (get_pair lit) watchers m
(* [recheck acnf st found worklist] re-examines, in order, the clauses whose
   indices are listed in [worklist] against the state [st], and returns
   [(found, st2)] once the worklist is exhausted.
   - [acnf] is the array of clauses (each a list of literals), indexed by
     clause number.
   - [st] is the tuple (remaining_clauses, watched, watchset, known,
     backjumps).
   - [found] is an accumulator: it is passed on as [true] as soon as some
     clause caused a change (clause solved, unit literal added, watched
     literals replaced, or a backjump).
   Raises [Unsat] when a clause has no literal left that could be true and
   there is no decision to undo. *)
let rec recheck acnf ((remaining_clauses, watched, watchset, known, backjumps) as st) found = function
| [] -> (found, st)
(* Clause already removed from the remaining set: nothing to do. *)
| rc :: rcs when not (ISet.mem rc remaining_clauses) -> recheck acnf st found rcs
| rc :: rcs ->
let lit1, lit2 = Parray.get watched rc in
(* Drop clause [rc] from the remaining set and from the watch sets of both
   of its watched literals. *)
let mark_solved () =
(* Format.printf "%d marked solved@." rc; *)
let c1 = get lit1 watchset in
let c2 = get lit2 watchset in
let nw = UMap.add (get_pair lit1) (ISet.remove rc c1)
(UMap.add (get_pair lit2) (ISet.remove rc c2) watchset) in
(ISet.remove rc remaining_clauses, watched, nw, known, backjumps)
in
(* A watched literal is known true: the clause is satisfied. *)
if literal_known_true known lit1 || literal_known_true known lit2 then
recheck acnf (mark_solved ()) true rcs
else
let l1, l2 = literal_known_false known lit1, literal_known_false known lit2 in
(* New watches are needed when a watched literal is known false, or when
   both watch slots hold the same literal (which is how [cnf_solve] sets up
   single-literal clauses). *)
if l1 || l2 || lit1 = lit2 then
let pair1, pair2 = get_pair lit1, get_pair lit2 in
(* [al] holds the watched literal that can be kept (if any) and [w] is the
   number of additional literals to look for. *)
let al, w = if l1 then if l2 then [], 2 else [lit2], 1 else [lit1], 1 in
(* Scan the clause: answer [(true, [])] as soon as a literal is known true;
   skip literals known false; otherwise collect literals whose pair of ints
   is not already among the collected ones, until [rem] more have been
   collected or the clause is exhausted. *)
let rec find found rem dj =
if rem = 0 then (false, found)
else match dj with
| [] -> (false, found)
| lit :: djs ->
if literal_known_true known lit then
(true, [])
else if literal_known_false known lit then
find found rem djs
else
let p = get_pair lit in
if not (List.mem p (List.map get_pair found)) then
find (lit :: found) (rem - 1) djs
else
find found rem djs
in
let (one_true, new_watched) = find al w acnf.(rc) in
if one_true then
recheck acnf (mark_solved ()) true rcs
else
(* [nw] is the watch map with [rc] removed from the sets of both previously
   watched pairs. *)
let c1 = get lit1 watchset in
let c2 = get lit2 watchset in
let nw = UMap.add pair1 (ISet.remove rc c1)
(UMap.add pair2 (ISet.remove rc c2) watchset)
in
match new_watched with
(* No candidate literal at all: conflict. Restore the state saved with the
   last decision, add the negation of the decision literal to it, and
   recheck the clauses watching that pair before the rest of the worklist.
   The pattern below shadows the outer state components with the saved
   ones. *)
| [] -> begin match backjumps with
| Backjump (lt, (remaining_clauses, watched, watchset, known, backjumps)) ->
(* Format.printf "Backjump@."; *)
recheck acnf (remaining_clauses, watched,
watchset, add_known (literal_negate lt) known, backjumps) true
((ISet.elements (get lt watchset)) @ rcs)
| Empty -> raise Unsat
end
(* Exactly one candidate: add it to the known facts, drop the clause from
   the remaining set, and recheck the clauses watching the same pair
   first. *)
| [lt] -> (* Unit *)
(* Format.printf "Unit %d@." rc; *)
recheck acnf (ISet.remove rc remaining_clauses, watched,
nw, add_known lt known, backjumps) true
((ISet.elements (get lt nw)) @ rcs)
(* Two candidates: register [rc] in both of their watch sets and record
   them as the watched literals of the clause. *)
| [lt1; lt2] -> (* Watch changed *)
let cc1 = get lt1 nw in
let cc2 = get lt2 nw in
let nnw = UMap.add (get_pair lt1) (ISet.add rc cc1)
(UMap.add (get_pair lt2) (ISet.add rc cc2) nw)
in
recheck acnf (remaining_clauses, Parray.set watched rc (lt1, lt2),
nnw, known, backjumps) true rcs
| _ -> assert false (* Impossible *)
else
(* Two distinct watched literals, neither known true nor known false:
   leave the clause untouched. *)
recheck acnf st found rcs
(** [next_state acnf st] performs one step of the search from [st].

    - If no clause remains, raises [Solved] with the current facts.
    - Otherwise every remaining clause is rechecked; if that changed
      anything, the updated state is returned.
    - If nothing changed, a decision is made: an arbitrary remaining clause
      is chosen, the negation of its first watched literal is added to the
      known facts, [st] is saved as the backjump point, and the chosen
      clause alone is rechecked.

    [Unsat] raised by [recheck] is not caught here. *)
let next_state acnf ((remaining_clauses, watched, watchset, known, _) as st) =
  if ISet.is_empty remaining_clauses then raise (Solved known)
  else
    match recheck acnf st false (ISet.elements remaining_clauses) with
    | true, progressed -> progressed
    | false, _ ->
      (* Nothing new: decide *)
      let cl = ISet.choose remaining_clauses in
      let first_watched, _ = Parray.get watched cl in
      let decision = literal_negate first_watched in
      let decided =
        ( remaining_clauses,
          watched,
          watchset,
          add_known decision known,
          Backjump (decision, st) )
      in
      snd (recheck acnf decided false [ cl ])
(** [cnf_solve nvars cnf] searches for an assignment satisfying [cnf].

    [nvars] is handed to [Solver.create] to build the initial solver state.
    Returns [Some model], where the model is built by [Solver.build_model]
    from the final state, or [None] when a clause simplifies to the empty
    clause or the search raises [Unsat]. *)
let cnf_solve nvars cnf =
  (* Rewrite a literal so that its smaller int comes first. *)
  let normalize = function
    | Solver.Equal (x, y) -> Solver.Equal (min x y, max x y)
    | Solver.Different (x, y) -> Solver.Different (min x y, max x y)
  in
  (* Simplify one clause, accumulating kept literals in [acc] (so in reverse
     order of appearance). The result is [None] when the clause contains
     [x = x] or both a literal and its negation; [x <> x] and duplicate
     literals are dropped. *)
  let rec simplify_disj acc = function
    | [] -> Some acc
    | Solver.Equal (x, y) :: _ when x = y -> None
    | Solver.Different (x, y) :: rest when x = y -> simplify_disj acc rest
    | lit :: rest ->
      let w = normalize lit in
      let nw = literal_negate w in
      if List.mem nw acc then None
      else if List.mem w acc then simplify_disj acc rest
      else simplify_disj (w :: acc) rest
  in
  let simp = List.map (fun clause -> simplify_disj [] clause) cnf in
  (* NOTE(review): [List.extract] is not defined in this file; presumably it
     keeps the payloads of the [Some] entries and drops the [None] ones. *)
  let acnf = Array.of_list (List.extract simp) in
  if List.mem (Some []) simp then
    (* Some clause lost all of its literals. *)
    None
  else begin
    let nclauses = Array.length acnf in
    (* Register clauses [i], [i + 1], ... in the state: each one is added to
       the remaining set and gets its watched literals. A single-literal
       clause watches that literal twice and the literal is added to the
       known facts right away. *)
    let rec build_initial
        ((remaining_clauses, watched, watchset, known, backjumps) as st) i =
      if i >= nclauses then st
      else
        let st_next =
          match acnf.(i) with
          | [] -> assert false
          | [ lit ] ->
            (ISet.add i remaining_clauses, Parray.set watched i (lit, lit),
             add lit i watchset, add_known lit known, backjumps)
          | lit1 :: lit2 :: _ ->
            (ISet.add i remaining_clauses, Parray.set watched i (lit1, lit2),
             add lit1 i (add lit2 i watchset),
             known, backjumps)
        in
        build_initial st_next (i + 1)
    in
    let known = Solver.create nvars in
    (* Placeholder entry; every slot is overwritten by [build_initial]. *)
    let watched =
      Parray.create nclauses (Solver.Equal (-1, -1), Solver.Equal (-1, -1))
    in
    (* Step until [next_state] raises [Solved] or [Unsat]. *)
    let rec loop state = loop (next_state acnf state) in
    try loop (build_initial (ISet.empty, watched, UMap.empty, known, Empty) 0)
    with
    | Unsat -> None
    | Solved known -> Some (Solver.build_model known)
  end