1#![doc = include_str!("../README.md")]
2
3use std::task::Context;
4use std::task::Poll;
5
6use fastrace::prelude::*;
7use http::HeaderValue;
8use http::Request;
9use tower_layer::Layer;
10use tower_service::Service;
11
/// Name of the HTTP header carrying the W3C `traceparent` span context.
///
/// [`FastraceServerService`] reads this header from incoming requests to find the parent
/// span, and [`FastraceClientService`] writes it onto outgoing requests.
pub const TRACEPARENT_HEADER: &str = "traceparent";
17
/// Tower layer that wraps a server-side service in a [`FastraceServerService`].
///
/// Each wrapped request is executed inside a root span whose parent is taken from the
/// incoming `traceparent` header when present.
#[derive(Clone, Debug, Default)]
pub struct FastraceServerLayer;
25
26impl<S> Layer<S> for FastraceServerLayer {
27 type Service = FastraceServerService<S>;
28
29 fn layer(&self, service: S) -> Self::Service {
30 FastraceServerService { service }
31 }
32}
33
/// Server-side tower service produced by [`FastraceServerLayer`].
///
/// Delegates all work to the wrapped `service`, running each call's future inside a
/// root span.
#[derive(Clone, Debug)]
pub struct FastraceServerService<S> {
    /// The inner service that actually handles requests.
    service: S,
}
43
44impl<S, Body> Service<Request<Body>> for FastraceServerService<S>
45where S: Service<Request<Body>>
46{
47 type Response = S::Response;
48 type Error = S::Error;
49 type Future = fastrace::future::InSpan<S::Future>;
50
51 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
52 self.service.poll_ready(cx)
53 }
54
55 fn call(&mut self, req: Request<Body>) -> Self::Future {
56 let headers = req.headers();
57 let parent = headers
58 .get(TRACEPARENT_HEADER)
59 .and_then(|traceparent| SpanContext::decode_w3c_traceparent(traceparent.to_str().ok()?))
60 .unwrap_or(SpanContext::random());
61 let root = Span::root(req.uri().to_string(), parent);
62 self.service.call(req).in_span(root)
63 }
64}
65
/// Tower layer that wraps a client-side service in a [`FastraceClientService`].
///
/// Outgoing requests made through the wrapped service get a `traceparent` header when a
/// local parent span is active.
#[derive(Clone, Debug, Default)]
pub struct FastraceClientLayer;
73
74impl<S> Layer<S> for FastraceClientLayer {
75 type Service = FastraceClientService<S>;
76
77 fn layer(&self, service: S) -> Self::Service {
78 FastraceClientService { service }
79 }
80}
81
/// Client-side tower service produced by [`FastraceClientLayer`].
///
/// Delegates all work to the wrapped `service` after injecting trace context into the
/// outgoing request.
#[derive(Clone, Debug)]
pub struct FastraceClientService<S> {
    /// The inner service that actually sends requests.
    service: S,
}
90
91impl<S, Body> Service<Request<Body>> for FastraceClientService<S>
92where S: Service<Request<Body>>
93{
94 type Response = S::Response;
95 type Error = S::Error;
96 type Future = S::Future;
97
98 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
99 self.service.poll_ready(cx)
100 }
101
102 fn call(&mut self, mut req: Request<Body>) -> Self::Future {
103 if let Some(current) = SpanContext::current_local_parent() {
104 req.headers_mut().insert(
105 TRACEPARENT_HEADER,
106 HeaderValue::from_str(¤t.encode_w3c_traceparent()).unwrap(),
107 );
108 }
109 self.service.call(req)
110 }
111}