fix: adblock::Engine BorrowMut panic

This commit is contained in:
Break27 2024-06-08 23:34:43 +08:00
parent a501a30f29
commit 4f09baaa1c
8 changed files with 235 additions and 209 deletions

1
Cargo.lock generated
View File

@ -1290,6 +1290,7 @@ dependencies = [
"log", "log",
"native-tls", "native-tls",
"once_cell", "once_cell",
"socks",
"url", "url",
] ]

View File

@ -16,5 +16,5 @@ log = "0.4.21"
native-tls = "0.2.12" native-tls = "0.2.12"
socks = "0.3.4" socks = "0.3.4"
tokio = { version = "1.38.0", features = ["io-util"] } tokio = { version = "1.38.0", features = ["io-util"] }
ureq = { version = "2.9.7", default-features = false, features = ["native-tls"] } ureq = { version = "2.9.7", default-features = false, features = ["native-tls", "socks-proxy"] }
url = "2.5.0" url = "2.5.0"

View File

@ -1,164 +1,106 @@
use async_std::io::{Read, Write, ReadExt, WriteExt}; use async_std::io::{Read, ReadExt, Write, WriteExt};
use async_std::net::TcpStream; use async_std::net::TcpStream;
use crate::connection::ConnectionBuilder; use crate::connection::ConnectionBuilder as Builder;
use crate::error::{Result, Error, BuildError, BuildResult}; use crate::engine::Engine;
use crate::error::{Error, Result};
use crate::http; use crate::http;
/// Fluent builder that gathers optional settings and produces an [`Agent`].
pub struct AgentBuilder {
    /// Where the rule list is downloaded from; no download happens when `None`.
    filter_url: Option<url::Url>,
    /// Buffer size stored in [`AgentConfig`]; 1024 when unset.
    buf_size: Option<usize>,
    /// Timeout in seconds; an unset value becomes `u64::MAX` seconds in [`AgentConfig`].
    timeout: Option<u64>,
    /// Whether the downloaded rule list is base64-decoded before it is parsed.
    decode: bool,
}

impl AgentBuilder {
    /// Creates a builder with nothing configured and base64 decoding enabled.
    pub fn new() -> Self {
        Self {
            filter_url: None,
            buf_size: None,
            timeout: None,
            decode: true,
        }
    }

    /// Sets the URL the rule list is downloaded from.
    pub fn filter(mut self, url: url::Url) -> Self {
        self.filter_url = Some(url);
        self
    }

    /// Sets the timeout, in seconds.
    pub fn timeout(mut self, timeout: u64) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Sets the buffer size stored in the agent configuration.
    pub fn buffer(mut self, size: usize) -> Self {
        self.buf_size = Some(size);
        self
    }

    /// Chooses whether the downloaded rule list is base64-decoded before parsing.
    pub fn decode(mut self, decode: bool) -> Self {
        self.decode = decode;
        self
    }

