imposter/src/agent.rs
2024-06-07 17:30:45 +08:00

269 lines
8.1 KiB
Rust

use async_std::io::{Read, Write, ReadExt, WriteExt};
use async_std::net::TcpStream;
use crate::connection::ConnectionBuilder;
use crate::error::{Result, Error, BuildError, BuildResult};
use crate::http;
pub struct AgentBuilder {
filter_url: Option<url::Url>,
buf_size: Option<usize>,
timeout: Option<u64>,
decode: bool
}
impl AgentBuilder {
pub fn new() -> Self {
Self {
filter_url: None,
buf_size: None,
timeout: None,
decode: true
}
}
pub fn filter(mut self, url: url::Url) -> Self {
let _ = self.filter_url.insert(url);
self
}
pub fn timeout(mut self, timeout: u64) -> Self {
let _ = self.timeout.insert(timeout);
self
}
pub fn buffer(mut self, size: usize) -> Self {
let _ = self.buf_size.insert(size);
self
}
pub fn decode(mut self, decode: bool) -> Self {
self.decode = decode;
self
}
pub fn build(self, remote: url::Url) -> BuildResult<Agent> {
use ConnectionBuilder as CB;
let builder = match remote.scheme() {
"http" | "" => CB::Http(remote.authority().to_string()),
"socks" | "socks5" => CB::Socks5(remote.authority().to_string()),
other => return Err(BuildError::Unsupported(other.to_string()))
};
let mut ruleset = None;
let time = self.timeout.unwrap_or(u64::MAX);
let config = AgentConfig {
buf_size: self.buf_size.unwrap_or(1024),
timeout: std::time::Duration::from_secs(time)
};
if let Some(ref url) = self.filter_url {
log::info!(target: "builder", "Try downloading rule list from '{url}'");
let https = native_tls::TlsConnector::new()?;
let client = ureq::AgentBuilder::new()
.proxy(ureq::Proxy::new(remote)?)
.tls_connector(https.into())
.timeout(config.timeout)
.build();
let resp = client.get(url.as_str()).call()?;
let text = resp.into_string()?;
let len = text.len() as f32 / 1000f32;
log::info!(target: "builder", "Successfully downloaded data ({len}/kB transmitted)");
ruleset = Some(self.build_rules(text)?);
}
Ok(Agent { builder, ruleset, config })
}
fn build_rules(&self, mut text: String) -> BuildResult<adblock::Engine> {
if self.decode {
log::info!(target: "builder", "Try decoding raw textual data (base64 encoded)");
use base64::{Engine, engine::general_purpose::STANDARD};
let line = text.split_whitespace().collect::<String>();
let decoded = STANDARD.decode(line)?;
text = String::from_utf8(decoded)?;
}
let mut filters = adblock::FilterSet::new(false);
let opts = adblock::lists::ParseOptions::default();
filters.add_filter_list(&text, opts);
log::info!(target: "builder", "Rule data parsed successfully");
Ok(adblock::Engine::from_filter_set(filters, true))
}
}
pub struct AgentConfig {
pub buf_size: usize,
pub timeout: std::time::Duration,
}
pub struct Agent {
ruleset: Option<adblock::Engine>,
builder: ConnectionBuilder,
config: AgentConfig,
}
unsafe impl Send for Agent {}
unsafe impl Sync for Agent {}
impl Agent {
pub async fn handle<S>(&self, conn: &mut S) -> Result<()>
where
S: Read + Write + Send + Sync + Unpin + 'static
{
let request = self.read(conn)?;
let host = request.host();
log::info!("CLIENT --> {host}");
if self.check_request_blocked(&request.path) {
log::info!("CLIENT --> PROXY --> {host}");
let mut outbound = self.io(self.builder.connect(host))?;
// forward intercepted request
outbound.write_all(request.as_bytes()).await?;
outbound.flush().await?;
log::info!("CLIENT <-> PROXY (connection established)");
self.tunnel(conn, outbound).await?;
return Ok(());
}
let mut target = self.io(TcpStream::connect(host))?;
log::info!("CLIENT <-> TARGET (direct)");
if let http::Method::CONNECT = request.method {
let resp = http::Response::default();
// send response to client with code 200 and an EMPTY body
conn.write_all(resp.to_string().as_bytes()).await?;
conn.flush().await?;
log::debug!("Received CONNECT (200 OK)");
} else {
// forward intercepted request
target.write_all(request.as_bytes()).await?;
target.flush().await?;
log::debug!("CLIENT --> (intercepted) --> TARGET");
}
self.tunnel(conn, target).await?;
return Ok(());
}
async fn tunnel<A, B>(&self, inbound: &mut A, mut outbound: B) -> Result<()>
where
A: Read + Write + Send + Sync + Unpin + 'static,
B: Read + Write + Send + Sync + Unpin + 'static,
{
use async_compat::CompatExt;
use tokio::io::copy_bidirectional as copy;
if let Err(e) = copy(
&mut outbound.compat_mut(), &mut inbound.compat_mut()).await
{
log::warn!("{}", e);
}
Ok(())
}
fn read<S>(&self, conn: &mut S) -> Result<http::Request>
where
S: Read + Write + Send + Unpin + 'static
{
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut request = httparse::Request::new(&mut headers);
let mut buf = vec![0; self.config.buf_size];
self.io(conn.read(&mut buf))?;
let offset = request.parse(&buf)?.unwrap();
let payload = buf[..offset].to_vec();
let method = match request.method {
Some(x) => x.parse::<crate::http::Method>().unwrap(),
None => return Err(Error::BadRequest("METHOD".to_string()))
};
let mut path = match request.path {
Some(x) => x.to_string(),
None => return Err(Error::BadRequest("PATH".to_string()))
};
if path.find("://").is_none() {
// in case of an cannot-be-a-base url
// find a port number, if any
let port = path
.rfind(":")
.and_then(|x| path.get(x + 1..));
let scheme = match port {
Some("443") => "https",
Some("21") => "ftp",
Some("80") | _ => "http",
};
path = format!("{}://{}", scheme, path);
}
let version = match request.version {
Some(3) => http::Version::HTTP_3,
Some(2) => http::Version::HTTP_2,
Some(11) => http::Version::HTTP_11,
Some(1) => http::Version::HTTP_10,
Some(_) => http::Version::HTTP_09,
None => return Err(Error::BadRequest("VERSION".to_string()))
};
let mut host = headers.iter()
.find_map(|x: _| (x.name == "Host").then_some(x.value))
.map(|x| std::str::from_utf8(x))
.ok_or(Error::BadRequest("Host".to_string()))??
.to_string();
if host.find(":").is_none() {
// append a port number when without one
host += ":80";
}
let request = crate::http::Request {
method,
path,
version,
host,
payload: payload.into(),
};
Ok(request)
}
fn check_request_blocked(&self, url: &str) -> bool {
let attempt: _ = adblock::request::Request::new(
url, url, "fetch"
);
let req = match attempt {
Ok(x) => x,
Err(_) => return true
};
match &self.ruleset {
Some(x) => x.check_network_request(&req).matched,
None => true // always use tunnel when without rules
}
}
fn io<T, F>(&self, f: F) -> Result<T>
where
F: std::future::Future<Output=std::result::Result<T, std::io::Error>>,
{
async_std::task::block_on(async {
Ok(async_std::io::timeout(self.config.timeout, f).await?)
})
}
}