Add minesweeper solver

This commit is contained in:
hal8174 2025-02-22 18:56:34 +01:00
parent e5f408aed2
commit d130fd57d9

159
src/bin/minesweeper.rs Normal file
View file

@ -0,0 +1,159 @@
use std::collections::HashMap;
use simple_sat_solver::{
dpll::Index, expr::Expr, normal_form::NormalForm, solver::SolutionIterator,
};
fn solve_minesweeper<F>(f: F, input: &[Vec<u8>], mines: u64) -> (u64, Vec<Vec<u64>>, f64)
where
F: Fn(&NormalForm) -> Option<Vec<Index>>,
{
let mut cnf = Vec::new();
let mut total_flags = 0;
let mut unknown_fields = 0;
for y in 0..input.len() {
for x in 0..input[y].len() {
match input[y][x] {
0 => (),
1..=9 => {
let mut neighbors = Vec::new();
let mut flags = 0;
for (nx, ny) in (-1..=1)
.flat_map(|dx| (-1..=1).map(move |dy| (dx, dy)))
.filter_map(|(dx, dy)| {
x.checked_add_signed(dx)
.filter(|&nx| nx < input[y].len())
.zip(y.checked_add_signed(dy).filter(|&ny| ny < input.len()))
})
{
match input[ny][nx] {
0..10 => (),
10 => neighbors.push((nx, ny)),
11 => flags += 1,
_ => unreachable!(),
}
}
let bombs = (input[y][x] - flags) as usize;
if bombs > 0 {
assert!(!neighbors.is_empty());
cnf.push(Expr::cnf_from_truth_function(
|bits| bits.iter().filter(|&&b| b).count() == bombs,
&neighbors,
));
}
}
10 => unknown_fields += 1,
11 => total_flags += 1,
_ => unreachable!(),
}
}
}
let expr = Expr::And(cnf);
let mut result = vec![vec![u64::MAX; input[0].len()]; input.len()];
let mut total = 0;
let mut interesting_fields = 0;
let mut bomb_distribution = HashMap::new();
for i in SolutionIterator::new(f, expr) {
total += 1;
interesting_fields = i.len();
let mut bombs = 0;
for ((x, y), b) in i {
if result[y][x] == u64::MAX {
result[y][x] = 0;
}
if b {
result[y][x] += 1;
bombs += 1;
}
}
match bomb_distribution.entry(bombs) {
std::collections::hash_map::Entry::Occupied(mut occupied_entry) => {
*occupied_entry.get_mut() += 1;
}
std::collections::hash_map::Entry::Vacant(vacant_entry) => {
vacant_entry.insert(1);
}
}
}
let expected_bombs = bomb_distribution
.iter()
.map(|(&k, &v)| k as f64 * v as f64)
.sum::<f64>()
/ total as f64;
(
total,
result,
(mines as f64 - expected_bombs) / (unknown_fields - interesting_fields) as f64,
)
}
fn print_input(input: &[Vec<u8>]) {
for row in input.iter() {
for elem in row.iter() {
match elem {
0 => print!(" . "),
1..=9 => print!(" {elem} "),
10 => print!(" # "),
11 => print!(" F "),
_ => unreachable!(),
};
}
println!();
}
}
fn print_output(total: u64, input: &[Vec<u8>], map: &[Vec<u64>], default: f64) {
for (input_row, map_row) in input.iter().zip(map.iter()) {
for (input_elem, map_elem) in input_row.iter().zip(map_row.iter()) {
match input_elem {
0 => print!(" . "),
1..=9 => print!(" {input_elem} "),
10 => {
if *map_elem == u64::MAX {
print!("{:3}", (default * 100.0) as u8);
} else {
print!("{:3}", map_elem * 100 / total);
}
}
11 => print!(" F "),
_ => unreachable!(),
};
}
println!();
}
}
fn main() {
// let input = vec![
// vec![0, 0, 1, 10, 10, 10],
// vec![1, 1, 2, 10, 10, 10],
// vec![1, 11, 2, 2, 10, 10],
// vec![1, 1, 2, 10, 10, 10],
// vec![1, 1, 2, 10, 10, 10],
// vec![10, 10, 2, 10, 10, 10],
// vec![10; 6],
// ];
//
let input = vec![
vec![2, 2, 10],
vec![10, 10, 10],
vec![10, 10, 10],
vec![10, 10, 10],
];
print_input(&input);
let (total, result, default) = solve_minesweeper(simple_sat_solver::cdcl::cdcl, &input, 5);
print_output(total, &input, &result, default);
}