Add improved representation for normal forms

This commit is contained in:
hal8174 2024-10-13 01:02:27 +02:00
parent c23697d386
commit 6b9785dc1f
3 changed files with 268 additions and 7 deletions

59
src/dpll_normal_form.rs Normal file
View file

@ -0,0 +1,59 @@
use std::collections::HashSet;
use crate::{dpll::Index, normal_form::NormalForm};
fn unit_propagate(cnf: &mut NormalForm, solution: &mut Vec<Index>) {
if let Some((s, _)) = cnf.iter_len(1).next() {
let l = s[0];
solution.push(l);
cnf.remove_literal(l);
unit_propagate(cnf, solution);
}
}
fn pure_literal_assign(cnf: &mut NormalForm, solution: &mut Vec<Index>) {
let mut literals: HashSet<Index> = HashSet::new();
for (s, _, _) in cnf.iter() {
literals.extend(s.iter());
}
for &l in &literals {
if !literals.contains(&(-l)) {
solution.push(l);
cnf.remove_literal(l);
}
}
}
pub fn dpll(cnf: &NormalForm, solution: &mut Vec<Index>) -> bool {
let prev_solution = solution.len();
let mut own_cnf = cnf.to_owned();
unit_propagate(&mut own_cnf, solution);
pure_literal_assign(&mut own_cnf, solution);
if own_cnf.is_empty() {
return true;
}
if own_cnf.empty_clauses() {
solution.truncate(prev_solution);
return false;
}
let literal = own_cnf.iter().next().unwrap().0[0];
own_cnf.push_litearl_clause(literal);
if dpll(&own_cnf, solution) {
return true;
}
own_cnf.remove_literal_clause();
own_cnf.push_litearl_clause(-literal);
if dpll(&own_cnf, solution) {
return true;
}
solution.truncate(prev_solution);
false
}

View file

@ -4,6 +4,9 @@ use clap::Parser;
use dpll::dpll; use dpll::dpll;
use expr::{Expr, Tokanizer}; use expr::{Expr, Tokanizer};
use miette::{IntoDiagnostic, Result}; use miette::{IntoDiagnostic, Result};
use normal_form::NormalForm;
mod dpll_normal_form;
mod normal_form;
mod dpll; mod dpll;
@ -38,13 +41,26 @@ fn main() -> Result<()> {
let inv_map: HashMap<_, _> = map.into_iter().map(|(k, v)| (v, k)).collect(); let inv_map: HashMap<_, _> = map.into_iter().map(|(k, v)| (v, k)).collect();
for i in solution { dbg!(&solution);
if i > 0 { // for i in solution {
println!("{}: true", inv_map[&i]); // if i > 0 {
} else { // println!("{}: true", inv_map[&i]);
println!("{}: false", inv_map[&i.abs()]); // } else {
} // println!("{}: false", inv_map[&i.abs()]);
} // }
// }
let mut nf = NormalForm::from_slices(&cnf);
// dbg!(&nf);
// nf.remove_literal(-1);
// dbg!(&nf);
let mut s2 = Vec::new();
dbg!(dpll_normal_form::dpll(&nf, &mut s2));
dbg!(s2 == solution);
Ok(()) Ok(())
} }

186
src/normal_form.rs Normal file
View file

@ -0,0 +1,186 @@
use crate::dpll::Index;
#[derive(Debug, Clone)]
pub struct NormalForm {
data: Vec<Vec<Index>>,
}
pub struct Internal<'nf> {
nf: &'nf mut NormalForm,
len: usize,
index: usize,
}
impl NormalForm {
pub fn new() -> Self {
Self { data: Vec::new() }
}
pub fn remove_literal(&mut self, l: Index) {
// dbg!(l);
for len in 1..self.data.len() {
let mut retained = 0;
let (left, right) = self.data.split_at_mut(len);
let smaller = left.last_mut().unwrap();
let current = right.first_mut().unwrap();
for index in 0..current.len() / len {
let s = &current[(index * len)..((index + 1) * len)];
if s.binary_search(&l).is_ok() {
} else if let Ok(i) = s.binary_search(&(-l)) {
if len == 1 {
smaller.push(0);
} else {
smaller.extend_from_slice(&s[0..i]);
smaller.extend_from_slice(&s[(i + 1)..]);
}
} else {
if retained < index {
current.copy_within((index * len)..((index + 1) * len), retained * len);
}
retained += 1;
}
}
current.truncate(retained * len);
}
// dbg!(&self);
}
pub fn from_slices(slices: &[Vec<Index>]) -> Self {
let mut nf = Self::new();
for s in slices {
nf.add_clause(s.as_slice());
}
nf
}
pub fn is_empty(&self) -> bool {
self.data.iter().all(|v| v.is_empty())
}
pub fn empty_clauses(&self) -> bool {
!self.data[0].is_empty()
}
pub fn add_clause(&mut self, s: &[Index]) {
let len = s.len();
if self.data.len() < len {
self.data.resize(len + 1, Vec::new());
}
if len > 0 {
self.data[len].extend_from_slice(s);
} else {
self.data[0].push(0);
}
}
pub fn push_litearl_clause(&mut self, l: Index) {
self.data[1].push(l);
}
pub fn remove_literal_clause(&mut self) {
self.data[1].pop();
}
pub fn get_mut(&mut self, len: usize, index: usize) -> Internal<'_> {
Internal {
nf: self,
len,
index,
}
}
pub fn get_clause(&self, len: usize, index: usize) -> &[Index] {
&self.data[len][(index * len)..((index + 1) * len)]
}
pub fn remove_clause(&mut self, len: usize, index: usize) {
let r = self.data[len].len() - len;
self.data[len].copy_within(r.., index * len);
self.data[len].truncate(r);
}
pub fn iter(&self) -> NormalFormIterator<'_> {
NormalFormIterator::new(self)
}
pub fn iter_len(&self, len: usize) -> NormalFormIteratorLen<'_> {
NormalFormIteratorLen::new(self, len)
}
}
pub struct NormalFormIterator<'nf> {
nf: &'nf NormalForm,
len: usize,
index: usize,
}
impl<'nf> NormalFormIterator<'nf> {
fn new(nf: &'nf NormalForm) -> Self {
Self {
nf,
len: 0,
index: 0,
}
}
}
impl<'nf> Iterator for NormalFormIterator<'nf> {
type Item = (&'nf [Index], usize, usize);
fn next(&mut self) -> Option<Self::Item> {
while self.len < self.nf.data.len()
&& self.index * usize::max(self.len, 1) >= self.nf.data[self.len].len()
{
self.len += 1;
self.index = 0;
}
if self.nf.data.len() > self.len {
let s = self.nf.get_clause(self.len, self.index);
self.index += 1;
Some((s, self.len, self.index - 1))
} else {
None
}
}
}
pub struct NormalFormIteratorLen<'nf> {
nf: &'nf NormalForm,
len: usize,
index: usize,
}
impl<'nf> NormalFormIteratorLen<'nf> {
fn new(nf: &'nf NormalForm, len: usize) -> Self {
Self { nf, len, index: 0 }
}
}
impl<'nf> Iterator for NormalFormIteratorLen<'nf> {
type Item = (&'nf [Index], usize);
fn next(&mut self) -> Option<Self::Item> {
if self.len < self.nf.data.len()
&& self.nf.data[self.len].len() > self.index * usize::max(self.len, 1)
{
let s = self.nf.get_clause(self.len, self.index);
self.index += 1;
Some((s, self.index - 1))
} else {
None
}
}
}
impl Default for NormalForm {
fn default() -> Self {
Self::new()
}
}