wrpc_transport/
invoke.rs

1//! wRPC transport client handle
2
3use core::future::Future;
4use core::mem;
5use core::pin::pin;
6use core::time::Duration;
7
8use anyhow::Context as _;
9use bytes::{Bytes, BytesMut};
10use futures::TryStreamExt as _;
11use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _};
12use tokio::{select, try_join};
13use tokio_util::codec::{Encoder as _, FramedRead};
14use tracing::{debug, instrument, trace, Instrument as _};
15
16use crate::{Deferred as _, Incoming, Index, TupleDecode, TupleEncode};
17
/// Client-side handle to a wRPC transport
///
/// An implementation establishes invocations of remote functions and returns, for each one, a
/// pair of multiplexed byte streams: [`Self::Outgoing`] for data written to the peer and
/// [`Self::Incoming`] for data read from it. Typed helpers built on top of this trait are
/// provided by [`InvokeExt`].
pub trait Invoke: Send + Sync {
    /// Transport-specific invocation context
    type Context: Send + Sync;

    /// Outgoing multiplexed byte stream
    ///
    /// NOTE(review): the [`Index`] bound presumably lets callers derive nested sub-streams of
    /// the same type for asynchronously transferred values — confirm against [`Index`].
    type Outgoing: AsyncWrite + Index<Self::Outgoing> + Send + Sync + Unpin + 'static;

    /// Incoming multiplexed byte stream
    ///
    /// NOTE(review): as for [`Self::Outgoing`], the [`Index`] bound presumably provides nested
    /// sub-streams of the same type — confirm against [`Index`].
    type Incoming: AsyncRead + Index<Self::Incoming> + Send + Sync + Unpin + 'static;

    /// Invoke function `func` on instance `instance`
    ///
    /// `params` carries the already-encoded synchronous parameters
    /// ([`InvokeExt::invoke_values`] produces them with the `Params` [`TupleEncode`] encoder).
    /// `paths` is a list of index paths, each a sequence of `Option<usize>`. NOTE(review): these
    /// presumably identify nested values that are transferred asynchronously, with `None` acting
    /// as a wildcard index — confirm against a transport implementation.
    ///
    /// On success, returns the [`Self::Outgoing`] and [`Self::Incoming`] streams of the
    /// invocation. The conditions under which an error is returned are implementation-defined.
    ///
    /// Note that compiling code which calls methods on [`Invoke`] implementations within
    /// [`Send`] async functions may fail with hard-to-debug errors due to a compiler bug:
    /// [https://github.com/rust-lang/rust/issues/96865](https://github.com/rust-lang/rust/issues/96865)
    ///
    /// The following fails to compile with rustc 1.78.0:
    ///
    /// ```compile_fail
    /// use core::future::Future;
    ///
    /// fn invoke_send<T>() -> impl Future<Output = anyhow::Result<(T::Outgoing, T::Incoming)>> + Send
    /// where
    ///     T: wrpc_transport::Invoke<Context = ()> + Default,
    /// {
    ///     async { T::default().invoke((), "compiler-bug", "free", "since".into(), [[Some(2024)].as_slice(); 0]).await }
    /// }
    /// ```
    ///
    /// ```text
    /// async { T::default().invoke((), "compiler-bug", "free", "since".into(), [[Some(2024)].as_slice(); 0]).await }
    /// |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ implementation of `AsRef` is not general enough
    ///  |
    ///  = note: `[&'0 [Option<usize>]; 0]` must implement `AsRef<[&'1 [Option<usize>]]>`, for any two lifetimes `'0` and `'1`...
    ///  = note: ...but it actually implements `AsRef<[&[Option<usize>]]>`
    /// ```
    ///
    /// The fix is to call [`send`](send_future::SendFuture::send), provided by
    /// [`send_future::SendFuture`] and re-exported by this crate, on the future before awaiting
    /// it:
    /// ```
    /// use core::future::Future;
    /// use wrpc_transport::SendFuture as _;
    ///
    /// fn invoke_send<T>() -> impl Future<Output = anyhow::Result<(T::Outgoing, T::Incoming)>> + Send
    /// where
    ///     T: wrpc_transport::Invoke<Context = ()> + Default,
    /// {
    ///     async { T::default().invoke((), "compiler-bug", "free", "since".into(), [[Some(2024)].as_slice(); 0]).send().await }
    /// }
    /// ```
    fn invoke<P>(
        &self,
        cx: Self::Context,
        instance: &str,
        func: &str,
        params: Bytes,
        paths: impl AsRef<[P]> + Send,
    ) -> impl Future<Output = anyhow::Result<(Self::Outgoing, Self::Incoming)>> + Send
    where
        P: AsRef<[Option<usize>]> + Send + Sync;
}
79
/// Wrapper struct returned by [`InvokeExt::timeout`]
///
/// Borrows an [Invoke] implementation and applies `timeout` to each call to its
/// [`Invoke::invoke`]. The limit covers only that call, not any I/O later performed on the
/// returned streams.
///
/// `Clone` and `Copy` are implemented by hand instead of derived. The wrapper only holds a
/// shared reference and a [`Duration`], both of which are always `Copy`. A derive would add
/// `T: Clone` / `T: Copy` bounds, and therefore implicitly `T: Sized`, for no reason.
#[derive(Debug, Eq, PartialEq)]
pub struct Timeout<'a, T: ?Sized> {
    /// Inner [Invoke]
    pub inner: &'a T,
    /// Invocation timeout
    pub timeout: Duration,
}

