#pragma once #include #include #include "random_sampler_wrapper.hpp" #include "math/vec2.h" #include "math/vec3fa.h" class MLTRandomSampler : public embree::RandomSamplerWrapper { private: size_t index; std::vector data; std::vector new_data; std::vector last_changed; size_t time; size_t last_large_step; float small_step_size; bool large_step; float normalize(float x) { if (x < 0.0) { return x + 1.0; } else if (x >= 1.0) { return x - 1.0; } return x; } public: MLTRandomSampler() { } MLTRandomSampler(float small_step_size) : index(0), data({}), last_changed({}), time(0), last_large_step(0), small_step_size(small_step_size) {} void init(int id) override { RandomSampler_init(sampler, id); } void accept() { time++; for (size_t i = 0; i < new_data.size(); i++) { if (i >= data.size()) { data.push_back(new_data[i]); last_changed.push_back(time); } else { data.at(i) = new_data.at(i); last_changed.at(i) = time; } } if (large_step) last_large_step = time; } void new_ray(bool is_large_step) { index = 0; new_data.clear(); large_step = is_large_step; } bool is_large_step() { return large_step; } float get1D() override { float result; if (is_large_step()) { float r = RandomSampler_get1D(sampler); new_data.push_back(r); result = r; } else { if (index >= data.size()) { float r = RandomSampler_get1D(sampler); data.push_back(r); last_changed.push_back(time); new_data.push_back(r); result = r; } // printf("%d, %d\n", index, last_large_step); if (last_changed.at(index) < last_large_step) { float r = RandomSampler_get1D(sampler); data.at(index) = r; last_changed.at(index) = time; new_data.push_back(r); result = r; } else { size_t steps = time - last_changed.at(index); float d = data.at(index); for (size_t i = 0; i < steps; i++) { float r = RandomSampler_get1D(sampler); float o = r * small_step_size - (small_step_size / 2.0); d = normalize(d + o); } data.at(index) = d; last_changed.at(index) = time; float r = RandomSampler_get1D(sampler); float o = r * small_step_size - (small_step_size / 2.0); d = normalize(d + o); new_data.push_back(d); // printf("%zu, %zu, %f, %f\n", index, steps, data.at(index), d); result = d; } } // printf("%zu, %f\n", index, result); index++; return result; } embree::Vec2f get2D() override { return embree::Vec2f(get1D(), get1D()); } embree::Vec3fa get3D() override { return embree::Vec3fa(get1D(), get1D(), get1D()); } };