From 312ba329e5dbbdc48f6f75a4dadc9dcf7e450e5f Mon Sep 17 00:00:00 2001 From: break27 Date: Fri, 7 Jun 2024 16:24:33 +0800 Subject: [PATCH] add http.rs --- Cargo.lock | 24 ------------- Cargo.toml | 1 - src/agent.rs | 98 +++++++++++++++++++++++++--------------------------- src/error.rs | 5 --- src/http.rs | 74 +++++++++++++++++++++++++++++++++++++++ src/main.rs | 1 + 6 files changed, 122 insertions(+), 81 deletions(-) create mode 100644 src/http.rs diff --git a/Cargo.lock b/Cargo.lock index 2cb00f7..b31a918 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -526,12 +526,6 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" -[[package]] -name = "fnv" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" - [[package]] name = "foreign-types" version = "0.3.2" @@ -635,17 +629,6 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" -[[package]] -name = "http" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" -dependencies = [ - "bytes", - "fnv", - "itoa", -] - [[package]] name = "httparse" version = "1.8.0" @@ -690,7 +673,6 @@ dependencies = [ "base64 0.22.1", "clap", "env_logger", - "http", "httparse", "log", "native-tls", @@ -735,12 +717,6 @@ dependencies = [ "either", ] -[[package]] -name = "itoa" -version = "1.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" - [[package]] name = "js-sys" version = "0.3.69" diff --git a/Cargo.toml b/Cargo.toml index 8563f8b..2bf8781 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,6 @@ async-std = { version = "1.12.0", features = ["attributes"] } base64 = "0.22.1" clap = { version = "4.5.4", features = ["derive"] } env_logger = "0.11.3" -http = "1.1.0" httparse = "1.8.0" log = "0.4.21" native-tls = "0.2.12" diff --git a/src/agent.rs b/src/agent.rs index c9631d5..5e056fb 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -3,6 +3,7 @@ use async_std::net::TcpStream; use crate::connection::ConnectionBuilder; use crate::error::{Result, Error, BuildError, BuildResult}; +use crate::http; pub struct AgentBuilder { filter_url: Option, @@ -58,7 +59,7 @@ impl AgentBuilder { }; if let Some(ref url) = self.filter_url { - log::info!(target: "builder", "Try downloading rule list from '{}'", url); + log::info!(target: "builder", "Try downloading rule list from '{url}'"); let https = native_tls::TlsConnector::new()?; let client = ureq::AgentBuilder::new() @@ -68,9 +69,9 @@ impl AgentBuilder { .build(); let resp = client.get(url.as_str()).call()?; let text = resp.into_string()?; - let kbs = text.len() as f32 / 1000f32; + let len = text.len() as f32 / 1000f32; - log::info!(target: "builder", "Successfully downloaded data ({}/kB transmitted)", kbs); + log::info!(target: "builder", "Successfully downloaded data ({len}/kB transmitted)"); ruleset = Some(self.build_rules(text)?); } @@ -115,25 +116,17 @@ impl Agent { where S: Read + Write + Send + Sync + Unpin + 'static { - let (request, payload) = self.read(&mut conn)?; + let request = self.read(&mut conn)?; + let host = request.host(); - let value = request.headers.get("host").unwrap(); - let mut host = value.to_str()?.to_string(); + log::info!("CLIENT --> {host}"); - if ! host.ends_with(char::is_numeric) { - // append a port number when without one - host += ":80"; - } - - log::info!("CLIENT --> {} ({}/bit request intercepted)", - host, payload.len()); - - if self.check_request_blocked(&request.uri.to_string()) { - log::info!("CLIENT --> PROXY --> {}", host); + if self.check_request_blocked(&request.path) { + log::info!("CLIENT --> PROXY --> {host}"); let mut outbound = self.io(self.builder.connect(&host))?; // forward intercepted request - outbound.write_all(&payload).await?; + outbound.write_all(request.as_bytes()).await?; outbound.flush().await?; log::info!("CLIENT <-> PROXY (connection established)"); @@ -173,7 +166,7 @@ impl Agent { Ok(()) } - fn read(&self, conn: &mut S) -> Result<(http::request::Parts, Vec)> + fn read(&self, conn: &mut S) -> Result where S: Read + Write + Send + Unpin + 'static { @@ -187,33 +180,31 @@ impl Agent { let payload = buf[..offset].to_vec(); let method = match request.method { - Some(x) => x, + Some(x) => x.parse::().unwrap(), None => return Err(Error::BadRequest("METHOD".to_string())) }; - let path = match request.path { - Some(x) => { - let mut text = x.to_string(); - if text.find("://").is_none() { - // in case of an cannot-be-a-base url - // find a port number, if any - let port = text - .rfind(":") - .and_then(|x| text.get(x + 1..)); - - let scheme = match port { - Some("443") => "https", - Some("21") => "ftp", - Some("80") | _ => "http", - }; - - text = format!("{}://{}", scheme, text); - } - text.parse::()? - }, + let mut path = match request.path { + Some(x) => x.to_string(), None => return Err(Error::BadRequest("PATH".to_string())) }; + if path.find("://").is_none() { + // in case of an cannot-be-a-base url + // find a port number, if any + let port = path + .rfind(":") + .and_then(|x| path.get(x + 1..)); + + let scheme = match port { + Some("443") => "https", + Some("21") => "ftp", + Some("80") | _ => "http", + }; + + path = format!("{}://{}", scheme, path); + } + let version = match request.version { Some(3) => http::Version::HTTP_3, Some(2) => http::Version::HTTP_2, @@ -223,21 +214,26 @@ impl Agent { None => return Err(Error::BadRequest("VERSION".to_string())) }; - let (mut parts, _) = http::Request::builder() - .method(method) - .uri(path) - .version(version) - .body(())? - .into_parts(); + let mut host = headers.iter() + .find_map(|x: _| (x.name == "Host").then_some(x.value)) + .map(|x| std::str::from_utf8(x)) + .ok_or(Error::BadRequest("Host".to_string()))?? + .to_string(); - for (k, v) in headers.map(|x: _| (x.name, x.value)) { - if k.is_empty() { break } - let key = k.parse::()?; - let value = std::str::from_utf8(v)?.parse::()?; - parts.headers.insert(key, value); + if host.find(":").is_none() { + // append a port number when without one + host += ":80"; } - Ok((parts, payload)) + let request = crate::http::Request { + method, + path, + version, + host, + payload: payload.into(), + }; + + Ok(request) } fn check_request_blocked(&self, url: &str) -> bool { diff --git a/src/error.rs b/src/error.rs index 4d0b42e..ef94638 100644 --- a/src/error.rs +++ b/src/error.rs @@ -30,11 +30,6 @@ impl_error! { BadRequest("Missing part '{}'"), Io(std::io::Error), Parse(httparse::Error), - Http(http::Error), - Uri(http::uri::InvalidUri), - HeaderName(http::header::InvalidHeaderName), - HeaderValue(http::header::InvalidHeaderValue), - ToStr(http::header::ToStrError), Utf8(std::str::Utf8Error), Timeout(async_std::future::TimeoutError), } diff --git a/src/http.rs b/src/http.rs new file mode 100644 index 0000000..772aee9 --- /dev/null +++ b/src/http.rs @@ -0,0 +1,74 @@ +macro_rules! impl_http { + (pub struct $name:ident($inner:ident) { $($variant:ident = $value:literal,)* }) => { + #[derive(PartialEq)] + pub struct $name($inner); + + #[derive(PartialEq)] + #[allow(non_camel_case_types)] + enum $inner{ $($variant,)* } + + impl $name { + $( + pub const $variant: Self = Self($inner::$variant); + )* + } + impl ToString for $name { + fn to_string(&self) -> String { + match self.0 { + $($inner::$variant => $value.to_string(),)* + } + } + } + impl std::str::FromStr for $name { + type Err = (); + fn from_str(s: &str) -> Result { + match s { + $($value => Ok(Self::$variant),)* + _ => Err(()) + } + } + } + }; +} + +impl_http! { + pub struct Method(InnerMethod) { + OPTIONS = "OPTIONS", + GET = "GET", + POST = "POST", + PUT = "PUT", + DELETE = "DELETE", + HEAD = "HEAD", + TRACE = "TRACE", + CONNECT = "CONNECT", + PATCH = "PATCH", + } +} + +impl_http! { + pub struct Version(InnerVersion) { + HTTP_09 = "HTTP/0.9", + HTTP_10 = "HTTP/1.0", + HTTP_11 = "HTTP/1.1", + HTTP_2 = "HTTP/2.0", + HTTP_3 = "HTTP/3.0", + } +} + +pub struct Request { + pub method: Method, + pub path: String, + pub version: Version, + pub(crate) host: String, + pub(crate) payload: Box<[u8]>, +} + +impl Request { + pub fn as_bytes(&self) -> &[u8] { + &self.payload + } + + pub fn host(&self) -> &str { + &self.host + } +} diff --git a/src/main.rs b/src/main.rs index 5bd112b..d7a8895 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ use clap::Parser; +mod http; pub mod agent; pub mod error; pub mod connection;