handle concurrently (again)
This commit is contained in:
parent
4f09baaa1c
commit
2699f702b1
40
src/agent.rs
40
src/agent.rs
@ -25,35 +25,32 @@ impl Agent {
|
||||
Self { builder, config, engine }
|
||||
}
|
||||
|
||||
pub fn spawn(self: &std::sync::Arc<Self>, mut conn: TcpStream) {
|
||||
pub async fn handle(&self, mut conn: TcpStream) {
|
||||
let req = match self.read(&mut conn) {
|
||||
Ok(x) => x,
|
||||
Err(e) => return log::error!("Read Error: {e}")
|
||||
Err(e) => return log::error!("Read: {e}")
|
||||
};
|
||||
|
||||
let mat = self.engine.check_request_blocked(&req.path);
|
||||
let agent = self.clone();
|
||||
|
||||
log::info!("CLIENT --> {} ({})",
|
||||
req.host, mat.then_some("tunnel").unwrap_or("direct"));
|
||||
|
||||
async_std::task::spawn(async move {
|
||||
let res = if mat {
|
||||
agent.remote(req, &mut conn).await
|
||||
} else {
|
||||
agent.direct(req, &mut conn).await
|
||||
};
|
||||
let res = if mat {
|
||||
self.remote(req, &mut conn).await
|
||||
} else {
|
||||
self.direct(req, &mut conn).await
|
||||
};
|
||||
|
||||
if let Err(e) = res {
|
||||
log::error!("Agent: {e}");
|
||||
let resp = http::Response::from_err(e);
|
||||
if let Err(e) = res {
|
||||
log::error!("Agent: {e}");
|
||||
let resp = http::Response::from_err(e);
|
||||
|
||||
conn.write(resp.to_string().as_bytes()).await.unwrap();
|
||||
conn.flush().await.unwrap();
|
||||
}
|
||||
conn.write(resp.to_string().as_bytes()).await.unwrap();
|
||||
conn.flush().await.unwrap();
|
||||
}
|
||||
|
||||
let _ = conn.shutdown(std::net::Shutdown::Both);
|
||||
});
|
||||
let _ = conn.shutdown(std::net::Shutdown::Both);
|
||||
}
|
||||
|
||||
async fn remote<S>(&self, req: http::Request, inbound: &mut S) -> Result<()>
|
||||
@ -61,7 +58,7 @@ impl Agent {
|
||||
S: Read + Write + Send + Sync + Unpin + 'static
|
||||
{
|
||||
let mut outbound = self.io(self.builder.connect(&req.host))?;
|
||||
log::info!("CLIENT --> PROXY --> TARGET");
|
||||
log::info!("CLIENT --> PROXY (pending)");
|
||||
|
||||
// forward intercepted request
|
||||
outbound.write_all(req.as_bytes()).await?;
|
||||
@ -79,14 +76,14 @@ impl Agent {
|
||||
S: Read + Write + Send + Sync + Unpin + 'static
|
||||
{
|
||||
let mut outbound = self.io(TcpStream::connect(&req.host))?;
|
||||
log::info!("CLIENT <=> TARGET (direct)");
|
||||
log::info!("CLIENT --> TARGET (pending)");
|
||||
|
||||
if let http::Method::CONNECT = req.method {
|
||||
let resp = http::Response::default();
|
||||
// respond to client with code 200 and an EMPTY body
|
||||
inbound.write_all(resp.to_string().as_bytes()).await?;
|
||||
inbound.flush().await?;
|
||||
log::debug!("Received CONNECT (200 OK)");
|
||||
log::debug!("Agent: received CONNECT (200 OK)");
|
||||
} else {
|
||||
// forward intercepted request
|
||||
outbound.write_all(req.as_bytes()).await?;
|
||||
@ -94,6 +91,7 @@ impl Agent {
|
||||
log::debug!("CLIENT --> (intercepted) --> TARGET");
|
||||
}
|
||||
|
||||
log::info!("CLIENT <=> TARGET (connection established)");
|
||||
self.tunnel(inbound, &mut outbound).await;
|
||||
let _ = outbound.shutdown(std::net::Shutdown::Both);
|
||||
|
||||
@ -148,7 +146,7 @@ impl Agent {
|
||||
Some("80") | _ => "http",
|
||||
};
|
||||
|
||||
path = format!("{}://{}", scheme, path);
|
||||
path = format!("{scheme}://{path}");
|
||||
}
|
||||
|
||||
let version = match request.version {
|
||||
|
||||
@ -1,23 +1,24 @@
|
||||
use std::sync::Mutex;
|
||||
use adblock::request::Request;
|
||||
|
||||
|
||||
pub struct Engine {
|
||||
inner: Option<adblock::Engine>
|
||||
inner: Option<Mutex<adblock::Engine>>
|
||||
}
|
||||
|
||||
impl Engine {
|
||||
pub fn new(inner: Option<adblock::Engine>) -> Self {
|
||||
Self { inner }
|
||||
Self { inner: inner.map(|x| x.into()) }
|
||||
}
|
||||
|
||||
pub fn check_request_blocked(&self, url: &str) -> bool {
|
||||
let inner = match &self.inner {
|
||||
Some(x) => x,
|
||||
Some(x) => x.lock().unwrap(),
|
||||
None => return true // always use tunnel when without rules
|
||||
};
|
||||
|
||||
Request::new(url, url, "fetch")
|
||||
.map(|x| inner.check_network_request(&x).matched)
|
||||
.map(|req| inner.check_network_request(&req).matched)
|
||||
.unwrap_or(true)
|
||||
}
|
||||
}
|
||||
|
||||
@ -69,9 +69,9 @@ impl ServerBuilder {
|
||||
}
|
||||
|
||||
let engine = Engine::new(ruleset);
|
||||
let agent = Agent::new(builder, config, engine).into();
|
||||
let agent = Agent::new(builder, config, engine);
|
||||
|
||||
Ok(Server { agent })
|
||||
Ok(Server { agent: agent.into() })
|
||||
}
|
||||
|
||||
fn build_rules(&self, mut text: String) -> BuildResult<adblock::Engine> {
|
||||
@ -117,9 +117,13 @@ impl Server {
|
||||
let listener = TcpListener::bind(addrs).await?;
|
||||
loop {
|
||||
let (inbound, addr) = listener.accept().await?;
|
||||
let agent = self.agent.clone();
|
||||
|
||||
log::info!("*** Incoming connection from {addr}");
|
||||
self.agent.spawn(inbound);
|
||||
|
||||
async_std::task::spawn(async move {
|
||||
agent.handle(inbound).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user