diff --git a/src/agent.rs b/src/agent.rs index bce337e..d5cf560 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -112,11 +112,11 @@ unsafe impl Send for Agent {} unsafe impl Sync for Agent {} impl Agent { - pub async fn handle(&self, mut conn: S) -> Result<()> + pub async fn handle(&self, conn: &mut S) -> Result<()> where S: Read + Write + Send + Sync + Unpin + 'static { - let request = self.read(&mut conn)?; + let request = self.read(conn)?; let host = request.host(); log::info!("CLIENT --> {host}"); @@ -138,9 +138,9 @@ impl Agent { log::info!("CLIENT <-> TARGET (direct)"); if let http::Method::CONNECT = request.method { - let resp = b"HTTP/1.1 200 OK\r\n\r\n"; + let resp = http::Response::default(); // send response to client with code 200 and an EMPTY body - conn.write_all(resp).await?; + conn.write_all(resp.to_string().as_bytes()).await?; conn.flush().await?; log::debug!("Received CONNECT (200 OK)"); } else { @@ -154,7 +154,7 @@ impl Agent { return Ok(()); } - async fn tunnel(&self, mut inbound: A, mut outbound: B) -> Result<()> + async fn tunnel(&self, inbound: &mut A, mut outbound: B) -> Result<()> where A: Read + Write + Send + Sync + Unpin + 'static, B: Read + Write + Send + Sync + Unpin + 'static, diff --git a/src/error.rs b/src/error.rs index ef94638..7138b35 100644 --- a/src/error.rs +++ b/src/error.rs @@ -31,7 +31,6 @@ impl_error! { Io(std::io::Error), Parse(httparse::Error), Utf8(std::str::Utf8Error), - Timeout(async_std::future::TimeoutError), } } diff --git a/src/http.rs b/src/http.rs index 772aee9..e86798c 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,3 +1,5 @@ +use crate::error::Error; + macro_rules! impl_http { (pub struct $name:ident($inner:ident) { $($variant:ident = $value:literal,)* }) => { #[derive(PartialEq)] @@ -72,3 +74,54 @@ impl Request { &self.host } } + +pub struct Response { + pub version: Version, + pub status: u16, + pub message: String +} + +impl Response { + pub fn new(version: Version, status: u16, message: T) -> Self + where + T: ToString + { + Self { version, status, message: message.to_string() } + } + + pub fn make(status: u16, message: T) -> Self + where + T: ToString + { + Self { version: Version::HTTP_11, status, message: message.to_string() } + } + + pub fn from_err(err: Error) -> Self { + use async_std::io::ErrorKind::TimedOut; + match err { + Error::BadRequest(x) => + Self::make(400, x), + Error::Io(e) if e.kind() == TimedOut => + Self::make(408, "Timeout"), + Error::Io(_) => + Self::make(503, "Unavailable"), + Error::Parse(_) => + Self::make(500, "Internal Server Error"), + Error::Utf8(_) => + Self::make(500, "Internal Server Error"), + } + } +} + +impl Default for Response { + fn default() -> Self { + Self::make(200, "OK") + } +} + +impl ToString for Response { + fn to_string(&self) -> String { + format!("{} {} {}\r\n\r\n", + self.version.to_string(), self.status, self.message) + } +} diff --git a/src/main.rs b/src/main.rs index d7a8895..c51e485 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ use clap::Parser; -mod http; +pub mod http; pub mod agent; pub mod error; pub mod connection; diff --git a/src/server.rs b/src/server.rs index c902bf2..453d593 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,5 +1,9 @@ use std::sync::Arc; +use async_std::io::WriteExt; + +use crate::http; + pub struct Server { addrs: std::net::SocketAddr, @@ -15,13 +19,17 @@ impl Server { loop { let agent = agent.clone(); - let (inbound, addr) = listener.accept().await?; + let (mut inbound, addr) = listener.accept().await?; log::info!("*** Incoming connection from {}", addr); async_std::task::spawn(async move { - if let Err(e) = agent.handle(inbound).await { + if let Err(e) = agent.handle(&mut inbound).await { log::error!("Agent: {}", e); + let resp = http::Response::from_err(e); + + inbound.write(resp.to_string().as_bytes()).await.unwrap(); + inbound.flush().await.unwrap(); } }); }