diff --git a/futures-util/src/stream/try_stream/mod.rs b/futures-util/src/stream/try_stream/mod.rs index fc547237c..ba4abab30 100644 --- a/futures-util/src/stream/try_stream/mod.rs +++ b/futures-util/src/stream/try_stream/mod.rs @@ -164,6 +164,10 @@ mod into_async_read; #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 pub use self::into_async_read::IntoAsyncRead; +mod try_all; +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::try_all::TryAll; + impl TryStreamExt for S {} /// Adapters specific to `Result`-returning streams @@ -1071,4 +1075,33 @@ pub trait TryStreamExt: TryStream { { crate::io::assert_read(IntoAsyncRead::new(self)) } + + /// Attempt to execute a predicate over an asynchronous stream and evaluate if all items + /// satisfy the predicate. Exits early if an `Err` is encountered or if an `Ok` item is found + /// that does not satisfy the predicate. + /// + /// # Examples + /// + /// ``` + /// # futures::executor::block_on(async { + /// use futures::stream::{self, StreamExt, TryStreamExt}; + /// use std::convert::Infallible; + /// + /// let number_stream = stream::iter(1..10).map(Ok::<_, Infallible>); + /// let positive = number_stream.try_all(|i| async move { i > 0 }); + /// assert_eq!(positive.await, Ok(true)); + /// + /// let stream_with_errors = stream::iter([Ok(1), Err("err"), Ok(3)]); + /// let positive = stream_with_errors.try_all(|i| async move { i > 0 }); + /// assert_eq!(positive.await, Err("err")); + /// # }); + /// ``` + fn try_all(self, f: F) -> TryAll + where + Self: Sized, + F: FnMut(Self::Ok) -> Fut, + Fut: Future, + { + assert_future::, _>(TryAll::new(self, f)) + } } diff --git a/futures-util/src/stream/try_stream/try_all.rs b/futures-util/src/stream/try_stream/try_all.rs new file mode 100644 index 000000000..8179f86af --- /dev/null +++ b/futures-util/src/stream/try_stream/try_all.rs @@ -0,0 +1,98 @@ +use core::fmt; +use core::pin::Pin; +use futures_core::future::{FusedFuture, Future}; +use futures_core::ready; +use futures_core::stream::TryStream; +use futures_core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Future for the [`try_all`](super::TryStreamExt::try_all) method. + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct TryAll { + #[pin] + stream: St, + f: F, + done: bool, + #[pin] + future: Option, + } +} + +impl fmt::Debug for TryAll +where + St: fmt::Debug, + Fut: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TryAll") + .field("stream", &self.stream) + .field("done", &self.done) + .field("future", &self.future) + .finish() + } +} + +impl TryAll +where + St: TryStream, + F: FnMut(St::Ok) -> Fut, + Fut: Future, +{ + pub(super) fn new(stream: St, f: F) -> Self { + Self { stream, f, done: false, future: None } + } +} + +impl FusedFuture for TryAll +where + St: TryStream, + F: FnMut(St::Ok) -> Fut, + Fut: Future, +{ + fn is_terminated(&self) -> bool { + self.done && self.future.is_none() + } +} + +impl Future for TryAll +where + St: TryStream, + F: FnMut(St::Ok) -> Fut, + Fut: Future, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + Poll::Ready(loop { + if let Some(fut) = this.future.as_mut().as_pin_mut() { + // we're currently processing a future to produce a new value + let acc = ready!(fut.poll(cx)); + this.future.set(None); + if !acc { + *this.done = true; + break Ok(false); + } // early exit + } else if !*this.done { + // we're waiting on a new item from the stream + match ready!(this.stream.as_mut().try_poll_next(cx)) { + Some(Ok(item)) => { + this.future.set(Some((this.f)(item))); + } + Some(Err(err)) => { + *this.done = true; + break Err(err); + } + None => { + *this.done = true; + break Ok(true); + } + } + } else { + panic!("TryAll polled after completion") + } + }) + } +} diff --git a/futures/tests/stream_try_stream.rs b/futures/tests/stream_try_stream.rs index f67adfc08..97cb99d91 100644 --- a/futures/tests/stream_try_stream.rs +++ b/futures/tests/stream_try_stream.rs @@ -1,4 +1,5 @@ use core::pin::Pin; +use std::convert::Infallible; use futures::{ stream::{self, repeat, Repeat, StreamExt, TryStreamExt}, @@ -132,3 +133,29 @@ fn try_flatten_unordered() { assert_eq!(taken, 31); }) } + +async fn is_even(number: u8) -> bool { + number % 2 == 0 +} + +#[test] +fn try_all() { + block_on(async { + let empty: [Result; 0] = []; + let st = stream::iter(empty); + let all = st.try_all(is_even).await; + assert_eq!(Ok(true), all); + + let st = stream::iter([Ok::<_, Infallible>(2), Ok(4), Ok(6), Ok(8)]); + let all = st.try_all(is_even).await; + assert_eq!(Ok(true), all); + + let st = stream::iter([Ok::<_, Infallible>(2), Ok(3), Ok(4)]); + let all = st.try_all(is_even).await; + assert_eq!(Ok(false), all); + + let st = stream::iter([Ok(2), Ok(4), Err("err"), Ok(8)]); + let all = st.try_all(is_even).await; + assert_eq!(Err("err"), all); + }); +}