grafos_observe/
propagation.rs

1//! Trace context propagation helpers for grafos-rpc and grafos-mq.
2//!
3//! This module provides functions to inject and extract trace context from
4//! binary headers (for RPC shared-memory transport) and from message headers
5//! (for MQ message passing).
6//!
7//! The RPC propagation extends the shared-memory request region header with
8//! a 32-byte trace context field. MQ propagation uses the `"traceparent"`
9//! header key in the message's `headers` field.
10
11use crate::trace::{TraceContext, TraceContextError, TRACE_CONTEXT_BYTES};
12
/// Message-header key under which MQ propagation stores the W3C
/// `traceparent` value (sent in lowercase).
pub const TRACEPARENT_HEADER: &str = "traceparent";
15
16// ---------------------------------------------------------------------------
17// Binary propagation (for RPC shared-memory header extension)
18// ---------------------------------------------------------------------------
19
20/// Encode a trace context into a 32-byte buffer suitable for embedding in
21/// an RPC shared-memory header.
22///
23/// Returns all zeros if the context is empty (backward-compatible: old
24/// servers ignore zero bytes in the trace context field).
25pub fn encode_binary(ctx: &TraceContext) -> [u8; TRACE_CONTEXT_BYTES] {
26    if ctx.is_empty() {
27        [0u8; TRACE_CONTEXT_BYTES]
28    } else {
29        ctx.encode()
30    }
31}
32
33/// Decode a trace context from a 32-byte buffer read from an RPC header.
34///
35/// Returns `None` if the buffer is all zeros (no trace context was sent,
36/// backward-compatible with old clients).
37pub fn decode_binary(buf: &[u8; TRACE_CONTEXT_BYTES]) -> Option<TraceContext> {
38    // All-zero means no trace context
39    if buf.iter().all(|&b| b == 0) {
40        return None;
41    }
42    TraceContext::decode(buf).ok()
43}
44
45// ---------------------------------------------------------------------------
46// Header-based propagation (for MQ messages)
47// ---------------------------------------------------------------------------
48
49/// Inject a trace context as a W3C traceparent string into a header list.
50///
51/// If a `"traceparent"` header already exists, it is replaced.
52pub fn inject_traceparent(
53    headers: &mut alloc::vec::Vec<(alloc::string::String, alloc::vec::Vec<u8>)>,
54    ctx: &TraceContext,
55) {
56    let value = ctx.to_w3c_string().into_bytes();
57    if let Some(entry) = headers.iter_mut().find(|(k, _)| k == TRACEPARENT_HEADER) {
58        entry.1 = value;
59    } else {
60        headers.push((alloc::string::String::from(TRACEPARENT_HEADER), value));
61    }
62}
63
64/// Extract a trace context from a header list.
65///
66/// Looks for the `"traceparent"` header and parses it as a W3C traceparent
67/// string. Returns `None` if the header is not present or cannot be parsed.
68pub fn extract_traceparent(
69    headers: &[(alloc::string::String, alloc::vec::Vec<u8>)],
70) -> Option<TraceContext> {
71    let entry = headers.iter().find(|(k, _)| k == TRACEPARENT_HEADER)?;
72    let s = core::str::from_utf8(&entry.1).ok()?;
73    TraceContext::from_w3c_string(s).ok()
74}
75
76/// Extract a trace context from a header list, returning the parse error
77/// if the header is present but malformed.
78pub fn extract_traceparent_strict(
79    headers: &[(alloc::string::String, alloc::vec::Vec<u8>)],
80) -> Result<Option<TraceContext>, TraceContextError> {
81    let entry = match headers.iter().find(|(k, _)| k == TRACEPARENT_HEADER) {
82        Some(e) => e,
83        None => return Ok(None),
84    };
85    let s = core::str::from_utf8(&entry.1).map_err(|_| TraceContextError::InvalidFormat)?;
86    TraceContext::from_w3c_string(s).map(Some)
87}
88
/// Generate a fresh W3C traceparent string suitable for correlation.
///
/// Builds a 24-byte seed from three big-endian `u64` words and hands it to
/// `TraceContext::new_root`. No external entropy source is used. The words
/// are:
///
/// 1. the wall-clock time since the Unix epoch in nanoseconds, truncated to
///    `u64`;
/// 2. a per-process call counter;
/// 3. the XOR of the first two.
///
/// The counter gives every call in a process a different seed. Distinctness
/// across processes relies on the clock value.
#[cfg(feature = "std")]
pub fn generate_trace_id() -> alloc::string::String {
    use core::sync::atomic::{AtomicU64, Ordering};

    static NEXT_SEQ: AtomicU64 = AtomicU64::new(1);

    // Take the sequence number before reading the clock (same order as before).
    let seq = NEXT_SEQ.fetch_add(1, Ordering::Relaxed);
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    let nanos = since_epoch.as_nanos() as u64;

    let words = [nanos, seq, nanos ^ seq];
    let mut seed = [0u8; 24];
    for (chunk, word) in seed.chunks_exact_mut(8).zip(words.iter()) {
        chunk.copy_from_slice(&word.to_be_bytes());
    }

    crate::trace::TraceContext::new_root(&seed).to_w3c_string()
}
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::{SpanId, TraceContext, TraceId};
    use alloc::string::String;
    use alloc::vec;
    use alloc::vec::Vec;

    /// A fixed, non-empty context with the sampled flag (0x01) set.
    fn test_ctx() -> TraceContext {
        TraceContext {
            trace_id: TraceId(0x0102030405060708090a0b0c0d0e0f10),
            span_id: SpanId(0x1112131415161718),
            parent_span_id: SpanId::INVALID,
            flags: 0x01,
        }
    }

    #[test]
    fn binary_roundtrip() {
        let ctx = test_ctx();
        let encoded = encode_binary(&ctx);
        let decoded = decode_binary(&encoded).unwrap();
        assert_eq!(decoded.trace_id, ctx.trace_id);
        assert_eq!(decoded.span_id, ctx.span_id);
        assert_eq!(decoded.flags, ctx.flags);
    }

    #[test]
    fn binary_empty_context_is_zeros() {
        let ctx = TraceContext::default();
        let encoded = encode_binary(&ctx);
        assert!(encoded.iter().all(|&b| b == 0));
        assert!(decode_binary(&encoded).is_none());
    }

    #[test]
    fn binary_all_zeros_returns_none() {
        let buf = [0u8; TRACE_CONTEXT_BYTES];
        assert!(decode_binary(&buf).is_none());
    }

    #[test]
    fn mq_inject_extract_roundtrip() {
        let ctx = test_ctx();
        let mut headers: Vec<(String, Vec<u8>)> = Vec::new();

        inject_traceparent(&mut headers, &ctx);
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, TRACEPARENT_HEADER);

        let extracted = extract_traceparent(&headers).unwrap();
        assert_eq!(extracted.trace_id, ctx.trace_id);
        assert_eq!(extracted.span_id, ctx.span_id);
        assert_eq!(extracted.flags, ctx.flags);
    }

    #[test]
    fn mq_inject_replaces_existing() {
        let ctx1 = test_ctx();
        let ctx2 = TraceContext {
            trace_id: TraceId(0xAAAABBBBCCCCDDDDEEEEFFFF00001111),
            span_id: SpanId(0x2222333344445555),
            parent_span_id: SpanId::INVALID,
            flags: 0x00,
        };

        let mut headers: Vec<(String, Vec<u8>)> = Vec::new();
        inject_traceparent(&mut headers, &ctx1);
        inject_traceparent(&mut headers, &ctx2);

        // Should still be only one traceparent header
        assert_eq!(headers.len(), 1);

        let extracted = extract_traceparent(&headers).unwrap();
        assert_eq!(extracted.trace_id, ctx2.trace_id);
    }

    #[test]
    fn mq_extract_missing_returns_none() {
        let headers: Vec<(String, Vec<u8>)> =
            vec![(String::from("other-header"), b"value".to_vec())];
        assert!(extract_traceparent(&headers).is_none());
    }

    #[test]
    fn mq_extract_malformed_returns_none() {
        let headers: Vec<(String, Vec<u8>)> = vec![(
            String::from(TRACEPARENT_HEADER),
            b"not-a-valid-traceparent".to_vec(),
        )];
        assert!(extract_traceparent(&headers).is_none());
    }

    #[test]
    fn mq_extract_strict_malformed_returns_error() {
        let headers: Vec<(String, Vec<u8>)> = vec![(
            String::from(TRACEPARENT_HEADER),
            b"not-a-valid-traceparent".to_vec(),
        )];
        let result = extract_traceparent_strict(&headers);
        assert!(result.is_err());
    }

    #[test]
    fn mq_extract_strict_missing_returns_ok_none() {
        let headers: Vec<(String, Vec<u8>)> = Vec::new();
        assert_eq!(extract_traceparent_strict(&headers).unwrap(), None);
    }

    #[test]
    fn mq_preserves_other_headers() {
        let ctx = test_ctx();
        let mut headers: Vec<(String, Vec<u8>)> = vec![
            (String::from("content-type"), b"application/json".to_vec()),
            (String::from("x-custom"), b"value".to_vec()),
        ];

        inject_traceparent(&mut headers, &ctx);
        assert_eq!(headers.len(), 3);
        assert_eq!(headers[0].0, "content-type");
        assert_eq!(headers[1].0, "x-custom");
        assert_eq!(headers[2].0, TRACEPARENT_HEADER);
    }

    // `generate_trace_id` only exists when the `std` feature is enabled. The
    // test must be gated the same way, or building the tests without that
    // feature fails with an unresolved name.
    #[cfg(feature = "std")]
    #[test]
    fn generate_trace_id_length_and_uniqueness() {
        let id1 = super::generate_trace_id();
        let id2 = super::generate_trace_id();
        // W3C traceparent is 55 chars: "00-<32 hex>-<16 hex>-<2 hex>"
        assert_eq!(id1.len(), 55, "trace_id should be 55 chars");
        assert_eq!(id2.len(), 55, "trace_id should be 55 chars");
        assert_ne!(id1, id2, "two calls should produce different trace IDs");
    }
}