1use core::future::Future;
4use core::mem;
5use core::pin::pin;
6use core::time::Duration;
7
8use anyhow::Context as _;
9use bytes::{Bytes, BytesMut};
10use futures::TryStreamExt as _;
11use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _};
12use tokio::{select, try_join};
13use tokio_util::codec::{Encoder as _, FramedRead};
14use tracing::{debug, instrument, trace, Instrument as _};
15
16use crate::{Deferred as _, Incoming, Index, TupleDecode, TupleEncode};
17
18pub trait Invoke: Send + Sync {
20 type Context: Send + Sync;
22
23 type Outgoing: AsyncWrite + Index<Self::Outgoing> + Send + Sync + Unpin + 'static;
25
26 type Incoming: AsyncRead + Index<Self::Incoming> + Send + Sync + Unpin + 'static;
28
29 fn invoke<P>(
69 &self,
70 cx: Self::Context,
71 instance: &str,
72 func: &str,
73 params: Bytes,
74 paths: impl AsRef<[P]> + Send,
75 ) -> impl Future<Output = anyhow::Result<(Self::Outgoing, Self::Incoming)>> + Send
76 where
77 P: AsRef<[Option<usize>]> + Send + Sync;
78}
79
/// Borrowing wrapper around an invoker that bounds how long establishing
/// each invocation (the `inner.invoke(..)` call) may take.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Timeout<'a, T>
where
    T: ?Sized,
{
    /// The wrapped invoker.
    pub inner: &'a T,
    /// Upper bound on the duration of a single `invoke` call.
    pub timeout: Duration,
}
88
89impl<T: Invoke> Invoke for Timeout<'_, T> {
90 type Context = T::Context;
91 type Outgoing = T::Outgoing;
92 type Incoming = T::Incoming;
93
94 #[instrument(level = "trace", skip(self, cx, params, paths))]
95 async fn invoke<P>(
96 &self,
97 cx: Self::Context,
98 instance: &str,
99 func: &str,
100 params: Bytes,
101 paths: impl AsRef<[P]> + Send,
102 ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)>
103 where
104 P: AsRef<[Option<usize>]> + Send + Sync,
105 {
106 tokio::time::timeout(
107 self.timeout,
108 self.inner.invoke(cx, instance, func, params, paths),
109 )
110 .await
111 .context("invocation timed out")?
112 }
113}
114
/// Owning counterpart of `Timeout`: holds the wrapped invoker by value and bounds
/// how long establishing each invocation may take.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeoutOwned<T> {
    /// The wrapped invoker.
    pub inner: T,
    /// Upper bound on the duration of a single `invoke` call.
    pub timeout: Duration,
}
123
124impl<T: Invoke> Invoke for TimeoutOwned<T> {
125 type Context = T::Context;
126 type Outgoing = T::Outgoing;
127 type Incoming = T::Incoming;
128
129 #[instrument(level = "trace", skip(self, cx, params, paths))]
130 async fn invoke<P>(
131 &self,
132 cx: Self::Context,
133 instance: &str,
134 func: &str,
135 params: Bytes,
136 paths: impl AsRef<[P]> + Send,
137 ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)>
138 where
139 P: AsRef<[Option<usize>]> + Send + Sync,
140 {
141 self.inner
142 .timeout(self.timeout)
143 .invoke(cx, instance, func, params, paths)
144 .await
145 }
146}
147
148pub trait InvokeExt: Invoke {
150 #[instrument(level = "trace", skip(self, cx, params, paths))]
152 fn invoke_values<P, Params, Results>(
153 &self,
154 cx: Self::Context,
155 instance: &str,
156 func: &str,
157 params: Params,
158 paths: impl AsRef<[P]> + Send,
159 ) -> impl Future<
160 Output = anyhow::Result<(
161 Results,
162 Option<impl Future<Output = anyhow::Result<()>> + Send + 'static>,
163 )>,
164 > + Send
165 where
166 P: AsRef<[Option<usize>]> + Send + Sync,
167 Params: TupleEncode<Self::Outgoing> + Send,
168 Results: TupleDecode<Self::Incoming> + Send,
169 <Params::Encoder as tokio_util::codec::Encoder<Params>>::Error:
170 std::error::Error + Send + Sync + 'static,
171 <Results::Decoder as tokio_util::codec::Decoder>::Error:
172 std::error::Error + Send + Sync + 'static,
173 {
174 async {
175 let mut buf = BytesMut::default();
176 let mut enc = Params::Encoder::default();
177 trace!("encoding parameters");
178 enc.encode(params, &mut buf)
179 .context("failed to encode parameters")?;
180 debug!("invoking function");
181 let (mut outgoing, incoming) = self
182 .invoke(cx, instance, func, buf.freeze(), paths)
183 .await
184 .context("failed to invoke function")?;
185 trace!("shutdown synchronous parameter channel");
186 outgoing
187 .shutdown()
188 .await
189 .context("failed to shutdown synchronous parameter channel")?;
190 let mut tx = enc.take_deferred().map(|tx| {
191 tokio::spawn(
192 async {
193 debug!("transmitting async parameters");
194 tx(outgoing, Vec::default())
195 .await
196 .context("failed to write async parameters")
197 }
198 .in_current_span(),
199 )
200 });
201
202 let mut dec = FramedRead::new(incoming, Results::Decoder::default());
203 let results = async {
204 debug!("receiving sync results");
205 dec.try_next()
206 .await
207 .context("failed to receive sync results")?
208 .context("incomplete results")
209 };
210 let results = if let Some(mut fut) = tx.take() {
211 let mut results = pin!(results);
212 select! {
213 res = &mut results => {
214 tx = Some(fut);
215 res?
216 }
217 res = &mut fut => {
218 res??;
219 results.await?
220 }
221 }
222 } else {
223 results.await?
224 };
225 trace!("received sync results");
226 let buffer = mem::take(dec.read_buffer_mut());
227 let rx = dec.decoder_mut().take_deferred();
228 let incoming = Incoming {
229 buffer,
230 inner: dec.into_inner(),
231 };
232 Ok((
233 results,
234 (tx.is_some() || rx.is_some()).then_some(
235 async {
236 match (tx, rx) {
237 (Some(tx), Some(rx)) => {
238 try_join!(
239 async {
240 debug!("receiving async results");
241 rx(incoming, Vec::default())
242 .await
243 .context("receiving async results failed")
244 },
245 async {
246 tx.await.context("transmitting async parameters failed")?
247 }
248 )?;
249 }
250 (Some(tx), None) => {
251 tx.await.context("transmitting async parameters failed")??;
252 }
253 (None, Some(rx)) => {
254 debug!("receiving async results");
255 rx(incoming, Vec::default())
256 .await
257 .context("receiving async results failed")?;
258 }
259 _ => {}
260 }
261 Ok(())
262 }
263 .in_current_span(),
264 ),
265 ))
266 }
267 }
268
269 #[instrument(level = "trace", skip_all)]
272 fn invoke_values_blocking<P, Params, Results>(
273 &self,
274 cx: Self::Context,
275 instance: &str,
276 func: &str,
277 params: Params,
278 paths: impl AsRef<[P]> + Send,
279 ) -> impl Future<Output = anyhow::Result<Results>> + Send
280 where
281 P: AsRef<[Option<usize>]> + Send + Sync,
282 Params: TupleEncode<Self::Outgoing> + Send,
283 Results: TupleDecode<Self::Incoming> + Send,
284 <Params::Encoder as tokio_util::codec::Encoder<Params>>::Error:
285 std::error::Error + Send + Sync + 'static,
286 <Results::Decoder as tokio_util::codec::Decoder>::Error:
287 std::error::Error + Send + Sync + 'static,
288 {
289 async {
290 let (ret, io) = self
291 .invoke_values(cx, instance, func, params, paths)
292 .await?;
293 if let Some(io) = io {
294 trace!("awaiting I/O completion");
295 io.await?;
296 }
297 Ok(ret)
298 }
299 }
300
301 fn timeout(&self, timeout: Duration) -> Timeout<'_, Self> {
304 Timeout {
305 inner: self,
306 timeout,
307 }
308 }
309
310 fn timeout_owned(self, timeout: Duration) -> TimeoutOwned<Self>
312 where
313 Self: Sized,
314 {
315 TimeoutOwned {
316 inner: self,
317 timeout,
318 }
319 }
320}
321
322impl<T: Invoke> InvokeExt for T {}
323
// These items are compile-time checks only. None of them is a `#[test]`; they exist
// so that the compiler type-checks (and `Send`-checks) typical uses of the invocation
// APIs. They are never called, hence `dead_code` is allowed.
#[cfg(test)]
#[allow(dead_code)]
mod tests {
    use core::future::Future;
    use core::pin::Pin;

