fix: target shutdown

This commit is contained in:
Break27 2024-06-08 18:26:18 +08:00
parent 8ebba7271c
commit a501a30f29
4 changed files with 25 additions and 23 deletions

View File

@ -122,7 +122,7 @@ impl Agent {
log::info!("CLIENT --> {host}"); log::info!("CLIENT --> {host}");
if self.check_request_blocked(&request.path) { if self.check_request_blocked(&request.path) {
log::info!("CLIENT --> PROXY --> {host}"); log::info!("CLIENT --> PROXY --> TARGET");
let mut outbound = self.io(self.builder.connect(host))?; let mut outbound = self.io(self.builder.connect(host))?;
// forward intercepted request // forward intercepted request
@ -130,7 +130,9 @@ impl Agent {
outbound.flush().await?; outbound.flush().await?;
log::info!("CLIENT <-> PROXY (connection established)"); log::info!("CLIENT <-> PROXY (connection established)");
self.tunnel(conn, outbound).await?; self.tunnel(conn, &mut outbound).await?;
outbound.shutdown(std::net::Shutdown::Both)?;
return Ok(()); return Ok(());
} }
@ -150,11 +152,13 @@ impl Agent {
log::debug!("CLIENT --> (intercepted) --> TARGET"); log::debug!("CLIENT --> (intercepted) --> TARGET");
} }
self.tunnel(conn, target).await?; self.tunnel(conn, &mut target).await?;
target.shutdown(std::net::Shutdown::Both)?;
return Ok(()); return Ok(());
} }
async fn tunnel<A, B>(&self, inbound: &mut A, mut outbound: B) -> Result<()> async fn tunnel<A, B>(&self, inbound: &mut A, outbound: &mut B) -> Result<()>
where where
A: Read + Write + Send + Sync + Unpin + 'static, A: Read + Write + Send + Sync + Unpin + 'static,
B: Read + Write + Send + Sync + Unpin + 'static, B: Read + Write + Send + Sync + Unpin + 'static,

View File

@ -10,7 +10,7 @@ pub enum ConnectionBuilder {
} }
impl ConnectionBuilder { impl ConnectionBuilder {
pub async fn connect(&self, target: &str) -> Result<Connection, std::io::Error> { pub async fn connect(&self, target: &str) -> std::io::Result<Connection> {
let conn = match self { let conn = match self {
Self::Http(addr) => { Self::Http(addr) => {
Connection::new(TcpStream::connect(addr)?) Connection::new(TcpStream::connect(addr)?)
@ -36,9 +36,13 @@ impl Connection {
Self { inner: Async::new(conn).unwrap() } Self { inner: Async::new(conn).unwrap() }
} }
pub fn into_inner(self) -> Result<TcpStream, std::io::Error> { pub fn into_inner(self) -> std::io::Result<TcpStream> {
self.inner.into_inner() self.inner.into_inner()
} }
pub fn shutdown(self, how: std::net::Shutdown) -> std::io::Result<()> {
self.into_inner()?.shutdown(how)
}
} }
impl async_std::io::Read for Connection { impl async_std::io::Read for Connection {

View File

@ -1,11 +1,12 @@
use crate::error::Error; use crate::error::Error;
macro_rules! impl_http { macro_rules! impl_http {
(pub struct $name:ident($inner:ident) { $($variant:ident = $value:literal,)* }) => { (pub struct $name:ident($inner:ident) { $($variant:ident = $value:literal,)* }) => {
#[derive(PartialEq)] #[derive(PartialEq, Eq)]
pub struct $name($inner); pub struct $name($inner);
#[derive(PartialEq)] #[derive(PartialEq, Eq)]
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
enum $inner{ $($variant,)* } enum $inner{ $($variant,)* }
@ -99,15 +100,13 @@ impl Response {
pub fn from_err(err: Error) -> Self { pub fn from_err(err: Error) -> Self {
use async_std::io::ErrorKind::TimedOut; use async_std::io::ErrorKind::TimedOut;
match err { match err {
Error::BadRequest(x) => Error::BadRequest(_) =>
Self::make(400, x), Self::make(400, "Bad Request"),
Error::Io(e) if e.kind() == TimedOut => Error::Io(e) if e.kind() == TimedOut =>
Self::make(408, "Timeout"), Self::make(408, "Timeout"),
Error::Io(_) => Error::Io(_) =>
Self::make(503, "Unavailable"), Self::make(503, "Unavailable"),
Error::Parse(_) => Error::Parse(_) | Error::Utf8(_) =>
Self::make(500, "Internal Server Error"),
Error::Utf8(_) =>
Self::make(500, "Internal Server Error"), Self::make(500, "Internal Server Error"),
} }
} }

View File

@ -1,18 +1,11 @@
use std::sync::Arc;
use async_std::io::WriteExt;
use crate::http;
pub struct Server { pub struct Server {
addrs: std::net::SocketAddr, addrs: std::net::SocketAddr,
} }
impl Server { impl Server {
pub async fn run(self, agent: crate::agent::Agent) -> Result<(), std::io::Error> { pub async fn run(self, agent: crate::agent::Agent) -> std::io::Result<()> {
let listener = async_std::net::TcpListener::bind(self.addrs).await?; let listener = async_std::net::TcpListener::bind(self.addrs).await?;
let agent = Arc::new(agent); let agent = std::sync::Arc::new(agent);
log::info!("IMPOSTER/0.1 HTTP SERVER"); log::info!("IMPOSTER/0.1 HTTP SERVER");
log::info!("Server listening at {}", self.addrs); log::info!("Server listening at {}", self.addrs);
@ -26,7 +19,9 @@ impl Server {
async_std::task::spawn(async move { async_std::task::spawn(async move {
if let Err(e) = agent.handle(&mut inbound).await { if let Err(e) = agent.handle(&mut inbound).await {
log::error!("Agent: {e}"); log::error!("Agent: {e}");
let resp = http::Response::from_err(e);
let resp = crate::http::Response::from_err(e);
use async_std::io::WriteExt;
inbound.write(resp.to_string().as_bytes()).await.unwrap(); inbound.write(resp.to_string().as_bytes()).await.unwrap();
inbound.flush().await.unwrap(); inbound.flush().await.unwrap();