fix: target shutdown

This commit is contained in:
Break27 2024-06-08 18:26:18 +08:00
parent 8ebba7271c
commit a501a30f29
4 changed files with 25 additions and 23 deletions

View File

@ -122,7 +122,7 @@ impl Agent {
log::info!("CLIENT --> {host}");
if self.check_request_blocked(&request.path) {
log::info!("CLIENT --> PROXY --> {host}");
log::info!("CLIENT --> PROXY --> TARGET");
let mut outbound = self.io(self.builder.connect(host))?;
// forward intercepted request
@ -130,7 +130,9 @@ impl Agent {
outbound.flush().await?;
log::info!("CLIENT <-> PROXY (connection established)");
self.tunnel(conn, outbound).await?;
self.tunnel(conn, &mut outbound).await?;
outbound.shutdown(std::net::Shutdown::Both)?;
return Ok(());
}
@ -150,11 +152,13 @@ impl Agent {
log::debug!("CLIENT --> (intercepted) --> TARGET");
}
self.tunnel(conn, target).await?;
self.tunnel(conn, &mut target).await?;
target.shutdown(std::net::Shutdown::Both)?;
return Ok(());
}
async fn tunnel<A, B>(&self, inbound: &mut A, mut outbound: B) -> Result<()>
async fn tunnel<A, B>(&self, inbound: &mut A, outbound: &mut B) -> Result<()>
where
A: Read + Write + Send + Sync + Unpin + 'static,
B: Read + Write + Send + Sync + Unpin + 'static,

View File

@ -10,7 +10,7 @@ pub enum ConnectionBuilder {
}
impl ConnectionBuilder {
pub async fn connect(&self, target: &str) -> Result<Connection, std::io::Error> {
pub async fn connect(&self, target: &str) -> std::io::Result<Connection> {
let conn = match self {
Self::Http(addr) => {
Connection::new(TcpStream::connect(addr)?)
@ -36,9 +36,13 @@ impl Connection {
Self { inner: Async::new(conn).unwrap() }
}
pub fn into_inner(self) -> Result<TcpStream, std::io::Error> {
pub fn into_inner(self) -> std::io::Result<TcpStream> {
self.inner.into_inner()
}
pub fn shutdown(self, how: std::net::Shutdown) -> std::io::Result<()> {
self.into_inner()?.shutdown(how)
}
}
impl async_std::io::Read for Connection {

View File

@ -1,11 +1,12 @@
use crate::error::Error;
macro_rules! impl_http {
(pub struct $name:ident($inner:ident) { $($variant:ident = $value:literal,)* }) => {
#[derive(PartialEq)]
#[derive(PartialEq, Eq)]
pub struct $name($inner);
#[derive(PartialEq)]
#[derive(PartialEq, Eq)]
#[allow(non_camel_case_types)]
enum $inner{ $($variant,)* }
@ -99,15 +100,13 @@ impl Response {
pub fn from_err(err: Error) -> Self {
use async_std::io::ErrorKind::TimedOut;
match err {
Error::BadRequest(x) =>
Self::make(400, x),
Error::BadRequest(_) =>
Self::make(400, "Bad Request"),
Error::Io(e) if e.kind() == TimedOut =>
Self::make(408, "Timeout"),
Error::Io(_) =>
Self::make(503, "Unavailable"),
Error::Parse(_) =>
Self::make(500, "Internal Server Error"),
Error::Utf8(_) =>
Error::Parse(_) | Error::Utf8(_) =>
Self::make(500, "Internal Server Error"),
}
}

View File

@ -1,18 +1,11 @@
use std::sync::Arc;
use async_std::io::WriteExt;
use crate::http;
pub struct Server {
addrs: std::net::SocketAddr,
}
impl Server {
pub async fn run(self, agent: crate::agent::Agent) -> Result<(), std::io::Error> {
pub async fn run(self, agent: crate::agent::Agent) -> std::io::Result<()> {
let listener = async_std::net::TcpListener::bind(self.addrs).await?;
let agent = Arc::new(agent);
let agent = std::sync::Arc::new(agent);
log::info!("IMPOSTER/0.1 HTTP SERVER");
log::info!("Server listening at {}", self.addrs);
@ -26,7 +19,9 @@ impl Server {
async_std::task::spawn(async move {
if let Err(e) = agent.handle(&mut inbound).await {
log::error!("Agent: {e}");
let resp = http::Response::from_err(e);
let resp = crate::http::Response::from_err(e);
use async_std::io::WriteExt;
inbound.write(resp.to_string().as_bytes()).await.unwrap();
inbound.flush().await.unwrap();