diff --git a/src/agent.rs b/src/agent.rs index d5cf560..0733113 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -122,7 +122,7 @@ impl Agent { log::info!("CLIENT --> {host}"); if self.check_request_blocked(&request.path) { - log::info!("CLIENT --> PROXY --> {host}"); + log::info!("CLIENT --> PROXY --> TARGET"); let mut outbound = self.io(self.builder.connect(host))?; // forward intercepted request @@ -130,7 +130,9 @@ impl Agent { outbound.flush().await?; log::info!("CLIENT <-> PROXY (connection established)"); - self.tunnel(conn, outbound).await?; + self.tunnel(conn, &mut outbound).await?; + + outbound.shutdown(std::net::Shutdown::Both)?; return Ok(()); } @@ -150,11 +152,13 @@ impl Agent { log::debug!("CLIENT --> (intercepted) --> TARGET"); } - self.tunnel(conn, target).await?; + self.tunnel(conn, &mut target).await?; + target.shutdown(std::net::Shutdown::Both)?; + return Ok(()); } - async fn tunnel(&self, inbound: &mut A, mut outbound: B) -> Result<()> + async fn tunnel(&self, inbound: &mut A, outbound: &mut B) -> Result<()> where A: Read + Write + Send + Sync + Unpin + 'static, B: Read + Write + Send + Sync + Unpin + 'static, diff --git a/src/connection.rs b/src/connection.rs index 50e9da1..580ee27 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -10,7 +10,7 @@ pub enum ConnectionBuilder { } impl ConnectionBuilder { - pub async fn connect(&self, target: &str) -> Result { + pub async fn connect(&self, target: &str) -> std::io::Result { let conn = match self { Self::Http(addr) => { Connection::new(TcpStream::connect(addr)?) @@ -36,9 +36,13 @@ impl Connection { Self { inner: Async::new(conn).unwrap() } } - pub fn into_inner(self) -> Result { + pub fn into_inner(self) -> std::io::Result { self.inner.into_inner() } + + pub fn shutdown(self, how: std::net::Shutdown) -> std::io::Result<()> { + self.into_inner()?.shutdown(how) + } } impl async_std::io::Read for Connection { diff --git a/src/http.rs b/src/http.rs index e86798c..377c502 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,11 +1,12 @@ use crate::error::Error; + macro_rules! impl_http { (pub struct $name:ident($inner:ident) { $($variant:ident = $value:literal,)* }) => { - #[derive(PartialEq)] + #[derive(PartialEq, Eq)] pub struct $name($inner); - #[derive(PartialEq)] + #[derive(PartialEq, Eq)] #[allow(non_camel_case_types)] enum $inner{ $($variant,)* } @@ -99,15 +100,13 @@ impl Response { pub fn from_err(err: Error) -> Self { use async_std::io::ErrorKind::TimedOut; match err { - Error::BadRequest(x) => - Self::make(400, x), + Error::BadRequest(_) => + Self::make(400, "Bad Request"), Error::Io(e) if e.kind() == TimedOut => Self::make(408, "Timeout"), Error::Io(_) => Self::make(503, "Unavailable"), - Error::Parse(_) => - Self::make(500, "Internal Server Error"), - Error::Utf8(_) => + Error::Parse(_) | Error::Utf8(_) => Self::make(500, "Internal Server Error"), } } diff --git a/src/server.rs b/src/server.rs index 9069c98..a59d245 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,18 +1,11 @@ -use std::sync::Arc; - -use async_std::io::WriteExt; - -use crate::http; - - pub struct Server { addrs: std::net::SocketAddr, } impl Server { - pub async fn run(self, agent: crate::agent::Agent) -> Result<(), std::io::Error> { + pub async fn run(self, agent: crate::agent::Agent) -> std::io::Result<()> { let listener = async_std::net::TcpListener::bind(self.addrs).await?; - let agent = Arc::new(agent); + let agent = std::sync::Arc::new(agent); log::info!("IMPOSTER/0.1 HTTP SERVER"); log::info!("Server listening at {}", self.addrs); @@ -26,7 +19,9 @@ impl Server { async_std::task::spawn(async move { if let Err(e) = agent.handle(&mut inbound).await { log::error!("Agent: {e}"); - let resp = http::Response::from_err(e); + + let resp = crate::http::Response::from_err(e); + use async_std::io::WriteExt; inbound.write(resp.to_string().as_bytes()).await.unwrap(); inbound.flush().await.unwrap();