impl<T: ?Sized> Clone for Timeout<'_, T> {
    fn clone(&self) -> Self {
        // Both fields are `Copy` regardless of `T`, so a bitwise copy is a valid clone.
        *self
    }
}

impl<T: ?Sized> Copy for Timeout<'_, T> {}
88
89impl<T: Invoke> Invoke for Timeout<'_, T> {
90    type Context = T::Context;
91    type Outgoing = T::Outgoing;
92    type Incoming = T::Incoming;
93
94    #[instrument(level = "trace", skip(self, cx, params, paths))]
95    async fn invoke<P>(
96        &self,
97        cx: Self::Context,
98        instance: &str,
99        func: &str,
100        params: Bytes,
101        paths: impl AsRef<[P]> + Send,
102    ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)>
103    where
104        P: AsRef<[Option<usize>]> + Send + Sync,
105    {
106        tokio::time::timeout(
107            self.timeout,
108            self.inner.invoke(cx, instance, func, params, paths),
109        )
110        .await
111        .context("invocation timed out")?
112    }
113}
114
115/// Wrapper struct returned by [`InvokeExt::timeout_owned`]
116#[derive(Clone, Copy, Debug, Eq, PartialEq)]
117pub struct TimeoutOwned<T> {
118    /// Inner [Invoke]
119    pub inner: T,
120    /// Invocation timeout
121    pub timeout: Duration,
122}
123
124impl<T: Invoke> Invoke for TimeoutOwned<T> {
125    type Context = T::Context;
126    type Outgoing = T::Outgoing;
127    type Incoming = T::Incoming;
128
129    #[instrument(level = "trace", skip(self, cx, params, paths))]
130    async fn invoke<P>(
131        &self,
132        cx: Self::Context,
133        instance: &str,
134        func: &str,
135        params: Bytes,
136        paths: impl AsRef<[P]> + Send,
137    ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)>
138    where
139        P: AsRef<[Option<usize>]> + Send + Sync,
140    {
141        self.inner
142            .timeout(self.timeout)
143            .invoke(cx, instance, func, params, paths)
144            .await
145    }
146}
147
/// Extension trait for [Invoke]
///
/// Provides typed invocation helpers and timeout wrappers on top of [`Invoke::invoke`]. It is
/// implemented for every [Invoke] implementation by the blanket impl following this trait.
pub trait InvokeExt: Invoke {
    /// Invoke function `func` on instance `instance` using typed `Params` and `Results`
    ///
    /// The returned future performs these steps:
    ///
    /// 1. Encodes `params` with `Params::Encoder` and sends the encoded bytes as the synchronous
    ///    parameters via [`Invoke::invoke`].
    /// 2. Shuts down the synchronous parameter channel (the root outgoing stream).
    /// 3. If the encoder produced deferred (asynchronous) parameters, starts transmitting them
    ///    on a task created with [`tokio::spawn`].
    /// 4. Decodes the synchronous results from the incoming stream with `Results::Decoder`.
    ///    While waiting, it also polls the parameter task, so a transmission failure is
    ///    reported even if it happens before the results arrive.
    ///
    /// On success, yields the decoded `Results` together with an optional future. That future
    /// is [`Some`] only if asynchronous I/O remains, in either or both directions:
    /// - the spawned parameter task is still outstanding;
    /// - the decoder has deferred results to receive.
    ///
    /// Awaiting it waits for the parameter task and drives reception of asynchronous results.
    /// [`Self::invoke_values_blocking`] does this automatically.
    ///
    /// Because [`tokio::spawn`] is used when there are deferred parameters, the returned future
    /// must then be polled within a Tokio runtime.
    ///
    /// # Errors
    ///
    /// The returned future fails in any of these cases:
    /// - encoding `params` fails;
    /// - [`Invoke::invoke`] fails;
    /// - shutting down the synchronous parameter channel fails;
    /// - the incoming stream errors, or ends before a complete set of synchronous results is
    ///   decoded;
    /// - the parameter task finishes first with an error, panics, or is cancelled.
    #[instrument(level = "trace", skip(self, cx, params, paths))]
    fn invoke_values<P, Params, Results>(
        &self,
        cx: Self::Context,
        instance: &str,
        func: &str,
        params: Params,
        paths: impl AsRef<[P]> + Send,
    ) -> impl Future<
        Output = anyhow::Result<(
            Results,
            Option<impl Future<Output = anyhow::Result<()>> + Send + 'static>,
        )>,
    > + Send
    where
        P: AsRef<[Option<usize>]> + Send + Sync,
        Params: TupleEncode<Self::Outgoing> + Send,
        Results: TupleDecode<Self::Incoming> + Send,
        <Params::Encoder as tokio_util::codec::Encoder<Params>>::Error:
            std::error::Error + Send + Sync + 'static,
        <Results::Decoder as tokio_util::codec::Decoder>::Error:
            std::error::Error + Send + Sync + 'static,
    {
        async {
            let mut buf = BytesMut::default();
            let mut enc = Params::Encoder::default();
            trace!("encoding parameters");
            enc.encode(params, &mut buf)
                .context("failed to encode parameters")?;
            debug!("invoking function");
            let (mut outgoing, incoming) = self
                .invoke(cx, instance, func, buf.freeze(), paths)
                .await
                .context("failed to invoke function")?;
            trace!("shutdown synchronous parameter channel");
            outgoing
                .shutdown()
                .await
                .context("failed to shutdown synchronous parameter channel")?;
            // Deferred parameters, if the encoder produced any, are written by a separate task
            // that takes ownership of `outgoing`. Dropping the resulting `JoinHandle` detaches
            // the task rather than cancelling it.
            let mut tx = enc.take_deferred().map(|tx| {
                tokio::spawn(
                    async {
                        debug!("transmitting async parameters");
                        tx(outgoing, Vec::default())
                            .await
                            .context("failed to write async parameters")
                    }
                    .in_current_span(),
                )
            });

            let mut dec = FramedRead::new(incoming, Results::Decoder::default());
            let results = async {
                debug!("receiving sync results");
                dec.try_next()
                    .await
                    .context("failed to receive sync results")?
                    // `None`: the stream ended before a complete results tuple was decoded
                    .context("incomplete results")
            };
            // Race the results against the parameter task, so that a parameter transmission
            // failure is reported even if it happens before the results are received.
            let results = if let Some(mut fut) = tx.take() {
                let mut results = pin!(results);
                select! {
                    res = &mut results => {
                        // Results finished first: put the still-running parameter task back so
                        // that the returned I/O future awaits it. If `res` is an error, `tx` is
                        // dropped on return, which detaches the task.
                        tx = Some(fut);
                        res?
                    }
                    res = &mut fut => {
                        // Parameter task finished first: propagate a join error (panic or
                        // cancellation) or a transmission error, then wait for the results.
                        res??;
                        results.await?
                    }
                }
            } else {
                results.await?
            };
            trace!("received sync results");
            // Move any bytes that `FramedRead` has buffered but not yet decoded into `Incoming`,
            // so they are not lost. `incoming` is then handed to the deferred results reader,
            // if there is one.
            let buffer = mem::take(dec.read_buffer_mut());
            let rx = dec.decoder_mut().take_deferred();
            let incoming = Incoming {
                buffer,
                inner: dec.into_inner(),
            };
            Ok((
                results,
                // Only return an I/O future if deferred work remains in either direction.
                (tx.is_some() || rx.is_some()).then_some(
                    async {
                        match (tx, rx) {
                            (Some(tx), Some(rx)) => {
                                // Receive async results and finish sending async parameters
                                // concurrently; the first error aborts the join.
                                try_join!(
                                    async {
                                        debug!("receiving async results");
                                        rx(incoming, Vec::default())
                                            .await
                                            .context("receiving async results failed")
                                    },
                                    async {
                                        tx.await.context("transmitting async parameters failed")?
                                    }
                                )?;
                            }
                            (Some(tx), None) => {
                                tx.await.context("transmitting async parameters failed")??;
                            }
                            (None, Some(rx)) => {
                                debug!("receiving async results");
                                rx(incoming, Vec::default())
                                    .await
                                    .context("receiving async results failed")?;
                            }
                            // `(None, None)` cannot occur here because of the `then_some` guard
                            // above.
                            _ => {}
                        }
                        Ok(())
                    }
                    .in_current_span(),
                ),
            ))
        }
    }

