1use clap::{Parser, Subcommand};
2use serde::{Deserialize, Serialize};
3use std::path::PathBuf;
4use tga::{TgaEnum, TgaModel};
5use tokio::fs;
6use tracing::info;
7
8use crate::commands::Command;
9use crate::data::{DataRow, DataStreamResult};
10use crate::loaders::ip_list::load_ipv6_addresses_from_file;
11
12#[derive(Parser, Serialize, Deserialize)]
13pub struct TrainCommand {
14 #[arg(short, long)]
16 pub seeds: PathBuf,
17
18 #[arg(short, long)]
20 pub output: PathBuf,
21
22 #[command(subcommand)]
24 pub tga: TgaEnum,
25}
26
27impl Command for TrainCommand {
28 async fn run(&self) -> Result<(), String> {
29 info!("Reading seed addresses from file: {}", self.seeds.display());
30
31 let seed_bytes = load_ipv6_addresses_from_file(&self.seeds).await?;
33
34 info!("Training TGA: {}", self.tga.name());
36 let model = self
37 .tga
38 .train(seed_bytes)
39 .await
40 .map_err(|e| format!("Failed to train TGA: {}", e))?;
41
42 info!("Writing model to file: {}", self.output.display());
44 let model_data =
45 bincode::serialize(&model).map_err(|e| format!("Failed to serialize model: {}", e))?;
46 fs::write(&self.output, model_data)
47 .await
48 .map_err(|e| format!("Failed to write model to file: {}", e))?;
49
50 let model_info = format!("{}", model);
52 let row = DataRow::new()
53 .with_column("tga_name", self.tga.name())
54 .with_column("model_info", model_info)
55 .with_column("output_file", self.output.display().to_string());
56
57 Ok(())
58 }
59}