diff --git a/src/agent.rs b/src/agent.rs
index d5cf560..0733113 100644
--- a/src/agent.rs
+++ b/src/agent.rs
@@ -122,7 +122,7 @@ impl Agent {
log::info!("CLIENT --> {host}");
if self.check_request_blocked(&request.path) {
- log::info!("CLIENT --> PROXY --> {host}");
+ log::info!("CLIENT --> PROXY --> TARGET");
let mut outbound = self.io(self.builder.connect(host))?;
// forward intercepted request
@@ -130,7 +130,9 @@ impl Agent {
outbound.flush().await?;
log::info!("CLIENT <-> PROXY (connection established)");
- self.tunnel(conn, outbound).await?;
+ self.tunnel(conn, &mut outbound).await?;
+
+ outbound.shutdown(std::net::Shutdown::Both)?;
return Ok(());
}
@@ -150,11 +152,13 @@ impl Agent {
log::debug!("CLIENT --> (intercepted) --> TARGET");
}
- self.tunnel(conn, target).await?;
+ self.tunnel(conn, &mut target).await?;
+ target.shutdown(std::net::Shutdown::Both)?;
+
return Ok(());
}
- async fn tunnel(&self, inbound: &mut A, mut outbound: B) -> Result<()>
+ async fn tunnel(&self, inbound: &mut A, outbound: &mut B) -> Result<()>
where
A: Read + Write + Send + Sync + Unpin + 'static,
B: Read + Write + Send + Sync + Unpin + 'static,
diff --git a/src/connection.rs b/src/connection.rs
index 50e9da1..580ee27 100644
--- a/src/connection.rs
+++ b/src/connection.rs
@@ -10,7 +10,7 @@ pub enum ConnectionBuilder {
}
impl ConnectionBuilder {
- pub async fn connect(&self, target: &str) -> Result {
+ pub async fn connect(&self, target: &str) -> std::io::Result {
let conn = match self {
Self::Http(addr) => {
Connection::new(TcpStream::connect(addr)?)
@@ -36,9 +36,13 @@ impl Connection {
Self { inner: Async::new(conn).unwrap() }
}
- pub fn into_inner(self) -> Result {
+ pub fn into_inner(self) -> std::io::Result {
self.inner.into_inner()
}
+
+ pub fn shutdown(self, how: std::net::Shutdown) -> std::io::Result<()> {
+ self.into_inner()?.shutdown(how)
+ }
}
impl async_std::io::Read for Connection {
diff --git a/src/http.rs b/src/http.rs
index e86798c..377c502 100644
--- a/src/http.rs
+++ b/src/http.rs
@@ -1,11 +1,12 @@
use crate::error::Error;
+
macro_rules! impl_http {
(pub struct $name:ident($inner:ident) { $($variant:ident = $value:literal,)* }) => {
- #[derive(PartialEq)]
+ #[derive(PartialEq, Eq)]
pub struct $name($inner);
- #[derive(PartialEq)]
+ #[derive(PartialEq, Eq)]
#[allow(non_camel_case_types)]
enum $inner{ $($variant,)* }
@@ -99,15 +100,13 @@ impl Response {
pub fn from_err(err: Error) -> Self {
use async_std::io::ErrorKind::TimedOut;
match err {
- Error::BadRequest(x) =>
- Self::make(400, x),
+ Error::BadRequest(_) =>
+ Self::make(400, "Bad Request"),
Error::Io(e) if e.kind() == TimedOut =>
Self::make(408, "Timeout"),
Error::Io(_) =>
Self::make(503, "Unavailable"),
- Error::Parse(_) =>
- Self::make(500, "Internal Server Error"),
- Error::Utf8(_) =>
+ Error::Parse(_) | Error::Utf8(_) =>
Self::make(500, "Internal Server Error"),
}
}
diff --git a/src/server.rs b/src/server.rs
index 9069c98..a59d245 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,18 +1,11 @@
-use std::sync::Arc;
-
-use async_std::io::WriteExt;
-
-use crate::http;
-
-
pub struct Server {
addrs: std::net::SocketAddr,
}
impl Server {
- pub async fn run(self, agent: crate::agent::Agent) -> Result<(), std::io::Error> {
+ pub async fn run(self, agent: crate::agent::Agent) -> std::io::Result<()> {
let listener = async_std::net::TcpListener::bind(self.addrs).await?;
- let agent = Arc::new(agent);
+ let agent = std::sync::Arc::new(agent);
log::info!("IMPOSTER/0.1 HTTP SERVER");
log::info!("Server listening at {}", self.addrs);
@@ -26,7 +19,9 @@ impl Server {
async_std::task::spawn(async move {
if let Err(e) = agent.handle(&mut inbound).await {
log::error!("Agent: {e}");
- let resp = http::Response::from_err(e);
+
+ let resp = crate::http::Response::from_err(e);
+ use async_std::io::WriteExt;
inbound.write(resp.to_string().as_bytes()).await.unwrap();
inbound.flush().await.unwrap();