From 3a153acec6c4d189eb5de501b2155b4484b8651b Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Sun, 17 Mar 2024 03:47:40 +0100 Subject: [PATCH] Forward vectored writes (#45) * Migrate early-data test to rustls * Replace `match` with `if let` on `TlsState::EarlyData` * Extract client early data handling * Forward vectored writes --- src/client.rs | 180 +++++++++++++++++++++++++------------- src/common/mod.rs | 37 ++++++++ src/common/test_stream.rs | 11 ++- src/lib.rs | 20 +++++ src/server.rs | 18 ++++ tests/badssl.rs | 28 +++++- tests/early-data.rs | 174 +++++++++++++++--------------------- tests/utils.rs | 28 +++++- 8 files changed, 328 insertions(+), 168 deletions(-) diff --git a/src/client.rs b/src/client.rs index f03448fe..1e57647e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,6 +4,8 @@ use std::os::unix::io::{AsRawFd, RawFd}; #[cfg(windows)] use std::os::windows::io::{AsRawSocket, RawSocket}; use std::pin::Pin; +#[cfg(feature = "early-data")] +use std::task::Waker; use std::task::{Context, Poll}; use rustls::ClientConnection; @@ -20,7 +22,7 @@ pub struct TlsStream { pub(crate) state: TlsState, #[cfg(feature = "early-data")] - pub(crate) early_waker: Option, + pub(crate) early_waker: Option, } impl TlsStream { @@ -152,78 +154,70 @@ where let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); - #[allow(clippy::match_single_binding)] - match this.state { - #[cfg(feature = "early-data")] - TlsState::EarlyData(ref mut pos, ref mut data) => { - use std::io::Write; - - // write early data - if let Some(mut early_data) = stream.session.early_data() { - let len = match early_data.write(buf) { - Ok(n) => n, - Err(err) => return Poll::Ready(Err(err)), - }; - if len != 0 { - data.extend_from_slice(&buf[..len]); - return Poll::Ready(Ok(len)); - } - } - - // complete handshake - while stream.session.is_handshaking() { - ready!(stream.handshake(cx))?; - } - - // write early data (fallback) - if !stream.session.is_early_data_accepted() { - while *pos < data.len() { - let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; - *pos += len; - } - } - - // end - this.state = TlsState::Stream; - - if let Some(waker) = this.early_waker.take() { - waker.wake(); - } - - stream.as_mut_pin().poll_write(cx, buf) + #[cfg(feature = "early-data")] + { + let bufs = [io::IoSlice::new(buf)]; + let written = ready!(poll_handle_early_data( + &mut this.state, + &mut stream, + &mut this.early_waker, + cx, + &bufs + ))?; + if written != 0 { + return Poll::Ready(Ok(written)); } - _ => stream.as_mut_pin().poll_write(cx, buf), } + + stream.as_mut_pin().poll_write(cx, buf) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + /// Note: that it does not guarantee the final data to be sent. + /// To be cautious, you must manually call `flush`. + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { let this = self.get_mut(); let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); #[cfg(feature = "early-data")] { - if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state { - // complete handshake - while stream.session.is_handshaking() { - ready!(stream.handshake(cx))?; - } + let written = ready!(poll_handle_early_data( + &mut this.state, + &mut stream, + &mut this.early_waker, + cx, + bufs + ))?; + if written != 0 { + return Poll::Ready(Ok(written)); + } + } - // write early data (fallback) - if !stream.session.is_early_data_accepted() { - while *pos < data.len() { - let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; - *pos += len; - } - } + stream.as_mut_pin().poll_write_vectored(cx, bufs) + } - this.state = TlsState::Stream; + #[inline] + fn is_write_vectored(&self) -> bool { + true + } - if let Some(waker) = this.early_waker.take() { - waker.wake(); - } - } - } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + + #[cfg(feature = "early-data")] + ready!(poll_handle_early_data( + &mut this.state, + &mut stream, + &mut this.early_waker, + cx, + &[] + ))?; stream.as_mut_pin().poll_flush(cx) } @@ -248,3 +242,69 @@ where stream.as_mut_pin().poll_shutdown(cx) } } + +#[cfg(feature = "early-data")] +fn poll_handle_early_data( + state: &mut TlsState, + stream: &mut Stream, + early_waker: &mut Option, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], +) -> Poll> +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + if let TlsState::EarlyData(pos, data) = state { + use std::io::Write; + + // write early data + if let Some(mut early_data) = stream.session.early_data() { + let mut written = 0; + + for buf in bufs { + if buf.is_empty() { + continue; + } + + let len = match early_data.write(buf) { + Ok(0) => break, + Ok(n) => n, + Err(err) => return Poll::Ready(Err(err)), + }; + + written += len; + data.extend_from_slice(&buf[..len]); + + if len < buf.len() { + break; + } + } + + if written != 0 { + return Poll::Ready(Ok(written)); + } + } + + // complete handshake + while stream.session.is_handshaking() { + ready!(stream.handshake(cx))?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + *pos += len; + } + } + + // end + *state = TlsState::Stream; + + if let Some(waker) = early_waker.take() { + waker.wake(); + } + } + + Poll::Ready(Ok(0)) +} diff --git a/src/common/mod.rs b/src/common/mod.rs index 15c76c5b..18e9b94c 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -289,6 +289,43 @@ where Poll::Ready(Ok(pos)) } + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + if bufs.iter().all(|buf| buf.is_empty()) { + return Poll::Ready(Ok(0)); + } + + loop { + let mut would_block = false; + let written = self.session.writer().write_vectored(bufs)?; + + while self.session.wants_write() { + match self.write_io(cx) { + Poll::Ready(Ok(0)) | Poll::Pending => { + would_block = true; + break; + } + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + + return match (written, would_block) { + (0, true) => Poll::Pending, + (0, false) => continue, + (n, _) => Poll::Ready(Ok(n)), + }; + } + } + + #[inline] + fn is_write_vectored(&self) -> bool { + true + } + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.session.writer().flush()?; while self.session.wants_write() { diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index dbd9a233..149c461d 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -122,6 +122,15 @@ impl AsyncWrite for Expected { #[tokio::test] async fn stream_good() -> io::Result<()> { + stream_good_impl(false).await +} + +#[tokio::test] +async fn stream_good_vectored() -> io::Result<()> { + stream_good_impl(true).await +} + +async fn stream_good_impl(vectored: bool) -> io::Result<()> { const FILE: &[u8] = include_bytes!("../../README.md"); let (server, mut client) = make_pair(); @@ -139,7 +148,7 @@ async fn stream_good() -> io::Result<()> { dbg!(stream.read_to_end(&mut buf).await)?; assert_eq!(buf, FILE); - dbg!(stream.write_all(b"Hello World!").await)?; + dbg!(utils::write(&mut stream, b"Hello World!", vectored).await)?; stream.session.send_close_notify(); dbg!(stream.shutdown().await)?; diff --git a/src/lib.rs b/src/lib.rs index ea121a75..619368b1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -564,6 +564,26 @@ where } } + #[inline] + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs), + TlsStream::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs), + } + } + + #[inline] + fn is_write_vectored(&self) -> bool { + match self { + TlsStream::Client(x) => x.is_write_vectored(), + TlsStream::Server(x) => x.is_write_vectored(), + } + } + #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { diff --git a/src/server.rs b/src/server.rs index 9444a625..02debac3 100644 --- a/src/server.rs +++ b/src/server.rs @@ -113,6 +113,24 @@ where stream.as_mut_pin().poll_write(cx, buf) } + /// Note: that it does not guarantee the final data to be sent. + /// To be cautious, you must manually call `flush`. + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + stream.as_mut_pin().poll_write_vectored(cx, bufs) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + true + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); let mut stream = diff --git a/tests/badssl.rs b/tests/badssl.rs index 34da3293..87e1744d 100644 --- a/tests/badssl.rs +++ b/tests/badssl.rs @@ -14,6 +14,7 @@ async fn get( config: Arc, domain: &str, port: u16, + vectored: bool, ) -> io::Result<(TlsStream, String)> { let connector = TlsConnector::from(config); let input = format!("GET / HTTP/1.0\r\nHost: {}\r\n\r\n", domain); @@ -24,7 +25,7 @@ async fn get( let stream = TcpStream::connect(&addr).await?; let mut stream = connector.connect(domain, stream).await?; - stream.write_all(input.as_bytes()).await?; + utils::write(&mut stream, input.as_bytes(), vectored).await?; stream.flush().await?; stream.read_to_end(&mut buf).await?; @@ -33,6 +34,15 @@ async fn get( #[tokio::test] async fn test_tls12() -> io::Result<()> { + test_tls12_impl(false).await +} + +#[tokio::test] +async fn test_tls12_vectored() -> io::Result<()> { + test_tls12_impl(true).await +} + +async fn test_tls12_impl(vectored: bool) -> io::Result<()> { let mut root_store = rustls::RootCertStore::empty(); root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); let config = rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS12]) @@ -42,7 +52,7 @@ async fn test_tls12() -> io::Result<()> { let config = Arc::new(config); let domain = "tls-v1-2.badssl.com"; - let (_, output) = get(config.clone(), domain, 1012).await?; + let (_, output) = get(config.clone(), domain, 1012, vectored).await?; assert!( output.contains("tls-v1-2.badssl.com"), "failed badssl test, output: {}", @@ -61,6 +71,15 @@ fn test_tls13() { #[tokio::test] async fn test_modern() -> io::Result<()> { + test_modern_impl(false).await +} + +#[tokio::test] +async fn test_modern_vectored() -> io::Result<()> { + test_modern_impl(true).await +} + +async fn test_modern_impl(vectored: bool) -> io::Result<()> { let mut root_store = rustls::RootCertStore::empty(); root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); let config = rustls::ClientConfig::builder() @@ -69,7 +88,7 @@ async fn test_modern() -> io::Result<()> { let config = Arc::new(config); let domain = "mozilla-modern.badssl.com"; - let (_, output) = get(config.clone(), domain, 443).await?; + let (_, output) = get(config.clone(), domain, 443, vectored).await?; assert!( output.contains("mozilla-modern.badssl.com"), "failed badssl test, output: {}", @@ -78,3 +97,6 @@ async fn test_modern() -> io::Result<()> { Ok(()) } + +// Include `utils` module +include!("utils.rs"); diff --git a/tests/early-data.rs b/tests/early-data.rs index f01189d5..42faad32 100644 --- a/tests/early-data.rs +++ b/tests/early-data.rs @@ -1,20 +1,16 @@ #![cfg(feature = "early-data")] -use std::io::{self, BufRead, BufReader, Cursor}; -use std::net::SocketAddr; +use std::io::{self, BufReader, Cursor, Read, Write}; +use std::net::{SocketAddr, TcpListener}; use std::pin::Pin; -use std::process::{Child, Command, Stdio}; use std::sync::Arc; use std::task::{Context, Poll}; use std::thread; -use std::time::Duration; -use futures_util::{future, future::Future, ready}; -use rustls::{self, ClientConfig, RootCertStore}; -use tokio::io::{split, AsyncRead, AsyncWriteExt, ReadBuf}; +use futures_util::{future::Future, ready}; +use rustls::{self, ClientConfig, RootCertStore, ServerConfig, ServerConnection, Stream}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf}; use tokio::net::TcpStream; -use tokio::sync::oneshot; -use tokio::time::sleep; use tokio_rustls::{client::TlsStream, TlsConnector}; struct Read1(T); @@ -41,91 +37,77 @@ async fn send( config: Arc, addr: SocketAddr, data: &[u8], -) -> io::Result> { + vectored: bool, +) -> io::Result<(TlsStream, Vec)> { let connector = TlsConnector::from(config).early_data(true); let stream = TcpStream::connect(&addr).await?; let domain = pki_types::ServerName::try_from("foobar.com").unwrap(); - let stream = connector.connect(domain, stream).await?; - let (mut rd, mut wd) = split(stream); - let (notify, wait) = oneshot::channel(); - - let j = tokio::spawn(async move { - // read to eof - // - // see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html - let mut read_task = Read1(&mut rd); - let mut notify = Some(notify); - - // read once, then write - // - // this is a regression test, see https://github.com/tokio-rs/tls/issues/54 - future::poll_fn(|cx| { - let ret = Pin::new(&mut read_task).poll(cx)?; - assert_eq!(ret, Poll::Pending); - - notify.take().unwrap().send(()).unwrap(); - - Poll::Ready(Ok(())) as Poll> - }) - .await?; - - match read_task.await { - Ok(()) => (), - Err(ref err) if err.kind() == io::ErrorKind::UnexpectedEof => (), - Err(err) => return Err(err), - } - - Ok(rd) as io::Result<_> - }); + let mut stream = connector.connect(domain, stream).await?; + utils::write(&mut stream, data, vectored).await?; + stream.flush().await?; + stream.shutdown().await?; - wait.await.unwrap(); + let mut buf = Vec::new(); + stream.read_to_end(&mut buf).await?; - wd.write_all(data).await?; - wd.flush().await?; - wd.shutdown().await?; - - let rd: tokio::io::ReadHalf<_> = j.await??; - - Ok(rd.unsplit(wd)) + Ok((stream, buf)) } -struct DropKill(Child); - -impl Drop for DropKill { - fn drop(&mut self) { - self.0.kill().unwrap(); - } +#[tokio::test] +async fn test_0rtt() -> io::Result<()> { + test_0rtt_impl(false).await } -async fn wait_for_server(addr: &str) { - let tries = 10; - for i in 0..tries { - if let Ok(_) = TcpStream::connect(addr).await { - return; - } - sleep(Duration::from_millis(i * 100)).await; - } - panic!("failed to connect to {:?} after {} tries", addr, tries) +#[tokio::test] +async fn test_0rtt_vectored() -> io::Result<()> { + test_0rtt_impl(true).await } -#[tokio::test] -async fn test_0rtt() -> io::Result<()> { - let server_port = 12354; - let mut handle = Command::new("openssl") - .arg("s_server") - .arg("-early_data") - .arg("-tls1_3") - .args(["-cert", "./tests/end.cert"]) - .args(["-key", "./tests/end.rsa"]) - .args(["-port", &server_port.to_string()]) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .spawn() - .map(DropKill)?; - - // wait openssl server - wait_for_server(format!("127.0.0.1:{}", server_port).as_str()).await; +async fn test_0rtt_impl(vectored: bool) -> io::Result<()> { + let cert_chain = rustls_pemfile::certs(&mut Cursor::new(include_bytes!("end.cert"))) + .collect::>>()?; + let key_der = + rustls_pemfile::private_key(&mut Cursor::new(include_bytes!("end.rsa")))?.unwrap(); + let mut server = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(cert_chain, key_der) + .unwrap(); + server.max_early_data_size = 8192; + let server = Arc::new(server); + + let listener = TcpListener::bind("127.0.0.1:0")?; + let server_port = listener.local_addr().unwrap().port(); + thread::spawn(move || loop { + let (mut sock, _addr) = listener.accept().unwrap(); + + let server = Arc::clone(&server); + thread::spawn(move || { + let mut conn = ServerConnection::new(server).unwrap(); + conn.complete_io(&mut sock).unwrap(); + + if let Some(mut early_data) = conn.early_data() { + let mut buf = Vec::new(); + early_data.read_to_end(&mut buf).unwrap(); + let mut stream = Stream::new(&mut conn, &mut sock); + stream.write_all(b"EARLY:").unwrap(); + stream.write_all(&buf).unwrap(); + } + + let mut stream = Stream::new(&mut conn, &mut sock); + stream.write_all(b"LATE:").unwrap(); + loop { + let mut buf = [0; 1024]; + let n = stream.read(&mut buf).unwrap(); + if n == 0 { + conn.send_close_notify(); + conn.complete_io(&mut sock).unwrap(); + break; + } + stream.write_all(&buf[..n]).unwrap(); + } + }); + }); let mut chain = BufReader::new(Cursor::new(include_str!("end.chain"))); let mut root_store = RootCertStore::empty(); @@ -141,30 +123,16 @@ async fn test_0rtt() -> io::Result<()> { let config = Arc::new(config); let addr = SocketAddr::from(([127, 0, 0, 1], server_port)); - // workaround: write to openssl s_server standard input periodically, to - // get it unstuck on Windows - let stdin = handle.0.stdin.take().unwrap(); - thread::spawn(move || { - let mut stdin = stdin; - loop { - thread::sleep(std::time::Duration::from_secs(5)); - std::io::Write::write_all(&mut stdin, b"\n").unwrap(); - } - }); - - let io = send(config.clone(), addr, b"hello").await?; + let (io, buf) = send(config.clone(), addr, b"hello", vectored).await?; assert!(!io.get_ref().1.is_early_data_accepted()); + assert_eq!("LATE:hello", String::from_utf8_lossy(&buf)); - let io = send(config, addr, b"world!").await?; + let (io, buf) = send(config, addr, b"world!", vectored).await?; assert!(io.get_ref().1.is_early_data_accepted()); - - let stdout = handle.0.stdout.as_mut().unwrap(); - let mut lines = BufReader::new(stdout).lines(); - - let has_msg1 = lines.by_ref().any(|line| line.unwrap().contains("hello")); - let has_msg2 = lines.by_ref().any(|line| line.unwrap().contains("world!")); - - assert!(has_msg1 && has_msg2); + assert_eq!("EARLY:world!LATE:", String::from_utf8_lossy(&buf)); Ok(()) } + +// Include `utils` module +include!("utils.rs"); diff --git a/tests/utils.rs b/tests/utils.rs index 8282a03a..94f7f12c 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -1,9 +1,10 @@ mod utils { - use std::io::{BufReader, Cursor}; + use std::io::{BufReader, Cursor, IoSlice}; use std::sync::Arc; use rustls::{ClientConfig, RootCertStore, ServerConfig}; use rustls_pemfile::{certs, rsa_private_keys}; + use tokio::io::{self, AsyncWrite, AsyncWriteExt}; #[allow(dead_code)] pub fn make_configs() -> (Arc, Arc) { @@ -35,4 +36,29 @@ mod utils { (Arc::new(sconfig), Arc::new(cconfig)) } + + #[allow(dead_code)] + pub async fn write( + w: &mut W, + data: &[u8], + vectored: bool, + ) -> io::Result<()> { + if !vectored { + return w.write_all(data).await; + } + + let mut data = data; + + while !data.is_empty() { + let chunk_size = (data.len() / 4).max(1); + let vectors = data + .chunks(chunk_size) + .map(IoSlice::new) + .collect::>(); + let written = w.write_vectored(&vectors).await?; + data = &data[written..]; + } + + Ok(()) + } }