From 2699f702b1d04c9dad9ed469cf1a859f47f4e03f Mon Sep 17 00:00:00 2001 From: Break27 Date: Sun, 9 Jun 2024 01:10:30 +0800 Subject: [PATCH] handle concurrently (again) --- src/agent.rs | 40 +++++++++++++++++++--------------------- src/engine.rs | 9 +++++---- src/server.rs | 10 +++++++--- 3 files changed, 31 insertions(+), 28 deletions(-) diff --git a/src/agent.rs b/src/agent.rs index afdd8df..4bf8930 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -25,35 +25,32 @@ impl Agent { Self { builder, config, engine } } - pub fn spawn(self: &std::sync::Arc, mut conn: TcpStream) { + pub async fn handle(&self, mut conn: TcpStream) { let req = match self.read(&mut conn) { Ok(x) => x, - Err(e) => return log::error!("Read Error: {e}") + Err(e) => return log::error!("Read: {e}") }; let mat = self.engine.check_request_blocked(&req.path); - let agent = self.clone(); log::info!("CLIENT --> {} ({})", req.host, mat.then_some("tunnel").unwrap_or("direct")); - async_std::task::spawn(async move { - let res = if mat { - agent.remote(req, &mut conn).await - } else { - agent.direct(req, &mut conn).await - }; + let res = if mat { + self.remote(req, &mut conn).await + } else { + self.direct(req, &mut conn).await + }; - if let Err(e) = res { - log::error!("Agent: {e}"); - let resp = http::Response::from_err(e); + if let Err(e) = res { + log::error!("Agent: {e}"); + let resp = http::Response::from_err(e); - conn.write(resp.to_string().as_bytes()).await.unwrap(); - conn.flush().await.unwrap(); - } + conn.write(resp.to_string().as_bytes()).await.unwrap(); + conn.flush().await.unwrap(); + } - let _ = conn.shutdown(std::net::Shutdown::Both); - }); + let _ = conn.shutdown(std::net::Shutdown::Both); } async fn remote(&self, req: http::Request, inbound: &mut S) -> Result<()> @@ -61,7 +58,7 @@ impl Agent { S: Read + Write + Send + Sync + Unpin + 'static { let mut outbound = self.io(self.builder.connect(&req.host))?; - log::info!("CLIENT --> PROXY --> TARGET"); + log::info!("CLIENT --> PROXY (pending)"); // forward intercepted request outbound.write_all(req.as_bytes()).await?; @@ -79,14 +76,14 @@ impl Agent { S: Read + Write + Send + Sync + Unpin + 'static { let mut outbound = self.io(TcpStream::connect(&req.host))?; - log::info!("CLIENT <=> TARGET (direct)"); + log::info!("CLIENT --> TARGET (pending)"); if let http::Method::CONNECT = req.method { let resp = http::Response::default(); // respond to client with code 200 and an EMPTY body inbound.write_all(resp.to_string().as_bytes()).await?; inbound.flush().await?; - log::debug!("Received CONNECT (200 OK)"); + log::debug!("Agent: received CONNECT (200 OK)"); } else { // forward intercepted request outbound.write_all(req.as_bytes()).await?; @@ -94,6 +91,7 @@ impl Agent { log::debug!("CLIENT --> (intercepted) --> TARGET"); } + log::info!("CLIENT <=> TARGET (connection established)"); self.tunnel(inbound, &mut outbound).await; let _ = outbound.shutdown(std::net::Shutdown::Both); @@ -148,7 +146,7 @@ impl Agent { Some("80") | _ => "http", }; - path = format!("{}://{}", scheme, path); + path = format!("{scheme}://{path}"); } let version = match request.version { diff --git a/src/engine.rs b/src/engine.rs index 3f5ef8e..34db8c1 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -1,23 +1,24 @@ +use std::sync::Mutex; use adblock::request::Request; pub struct Engine { - inner: Option + inner: Option> } impl Engine { pub fn new(inner: Option) -> Self { - Self { inner } + Self { inner: inner.map(|x| x.into()) } } pub fn check_request_blocked(&self, url: &str) -> bool { let inner = match &self.inner { - Some(x) => x, + Some(x) => x.lock().unwrap(), None => return true // always use tunnel when without rules }; Request::new(url, url, "fetch") - .map(|x| inner.check_network_request(&x).matched) + .map(|req| inner.check_network_request(&req).matched) .unwrap_or(true) } } diff --git a/src/server.rs b/src/server.rs index d799a96..4677021 100644 --- a/src/server.rs +++ b/src/server.rs @@ -69,9 +69,9 @@ impl ServerBuilder { } let engine = Engine::new(ruleset); - let agent = Agent::new(builder, config, engine).into(); + let agent = Agent::new(builder, config, engine); - Ok(Server { agent }) + Ok(Server { agent: agent.into() }) } fn build_rules(&self, mut text: String) -> BuildResult { @@ -117,9 +117,13 @@ impl Server { let listener = TcpListener::bind(addrs).await?; loop { let (inbound, addr) = listener.accept().await?; + let agent = self.agent.clone(); log::info!("*** Incoming connection from {addr}"); - self.agent.spawn(inbound); + + async_std::task::spawn(async move { + agent.handle(inbound).await; + }); } } }