diff --git a/lib/cdk-rust/src/transport/mod.rs b/lib/cdk-rust/src/transport/mod.rs index b7f552a12..f075d5487 100644 --- a/lib/cdk-rust/src/transport/mod.rs +++ b/lib/cdk-rust/src/transport/mod.rs @@ -1,3 +1,4 @@ +pub mod tcp; pub mod webtransport; use anyhow::Result; diff --git a/lib/cdk-rust/src/transport/tcp/mod.rs b/lib/cdk-rust/src/transport/tcp/mod.rs new file mode 100644 index 000000000..e5f745783 --- /dev/null +++ b/lib/cdk-rust/src/transport/tcp/mod.rs @@ -0,0 +1,81 @@ +use std::io; +use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use async_trait::async_trait; +use bytes::Bytes; +use futures::{Sink, Stream}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use tokio::net::TcpStream; +use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; + +use crate::transport::{Message, Transport}; + +struct TcpTransport { + target: SocketAddr, +} + +#[async_trait] +impl Transport for TcpTransport { + type Sender = TcpSender; + type Receiver = TcpReceiver; + + async fn connect(&self) -> anyhow::Result<(Self::Sender, Self::Receiver)> { + let stream = TcpStream::connect(self.target).await?; + let (rx, tx) = stream.into_split(); + Ok(( + TcpSender { + inner: FramedWrite::new(tx, LengthDelimitedCodec::new()), + }, + TcpReceiver { + inner: FramedRead::new(rx, LengthDelimitedCodec::new()), + }, + )) + } +} + +pub struct TcpSender { + inner: FramedWrite, +} + +impl Sink for TcpSender { + type Error = io::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + Pin::new(&mut self.inner).start_send(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_close(cx) + } +} + +pub struct TcpReceiver { + inner: FramedRead, +} + +impl Stream for TcpReceiver { + type Item = Message; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner) + .poll_next(cx) + .map(|opt| match opt { + None => None, + Some(Ok(bytes)) => Some(Bytes::from(bytes)), + Some(Err(e)) => { + log::error!("unexpected error in receiving stream: {e:?}"); + None + }, + }) + } +}