Skip to main content

tower/util/call_all/
common.rs

1use futures_core::Stream;
2use pin_project_lite::pin_project;
3use std::{
4    fmt,
5    future::Future,
6    pin::Pin,
7    task::{ready, Context, Poll},
8};
9use tower_service::Service;
10
11pin_project! {
12    /// The [`Future`] returned by the [`ServiceExt::call_all`] combinator.
13    pub(crate) struct CallAll<Svc, S, Q>
14    where
15        S: Stream,
16    {
17        service: Option<Svc>,
18        #[pin]
19        stream: S,
20        queue: Q,
21        eof: bool,
22        curr_req: Option<S::Item>
23    }
24}
25
26impl<Svc, S, Q> fmt::Debug for CallAll<Svc, S, Q>
27where
28    Svc: fmt::Debug,
29    S: Stream + fmt::Debug,
30{
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        f.debug_struct("CallAll")
33            .field("service", &self.service)
34            .field("stream", &self.stream)
35            .field("eof", &self.eof)
36            .finish()
37    }
38}
39
/// Container for the response futures that `CallAll` has dispatched but not yet
/// yielded.
///
/// The implementation decides the order in which finished outputs are handed
/// back. Presumably this is what separates the ordered and unordered `call_all`
/// streams; confirm against the implementors.
pub(crate) trait Drive<F>
where
    F: Future,
{
    /// Takes ownership of a newly dispatched response future.
    fn push(&mut self, future: F);

    /// Reports whether the container currently holds no futures.
    fn is_empty(&self) -> bool;

    /// Makes progress on the held futures.
    ///
    /// `CallAll` yields the value whenever this returns `Poll::Ready(Some(_))`.
    /// It treats `Poll::Ready(None)` and `Poll::Pending` alike, as "nothing to
    /// yield right now".
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<F::Output>>;
}
47
48impl<Svc, S, Q> CallAll<Svc, S, Q>
49where
50    Svc: Service<S::Item>,
51    S: Stream,
52    Q: Drive<Svc::Future>,
53{
54    pub(crate) const fn new(service: Svc, stream: S, queue: Q) -> CallAll<Svc, S, Q> {
55        CallAll {
56            service: Some(service),
57            stream,
58            queue,
59            eof: false,
60            curr_req: None,
61        }
62    }
63
64    /// Extract the wrapped [`Service`].
65    pub(crate) fn into_inner(mut self) -> Svc {
66        self.service.take().expect("Service already taken")
67    }
68
69    /// Extract the wrapped [`Service`].
70    pub(crate) fn take_service(self: Pin<&mut Self>) -> Svc {
71        self.project()
72            .service
73            .take()
74            .expect("Service already taken")
75    }
76
77    /// Transition this `CallAll` instance into a `CallAllUnordered` stream.
78    ///
79    /// This conversion preserves the internal stream, backpressure flags (`eof`),
80    /// and any pulled but un-submitted request currently sitting in `curr_req`.
81    pub(crate) fn unordered(self) -> super::CallAllUnordered<Svc, S> {
82        // Ensure we don't discard any active, in-flight response futures.
83        assert!(self.queue.is_empty());
84
85        let CallAll {
86            service,
87            stream,
88            queue: _,
89            eof,
90            curr_req,
91        } = self;
92
93        // Reassemble the internal state machine while transitioning
94        // to an unordered concurrency driver.
95        let inner = CallAll {
96            service,
97            stream,
98            queue: futures_util::stream::FuturesUnordered::new(),
99            eof,
100            curr_req,
101        };
102
103        super::CallAllUnordered::from_inner(inner)
104    }
105}
106
107impl<Svc, S, Q> Stream for CallAll<Svc, S, Q>
108where
109    Svc: Service<S::Item>,
110    S: Stream,
111    Q: Drive<Svc::Future>,
112{
113    type Item = Result<Svc::Response, Svc::Error>;
114
115    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
116        let mut this = self.project();
117
118        loop {
119            // First, see if we have any responses to yield
120            if let Poll::Ready(r) = this.queue.poll(cx) {
121                if let Some(rsp) = r.transpose()? {
122                    return Poll::Ready(Some(Ok(rsp)));
123                }
124            }
125
126            // If there are no more requests coming, check if we're done
127            if *this.eof {
128                if this.queue.is_empty() {
129                    return Poll::Ready(None);
130                } else {
131                    return Poll::Pending;
132                }
133            }
134
135            // If not done, and we don't have a stored request, gather the next request from the
136            // stream (if there is one), or return `Pending` if the stream is not ready.
137            if this.curr_req.is_none() {
138                *this.curr_req = match ready!(this.stream.as_mut().poll_next(cx)) {
139                    Some(next_req) => Some(next_req),
140                    None => {
141                        // Mark that there will be no more requests.
142                        *this.eof = true;
143                        continue;
144                    }
145                };
146            }
147
148            // Then, see that the service is ready for another request
149            let svc = this
150                .service
151                .as_mut()
152                .expect("Using CallAll after extracting inner Service");
153
154            if let Err(e) = ready!(svc.poll_ready(cx)) {
155                // Set eof to prevent the service from being called again after a `poll_ready` error
156                *this.eof = true;
157                return Poll::Ready(Some(Err(e)));
158            }
159
160            // Unwrap: The check above always sets `this.curr_req` if none.
161            this.queue.push(svc.call(this.curr_req.take().unwrap()));
162        }
163    }
164}