Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(WIP) feat: add support for processing handshake packets async via vacation #99

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ jobs:
run: |
cargo test --locked --all
cargo test --locked -p tokio-rustls --features early-data --test early-data
# we run all test suites against this feature
# to capture any regressions that come from changes to the handshake future state machine
cargo test --locked -p tokio-rustls --features vacation

lints:
name: Lints
Expand Down
25 changes: 23 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@ rust-version = "1.70"
exclude = ["/.github", "/examples", "/scripts"]

[dependencies]
vacation = { version = "0.1", optional = true, default-features = false }
pin-project-lite = { version = "0.2.15", optional = true }
rustls = { version = "0.23.15", default-features = false, features = ["std"] }
tokio = "1.0"

[features]
default = ["logging", "tls12", "aws_lc_rs"]
aws_lc_rs = ["rustls/aws_lc_rs"]
aws-lc-rs = ["aws_lc_rs"] # Alias because Cargo features commonly use `-`
vacation = ["dep:vacation", "pin-project-lite"]
early-data = []
fips = ["rustls/fips"]
logging = ["rustls/logging"]
Expand All @@ -32,4 +35,5 @@ futures-util = "0.3.1"
lazy_static = "1.1"
rcgen = { version = "0.13", features = ["pem"] }
tokio = { version = "1.0", features = ["full"] }
vacation = { version = "0.1", features = ["tokio"] }
webpki-roots = "0.26"
39 changes: 38 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,15 @@ where
self.get_ref().0.as_raw_socket()
}
}
#[cfg(feature = "early-data")]
type TlsStreamExtras = Option<Waker>;
#[cfg(not(feature = "early-data"))]
type TlsStreamExtras = ();

impl<IO> IoSession for TlsStream<IO> {
type Io = IO;
type Session = ClientConnection;
type Extras = TlsStreamExtras;

#[inline]
fn skip_handshake(&self) -> bool {
Expand All @@ -80,6 +85,35 @@ impl<IO> IoSession for TlsStream<IO> {
fn into_io(self) -> Self::Io {
self.io
}

#[inline]
fn into_inner(self) -> (TlsState, Self::Io, Self::Session, Self::Extras) {
#[cfg(feature = "early-data")]
return (self.state, self.io, self.session, self.early_waker);

#[cfg(not(feature = "early-data"))]
(self.state, self.io, self.session, ())
}

#[inline]
#[allow(unused_variables)]
fn from_inner(
state: TlsState,
io: Self::Io,
session: Self::Session,
extras: Self::Extras,
) -> Self {
#[cfg(feature = "early-data")]
return Self {
io,
session,
state,
early_waker: extras,
};

#[cfg(not(feature = "early-data"))]
Self { io, session, state }
}
}

impl<IO> AsyncRead for TlsStream<IO>
Expand Down Expand Up @@ -254,6 +288,8 @@ fn poll_handle_early_data<IO>(
where
IO: AsyncRead + AsyncWrite + Unpin,
{
use crate::common::PacketProcessingMode;

if let TlsState::EarlyData(pos, data) = state {
use std::io::Write;

Expand Down Expand Up @@ -287,7 +323,8 @@ where

// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
// TODO: also model as using `vacation` executor
ready!(stream.handshake(cx, PacketProcessingMode::Sync))?;
}

// write early data (fallback)
Expand Down
130 changes: 130 additions & 0 deletions src/common/async_session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
use std::{
future::Future,
io,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
};

use pin_project_lite::pin_project;
use rustls::{ConnectionCommon, SideData};
use tokio::io::{AsyncRead, AsyncWrite};

use crate::common::IoSession;

use super::{Stream, TlsState};

/// Full result of sync closure
type SessionResult<S> = Result<S, (Option<S>, io::Error)>;
/// Executor result wrapping sync closure result
type ExecutorResult<S> = Result<SessionResult<S>, vacation::Error>;
/// Future wrapping waiting on executor
type SessionFuture<S> = Box<dyn Future<Output = ExecutorResult<S>> + Unpin + Send>;

pin_project! {
/// Session is off doing compute-heavy sync work, such as initializing the session or processing handshake packets.
/// Might be on another thread / external threadpool.
///
/// This future sleeps on it in current worker thread until it completes.
pub(crate) struct AsyncSession<IS: IoSession> {
#[pin]
future: SessionFuture<IS::Session>,
io: IS::Io,
state: TlsState,
extras: IS::Extras,
}
}

impl<IS, SD> AsyncSession<IS>
where
IS: IoSession + Unpin,
IS::Io: AsyncRead + AsyncWrite + Unpin,
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin + Send + 'static,
SD: SideData,
{
pub(crate) fn process_packets(stream: IS) -> Self {
let (state, io, mut session, extras) = stream.into_inner();

let closure = move || match session.process_new_packets() {
Ok(_) => Ok(session),
Err(err) => Err((
Some(session),
io::Error::new(io::ErrorKind::InvalidData, err),
)),
};

// TODO: if we ever start also delegating non-handshake byte processing, make this chance of blocking
// variable and set by caller
let future = vacation::execute(closure, vacation::ChanceOfBlocking::High);

Self {
future: Box::new(Box::pin(future)),
io,
state,
extras,
}
}

pub(crate) fn into_stream(
mut self,
session_result: Result<IS::Session, (Option<IS::Session>, io::Error)>,
cx: &mut Context<'_>,
) -> Result<IS, (io::Error, IS::Io)> {
match session_result {
Ok(session) => Ok(IS::from_inner(self.state, self.io, session, self.extras)),
Err((Some(mut session), err)) => {
// In case we have an alert to send describing this error,
// try a last-gasp write -- but don't predate the primary
// error.
let mut tls_stream: Stream<'_, <IS as IoSession>::Io, <IS as IoSession>::Session> =
Stream::new(&mut self.io, &mut session).set_eof(!self.state.readable());
let _ = tls_stream.write_io(cx);

// still drop the tls session and return the io error only
Err((err, self.io))
}
Err((None, err)) => Err((err, self.io)),
}
}

#[inline]
pub fn get_ref(&self) -> &IS::Io {
&self.io
}

#[inline]
pub fn get_mut(&mut self) -> &mut IS::Io {
&mut self.io
}
}

impl<IS, SD> Future for AsyncSession<IS>
where
IS: IoSession + Unpin,
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin + Send + 'static,
SD: SideData,
{
type Output = Result<IS::Session, (Option<IS::Session>, io::Error)>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();

match ready!(this.future.as_mut().poll(cx)) {
Ok(session_res) => match session_res {
Ok(res) => Poll::Ready(Ok(res)),
// return any session along with the error,
// so the caller can flush any remaining alerts in buffer to i/o
Err((session, err)) => Poll::Ready(Err((
session,
io::Error::new(io::ErrorKind::InvalidData, err),
))),
},
// We don't have a session to flush here because the executor ate it
// TODO: not all errors should be modeled as io
Err(executor_error) => Poll::Ready(Err((
None,
io::Error::new(io::ErrorKind::Other, executor_error),
))),
}
}
}
Loading
Loading