// rmap/commands/generate.rs
use clap::Parser;
use futures::Stream;
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressStyle};
use rand::{Rng, SeedableRng, rngs::StdRng};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::net::IpAddr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tga::{ModelEnum, TgaGenerator, TgaModel};
use tokio::fs::File;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::sync::mpsc;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::ReceiverStream;
use tracing::info;

use crate::commands::Command;
use crate::data::{DataRow, DataStreamInfo, DataStreamResult, stream_from_iter};

22#[derive(Parser, Serialize, Deserialize)]
23pub struct GenerateCommand {
24 #[arg(short, long)]
26 pub model: PathBuf,
27
28 #[arg(short = 'n', long)]
30 pub count: usize,
31
32 #[arg(short = 'u', long)]
34 pub unique: bool,
35
36 #[arg(short = 'a', long, default_value = "1000")]
38 pub max_attempts: usize,
39
40 #[arg(short = 'e', long)]
42 pub exclude: Option<PathBuf>,
43
44 #[arg(long)]
46 pub random_iid: bool,
47
48 #[arg(short = 'o', long, default_value = "output.txt")]
50 pub output: PathBuf,
51}
52
53impl Command for GenerateCommand {
54 async fn run(&self) -> Result<(), String> {
55 info!("Loading model from: {}", self.model.display());
56
57 let model = {
59 let model_data = tokio::fs::read(&self.model)
60 .await
61 .map_err(|e| format!("Failed to read model file: {}", e))?;
62 let model: ModelEnum = bincode::deserialize(&model_data)
63 .map_err(|e| format!("Failed to deserialize model: {}", e))?;
64
65 model
66 };
67
68 let mut exclude = HashSet::new();
69 if let Some(path) = self.exclude.clone() {
70 info!("Loading excluded targets");
71 let file = File::open(path).await.unwrap();
72 let reader = BufReader::new(file);
73 let mut lines = reader.lines();
74 while let Some(line) = lines.next_line().await.unwrap() {
75 let line = line.trim();
76 if line.is_empty() {
77 continue;
78 }
79 exclude.insert(line.parse::<IpAddr>().unwrap());
80 }
81 }
82
83 let mut stream = stream_targets_from_model(model, 16, 16, self.random_iid);
84 let mut targets = HashSet::new();
85
86 info!(
87 "Generating {} addresses (unique: {})",
88 self.count, self.unique
89 );
90
91 let progress_bar = ProgressBar::new(self.count as u64);
93 progress_bar.set_style(
94 ProgressStyle::default_bar()
95 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} unique addresses ({eta})")
96 .unwrap()
97 .progress_chars("#>-")
98 );
99
100 let mut output_file = File::create(self.output.clone()).await.unwrap();
101
102 let mut attempts = 0;
103 let mut unique_attempts = 0;
104 while let Some(target) = stream.next().await {
105 if targets.len() >= self.count {
107 break;
108 }
109
110 if (!self.unique || !targets.contains(&target)) && !exclude.contains(&target) {
111 unique_attempts = 0;
112 targets.insert(target);
113 let s = format!("{}\n", target);
114 output_file.write_all(s.as_bytes()).await.unwrap();
115 } else {
116 unique_attempts += 1;
117 if unique_attempts >= self.max_attempts {
118 tracing::error!("Exceeded max attempts at generating a unique address");
119 break;
120 }
121 }
122
123 attempts += 1;
124 progress_bar.set_position(targets.len() as u64);
125 }
126
127 progress_bar.finish_with_message(format!(
128 "Generated {} addresses in {} attempts",
129 targets.len(),
130 attempts
131 ));
132
133 info!(
134 "Generated {} addresses in {} attempts",
135 targets.len(),
136 attempts
137 );
138
139 let data_rows: Vec<DataRow> = targets
141 .into_iter()
142 .map(|addr| DataRow::new().with_column("address", addr.to_string()))
143 .collect();
144
145 let headers = vec!["address".to_string()];
146 let info = DataStreamInfo::new(headers)
147 .with_total_rows(data_rows.len())
148 .with_description(format!(
149 "Generated {} addresses from model",
150 data_rows.len()
151 ));
152
153 let stream = stream_from_iter(data_rows);
154 Ok(())
155 }
156}
157
158pub fn stream_targets_from_model(
159 model: ModelEnum,
160 buffer: usize,
161 workers: usize,
162 random_iid: bool,
163) -> impl Stream<Item = IpAddr> {
164 let (target_tx, target_rx) = mpsc::channel(buffer);
165 let generator = Arc::new(model.build(0));
166
167 for _ in 0..workers {
168 let generator = generator.clone();
169 let tx = target_tx.clone();
170 let random_iid = random_iid;
171 tokio::spawn(async move {
172 let mut rng = StdRng::from_entropy();
173 while let Ok(permit) = tx.reserve().await {
174 let mut bytes = generator.generate();
175 if random_iid {
176 let random_tail: u64 = rng.r#gen();
177 bytes[8..].copy_from_slice(&random_tail.to_be_bytes());
178 }
179 let addr = IpAddr::from(bytes);
180 permit.send(addr);
181 }
182 });
183 }
184
185 ReceiverStream::new(target_rx)
186}