    use std::sync::Arc;

    use bytes::Bytes;
    use futures::{Stream, StreamExt as _};
    use send_future::SendFuture as _;

    use super::*;

    /// Checks that `invoke_values` can be used inside a `Send` future. The decoded
    /// results here contain a stream of vectors of boxed futures.
    // A plain `fn` returning `impl Future + Send` (not an `async fn`) states the
    // `Send` requirement explicitly, which is why `manual_async_fn` is allowed.
    // NOTE(review): `.send()` from `send_future` presumably helps the compiler prove
    // `Send` for the invocation future — confirm.
    #[allow(clippy::manual_async_fn)]
    fn invoke_values_send<T>() -> impl Future<
        Output = anyhow::Result<(
            Pin<Box<dyn Stream<Item = Vec<Pin<Box<dyn Future<Output = String> + Send>>>> + Send>>,
        )>,
    > + Send
    where
        T: Invoke<Context = ()> + Default,
    {
        async {
            let client = T::default();
            let ((stream,), _) = client
                .invoke_values(
                    (),
                    "wrpc-test:integration/async",
                    "with-streams",
                    (),
                    [[None].as_slice()],
                )
                .send()
                .await?;
            Ok(stream)
        }
    }

    /// Checks that `Invoke::invoke` accepts a borrowed `Arc` of `Arc`-backed paths.
    async fn call_invoke<T: Invoke>(
        invoker: &T,
        cx: T::Context,
        paths: Arc<[Arc<[Option<usize>]>]>,
    ) -> anyhow::Result<(T::Outgoing, T::Incoming)> {
        invoker
            .invoke(cx, "foo", "bar", Bytes::default(), &paths)
            .await
    }

    /// Checks that `invoke_values` accepts several heterogeneous-length paths.
    async fn call_invoke_async<T>() -> anyhow::Result<(Pin<Box<dyn Stream<Item = Bytes> + Send>>,)>
    where
        T: Invoke<Context = ()> + Default,
    {
        let client = T::default();
        let ((stream,), _) = client
            .invoke_values(
                (),
                "wrpc-test:integration/async",
                "with-streams",
                (),
                [
                    [Some(1), Some(2)].as_slice(),
                    [None].as_slice(),
                    [Some(42)].as_slice(),
                ],
            )
            .await?;
        Ok(stream)
    }

    /// Checks that the helpers above compose inside an `async fn` trait method.
    trait Handler {
        fn foo() -> impl Future<Output = anyhow::Result<()>>;
    }

    impl<T> Handler for T
    where
        T: Invoke<Context = ()> + Default,
    {
        async fn foo() -> anyhow::Result<()> {
            let (stream,) = call_invoke_async::<Self>().await?;
            let _: Vec<_> = stream.collect().await;
            Ok(())
        }
    }
}