From 4f09baaa1c159a486cbede8dfdc9aa568d8974b8 Mon Sep 17 00:00:00 2001 From: Break27 Date: Sat, 8 Jun 2024 23:34:43 +0800 Subject: [PATCH] fix: adblock::Engine BorrowMut panic --- Cargo.lock | 1 + Cargo.toml | 2 +- src/agent.rs | 225 +++++++++++++++------------------------------- src/connection.rs | 3 +- src/engine.rs | 23 +++++ src/http.rs | 4 - src/main.rs | 38 ++++---- src/server.rs | 148 ++++++++++++++++++++++-------- 8 files changed, 235 insertions(+), 209 deletions(-) create mode 100644 src/engine.rs diff --git a/Cargo.lock b/Cargo.lock index b31a918..301e84c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1290,6 +1290,7 @@ dependencies = [ "log", "native-tls", "once_cell", + "socks", "url", ] diff --git a/Cargo.toml b/Cargo.toml index 2bf8781..770c8ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,5 +16,5 @@ log = "0.4.21" native-tls = "0.2.12" socks = "0.3.4" tokio = { version = "1.38.0", features = ["io-util"] } -ureq = { version = "2.9.7", default-features = false, features = ["native-tls"] } +ureq = { version = "2.9.7", default-features = false, features = ["native-tls", "socks-proxy"] } url = "2.5.0" diff --git a/src/agent.rs b/src/agent.rs index 0733113..afdd8df 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -1,164 +1,106 @@ -use async_std::io::{Read, Write, ReadExt, WriteExt}; +use async_std::io::{Read, ReadExt, Write, WriteExt}; use async_std::net::TcpStream; -use crate::connection::ConnectionBuilder; -use crate::error::{Result, Error, BuildError, BuildResult}; +use crate::connection::ConnectionBuilder as Builder; +use crate::engine::Engine; +use crate::error::{Error, Result}; use crate::http; -pub struct AgentBuilder { - filter_url: Option, - buf_size: Option, - timeout: Option, - decode: bool -} - -impl AgentBuilder { - pub fn new() -> Self { - Self { - filter_url: None, - buf_size: None, - timeout: None, - decode: true - } - } - - pub fn filter(mut self, url: url::Url) -> Self { - let _ = self.filter_url.insert(url); - self - } - - pub fn timeout(mut self, timeout: u64) -> Self { - let _ = self.timeout.insert(timeout); - self - } - - pub fn buffer(mut self, size: usize) -> Self { - let _ = self.buf_size.insert(size); - self - } - - pub fn decode(mut self, decode: bool) -> Self { - self.decode = decode; - self - } - - pub fn build(self, remote: url::Url) -> BuildResult { - use ConnectionBuilder as CB; - let builder = match remote.scheme() { - "http" | "" => CB::Http(remote.authority().to_string()), - "socks" | "socks5" => CB::Socks5(remote.authority().to_string()), - other => return Err(BuildError::Unsupported(other.to_string())) - }; - - let mut ruleset = None; - let time = self.timeout.unwrap_or(u64::MAX); - - let config = AgentConfig { - buf_size: self.buf_size.unwrap_or(1024), - timeout: std::time::Duration::from_secs(time) - }; - - if let Some(ref url) = self.filter_url { - log::info!(target: "builder", "Try downloading rule list from '{url}'"); - let https = native_tls::TlsConnector::new()?; - - let client = ureq::AgentBuilder::new() - .proxy(ureq::Proxy::new(remote)?) - .tls_connector(https.into()) - .timeout(config.timeout) - .build(); - let resp = client.get(url.as_str()).call()?; - let text = resp.into_string()?; - let len = text.len() as f32 / 1000f32; - - log::info!(target: "builder", "Successfully downloaded data ({len}/kB transmitted)"); - ruleset = Some(self.build_rules(text)?); - } - - Ok(Agent { builder, ruleset, config }) - } - - fn build_rules(&self, mut text: String) -> BuildResult { - if self.decode { - log::info!(target: "builder", "Try decoding raw textual data (base64 encoded)"); - use base64::{Engine, engine::general_purpose::STANDARD}; - let line = text.split_whitespace().collect::(); - let decoded = STANDARD.decode(line)?; - - text = String::from_utf8(decoded)?; - } - - let mut filters = adblock::FilterSet::new(false); - let opts = adblock::lists::ParseOptions::default(); - filters.add_filter_list(&text, opts); - - log::info!(target: "builder", "Rule data parsed successfully"); - Ok(adblock::Engine::from_filter_set(filters, true)) - } -} - pub struct AgentConfig { - pub buf_size: usize, + pub bufsize: usize, pub timeout: std::time::Duration, } pub struct Agent { - ruleset: Option, - builder: ConnectionBuilder, + builder: Builder, config: AgentConfig, + engine: Engine } unsafe impl Send for Agent {} unsafe impl Sync for Agent {} impl Agent { - pub async fn handle(&self, conn: &mut S) -> Result<()> + pub fn new(builder: Builder, config: AgentConfig, engine: Engine) -> Self { + Self { builder, config, engine } + } + + pub fn spawn(self: &std::sync::Arc, mut conn: TcpStream) { + let req = match self.read(&mut conn) { + Ok(x) => x, + Err(e) => return log::error!("Read Error: {e}") + }; + + let mat = self.engine.check_request_blocked(&req.path); + let agent = self.clone(); + + log::info!("CLIENT --> {} ({})", + req.host, mat.then_some("tunnel").unwrap_or("direct")); + + async_std::task::spawn(async move { + let res = if mat { + agent.remote(req, &mut conn).await + } else { + agent.direct(req, &mut conn).await + }; + + if let Err(e) = res { + log::error!("Agent: {e}"); + let resp = http::Response::from_err(e); + + conn.write(resp.to_string().as_bytes()).await.unwrap(); + conn.flush().await.unwrap(); + } + + let _ = conn.shutdown(std::net::Shutdown::Both); + }); + } + + async fn remote(&self, req: http::Request, inbound: &mut S) -> Result<()> where S: Read + Write + Send + Sync + Unpin + 'static { - let request = self.read(conn)?; - let host = request.host(); + let mut outbound = self.io(self.builder.connect(&req.host))?; + log::info!("CLIENT --> PROXY --> TARGET"); - log::info!("CLIENT --> {host}"); + // forward intercepted request + outbound.write_all(req.as_bytes()).await?; + outbound.flush().await?; - if self.check_request_blocked(&request.path) { - log::info!("CLIENT --> PROXY --> TARGET"); - let mut outbound = self.io(self.builder.connect(host))?; + log::info!("CLIENT <=> PROXY (connection established)"); + self.tunnel(inbound, &mut outbound).await; - // forward intercepted request - outbound.write_all(request.as_bytes()).await?; - outbound.flush().await?; + let _ = outbound.shutdown(std::net::Shutdown::Both); + return Ok(()); + } - log::info!("CLIENT <-> PROXY (connection established)"); - self.tunnel(conn, &mut outbound).await?; + async fn direct(&self, req: http::Request, inbound: &mut S) -> Result<()> + where + S: Read + Write + Send + Sync + Unpin + 'static + { + let mut outbound = self.io(TcpStream::connect(&req.host))?; + log::info!("CLIENT <=> TARGET (direct)"); - outbound.shutdown(std::net::Shutdown::Both)?; - return Ok(()); - } - - let mut target = self.io(TcpStream::connect(host))?; - log::info!("CLIENT <-> TARGET (direct)"); - - if let http::Method::CONNECT = request.method { + if let http::Method::CONNECT = req.method { let resp = http::Response::default(); - // send response to client with code 200 and an EMPTY body - conn.write_all(resp.to_string().as_bytes()).await?; - conn.flush().await?; + // respond to client with code 200 and an EMPTY body + inbound.write_all(resp.to_string().as_bytes()).await?; + inbound.flush().await?; log::debug!("Received CONNECT (200 OK)"); } else { // forward intercepted request - target.write_all(request.as_bytes()).await?; - target.flush().await?; + outbound.write_all(req.as_bytes()).await?; + outbound.flush().await?; log::debug!("CLIENT --> (intercepted) --> TARGET"); } - self.tunnel(conn, &mut target).await?; - target.shutdown(std::net::Shutdown::Both)?; + self.tunnel(inbound, &mut outbound).await; + let _ = outbound.shutdown(std::net::Shutdown::Both); return Ok(()); } - async fn tunnel(&self, inbound: &mut A, outbound: &mut B) -> Result<()> + async fn tunnel(&self, inbound: &mut A, outbound: &mut B) where A: Read + Write + Send + Sync + Unpin + 'static, B: Read + Write + Send + Sync + Unpin + 'static, @@ -169,27 +111,22 @@ impl Agent { if let Err(e) = copy( &mut outbound.compat_mut(), &mut inbound.compat_mut()).await { - log::warn!("{}", e); + log::warn!("{e}"); } - - Ok(()) } - fn read(&self, conn: &mut S) -> Result - where - S: Read + Write + Send + Unpin + 'static - { + fn read(&self, conn: &mut TcpStream) -> Result { let mut headers = [httparse::EMPTY_HEADER; 64]; let mut request = httparse::Request::new(&mut headers); - let mut buf = vec![0; self.config.buf_size]; + let mut buf = vec![0; self.config.bufsize]; self.io(conn.read(&mut buf))?; let offset = request.parse(&buf)?.unwrap(); let payload = buf[..offset].to_vec(); let method = match request.method { - Some(x) => x.parse::().unwrap(), + Some(x) => x.parse::().unwrap(), None => return Err(Error::BadRequest("METHOD".to_string())) }; @@ -234,7 +171,7 @@ impl Agent { host += ":80"; } - let request = crate::http::Request { + let request = http::Request { method, path, version, @@ -245,22 +182,6 @@ impl Agent { Ok(request) } - fn check_request_blocked(&self, url: &str) -> bool { - let attempt: _ = adblock::request::Request::new( - url, url, "fetch" - ); - - let req = match attempt { - Ok(x) => x, - Err(_) => return true - }; - - match &self.ruleset { - Some(x) => x.check_network_request(&req).matched, - None => true // always use tunnel when without rules - } - } - fn io(&self, f: F) -> Result where F: std::future::Future>, diff --git a/src/connection.rs b/src/connection.rs index 580ee27..0bebe13 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -31,8 +31,7 @@ unsafe impl Send for Connection {} unsafe impl Sync for Connection {} impl Connection { - pub fn new(conn: TcpStream) -> Self - { + pub fn new(conn: TcpStream) -> Self { Self { inner: Async::new(conn).unwrap() } } diff --git a/src/engine.rs b/src/engine.rs new file mode 100644 index 0000000..3f5ef8e --- /dev/null +++ b/src/engine.rs @@ -0,0 +1,23 @@ +use adblock::request::Request; + + +pub struct Engine { + inner: Option +} + +impl Engine { + pub fn new(inner: Option) -> Self { + Self { inner } + } + + pub fn check_request_blocked(&self, url: &str) -> bool { + let inner = match &self.inner { + Some(x) => x, + None => return true // always use tunnel when without rules + }; + + Request::new(url, url, "fetch") + .map(|x| inner.check_network_request(&x).matched) + .unwrap_or(true) + } +} diff --git a/src/http.rs b/src/http.rs index 377c502..1e8cbde 100644 --- a/src/http.rs +++ b/src/http.rs @@ -70,10 +70,6 @@ impl Request { pub fn as_bytes(&self) -> &[u8] { &self.payload } - - pub fn host(&self) -> &str { - &self.host - } } pub struct Response { diff --git a/src/main.rs b/src/main.rs index c51e485..c84e226 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +use std::ops::Not; use clap::Parser; pub mod http; @@ -5,23 +6,33 @@ pub mod agent; pub mod error; pub mod connection; pub mod server; +pub mod engine; #[derive(Parser)] #[command(version, about, long_about = None)] struct Cli { + /// Server port number #[arg(short, long)] port: Option, + /// Rule data URL #[arg(short, long, value_name = "URL")] filter_url: Option, + /// Buffer size #[arg(long, value_name = "SIZE")] buf_size: Option, + /// Seconds before server return a timeout response (408) #[arg(short, long, value_name = "SEC")] timeout: Option, + /// Load downloaded rule data without decoding (base64) + #[arg(long)] + plain_text: bool, + + /// Proxy server URL #[arg(value_name = "URL")] remote: url::Url } @@ -32,32 +43,29 @@ const PORT: u16 = 9000; const BUF_SIZE:usize = 1024; const TIMEOUT: u64 = 15; -async fn try_launch(agent: Result, - server: server::Server) -> Result<(), Box> -{ - Ok(server.run(agent?).await?) -} - -fn main() { +async fn launch() -> Result<(), Box> { if std::env::var("RUST_LOG").ok().is_none() { - std::env::set_var("RUST_LOG", "info"); + unsafe { std::env::set_var("RUST_LOG", "info") } } let cli = Cli::parse(); env_logger::init(); let port = cli.port.unwrap_or(PORT); - let server = server::Server::bind((LOCALHOST, port)); - - let agent = agent::AgentBuilder::new() + let server: _ = server::Server::builder() .buffer(cli.buf_size.unwrap_or(BUF_SIZE)) .filter(cli.filter_url.unwrap_or(URL.parse().unwrap())) .timeout(cli.timeout.unwrap_or(TIMEOUT)) - .build(cli.remote); + .encoded(cli.plain_text.not()) + .build(cli.remote)?; - if let Err(e) = async_std::task::block_on( - try_launch(agent, server)) + Ok(server.bind((LOCALHOST, port)).await?) +} + +fn main() { + if let Err(e) = + async_std::task::block_on(launch()) { - eprintln!("Error: {}", e); + eprintln!("Error: {e}"); } } diff --git a/src/server.rs b/src/server.rs index a59d245..d799a96 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,47 +1,125 @@ +use async_std::net::TcpListener; + +use crate::agent::{Agent, AgentConfig}; +use crate::connection::ConnectionBuilder; +use crate::error::{BuildError, BuildResult}; +use crate::engine::Engine; + + +#[derive(Default)] +pub struct ServerBuilder { + filter_url: Option, + buf_size: Option, + timeout: Option, + encoded: bool +} + +impl ServerBuilder { + pub fn filter(mut self, url: url::Url) -> Self { + let _ = self.filter_url.insert(url); + self + } + + pub fn buffer(mut self, size: usize) -> Self { + let _ = self.buf_size.insert(size); + self + } + + pub fn timeout(mut self, timeout: u64) -> Self { + let _ = self.timeout.insert(timeout); + self + } + + pub fn encoded(mut self, encoded: bool) -> Self { + self.encoded = encoded; + self + } + + pub fn build(self, remote: url::Url) -> BuildResult { + use ConnectionBuilder as CB; + let builder = match remote.scheme() { + "http" | "" => CB::Http(remote.authority().to_string()), + "socks" | "socks5" => CB::Socks5(remote.authority().to_string()), + other => return Err(BuildError::Unsupported(other.to_string())) + }; + + let mut ruleset = None; + let time = self.timeout.unwrap_or(u64::MAX); + + let config = AgentConfig { + bufsize: self.buf_size.unwrap_or(1024), + timeout: std::time::Duration::from_secs(time), + }; + + if let Some(ref url) = self.filter_url { + log::info!("Try downloading rule list from '{url}'"); + let https = native_tls::TlsConnector::new()?; + + let client = ureq::AgentBuilder::new() + .proxy(ureq::Proxy::new(remote)?) + .tls_connector(https.into()) + .timeout(config.timeout) + .build(); + let resp = client.get(url.as_str()).call()?; + let text = resp.into_string()?; + let len = text.len() as f32 / 1000f32; + + log::info!("Successfully downloaded data ({len}/kB transmitted)"); + ruleset = Some(self.build_rules(text)?); + } + + let engine = Engine::new(ruleset); + let agent = Agent::new(builder, config, engine).into(); + + Ok(Server { agent }) + } + + fn build_rules(&self, mut text: String) -> BuildResult { + if self.encoded { + log::info!("Try decoding raw textual data (base64 encoded)"); + use base64::{Engine, engine::general_purpose::STANDARD}; + let line = text.split_whitespace().collect::(); + let decoded = STANDARD.decode(line)?; + + text = String::from_utf8(decoded)?; + } + + let mut filters = adblock::FilterSet::new(false); + let opts = adblock::lists::ParseOptions::default(); + filters.add_filter_list(&text, opts); + + log::info!("Rule data parsed successfully"); + Ok(adblock::Engine::from_filter_set(filters, true)) + } +} + pub struct Server { - addrs: std::net::SocketAddr, + agent: std::sync::Arc } impl Server { - pub async fn run(self, agent: crate::agent::Agent) -> std::io::Result<()> { - let listener = async_std::net::TcpListener::bind(self.addrs).await?; - let agent = std::sync::Arc::new(agent); - - log::info!("IMPOSTER/0.1 HTTP SERVER"); - log::info!("Server listening at {}", self.addrs); - - loop { - let agent = agent.clone(); - let (mut inbound, addr) = listener.accept().await?; - - log::info!("*** Incoming connection from {addr}"); - - async_std::task::spawn(async move { - if let Err(e) = agent.handle(&mut inbound).await { - log::error!("Agent: {e}"); - - let resp = crate::http::Response::from_err(e); - use async_std::io::WriteExt; - - inbound.write(resp.to_string().as_bytes()).await.unwrap(); - inbound.flush().await.unwrap(); - } - - let _ = inbound.shutdown(std::net::Shutdown::Both); - }); - } + pub fn builder() -> ServerBuilder { + ServerBuilder::default() } - pub fn bind(addrs: A) -> Self + pub async fn bind(self, addrs: A) -> std::io::Result<()> where A: std::net::ToSocketAddrs { - let addrs = addrs.to_socket_addrs() - .expect("Bind Error") - .collect::>() - .pop() - .expect("Bind Error"); + let addrs = addrs.to_socket_addrs()? + .collect::>() + .pop() + .expect("Bind Error"); - Self { addrs } + log::info!("IMPOSTER/0.1 HTTP SERVER"); + log::info!("Server listening at {addrs}"); + + let listener = TcpListener::bind(addrs).await?; + loop { + let (inbound, addr) = listener.accept().await?; + + log::info!("*** Incoming connection from {addr}"); + self.agent.spawn(inbound); + } } }