    /// Invoke function `func` on instance `instance` using typed `Params` and `Results`
    ///
    /// This is like [`Self::invoke_values`], but the returned future only resolves once all
    /// I/O is done. That is, it first awaits the asynchronous I/O future returned by
    /// [`Self::invoke_values`], if there is one, and only then yields the results.
    ///
    /// # Errors
    ///
    /// Fails if [`Self::invoke_values`] fails or if the remaining asynchronous I/O fails.
    #[instrument(level = "trace", skip_all)]
    fn invoke_values_blocking<P, Params, Results>(
        &self,
        cx: Self::Context,
        instance: &str,
        func: &str,
        params: Params,
        paths: impl AsRef<[P]> + Send,
    ) -> impl Future<Output = anyhow::Result<Results>> + Send
    where
        P: AsRef<[Option<usize>]> + Send + Sync,
        Params: TupleEncode<Self::Outgoing> + Send,
        Results: TupleDecode<Self::Incoming> + Send,
        <Params::Encoder as tokio_util::codec::Encoder<Params>>::Error:
            std::error::Error + Send + Sync + 'static,
        <Results::Decoder as tokio_util::codec::Decoder>::Error:
            std::error::Error + Send + Sync + 'static,
    {
        async {
            let (ret, io) = self
                .invoke_values(cx, instance, func, params, paths)
                .await?;
            if let Some(io) = io {
                trace!("awaiting I/O completion");
                io.await?;
            }
            Ok(ret)
        }
    }