    /// Builds an [`Agent`] that reaches its upstream proxy at `remote`.
    ///
    /// When a filter URL is configured, the rule list is downloaded through
    /// `remote` and compiled into an adblock engine; otherwise the agent is
    /// created without rules.
    ///
    /// # Errors
    ///
    /// Returns [`BuildError::Unsupported`] when the scheme of `remote` is not
    /// `http`, `socks` or `socks5`. Failures from TLS setup, proxy parsing,
    /// the download and rule decoding are propagated with `?`.
    pub fn build(self, remote: url::Url) -> BuildResult<Agent> {
        use ConnectionBuilder as CB;

        let builder = match remote.scheme() {
            "http" | "" => CB::Http(remote.authority().to_string()),
            "socks" | "socks5" => CB::Socks5(remote.authority().to_string()),
            other => return Err(BuildError::Unsupported(other.to_string())),
        };

        let config = AgentConfig {
            buf_size: self.buf_size.unwrap_or(1024),
            timeout: std::time::Duration::from_secs(self.timeout.unwrap_or(u64::MAX)),
        };

        let ruleset = match &self.filter_url {
            Some(url) => {
                log::info!(target: "builder", "Try downloading rule list from '{url}'");

                let https = native_tls::TlsConnector::new()?;
                let mut http = ureq::AgentBuilder::new()
                    .proxy(ureq::Proxy::new(remote)?)
                    .tls_connector(https.into());
                // Only pass a timeout the caller actually chose: the
                // `u64::MAX`-second fallback cannot be added to an `Instant`
                // to form a request deadline.
                // NOTE(review): confirm against ureq's deadline computation.
                if self.timeout.is_some() {
                    http = http.timeout(config.timeout);
                }
                let client = http.build();

                let resp = client.get(url.as_str()).call()?;
                let text = resp.into_string()?;

                let len = text.len() as f32 / 1000f32;
                log::info!(target: "builder", "Successfully downloaded data ({len}/kB transmitted)");
                Some(self.build_rules(text)?)
            }
            None => None,
        };

        Ok(Agent { builder, ruleset, config })
    }

