From bfc54bc5a6048cf3ab846045cbe6c9cb70edb573 Mon Sep 17 00:00:00 2001 From: hal8174 Date: Tue, 15 Oct 2024 21:22:34 +0200 Subject: [PATCH] Add direct transformation to normal form --- sat/simple.sat | 3 +- src/expr.rs | 85 +++++++++++++++++++++++++++++++++++++++++++++- src/main.rs | 55 ++++++++++++++++++------------ src/mapping.rs | 32 +++++++++++++++++ src/normal_form.rs | 23 ++++++++++++- 5 files changed, 174 insertions(+), 24 deletions(-) create mode 100644 src/mapping.rs diff --git a/sat/simple.sat b/sat/simple.sat index 044b38a..c5ed155 100644 --- a/sat/simple.sat +++ b/sat/simple.sat @@ -1 +1,2 @@ - dfjsk&(dfjsk| !fjdsk| ((djask& ffadsfj)|fjskd)) + a&(b| !c| ((!b& d)|e)) + a&!b diff --git a/src/expr.rs b/src/expr.rs index 0c65af8..dbdecb4 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -1,4 +1,4 @@ -use crate::dpll::Index; +use crate::{dpll::Index, mapping::Mapping, normal_form::NormalForm}; use miette::{miette, Context, Result}; use std::collections::HashMap; @@ -243,6 +243,79 @@ impl<'s> Expr<'s> { } } + pub fn disjunctiveNormalForm2(self, mapping: &mut Mapping<'s>) -> NormalForm { + match self { + Expr::Literal(l) => { + let mut n = NormalForm::new(); + n.add_clause(&[mapping.forward(l)]); + n + } + Expr::Not(e) => match *e { + Expr::Literal(l) => { + let mut n = NormalForm::new(); + n.add_clause(&[-mapping.forward(l)]); + n + } + Expr::Not(n) => n.disjunctiveNormalForm2(mapping), + Expr::And(v) => { + let mut n = NormalForm::new(); + for e in v + .into_iter() + .map(|e| Expr::Not(Box::new(e)).disjunctiveNormalForm2(mapping)) + { + for (s, _, _) in e.iter() { + n.add_clause(s); + } + } + n + } + + Expr::Or(v) => Expr::And(v.into_iter().map(|e| Expr::Not(Box::new(e))).collect()) + .disjunctiveNormalForm2(mapping), + }, + Expr::And(v) => { + let mut v: Vec<_> = v + .into_iter() + .map(|e| e.disjunctiveNormalForm2(mapping)) + .collect(); + + let mut n = v.pop().unwrap(); + + let mut scratch = Vec::new(); + + for on in v { + let mut new_n = NormalForm::new(); + + for (old_s, _, _) in n.iter() { + for (new_s, _, _) in on.iter() { + scratch.clear(); + scratch.extend_from_slice(old_s); + scratch.extend_from_slice(new_s); + scratch.sort(); + scratch.dedup(); + new_n.add_clause(&scratch); + } + } + + n = new_n; + } + + n + } + Expr::Or(v) => { + let mut n = NormalForm::new(); + for e in v { + let on = e.disjunctiveNormalForm2(mapping); + for (s, _, _) in on.iter() { + n.add_clause(s); + } + } + + n + } + } + } + pub fn conjunctiveNormalForm(self) -> Self { if let Expr::Or(v) = Expr::Not(Box::new(self)).disjunctiveNormalForm() { Expr::And( @@ -275,6 +348,16 @@ impl<'s> Expr<'s> { } } + pub fn to_dpll_cnf_normal_form(self) -> (NormalForm, Mapping<'s>) { + let mut mapping = Mapping::new(); + + let mut cnf = Expr::Not(Box::new(self)).disjunctiveNormalForm2(&mut mapping); + + cnf.negate(); + + (cnf, mapping) + } + pub fn to_dpll_cnf(self) -> (Vec>, HashMap<&'s str, Index>) { let mut map = HashMap::new(); diff --git a/src/main.rs b/src/main.rs index a62a25f..03ec90e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,12 +3,13 @@ use std::collections::HashMap; use clap::Parser; use dpll::dpll; use expr::{Expr, Tokanizer}; +use mapping::Mapping; use miette::{IntoDiagnostic, Result}; use normal_form::NormalForm; -mod dpll_normal_form; -mod normal_form; - mod dpll; +mod dpll_normal_form; +mod mapping; +mod normal_form; mod expr; @@ -28,39 +29,51 @@ fn main() -> Result<()> { // dbg!(&e); // dbg!(Expr::Not(Box::new(e.clone())).disjunctiveNormalForm()); - // dbg!(e.clone().conjunctiveNormalForm()); - let (cnf, map) = e.to_dpll_cnf(); + let (cnf, map) = e.to_dpll_cnf_normal_form(); - // dbg!(&cnf, &map); - println!("cnf clauses: {}", cnf.len()); + println!("Normal form created."); + cnf.stats(); let mut solution = Vec::new(); + dbg!(crate::dpll_normal_form::dpll(&cnf, &mut solution)); - dpll(&cnf, &mut solution); + // dbg!(&solution); + // dbg!(e.clone().conjunctiveNormalForm()); - let inv_map: HashMap<_, _> = map.into_iter().map(|(k, v)| (v, k)).collect(); + // let (cnf, map) = e.to_dpll_cnf(); - dbg!(&solution); - // for i in solution { - // if i > 0 { - // println!("{}: true", inv_map[&i]); - // } else { - // println!("{}: false", inv_map[&i.abs()]); - // } - // } + // dbg!(&cnf, &map); + // println!("cnf clauses: {}", cnf.len()); - let mut nf = NormalForm::from_slices(&cnf); + // let mut solution = Vec::new(); + + // dpll(&cnf, &mut solution); + + // dbg!(solution); + + // let inv_map: HashMap<_, _> = map.into_iter().map(|(k, v)| (v, k)).collect(); + + // dbg!(&solution); + for i in solution { + if i > 0 { + println!("{}: true", map.backward(i)); + } else { + println!("{}: false", map.backward(i.abs())); + } + } + + // let mut nf = NormalForm::from_slices(&cnf); // dbg!(&nf); // nf.remove_literal(-1); // dbg!(&nf); - let mut s2 = Vec::new(); + // let mut s2 = Vec::new(); - dbg!(dpll_normal_form::dpll(&nf, &mut s2)); + // dbg!(dpll_normal_form::dpll(&nf, &mut s2)); - dbg!(s2 == solution); + // dbg!(s2 == solution); Ok(()) } diff --git a/src/mapping.rs b/src/mapping.rs new file mode 100644 index 0000000..e76707c --- /dev/null +++ b/src/mapping.rs @@ -0,0 +1,32 @@ +use std::collections::HashMap; + +use crate::dpll::Index; + +pub struct Mapping<'s> { + forward: HashMap<&'s str, Index>, + backward: Vec<&'s str>, +} + +impl<'s> Mapping<'s> { + pub fn new() -> Self { + Self { + forward: HashMap::new(), + backward: Vec::new(), + } + } + + pub fn forward(&mut self, str: &'s str) -> Index { + if let Some(&i) = self.forward.get(str) { + i + } else { + let i = self.backward.len() as Index + 1; + self.forward.insert(str, i); + self.backward.push(str); + i + } + } + + pub fn backward(&self, i: Index) -> &'s str { + self.backward[(i - 1) as usize] + } +} diff --git a/src/normal_form.rs b/src/normal_form.rs index 2c4603e..a3ce0f6 100644 --- a/src/normal_form.rs +++ b/src/normal_form.rs @@ -1,3 +1,5 @@ +use std::ops::Neg; + use crate::dpll::Index; #[derive(Debug, Clone)] @@ -66,7 +68,7 @@ impl NormalForm { pub fn add_clause(&mut self, s: &[Index]) { let len = s.len(); - if self.data.len() < len { + if self.data.len() <= len { self.data.resize(len + 1, Vec::new()); } @@ -110,6 +112,25 @@ impl NormalForm { pub fn iter_len(&self, len: usize) -> NormalFormIteratorLen<'_> { NormalFormIteratorLen::new(self, len) } + + pub fn negate(&mut self) { + for v in &mut self.data[1..] { + v.reverse(); + for e in v { + *e = -*e; + } + } + } + + pub fn stats(&self) { + println!("Normal form stats:"); + let mut sum = 0; + for (i, v) in self.data.iter().enumerate() { + sum += v.len(); + println!("{i:2}: {:15}", v.len() / usize::max(1, i)); + } + println!("Total size: {}byte", sum * Index::BITS as usize / 8); + } } pub struct NormalFormIterator<'nf> {