    /// Returns a [`Timeout`], wrapping [Self] with an implementation of [Invoke].
    ///
    /// The wrapper fails if a call to [`Invoke::invoke`] does not return within the supplied
    /// `timeout`.
    fn timeout(&self, timeout: Duration) -> Timeout<'_, Self> {
        Timeout {
            inner: self,
            timeout,
        }
    }

    /// This is like [`InvokeExt::timeout`], but moves [Self] and returns the corresponding
    /// [`TimeoutOwned`]
    fn timeout_owned(self, timeout: Duration) -> TimeoutOwned<Self>
    where
        Self: Sized,
    {
        TimeoutOwned {
            inner: self,
            timeout,
        }
    }
}

// Blanket implementation: every [Invoke] implementation gets the extension methods above.
impl<T: Invoke> InvokeExt for T {}
323
// The items below are compile-time checks only. Nothing here is annotated with `#[test]`, so
// none of these functions are ever called, which is why `dead_code` is allowed. The checks
// pass if the module type-checks under `cfg(test)`.
#[allow(dead_code)]
#[cfg(test)]
mod tests {
    use core::future::Future;
    use core::pin::Pin;

    use std::sync::Arc;

    use bytes::Bytes;
    use futures::{Stream, StreamExt as _};
    use send_future::SendFuture as _;