    /// Compiles the downloaded rule list into an adblock engine.
    ///
    /// With decoding enabled, all whitespace is removed from `text` and the
    /// remainder is base64-decoded into UTF-8 before it is parsed.
    fn build_rules(&self, mut text: String) -> BuildResult<adblock::Engine> {
        if self.decode {
            log::info!(target: "builder", "Try decoding raw textual data (base64 encoded)");
            use base64::{Engine, engine::general_purpose::STANDARD};

            // Joining the whitespace-separated pieces strips the line breaks
            // of the downloaded payload before decoding.
            let line = text.split_whitespace().collect::<String>();
            let decoded = STANDARD.decode(line)?;
            text = String::from_utf8(decoded)?;
        }

        let mut filters = adblock::FilterSet::new(false);
        let opts = adblock::lists::ParseOptions::default();
        filters.add_filter_list(&text, opts);

        log::info!(target: "builder", "Rule data parsed successfully");
        Ok(adblock::Engine::from_filter_set(filters, true))
    }
}
pub struct AgentConfig { pub struct AgentConfig {
pub buf_size: usize, pub bufsize: usize,
pub timeout: std::time::Duration, pub timeout: std::time::Duration,
} }
pub struct Agent { pub struct Agent {
ruleset: Option<adblock::Engine>, builder: Builder,
builder: ConnectionBuilder,
config: AgentConfig, config: AgentConfig,
engine: Engine
} }
unsafe impl Send for Agent {} unsafe impl Send for Agent {}
unsafe impl Sync for Agent {} unsafe impl Sync for Agent {}
impl Agent { impl Agent {
pub async fn handle<S>(&self, conn: &mut S) -> Result<()> pub fn new(builder: Builder, config: AgentConfig, engine: Engine) -> Self {
Self { builder, config, engine }
}
pub fn spawn(self: &std::sync::Arc<Self>, mut conn: TcpStream) {
let req = match self.read(&mut conn) {
Ok(x) => x,
Err(e) => return log::error!("Read Error: {e}")
};
let mat = self.engine.check_request_blocked(&req.path);
let agent = self.clone();
log::info!("CLIENT --> {} ({})",
req.host, mat.then_some("tunnel").unwrap_or("direct"));
async_std::task::spawn(async move {
let res = if mat {
agent.remote(req, &mut conn).await
} else {
agent.direct(req, &mut conn).await
};
if let Err(e) = res {
log::error!("Agent: {e}");
let resp = http::Response::from_err(e);
conn.write(resp.to_string().as_bytes()).await.unwrap();
conn.flush().await.unwrap();
}
let _ = conn.shutdown(std::net::Shutdown::Both);
});
}
async fn remote<S>(&self, req: http::Request, inbound: &mut S) -> Result<()>
where where
S: Read + Write + Send + Sync + Unpin + 'static S: Read + Write + Send + Sync + Unpin + 'static
{ {
let request = self.read(conn)?; let mut outbound = self.io(self.builder.connect(&req.host))?;
let host = request.host();
log::info!("CLIENT --> {host}");
if self.check_request_blocked(&request.path) {
log::info!("CLIENT --> PROXY --> TARGET"); log::info!("CLIENT --> PROXY --> TARGET");
let mut outbound = self.io(self.builder.connect(host))?;
// forward intercepted request // forward intercepted request
outbound.write_all(request.as_bytes()).await?; outbound.write_all(req.as_bytes()).await?;
outbound.flush().await?; outbound.flush().await?;
log::info!("CLIENT <-> PROXY (connection established)"); log::info!("CLIENT <=> PROXY (connection established)");
self.tunnel(conn, &mut outbound).await?; self.tunnel(inbound, &mut outbound).await;
outbound.shutdown(std::net::Shutdown::Both)?; let _ = outbound.shutdown(std::net::Shutdown::Both);
return Ok(()); return Ok(());
} }
let mut target = self.io(TcpStream::connect(host))?; async fn direct<S>(&self, req: http::Request, inbound: &mut S) -> Result<()>
log::info!("CLIENT <-> TARGET (direct)"); where
S: Read + Write + Send + Sync + Unpin + 'static
{
let mut outbound = self.io(TcpStream::connect(&req.host))?;
log::info!("CLIENT <=> TARGET (direct)");
if let http::Method::CONNECT = request.method { if let http::Method::CONNECT = req.method {
let resp = http::Response::default(); let resp = http::Response::default();
// send response to client with code 200 and an EMPTY body // respond to client with code 200 and an EMPTY body
conn.write_all(resp.to_string().as_bytes()).await?; inbound.write_all(resp.to_string().as_bytes()).await?;
conn.flush().await?; inbound.flush().await?;
log::debug!("Received CONNECT (200 OK)"); log::debug!("Received CONNECT (200 OK)");
} else { } else {
// forward intercepted request // forward intercepted request
target.write_all(request.as_bytes()).await?; outbound.write_all(req.as_bytes()).await?;
target.flush().await?; outbound.flush().await?;
log::debug!("CLIENT --> (intercepted) --> TARGET"); log::debug!("CLIENT --> (intercepted) --> TARGET");
} }
self.tunnel(conn, &mut target).await?; self.tunnel(inbound, &mut outbound).await;
target.shutdown(std::net::Shutdown::Both)?; let _ = outbound.shutdown(std::net::Shutdown::Both);
return Ok(()); return Ok(());
} }
async fn tunnel<A, B>(&self, inbound: &mut A, outbound: &mut B) -> Result<()> async fn tunnel<A, B>(&self, inbound: &mut A, outbound: &mut B)
where where
A: Read + Write + Send + Sync + Unpin + 'static, A: Read + Write + Send + Sync + Unpin + 'static,
B: Read + Write + Send + Sync + Unpin + 'static, B: Read + Write + Send + Sync + Unpin + 'static,
@ -169,27 +111,22 @@ impl Agent {
if let Err(e) = copy( if let Err(e) = copy(
&mut outbound.compat_mut(), &mut inbound.compat_mut()).await &mut outbound.compat_mut(), &mut inbound.compat_mut()).await
{ {
log::warn!("{}", e); log::warn!("{e}");
}
} }
Ok(()) fn read(&self, conn: &mut TcpStream) -> Result<http::Request> {
}
fn read<S>(&self, conn: &mut S) -> Result<http::Request>
where
S: Read + Write + Send + Unpin + 'static
{
let mut headers = [httparse::EMPTY_HEADER; 64]; let mut headers = [httparse::EMPTY_HEADER; 64];
let mut request = httparse::Request::new(&mut headers); let mut request = httparse::Request::new(&mut headers);
let mut buf = vec![0; self.config.buf_size]; let mut buf = vec![0; self.config.bufsize];
self.io(conn.read(&mut buf))?; self.io(conn.read(&mut buf))?;
let offset = request.parse(&buf)?.unwrap(); let offset = request.parse(&buf)?.unwrap();
let payload = buf[..offset].to_vec(); let payload = buf[..offset].to_vec();
let method = match request.method { let method = match request.method {
Some(x) => x.parse::<crate::http::Method>().unwrap(), Some(x) => x.parse::<http::Method>().unwrap(),
None => return Err(Error::BadRequest("METHOD".to_string())) None => return Err(Error::BadRequest("METHOD".to_string()))
}; };
@ -234,7 +171,7 @@ impl Agent {
host += ":80"; host += ":80";
} }
let request = crate::http::Request { let request = http::Request {
method, method,
path, path,
version, version,
@ -245,22 +182,6 @@ impl Agent {
Ok(request) Ok(request)
} }
fn check_request_blocked(&self, url: &str) -> bool {
let attempt: _ = adblock::request::Request::new(
url, url, "fetch"
);
let req = match attempt {
Ok(x) => x,
Err(_) => return true
};
match &self.ruleset {
Some(x) => x.check_network_request(&req).matched,
None => true // always use tunnel when without rules
}
}
fn io<T, F>(&self, f: F) -> Result<T> fn io<T, F>(&self, f: F) -> Result<T>
where where
F: std::future::Future<Output=std::result::Result<T, std::io::Error>>, F: std::future::Future<Output=std::result::Result<T, std::io::Error>>,

View File

@ -31,8 +31,7 @@ unsafe impl Send for Connection {}
unsafe impl Sync for Connection {} unsafe impl Sync for Connection {}
impl Connection { impl Connection {
pub fn new(conn: TcpStream) -> Self pub fn new(conn: TcpStream) -> Self {
{
Self { inner: Async::new(conn).unwrap() } Self { inner: Async::new(conn).unwrap() }
} }

23
src/engine.rs Normal file
View File

@ -0,0 +1,23 @@
use adblock::request::Request;
/// Wrapper around an optional adblock rule engine that decides whether a
/// request should be tunneled.
pub struct Engine {
    /// Compiled rule set; `None` when no rules were loaded.
    inner: Option<adblock::Engine>,
}

impl Engine {
    /// Wraps an optional, already compiled adblock engine.
    pub fn new(inner: Option<adblock::Engine>) -> Self {
        Self { inner }
    }

    /// Reports whether the request for `url` matches the loaded rules.
    ///
    /// Returns `true` when no rules are loaded (always use the tunnel in that
    /// case) and when `url` cannot be turned into an adblock request;
    /// otherwise returns the `matched` flag of the rule lookup.
    pub fn check_request_blocked(&self, url: &str) -> bool {
        self.inner.as_ref().map_or(true, |rules| {
            match Request::new(url, url, "fetch") {
                Ok(request) => rules.check_network_request(&request).matched,
                Err(_) => true,
            }
        })
    }
}

View File

@ -70,10 +70,6 @@ impl Request {
pub fn as_bytes(&self) -> &[u8] { pub fn as_bytes(&self) -> &[u8] {
&self.payload &self.payload
} }
pub fn host(&self) -> &str {
&self.host
}
} }
pub struct Response { pub struct Response {

View File

@ -1,3 +1,4 @@
use std::ops::Not;
use clap::Parser; use clap::Parser;
pub mod http; pub mod http;
@ -5,23 +6,33 @@ pub mod agent;
pub mod error; pub mod error;
pub mod connection; pub mod connection;
pub mod server; pub mod server;
pub mod engine;
#[derive(Parser)] #[derive(Parser)]
#[command(version, about, long_about = None)] #[command(version, about, long_about = None)]
struct Cli { struct Cli {
/// Server port number
#[arg(short, long)] #[arg(short, long)]
port: Option<u16>, port: Option<u16>,
/// Rule data URL
#[arg(short, long, value_name = "URL")] #[arg(short, long, value_name = "URL")]
filter_url: Option<url::Url>, filter_url: Option<url::Url>,
/// Buffer size
#[arg(long, value_name = "SIZE")] #[arg(long, value_name = "SIZE")]
buf_size: Option<usize>, buf_size: Option<usize>,
/// Seconds before server return a timeout response (408)
#[arg(short, long, value_name = "SEC")] #[arg(short, long, value_name = "SEC")]
timeout: Option<u64>, timeout: Option<u64>,
/// Load downloaded rule data without decoding (base64)
#[arg(long)]
plain_text: bool,
/// Proxy server URL
#[arg(value_name = "URL")] #[arg(value_name = "URL")]
remote: url::Url remote: url::Url
} }
@ -32,32 +43,29 @@ const PORT: u16 = 9000;
const BUF_SIZE:usize = 1024; const BUF_SIZE:usize = 1024;
const TIMEOUT: u64 = 15; const TIMEOUT: u64 = 15;
async fn try_launch(agent: Result<agent::Agent, error::BuildError>, async fn launch() -> Result<(), Box<dyn std::error::Error>> {
server: server::Server) -> Result<(), Box<dyn std::error::Error>>
{
Ok(server.run(agent?).await?)
}
fn main() {
if std::env::var("RUST_LOG").ok().is_none() { if std::env::var("RUST_LOG").ok().is_none() {
std::env::set_var("RUST_LOG", "info"); unsafe { std::env::set_var("RUST_LOG", "info") }
} }
let cli = Cli::parse(); let cli = Cli::parse();
env_logger::init(); env_logger::init();
let port = cli.port.unwrap_or(PORT); let port = cli.port.unwrap_or(PORT);
let server = server::Server::bind((LOCALHOST, port)); let server: _ = server::Server::builder()
let agent = agent::AgentBuilder::new()
.buffer(cli.buf_size.unwrap_or(BUF_SIZE)) .buffer(cli.buf_size.unwrap_or(BUF_SIZE))
.filter(cli.filter_url.unwrap_or(URL.parse().unwrap())) .filter(cli.filter_url.unwrap_or(URL.parse().unwrap()))
.timeout(cli.timeout.unwrap_or(TIMEOUT)) .timeout(cli.timeout.unwrap_or(TIMEOUT))
.build(cli.remote); .encoded(cli.plain_text.not())
.build(cli.remote)?;
if let Err(e) = async_std::task::block_on( Ok(server.bind((LOCALHOST, port)).await?)
try_launch(agent, server)) }
fn main() {
if let Err(e) =
async_std::task::block_on(launch())
{ {
eprintln!("Error: {}", e); eprintln!("Error: {e}");
} }
} }

View File

@ -1,47 +1,125 @@
use async_std::net::TcpListener;
use crate::agent::{Agent, AgentConfig};
use crate::connection::ConnectionBuilder;
use crate::error::{BuildError, BuildResult};
use crate::engine::Engine;
#[derive(Default)]
pub struct ServerBuilder {
filter_url: Option<url::Url>,
buf_size: Option<usize>,
timeout: Option<u64>,
encoded: bool
}
impl ServerBuilder {
pub fn filter(mut self, url: url::Url) -> Self {
let _ = self.filter_url.insert(url);
self
}
pub fn buffer(mut self, size: usize) -> Self {
let _ = self.buf_size.insert(size);
self
}
pub fn timeout(mut self, timeout: u64) -> Self {
let _ = self.timeout.insert(timeout);
self
}
pub fn encoded(mut self, encoded: bool) -> Self {
self.encoded = encoded;
self
}
pub fn build(self, remote: url::Url) -> BuildResult<Server> {
use ConnectionBuilder as CB;
let builder = match remote.scheme() {
"http" | "" => CB::Http(remote.authority().to_string()),
"socks" | "socks5" => CB::Socks5(remote.authority().to_string()),
other => return Err(BuildError::Unsupported(other.to_string()))
};
let mut ruleset = None;
let time = self.timeout.unwrap_or(u64::MAX);
let config = AgentConfig {
bufsize: self.buf_size.unwrap_or(1024),
timeout: std::time::Duration::from_secs(time),
};
if let Some(ref url) = self.filter_url {
log::info!("Try downloading rule list from '{url}'");
let https = native_tls::TlsConnector::new()?;
let client = ureq::AgentBuilder::new()
.proxy(ureq::Proxy::new(remote)?)
.tls_connector(https.into())
.timeout(config.timeout)
.build();
let resp = client.get(url.as_str()).call()?;
let text = resp.into_string()?;
let len = text.len() as f32 / 1000f32;
log::info!("Successfully downloaded data ({len}/kB transmitted)");
ruleset = Some(self.build_rules(text)?);
}
let engine = Engine::new(ruleset);
let agent = Agent::new(builder, config, engine).into();
Ok(Server { agent })
}
fn build_rules(&self, mut text: String) -> BuildResult<adblock::Engine> {
if self.encoded {
log::info!("Try decoding raw textual data (base64 encoded)");
use base64::{Engine, engine::general_purpose::STANDARD};
let line = text.split_whitespace().collect::<String>();
let decoded = STANDARD.decode(line)?;
text = String::from_utf8(decoded)?;
}
let mut filters = adblock::FilterSet::new(false);
let opts = adblock::lists::ParseOptions::default();
filters.add_filter_list(&text, opts);
log::info!("Rule data parsed successfully");
Ok(adblock::Engine::from_filter_set(filters, true))
}
}
pub struct Server { pub struct Server {
addrs: std::net::SocketAddr, agent: std::sync::Arc<Agent>
} }
impl Server { impl Server {
pub async fn run(self, agent: crate::agent::Agent) -> std::io::Result<()> { pub fn builder() -> ServerBuilder {
let listener = async_std::net::TcpListener::bind(self.addrs).await?; ServerBuilder::default()
let agent = std::sync::Arc::new(agent);
log::info!("IMPOSTER/0.1 HTTP SERVER");
log::info!("Server listening at {}", self.addrs);
loop {
let agent = agent.clone();
let (mut inbound, addr) = listener.accept().await?;
log::info!("*** Incoming connection from {addr}");
async_std::task::spawn(async move {
if let Err(e) = agent.handle(&mut inbound).await {
log::error!("Agent: {e}");
let resp = crate::http::Response::from_err(e);
use async_std::io::WriteExt;
inbound.write(resp.to_string().as_bytes()).await.unwrap();
inbound.flush().await.unwrap();
} }
let _ = inbound.shutdown(std::net::Shutdown::Both); pub async fn bind<A>(self, addrs: A) -> std::io::Result<()>
});
}
}
pub fn bind<A>(addrs: A) -> Self
where where
A: std::net::ToSocketAddrs A: std::net::ToSocketAddrs
{ {
let addrs = addrs.to_socket_addrs() let addrs = addrs.to_socket_addrs()?
.expect("Bind Error")
.collect::<Vec<std::net::SocketAddr>>() .collect::<Vec<std::net::SocketAddr>>()
.pop() .pop()
.expect("Bind Error"); .expect("Bind Error");
Self { addrs } log::info!("IMPOSTER/0.1 HTTP SERVER");
log::info!("Server listening at {addrs}");
let listener = TcpListener::bind(addrs).await?;
loop {
let (inbound, addr) = listener.accept().await?;
log::info!("*** Incoming connection from {addr}");
self.agent.spawn(inbound);
}
} }
} }