diff --git a/Cargo.toml b/Cargo.toml index fc3384c..cca7f95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,9 @@ wintun = { version = "0.5", default-features = false } [build-dependencies] serde_json = "1" +[features] +udp_packet = [] + [profile.release] opt-level = 's' # Optimize for size (with loop vectorization enabled). lto = true # Enable Link Time Optimization diff --git a/examples/tun.rs b/examples/tun.rs index 37364da..5126259 100644 --- a/examples/tun.rs +++ b/examples/tun.rs @@ -136,6 +136,7 @@ async fn main() -> Result<(), Box> { log::info!("#{number1} TCP closed, session count {c}"); }); } + #[cfg(not(feature = "udp_packet"))] IpStackStream::Udp(mut udp) => { let mut s = match UdpStream::connect(server_addr).await { Ok(s) => s, @@ -159,6 +160,50 @@ async fn main() -> Result<(), Box> { log::info!("#{number2} UDP closed, session count {c}"); }); } + #[cfg(feature = "udp_packet")] + IpStackStream::Udp(mut endpoint) => { + let c = count.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1; + let number2 = number; + log::info!("#{number2} UDP Packet Endpoint starting, session count {c}"); + + tokio::spawn(async move { + loop { + tokio::select! { + res = endpoint.recv() => { + match res { + Some((_src_addr, _dst_addr, _payload)) => { + + + } + None => { + log::info!("#{number2} UDP Packet Endpoint 底层通道已关闭"); + break; + } + } + } + // res = app.readpacket() => { + // match res { + // Ok(Some((remote_player_addr, my_local_addr, payload))) => { + // log::trace!("#{number2} [down] {} -> {} ({} bytes)", remote_player_addr, my_local_addr, payload.len()); + + // + // if let Err(e) = endpoint.send(remote_player_addr, my_local_addr, payload) { + // log::warn!("#{number2} faild to send packet: {}", e); + // } + // } + // Ok(None) | Err(_) => { + // + // break; + // } + // } + // } + + } + } + let c = count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed) - 1; + log::info!("#{number2} UDP Packet Endpoint closed, session count {c}"); + }); + } IpStackStream::UnknownTransport(u) => { let n = number; if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { diff --git a/src/lib.rs b/src/lib.rs index e67072a..f8ab252 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ use ahash::AHashMap; use packet::{NetworkPacket, NetworkTuple, TransportHeader}; + use std::{sync::Arc, time::Duration}; use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, @@ -23,6 +24,15 @@ pub use self::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStac pub use self::stream::{TcpConfig, TcpOptions}; pub use etherparse::IpNumber; +#[cfg(feature = "udp_packet")] +pub use self::stream::IpStackUdpPacketEndpoint; +#[cfg(feature = "udp_packet")] +use std::net::SocketAddr; +#[cfg(feature = "udp_packet")] +use std::sync::atomic::{AtomicU64, Ordering}; +#[cfg(feature = "udp_packet")] +use std::time::{SystemTime, UNIX_EPOCH}; + #[cfg(unix)] const TTL: u8 = 64; @@ -304,23 +314,35 @@ impl Drop for IpStack { } } +#[cfg(feature = "udp_packet")] +type UdpEdpSender = tokio::sync::mpsc::UnboundedSender<(std::net::SocketAddr, std::net::SocketAddr, Vec)>; + +#[cfg(feature = "udp_packet")] +type UdpEdpInfo = (UdpEdpSender, std::sync::Arc); + fn run( config: IpStackConfig, mut device: Device, accept_sender: UnboundedSender, ) -> JoinHandle> { let mut sessions: SessionCollection = AHashMap::new(); + //UDPendpoints + #[cfg(feature = "udp_packet")] + let mut packet_endpoints: AHashMap = AHashMap::new(); let (session_remove_tx, mut session_remove_rx) = mpsc::unbounded_channel::(); + //udp endpoints rm channel + #[cfg(feature = "udp_packet")] + let (edp_remove_tx, mut edp_remove_rx) = mpsc::unbounded_channel::(); let pi = config.packet_information; let offset = if pi && cfg!(unix) { 4 } else { 0 }; let mut buffer = vec![0_u8; config.mtu as usize + offset]; let (up_pkt_sender, mut up_pkt_receiver) = mpsc::unbounded_channel::(); - - tokio::spawn(async move { + #[cfg(feature = "udp_packet")] + return tokio::spawn(async move { loop { select! { Ok(n) = device.read(&mut buffer) => { - if let Err(e) = process_device_read(&buffer[offset..n], &mut sessions, &session_remove_tx, &up_pkt_sender, &config, &accept_sender).await { + if let Err(e) = process_device_read(&buffer[offset..n], &mut sessions, &session_remove_tx, &edp_remove_tx, &up_pkt_sender, &config, &accept_sender, &mut packet_endpoints).await { let io_err: std::io::Error = e.into(); if io_err.kind() == std::io::ErrorKind::ConnectionRefused { log::trace!("Received junk data: {io_err}"); @@ -328,26 +350,56 @@ fn run( log::warn!("process_device_read error: {io_err}"); } } - } + }, + Some(src_addr) = edp_remove_rx.recv() => { + packet_endpoints.remove(&src_addr); + log::debug!("Packet endpoint destroyed and removed: {}", src_addr); + }, Some(network_tuple) = session_remove_rx.recv() => { sessions.remove(&network_tuple); log::debug!("session destroyed: {network_tuple}"); + }, + Some(packet) = up_pkt_receiver.recv() => { + process_upstream_recv(packet, &mut device, #[cfg(unix)]pi).await?; } + } + } + }); + #[cfg(not(feature = "udp_packet"))] + return tokio::spawn(async move { + loop { + select! { + Ok(n) = device.read(&mut buffer) => { + if let Err(e) = process_device_read(&buffer[offset..n], &mut sessions, &session_remove_tx, &up_pkt_sender, &config, &accept_sender).await { + let io_err: std::io::Error = e.into(); + if io_err.kind() == std::io::ErrorKind::ConnectionRefused { + log::trace!("Received junk data: {io_err}"); + } else { + log::warn!("process_device_read error: {io_err}"); + } + } + }, + Some(network_tuple) = session_remove_rx.recv() => { + sessions.remove(&network_tuple); + log::debug!("session destroyed: {network_tuple}"); + }, Some(packet) = up_pkt_receiver.recv() => { process_upstream_recv(packet, &mut device, #[cfg(unix)]pi).await?; } } } - }) + }); } - +#[allow(clippy::too_many_arguments)] async fn process_device_read( data: &[u8], sessions: &mut SessionCollection, session_remove_tx: &UnboundedSender, + #[cfg(feature = "udp_packet")] edp_remove_tx: &tokio::sync::mpsc::UnboundedSender, up_pkt_sender: &PacketSender, config: &IpStackConfig, accept_sender: &UnboundedSender, + #[cfg(feature = "udp_packet")] packet_endpoints: &mut AHashMap, ) -> Result<()> { let Ok(packet) = NetworkPacket::parse(data) else { let stream = IpStackStream::UnknownNetwork(data.to_owned()); @@ -368,6 +420,79 @@ async fn process_device_read( return Ok(()); } + //UDP packet + #[cfg(feature = "udp_packet")] + if let TransportHeader::Udp(_udp_header) = packet.transport_header() { + let src_addr = packet.src_addr(); + let dst_addr = packet.dst_addr(); + let payload = packet.payload.unwrap_or_default(); + + match packet_endpoints.entry(src_addr) { + std::collections::hash_map::Entry::Occupied(entry) => { + let (tx, last_activity) = entry.get(); + + if let Err(e) = tx.send((src_addr, dst_addr, payload)) { + log::warn!("Failed to send to packet endpoint for {}: {}", src_addr, e); + // Receiver was dropped; remove stale endpoint so a new one can be created. + entry.remove(); + } else { + last_activity.store(now_secs(), Ordering::Relaxed); + } + } + + std::collections::hash_map::Entry::Vacant(entry) => { + //announce to destroy the channel when timeout or application layer take out + let (destroy_tx, mut destroy_rx) = tokio::sync::oneshot::channel::<()>(); + + let last_activity = Arc::new(AtomicU64::new(now_secs())); + let last_activity_clone = last_activity.clone(); + + let timeout_secs = config.udp_timeout.as_secs(); + + let edp_remove_tx_clone = edp_remove_tx.clone(); + let src_addr_clone = src_addr; + + tokio::spawn(async move { + loop { + let elapsed = now_secs() - last_activity_clone.load(Ordering::Relaxed); + if elapsed >= timeout_secs { + log::info!("removing channel for {} because no data for {}s", src_addr_clone, elapsed); + break; + } + + let sleep_duration = std::time::Duration::from_secs(timeout_secs - elapsed); + + tokio::select! { + + //sleep until timeout + _ = tokio::time::sleep(sleep_duration) => {} + + // application layer take out + _ = &mut destroy_rx => { + log::debug!("application layer Endpoint:{} removed the channel", src_addr_clone); + break; + } + } + } + + let _ = edp_remove_tx_clone.send(src_addr_clone); + }); + //ipstack to application layer channel + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + + let endpoint = IpStackUdpPacketEndpoint::new(rx, up_pkt_sender.clone(), src_addr, config.mtu, destroy_tx); + + accept_sender.send(IpStackStream::Udp(endpoint)).map_err(std::io::Error::other)?; + + entry.insert((tx.clone(), last_activity)); + if let Err(e) = tx.send((src_addr, dst_addr, payload)) { + log::warn!("Failed to send to packet endpoint: {}", e); + } + } + } + return Ok(()); + } + let network_tuple = packet.network_tuple(); match sessions.entry(network_tuple) { std::collections::hash_map::Entry::Occupied(entry) => { @@ -407,11 +532,14 @@ fn create_stream( let stream = IpStackTcpStream::new(src_addr, dst_addr, h.clone(), up_pkt_sender, cfg.mtu, msgr, cfg.tcp_config.clone())?; Ok(IpStackStream::Tcp(stream)) } + #[cfg(not(feature = "udp_packet"))] TransportHeader::Udp(_) => { let payload = packet.payload.unwrap_or_default(); let stream = IpStackUdpStream::new(src_addr, dst_addr, payload, up_pkt_sender, cfg.mtu, cfg.udp_timeout, msgr); Ok(IpStackStream::Udp(stream)) } + #[cfg(feature = "udp_packet")] + TransportHeader::Udp(_) => Err(IpStackError::UnsupportedTransportProtocol), TransportHeader::Unknown => Err(IpStackError::UnsupportedTransportProtocol), } } @@ -439,3 +567,12 @@ async fn process_upstream_recv( Ok(()) } + +//time +#[cfg(feature = "udp_packet")] +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_else(|_| Duration::from_secs(0)) + .as_secs() +} diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 777053f..d54b5fd 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -2,9 +2,11 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; pub use self::tcp::IpStackTcpStream; pub use self::tcp::{TcpConfig, TcpOptions}; +#[cfg(feature = "udp_packet")] +pub use self::udp::IpStackUdpPacketEndpoint; + pub use self::udp::IpStackUdpStream; pub use self::unknown::IpStackUnknownTransport; - mod seqnum; mod tcb; mod tcp; @@ -26,7 +28,11 @@ pub enum IpStackStream { /// A TCP connection stream. Tcp(IpStackTcpStream), /// A UDP stream. + #[cfg(not(feature = "udp_packet"))] Udp(IpStackUdpStream), + /// UDP PACKET. + #[cfg(feature = "udp_packet")] + Udp(IpStackUdpPacketEndpoint), /// A stream for unknown transport protocols. UnknownTransport(IpStackUnknownTransport), /// Raw network packets that couldn't be parsed. @@ -51,7 +57,10 @@ impl IpStackStream { pub fn local_addr(&self) -> SocketAddr { match self { IpStackStream::Tcp(tcp) => tcp.local_addr(), + #[cfg(not(feature = "udp_packet"))] IpStackStream::Udp(udp) => udp.local_addr(), + #[cfg(feature = "udp_packet")] + IpStackStream::Udp(udp_edp) => udp_edp.local_addr(), IpStackStream::UnknownNetwork(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), IpStackStream::UnknownTransport(unknown) => match unknown.src_addr() { IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), @@ -77,7 +86,10 @@ impl IpStackStream { pub fn peer_addr(&self) -> SocketAddr { match self { IpStackStream::Tcp(tcp) => tcp.peer_addr(), + #[cfg(not(feature = "udp_packet"))] IpStackStream::Udp(udp) => udp.peer_addr(), + #[cfg(feature = "udp_packet")] + IpStackStream::Udp(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), IpStackStream::UnknownNetwork(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), IpStackStream::UnknownTransport(unknown) => match unknown.dst_addr() { IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), @@ -89,6 +101,7 @@ impl IpStackStream { pub(crate) fn stream_sender(&self) -> Result { match self { IpStackStream::Tcp(tcp) => Ok(tcp.stream_sender()), + #[cfg(not(feature = "udp_packet"))] IpStackStream::Udp(udp) => Ok(udp.stream_sender()), _ => Err(std::io::Error::other("Unknown transport stream does not have a sender")), } diff --git a/src/stream/udp.rs b/src/stream/udp.rs index 87925b1..a3a0c44 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -40,6 +40,7 @@ use tokio::{ pub struct IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, + #[allow(dead_code)] stream_sender: PacketSender, stream_receiver: PacketReceiver, up_pkt_sender: PacketSender, @@ -75,11 +76,10 @@ impl IpStackUdpStream { destroy_messenger, } } - + #[allow(dead_code)] pub(crate) fn stream_sender(&self) -> PacketSender { self.stream_sender.clone() } - fn create_rev_packet(&self, ttl: u8, mut payload: Vec) -> std::io::Result { const UHS: usize = 8; // udp header size is 8 match (self.dst_addr.ip(), self.src_addr.ip()) { @@ -213,3 +213,101 @@ impl Drop for IpStackUdpStream { } } } +#[cfg(feature = "udp_packet")] +pub struct IpStackUdpPacketEndpoint { + //receive from TUN: (src, dst, payload) + receiver: mpsc::UnboundedReceiver<(SocketAddr, SocketAddr, Vec)>, + + //send to TUN: raw packet + up_pkt_sender: crate::PacketSender, + + local_addr: SocketAddr, + + mtu: u16, + + _destroy_messenger: tokio::sync::oneshot::Sender<()>, +} +#[cfg(feature = "udp_packet")] +impl IpStackUdpPacketEndpoint { + pub fn new( + receiver: mpsc::UnboundedReceiver<(SocketAddr, SocketAddr, Vec)>, + up_pkt_sender: crate::PacketSender, + local_addr: SocketAddr, + mtu: u16, + _destroy_messenger: tokio::sync::oneshot::Sender<()>, + ) -> Self { + Self { + receiver, + up_pkt_sender, + local_addr, + mtu, + _destroy_messenger, + } + } + + /// recv from TUN: (src, dst, payload) + pub async fn recv(&mut self) -> Option<(SocketAddr, SocketAddr, Vec)> { + self.receiver.recv().await + } + + /// send to TUN: raw packet + pub fn send(&self, src: SocketAddr, dst: SocketAddr, payload: Vec) -> std::io::Result<()> { + let raw_packet = build_raw_udp_packet(src, dst, payload, self.mtu)?; + + self.up_pkt_sender + .send(raw_packet) + .map_err(|_| std::io::Error::new(std::io::ErrorKind::BrokenPipe, "ipstack up_pkt_sender closed")) + } + + pub fn local_addr(&self) -> SocketAddr { + self.local_addr + } +} +#[cfg(feature = "udp_packet")] +pub fn build_raw_udp_packet(src_addr: SocketAddr, dst_addr: SocketAddr, mut payload: Vec, mtu: u16) -> std::io::Result { + const UHS: usize = 8; + let ttl = 64; + + match (src_addr.ip(), dst_addr.ip()) { + (std::net::IpAddr::V4(src), std::net::IpAddr::V4(dst)) => { + let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, src.octets(), dst.octets()).map_err(IpStackError::from)?; + + let line_buffer = mtu.saturating_sub((ip_h.header_len() + UHS) as u16); + payload.truncate(line_buffer as usize); + + ip_h.set_payload_len(payload.len() + UHS).map_err(IpStackError::from)?; + + let udp_header = + UdpHeader::with_ipv4_checksum(src_addr.port(), dst_addr.port(), &ip_h, &payload).map_err(IpStackError::from)?; + + Ok(NetworkPacket { + ip: IpHeader::Ipv4(ip_h), + transport: TransportHeader::Udp(udp_header), + payload: Some(payload), + }) + } + (std::net::IpAddr::V6(src), std::net::IpAddr::V6(dst)) => { + let mut ip_h = Ipv6Header { + traffic_class: 0, + flow_label: Ipv6FlowLabel::ZERO, + payload_length: 0, + next_header: IpNumber::UDP, + hop_limit: ttl, + source: src.octets(), + destination: dst.octets(), + }; + let line_buffer = mtu.saturating_sub((ip_h.header_len() + UHS) as u16); + payload.truncate(line_buffer as usize); + ip_h.payload_length = (payload.len() + UHS) as u16; + let udp_header = + UdpHeader::with_ipv6_checksum(src_addr.port(), dst_addr.port(), &ip_h, &payload).map_err(IpStackError::from)?; + + Ok(NetworkPacket { + ip: IpHeader::Ipv6(ip_h), + transport: TransportHeader::Udp(udp_header), + payload: Some(payload), + }) + } + _ => Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "IP version mismatch")), + } +}