    use super::*;

    // Checks that the future returned by `invoke_values` can be awaited inside a future that
    // must be `Send`, using the `.send()` workaround described on `Invoke::invoke`.
    // `manual_async_fn` is allowed because the explicit `impl Future + Send` return type is
    // what asserts `Send`, and `async fn` cannot express that bound.
    #[allow(clippy::manual_async_fn)]
    fn invoke_values_send<T>() -> impl Future<
        Output = anyhow::Result<(
            Pin<Box<dyn Stream<Item = Vec<Pin<Box<dyn Future<Output = String> + Send>>>> + Send>>,
        )>,
    > + Send
    where
        T: Invoke<Context = ()> + Default,
    {
        async {
            let wrpc = T::default();
            let ((r0,), _) = wrpc
                .invoke_values(
                    (),
                    "wrpc-test:integration/async",
                    "with-streams",
                    (),
                    [[None].as_slice()],
                )
                .send()
                .await?;
            Ok(r0)
        }
    }

    // Checks that a reference to an `Arc` slice of `Arc` slices satisfies the `paths` bound of
    // `Invoke::invoke`.
    async fn call_invoke<T: Invoke>(
        i: &T,
        cx: T::Context,
        paths: Arc<[Arc<[Option<usize>]>]>,
    ) -> anyhow::Result<(T::Outgoing, T::Incoming)> {
        i.invoke(cx, "foo", "bar", Bytes::default(), &paths).await
    }

    // Checks `invoke_values` from a plain `async fn` with several path shapes: a multi-element
    // index path, a `None` entry and a single index.
    async fn call_invoke_async<T>() -> anyhow::Result<(Pin<Box<dyn Stream<Item = Bytes> + Send>>,)>
    where
        T: Invoke<Context = ()> + Default,
    {
        let wrpc = T::default();
        let ((r0,), _) = wrpc
            .invoke_values(
                (),
                "wrpc-test:integration/async",
                "with-streams",
                (),
                [
                    [Some(1), Some(2)].as_slice(),
                    [None].as_slice(),
                    [Some(42)].as_slice(),
                ],
            )
            .await?;
        Ok(r0)
    }

    // Checks that a trait method declared as returning `impl Future` can be implemented with
    // an `async fn` that awaits `call_invoke_async` and consumes the resulting stream.
    trait Handler {
        fn foo() -> impl Future<Output = anyhow::Result<()>>;
    }

    impl<T> Handler for T
    where
        T: Invoke<Context = ()> + Default,
    {
        async fn foo() -> anyhow::Result<()> {
            let (st,) = call_invoke_async::<Self>().await?;
            st.collect::<Vec<_>>().await;
            Ok(())
        }
    }
}