diff --git a/src/bin/minesweeper.rs b/src/bin/minesweeper.rs new file mode 100644 index 0000000..0df19bb --- /dev/null +++ b/src/bin/minesweeper.rs @@ -0,0 +1,159 @@ +use std::collections::HashMap; + +use simple_sat_solver::{ + dpll::Index, expr::Expr, normal_form::NormalForm, solver::SolutionIterator, +}; + +fn solve_minesweeper(f: F, input: &[Vec], mines: u64) -> (u64, Vec>, f64) +where + F: Fn(&NormalForm) -> Option>, +{ + let mut cnf = Vec::new(); + + let mut total_flags = 0; + let mut unknown_fields = 0; + + for y in 0..input.len() { + for x in 0..input[y].len() { + match input[y][x] { + 0 => (), + 1..=9 => { + let mut neighbors = Vec::new(); + let mut flags = 0; + for (nx, ny) in (-1..=1) + .flat_map(|dx| (-1..=1).map(move |dy| (dx, dy))) + .filter_map(|(dx, dy)| { + x.checked_add_signed(dx) + .filter(|&nx| nx < input[y].len()) + .zip(y.checked_add_signed(dy).filter(|&ny| ny < input.len())) + }) + { + match input[ny][nx] { + 0..10 => (), + 10 => neighbors.push((nx, ny)), + 11 => flags += 1, + _ => unreachable!(), + } + } + + let bombs = (input[y][x] - flags) as usize; + if bombs > 0 { + assert!(!neighbors.is_empty()); + cnf.push(Expr::cnf_from_truth_function( + |bits| bits.iter().filter(|&&b| b).count() == bombs, + &neighbors, + )); + } + } + 10 => unknown_fields += 1, + 11 => total_flags += 1, + _ => unreachable!(), + } + } + } + + let expr = Expr::And(cnf); + + let mut result = vec![vec![u64::MAX; input[0].len()]; input.len()]; + let mut total = 0; + + let mut interesting_fields = 0; + + let mut bomb_distribution = HashMap::new(); + + for i in SolutionIterator::new(f, expr) { + total += 1; + interesting_fields = i.len(); + let mut bombs = 0; + for ((x, y), b) in i { + if result[y][x] == u64::MAX { + result[y][x] = 0; + } + if b { + result[y][x] += 1; + bombs += 1; + } + } + match bomb_distribution.entry(bombs) { + std::collections::hash_map::Entry::Occupied(mut occupied_entry) => { + *occupied_entry.get_mut() += 1; + } + std::collections::hash_map::Entry::Vacant(vacant_entry) => { + vacant_entry.insert(1); + } + } + } + + let expected_bombs = bomb_distribution + .iter() + .map(|(&k, &v)| k as f64 * v as f64) + .sum::() + / total as f64; + + ( + total, + result, + (mines as f64 - expected_bombs) / (unknown_fields - interesting_fields) as f64, + ) +} + +fn print_input(input: &[Vec]) { + for row in input.iter() { + for elem in row.iter() { + match elem { + 0 => print!(" . "), + 1..=9 => print!(" {elem} "), + 10 => print!(" # "), + 11 => print!(" F "), + _ => unreachable!(), + }; + } + println!(); + } +} + +fn print_output(total: u64, input: &[Vec], map: &[Vec], default: f64) { + for (input_row, map_row) in input.iter().zip(map.iter()) { + for (input_elem, map_elem) in input_row.iter().zip(map_row.iter()) { + match input_elem { + 0 => print!(" . "), + 1..=9 => print!(" {input_elem} "), + 10 => { + if *map_elem == u64::MAX { + print!("{:3}", (default * 100.0) as u8); + } else { + print!("{:3}", map_elem * 100 / total); + } + } + 11 => print!(" F "), + _ => unreachable!(), + }; + } + println!(); + } +} + +fn main() { + // let input = vec![ + // vec![0, 0, 1, 10, 10, 10], + // vec![1, 1, 2, 10, 10, 10], + // vec![1, 11, 2, 2, 10, 10], + // vec![1, 1, 2, 10, 10, 10], + // vec![1, 1, 2, 10, 10, 10], + // vec![10, 10, 2, 10, 10, 10], + // vec![10; 6], + // ]; + // + let input = vec![ + vec![2, 2, 10], + vec![10, 10, 10], + vec![10, 10, 10], + vec![10, 10, 10], + ]; + + print_input(&input); + + let (total, result, default) = solve_minesweeper(simple_sat_solver::cdcl::cdcl, &input, 5); + + print_output(total, &input, &result, default); +}