#include "training.hpp"
#include "utils.hpp"
#include "fungi_Paremetres.hpp"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <random>
#include <string>
#include <vector>
|
|
| void train_model(const FashionMNISTSet& train, const FashionMNISTSet& test, TrainConfig& cfg) { |
| const int N_train = train.N; |
| const int N_test = test.N; |
|
|
| OpticalParams params; |
| init_params(params, cfg.seed); |
|
|
| FungiSoA fungi; |
| fungi.resize(cfg.fungi_count, IMG_H, IMG_W); |
| fungi.init_random(cfg.seed); |
|
|
| DeviceBuffers db; |
| allocate_device_buffers(db, cfg.batch); |
|
|
| |
| upload_params_to_gpu(params, db); |
|
|
| FFTPlan fft; |
| create_fft_plan(fft, cfg.batch); |
|
|
| std::vector<int> train_indices(N_train); |
| std::iota(train_indices.begin(), train_indices.end(), 0); |
| std::mt19937 rng(cfg.seed); |
|
|
| int adam_step = 0; |
| double prev_accuracy = -1.0; |
|
|
| for (int ep = 1; ep <= cfg.epochs; ++ep) { |
| std::shuffle(train_indices.begin(), train_indices.end(), rng); |
| double epoch_loss = 0.0; |
| int samples_seen = 0; |
|
|
| |
| for (int start = 0; start < N_train; start += cfg.batch) { |
| int current_B = std::min(cfg.batch, N_train - start); |
|
|
| std::vector<float> h_batch_in(current_B * IMG_SIZE); |
| std::vector<uint8_t> h_batch_lbl(current_B); |
|
|
| for (int i = 0; i < current_B; ++i) { |
| int idx = train_indices[start + i]; |
| memcpy(&h_batch_in[i * IMG_SIZE], &train.images[idx * IMG_SIZE], IMG_SIZE * sizeof(float)); |
| h_batch_lbl[i] = train.labels[idx]; |
| } |
|
|
| adam_step++; |
| float loss = train_batch(h_batch_in.data(), h_batch_lbl.data(), current_B, fungi, params, db, fft, cfg.lr, cfg.wd, adam_step); |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| epoch_loss += loss * current_B; |
| samples_seen += current_B; |
| std::cout << "\r[Epoch " << ep << "] Progress: " << samples_seen << "/" << N_train |
| << " Avg Loss: " << std::fixed << std::setprecision(5) << (epoch_loss / samples_seen) |
| << std::flush; |
| } |
| std::cout << "\n"; |
|
|
| |
| std::cout << "[INFO] Evaluating on test set for epoch " << ep << "...\n"; |
| int correct_predictions = 0; |
| for (int start = 0; start < N_test; start += cfg.batch) { |
| int current_B = std::min(cfg.batch, N_test - start); |
|
|
| std::vector<float> h_batch_in(current_B * IMG_SIZE); |
| for (int i = 0; i < current_B; ++i) { |
| memcpy(&h_batch_in[i * IMG_SIZE], &test.images[(start + i) * IMG_SIZE], IMG_SIZE * sizeof(float)); |
| } |
|
|
| std::vector<int> predictions; |
| infer_batch(h_batch_in.data(), current_B, fungi, params, db, fft, predictions); |
|
|
| for (int i = 0; i < current_B; ++i) { |
| if (predictions[i] == test.labels[start + i]) { |
| correct_predictions++; |
| } |
| } |
| } |
| double accuracy = static_cast<double>(correct_predictions) / N_test; |
| std::cout << "[Epoch " << ep << " RESULT] Test Accuracy: " |
| << std::fixed << std::setprecision(4) << (accuracy * 100.0) << "%\n"; |
|
|
| if (prev_accuracy >= 0.0) { |
| double delta = accuracy - prev_accuracy; |
| if (delta > cfg.accuracy_tolerance) { |
| int target_fungi = static_cast<int>(std::ceil(static_cast<double>(fungi.F) * cfg.fungi_growth_rate)); |
| target_fungi = std::max(cfg.fungi_min, std::min(cfg.fungi_max, target_fungi)); |
| if (target_fungi > fungi.F) { |
| fungi.adjust_population(target_fungi, cfg.seed + static_cast<unsigned>(ep * 17)); |
| cfg.fungi_count = fungi.F; |
| std::cout << "[ADAPT] Accuracy improved by " << delta * 100.0 |
| << "% -> fungi population " << fungi.F << "\n"; |
| } |
| } else if (delta < -cfg.accuracy_tolerance) { |
| int target_fungi = static_cast<int>(std::floor(static_cast<double>(fungi.F) * cfg.fungi_decay_rate)); |
| target_fungi = std::max(cfg.fungi_min, std::min(cfg.fungi_max, target_fungi)); |
| if (target_fungi < fungi.F) { |
| fungi.adjust_population(target_fungi, cfg.seed + static_cast<unsigned>(ep * 23)); |
| cfg.fungi_count = fungi.F; |
| std::cout << "[ADAPT] Accuracy decreased by " << -delta * 100.0 |
| << "% -> fungi population " << fungi.F << "\n"; |
| } |
| } |
| } |
| prev_accuracy = accuracy; |
| } |
|
|
| free_device_buffers(db); |
| destroy_fft_plan(fft); |
| } |
|
|