rmap/commands/generate.rs

1use clap::Parser;
2use futures::Stream;
3use indicatif::{ParallelProgressIterator, ProgressBar, ProgressStyle};
4use rand::{Rng, SeedableRng, rngs::StdRng};
5use rayon::prelude::*;
6use serde::{Deserialize, Serialize};
7use std::collections::HashSet;
8use std::net::IpAddr;
9use std::path::PathBuf;
10use std::sync::Arc;
11use tga::{ModelEnum, TgaGenerator, TgaModel};
12use tokio::fs::File;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::sync::mpsc;
15use tokio_stream::StreamExt;
16use tokio_stream::wrappers::ReceiverStream;
17use tracing::info;
18
19use crate::commands::Command;
20use crate::data::{DataRow, DataStreamInfo, DataStreamResult, stream_from_iter};
21
22#[derive(Parser, Serialize, Deserialize)]
23pub struct GenerateCommand {
24    /// Load trained model from file
25    #[arg(short, long)]
26    pub model: PathBuf,
27
28    /// Number of addresses to generate
29    #[arg(short = 'n', long)]
30    pub count: usize,
31
32    /// Ensure generated addresses are unique
33    #[arg(short = 'u', long)]
34    pub unique: bool,
35
36    /// Maximum number of attempts at generating a unique address
37    #[arg(short = 'a', long, default_value = "1000")]
38    pub max_attempts: usize,
39
40    /// Exclude a list of addresses from generation
41    #[arg(short = 'e', long)]
42    pub exclude: Option<PathBuf>,
43
44    /// Randomize the interface identifier portion of each generated address
45    #[arg(long)]
46    pub random_iid: bool,
47
48    /// Output file to save the generated addresses
49    #[arg(short = 'o', long, default_value = "output.txt")]
50    pub output: PathBuf,
51}
52
53impl Command for GenerateCommand {
54    async fn run(&self) -> Result<(), String> {
55        info!("Loading model from: {}", self.model.display());
56
57        // Load model from file
58        let model = {
59            let model_data = tokio::fs::read(&self.model)
60                .await
61                .map_err(|e| format!("Failed to read model file: {}", e))?;
62            let model: ModelEnum = bincode::deserialize(&model_data)
63                .map_err(|e| format!("Failed to deserialize model: {}", e))?;
64
65            model
66        };
67
68        let mut exclude = HashSet::new();
69        if let Some(path) = self.exclude.clone() {
70            info!("Loading excluded targets");
71            let file = File::open(path).await.unwrap();
72            let reader = BufReader::new(file);
73            let mut lines = reader.lines();
74            while let Some(line) = lines.next_line().await.unwrap() {
75                let line = line.trim();
76                if line.is_empty() {
77                    continue;
78                }
79                exclude.insert(line.parse::<IpAddr>().unwrap());
80            }
81        }
82
83        let mut stream = stream_targets_from_model(model, 16, 16, self.random_iid);
84        let mut targets = HashSet::new();
85
86        info!(
87            "Generating {} addresses (unique: {})",
88            self.count, self.unique
89        );
90
91        // Create progress bar for unique generation
92        let progress_bar = ProgressBar::new(self.count as u64);
93        progress_bar.set_style(
94            ProgressStyle::default_bar()
95                .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} unique addresses ({eta})")
96                .unwrap()
97                .progress_chars("#>-")
98        );
99
100        let mut output_file = File::create(self.output.clone()).await.unwrap();
101
102        let mut attempts = 0;
103        let mut unique_attempts = 0;
104        while let Some(target) = stream.next().await {
105            // Done generating addresses
106            if targets.len() >= self.count {
107                break;
108            }
109
110            if (!self.unique || !targets.contains(&target)) && !exclude.contains(&target) {
111                unique_attempts = 0;
112                targets.insert(target);
113                let s = format!("{}\n", target);
114                output_file.write_all(s.as_bytes()).await.unwrap();
115            } else {
116                unique_attempts += 1;
117                if unique_attempts >= self.max_attempts {
118                    tracing::error!("Exceeded max attempts at generating a unique address");
119                    break;
120                }
121            }
122
123            attempts += 1;
124            progress_bar.set_position(targets.len() as u64);
125        }
126
127        progress_bar.finish_with_message(format!(
128            "Generated {} addresses in {} attempts",
129            targets.len(),
130            attempts
131        ));
132
133        info!(
134            "Generated {} addresses in {} attempts",
135            targets.len(),
136            attempts
137        );
138
139        // Convert addresses to DataRow stream
140        let data_rows: Vec<DataRow> = targets
141            .into_iter()
142            .map(|addr| DataRow::new().with_column("address", addr.to_string()))
143            .collect();
144
145        let headers = vec!["address".to_string()];
146        let info = DataStreamInfo::new(headers)
147            .with_total_rows(data_rows.len())
148            .with_description(format!(
149                "Generated {} addresses from model",
150                data_rows.len()
151            ));
152
153        let stream = stream_from_iter(data_rows);
154        Ok(())
155    }
156}
157
158pub fn stream_targets_from_model(
159    model: ModelEnum,
160    buffer: usize,
161    workers: usize,
162    random_iid: bool,
163) -> impl Stream<Item = IpAddr> {
164    let (target_tx, target_rx) = mpsc::channel(buffer);
165    let generator = Arc::new(model.build(0));
166
167    for _ in 0..workers {
168        let generator = generator.clone();
169        let tx = target_tx.clone();
170        let random_iid = random_iid;
171        tokio::spawn(async move {
172            let mut rng = StdRng::from_entropy();
173            while let Ok(permit) = tx.reserve().await {
174                let mut bytes = generator.generate();
175                if random_iid {
176                    let random_tail: u64 = rng.r#gen();
177                    bytes[8..].copy_from_slice(&random_tail.to_be_bytes());
178                }
179                let addr = IpAddr::from(bytes);
180                permit.send(addr);
181            }
182        });
183    }
184
185    ReceiverStream::new(target_rx)
186}