add Response

This commit is contained in:
break27 2024-06-07 17:30:45 +08:00
parent c5e30d916d
commit c464fb1e34
5 changed files with 69 additions and 9 deletions

View File

@ -112,11 +112,11 @@ unsafe impl Send for Agent {}
unsafe impl Sync for Agent {} unsafe impl Sync for Agent {}
impl Agent { impl Agent {
pub async fn handle<S>(&self, mut conn: S) -> Result<()> pub async fn handle<S>(&self, conn: &mut S) -> Result<()>
where where
S: Read + Write + Send + Sync + Unpin + 'static S: Read + Write + Send + Sync + Unpin + 'static
{ {
let request = self.read(&mut conn)?; let request = self.read(conn)?;
let host = request.host(); let host = request.host();
log::info!("CLIENT --> {host}"); log::info!("CLIENT --> {host}");
@ -138,9 +138,9 @@ impl Agent {
log::info!("CLIENT <-> TARGET (direct)"); log::info!("CLIENT <-> TARGET (direct)");
if let http::Method::CONNECT = request.method { if let http::Method::CONNECT = request.method {
let resp = b"HTTP/1.1 200 OK\r\n\r\n"; let resp = http::Response::default();
// send response to client with code 200 and an EMPTY body // send response to client with code 200 and an EMPTY body
conn.write_all(resp).await?; conn.write_all(resp.to_string().as_bytes()).await?;
conn.flush().await?; conn.flush().await?;
log::debug!("Received CONNECT (200 OK)"); log::debug!("Received CONNECT (200 OK)");
} else { } else {
@ -154,7 +154,7 @@ impl Agent {
return Ok(()); return Ok(());
} }
async fn tunnel<A, B>(&self, mut inbound: A, mut outbound: B) -> Result<()> async fn tunnel<A, B>(&self, inbound: &mut A, mut outbound: B) -> Result<()>
where where
A: Read + Write + Send + Sync + Unpin + 'static, A: Read + Write + Send + Sync + Unpin + 'static,
B: Read + Write + Send + Sync + Unpin + 'static, B: Read + Write + Send + Sync + Unpin + 'static,

View File

@ -31,7 +31,6 @@ impl_error! {
Io(std::io::Error), Io(std::io::Error),
Parse(httparse::Error), Parse(httparse::Error),
Utf8(std::str::Utf8Error), Utf8(std::str::Utf8Error),
Timeout(async_std::future::TimeoutError),
} }
} }

View File

@ -1,3 +1,5 @@
use crate::error::Error;
macro_rules! impl_http { macro_rules! impl_http {
(pub struct $name:ident($inner:ident) { $($variant:ident = $value:literal,)* }) => { (pub struct $name:ident($inner:ident) { $($variant:ident = $value:literal,)* }) => {
#[derive(PartialEq)] #[derive(PartialEq)]
@ -72,3 +74,54 @@ impl Request {
&self.host &self.host
} }
} }
pub struct Response {
pub version: Version,
pub status: u16,
pub message: String
}
impl Response {
pub fn new<T>(version: Version, status: u16, message: T) -> Self
where
T: ToString
{
Self { version, status, message: message.to_string() }
}
pub fn make<T>(status: u16, message: T) -> Self
where
T: ToString
{
Self { version: Version::HTTP_11, status, message: message.to_string() }
}
pub fn from_err(err: Error) -> Self {
use async_std::io::ErrorKind::TimedOut;
match err {
Error::BadRequest(x) =>
Self::make(400, x),
Error::Io(e) if e.kind() == TimedOut =>
Self::make(408, "Timeout"),
Error::Io(_) =>
Self::make(503, "Unavailable"),
Error::Parse(_) =>
Self::make(500, "Internal Server Error"),
Error::Utf8(_) =>
Self::make(500, "Internal Server Error"),
}
}
}
impl Default for Response {
fn default() -> Self {
Self::make(200, "OK")
}
}
impl ToString for Response {
fn to_string(&self) -> String {
format!("{} {} {}\r\n\r\n",
self.version.to_string(), self.status, self.message)
}
}

View File

@ -1,6 +1,6 @@
use clap::Parser; use clap::Parser;
mod http; pub mod http;
pub mod agent; pub mod agent;
pub mod error; pub mod error;
pub mod connection; pub mod connection;

View File

@ -1,5 +1,9 @@
use std::sync::Arc; use std::sync::Arc;
use async_std::io::WriteExt;
use crate::http;
pub struct Server { pub struct Server {
addrs: std::net::SocketAddr, addrs: std::net::SocketAddr,
@ -15,13 +19,17 @@ impl Server {
loop { loop {
let agent = agent.clone(); let agent = agent.clone();
let (inbound, addr) = listener.accept().await?; let (mut inbound, addr) = listener.accept().await?;
log::info!("*** Incoming connection from {}", addr); log::info!("*** Incoming connection from {}", addr);
async_std::task::spawn(async move { async_std::task::spawn(async move {
if let Err(e) = agent.handle(inbound).await { if let Err(e) = agent.handle(&mut inbound).await {
log::error!("Agent: {}", e); log::error!("Agent: {}", e);
let resp = http::Response::from_err(e);
inbound.write(resp.to_string().as_bytes()).await.unwrap();
inbound.flush().await.unwrap();
} }
}); });
} }