fix: target shutdown
This commit is contained in:
parent
8ebba7271c
commit
a501a30f29
12
src/agent.rs
12
src/agent.rs
@ -122,7 +122,7 @@ impl Agent {
|
|||||||
log::info!("CLIENT --> {host}");
|
log::info!("CLIENT --> {host}");
|
||||||
|
|
||||||
if self.check_request_blocked(&request.path) {
|
if self.check_request_blocked(&request.path) {
|
||||||
log::info!("CLIENT --> PROXY --> {host}");
|
log::info!("CLIENT --> PROXY --> TARGET");
|
||||||
let mut outbound = self.io(self.builder.connect(host))?;
|
let mut outbound = self.io(self.builder.connect(host))?;
|
||||||
|
|
||||||
// forward intercepted request
|
// forward intercepted request
|
||||||
@ -130,7 +130,9 @@ impl Agent {
|
|||||||
outbound.flush().await?;
|
outbound.flush().await?;
|
||||||
|
|
||||||
log::info!("CLIENT <-> PROXY (connection established)");
|
log::info!("CLIENT <-> PROXY (connection established)");
|
||||||
self.tunnel(conn, outbound).await?;
|
self.tunnel(conn, &mut outbound).await?;
|
||||||
|
|
||||||
|
outbound.shutdown(std::net::Shutdown::Both)?;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,11 +152,13 @@ impl Agent {
|
|||||||
log::debug!("CLIENT --> (intercepted) --> TARGET");
|
log::debug!("CLIENT --> (intercepted) --> TARGET");
|
||||||
}
|
}
|
||||||
|
|
||||||
self.tunnel(conn, target).await?;
|
self.tunnel(conn, &mut target).await?;
|
||||||
|
target.shutdown(std::net::Shutdown::Both)?;
|
||||||
|
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn tunnel<A, B>(&self, inbound: &mut A, mut outbound: B) -> Result<()>
|
async fn tunnel<A, B>(&self, inbound: &mut A, outbound: &mut B) -> Result<()>
|
||||||
where
|
where
|
||||||
A: Read + Write + Send + Sync + Unpin + 'static,
|
A: Read + Write + Send + Sync + Unpin + 'static,
|
||||||
B: Read + Write + Send + Sync + Unpin + 'static,
|
B: Read + Write + Send + Sync + Unpin + 'static,
|
||||||
|
|||||||
@ -10,7 +10,7 @@ pub enum ConnectionBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ConnectionBuilder {
|
impl ConnectionBuilder {
|
||||||
pub async fn connect(&self, target: &str) -> Result<Connection, std::io::Error> {
|
pub async fn connect(&self, target: &str) -> std::io::Result<Connection> {
|
||||||
let conn = match self {
|
let conn = match self {
|
||||||
Self::Http(addr) => {
|
Self::Http(addr) => {
|
||||||
Connection::new(TcpStream::connect(addr)?)
|
Connection::new(TcpStream::connect(addr)?)
|
||||||
@ -36,9 +36,13 @@ impl Connection {
|
|||||||
Self { inner: Async::new(conn).unwrap() }
|
Self { inner: Async::new(conn).unwrap() }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn into_inner(self) -> Result<TcpStream, std::io::Error> {
|
pub fn into_inner(self) -> std::io::Result<TcpStream> {
|
||||||
self.inner.into_inner()
|
self.inner.into_inner()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn shutdown(self, how: std::net::Shutdown) -> std::io::Result<()> {
|
||||||
|
self.into_inner()?.shutdown(how)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl async_std::io::Read for Connection {
|
impl async_std::io::Read for Connection {
|
||||||
|
|||||||
13
src/http.rs
13
src/http.rs
@ -1,11 +1,12 @@
|
|||||||
use crate::error::Error;
|
use crate::error::Error;
|
||||||
|
|
||||||
|
|
||||||
macro_rules! impl_http {
|
macro_rules! impl_http {
|
||||||
(pub struct $name:ident($inner:ident) { $($variant:ident = $value:literal,)* }) => {
|
(pub struct $name:ident($inner:ident) { $($variant:ident = $value:literal,)* }) => {
|
||||||
#[derive(PartialEq)]
|
#[derive(PartialEq, Eq)]
|
||||||
pub struct $name($inner);
|
pub struct $name($inner);
|
||||||
|
|
||||||
#[derive(PartialEq)]
|
#[derive(PartialEq, Eq)]
|
||||||
#[allow(non_camel_case_types)]
|
#[allow(non_camel_case_types)]
|
||||||
enum $inner{ $($variant,)* }
|
enum $inner{ $($variant,)* }
|
||||||
|
|
||||||
@ -99,15 +100,13 @@ impl Response {
|
|||||||
pub fn from_err(err: Error) -> Self {
|
pub fn from_err(err: Error) -> Self {
|
||||||
use async_std::io::ErrorKind::TimedOut;
|
use async_std::io::ErrorKind::TimedOut;
|
||||||
match err {
|
match err {
|
||||||
Error::BadRequest(x) =>
|
Error::BadRequest(_) =>
|
||||||
Self::make(400, x),
|
Self::make(400, "Bad Request"),
|
||||||
Error::Io(e) if e.kind() == TimedOut =>
|
Error::Io(e) if e.kind() == TimedOut =>
|
||||||
Self::make(408, "Timeout"),
|
Self::make(408, "Timeout"),
|
||||||
Error::Io(_) =>
|
Error::Io(_) =>
|
||||||
Self::make(503, "Unavailable"),
|
Self::make(503, "Unavailable"),
|
||||||
Error::Parse(_) =>
|
Error::Parse(_) | Error::Utf8(_) =>
|
||||||
Self::make(500, "Internal Server Error"),
|
|
||||||
Error::Utf8(_) =>
|
|
||||||
Self::make(500, "Internal Server Error"),
|
Self::make(500, "Internal Server Error"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,18 +1,11 @@
|
|||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use async_std::io::WriteExt;
|
|
||||||
|
|
||||||
use crate::http;
|
|
||||||
|
|
||||||
|
|
||||||
pub struct Server {
|
pub struct Server {
|
||||||
addrs: std::net::SocketAddr,
|
addrs: std::net::SocketAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Server {
|
impl Server {
|
||||||
pub async fn run(self, agent: crate::agent::Agent) -> Result<(), std::io::Error> {
|
pub async fn run(self, agent: crate::agent::Agent) -> std::io::Result<()> {
|
||||||
let listener = async_std::net::TcpListener::bind(self.addrs).await?;
|
let listener = async_std::net::TcpListener::bind(self.addrs).await?;
|
||||||
let agent = Arc::new(agent);
|
let agent = std::sync::Arc::new(agent);
|
||||||
|
|
||||||
log::info!("IMPOSTER/0.1 HTTP SERVER");
|
log::info!("IMPOSTER/0.1 HTTP SERVER");
|
||||||
log::info!("Server listening at {}", self.addrs);
|
log::info!("Server listening at {}", self.addrs);
|
||||||
@ -26,7 +19,9 @@ impl Server {
|
|||||||
async_std::task::spawn(async move {
|
async_std::task::spawn(async move {
|
||||||
if let Err(e) = agent.handle(&mut inbound).await {
|
if let Err(e) = agent.handle(&mut inbound).await {
|
||||||
log::error!("Agent: {e}");
|
log::error!("Agent: {e}");
|
||||||
let resp = http::Response::from_err(e);
|
|
||||||
|
let resp = crate::http::Response::from_err(e);
|
||||||
|
use async_std::io::WriteExt;
|
||||||
|
|
||||||
inbound.write(resp.to_string().as_bytes()).await.unwrap();
|
inbound.write(resp.to_string().as_bytes()).await.unwrap();
|
||||||
inbound.flush().await.unwrap();
|
inbound.flush().await.unwrap();
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user