rmap/commands/
train.rs

1use clap::{Parser, Subcommand};
2use serde::{Deserialize, Serialize};
3use std::path::PathBuf;
4use tga::{TgaEnum, TgaModel};
5use tokio::fs;
6use tracing::info;
7
8use crate::commands::Command;
9use crate::data::{DataRow, DataStreamResult};
10use crate::loaders::ip_list::load_ipv6_addresses_from_file;
11
12#[derive(Parser, Serialize, Deserialize)]
13pub struct TrainCommand {
14    /// Path to file containing seed addresses (one per line)
15    #[arg(short, long)]
16    pub seeds: PathBuf,
17
18    /// Output file to save the trained model
19    #[arg(short, long)]
20    pub output: PathBuf,
21
22    /// TGA algorithm to use for training
23    #[command(subcommand)]
24    pub tga: TgaEnum,
25}
26
27impl Command for TrainCommand {
28    async fn run(&self) -> Result<(), String> {
29        info!("Reading seed addresses from file: {}", self.seeds.display());
30
31        // Read seed addresses from file using async loader
32        let seed_bytes = load_ipv6_addresses_from_file(&self.seeds).await?;
33
34        // Create TGA enum based on subcommand
35        info!("Training TGA: {}", self.tga.name());
36        let model = self
37            .tga
38            .train(seed_bytes)
39            .await
40            .map_err(|e| format!("Failed to train TGA: {}", e))?;
41
42        // Save model to file asynchronously
43        info!("Writing model to file: {}", self.output.display());
44        let model_data =
45            bincode::serialize(&model).map_err(|e| format!("Failed to serialize model: {}", e))?;
46        fs::write(&self.output, model_data)
47            .await
48            .map_err(|e| format!("Failed to write model to file: {}", e))?;
49
50        // Create a simple result with model info
51        let model_info = format!("{}", model);
52        let row = DataRow::new()
53            .with_column("tga_name", self.tga.name())
54            .with_column("model_info", model_info)
55            .with_column("output_file", self.output.display().to_string());
56
57        Ok(())
58    }
59}