grafos_observe/
propagation.rs

1//! Trace context propagation helpers for grafos-rpc and grafos-mq.
2//!
3//! This module provides functions to inject and extract trace context from
4//! binary headers (for RPC shared-memory transport) and from message headers
5//! (for MQ message passing).
6//!
7//! The RPC propagation extends the shared-memory request region header with
8//! a 32-byte trace context field. MQ propagation uses the `"traceparent"`
9//! header key in the message's `headers` field.
10
11use crate::trace::{TraceContext, TraceContextError, TRACE_CONTEXT_BYTES};
12
/// Header key under which MQ messages carry the W3C `traceparent` value.
///
/// Lookups in this module compare keys against this constant exactly
/// (case-sensitively).
pub const TRACEPARENT_HEADER: &str = "traceparent";
15
16// ---------------------------------------------------------------------------
17// Binary propagation (for RPC shared-memory header extension)
18// ---------------------------------------------------------------------------
19
20/// Encode a trace context into a 32-byte buffer suitable for embedding in
21/// an RPC shared-memory header.
22///
23/// Returns all zeros if the context is empty (backward-compatible: old
24/// servers ignore zero bytes in the trace context field).
25pub fn encode_binary(ctx: &TraceContext) -> [u8; TRACE_CONTEXT_BYTES] {
26    if ctx.is_empty() {
27        [0u8; TRACE_CONTEXT_BYTES]
28    } else {
29        ctx.encode()
30    }
31}
32
33/// Decode a trace context from a 32-byte buffer read from an RPC header.
34///
35/// Returns `None` if the buffer is all zeros (no trace context was sent,
36/// backward-compatible with old clients).
37pub fn decode_binary(buf: &[u8; TRACE_CONTEXT_BYTES]) -> Option<TraceContext> {
38    // All-zero means no trace context
39    if buf.iter().all(|&b| b == 0) {
40        return None;
41    }
42    TraceContext::decode(buf).ok()
43}
44
45// ---------------------------------------------------------------------------
46// Header-based propagation (for MQ messages)
47// ---------------------------------------------------------------------------
48
49/// Inject a trace context as a W3C traceparent string into a header list.
50///
51/// If a `"traceparent"` header already exists, it is replaced.
52pub fn inject_traceparent(
53    headers: &mut alloc::vec::Vec<(alloc::string::String, alloc::vec::Vec<u8>)>,
54    ctx: &TraceContext,
55) {
56    let value = ctx.to_w3c_string().into_bytes();
57    if let Some(entry) = headers.iter_mut().find(|(k, _)| k == TRACEPARENT_HEADER) {
58        entry.1 = value;
59    } else {
60        headers.push((alloc::string::String::from(TRACEPARENT_HEADER), value));
61    }
62}
63
64/// Extract a trace context from a header list.
65///
66/// Looks for the `"traceparent"` header and parses it as a W3C traceparent
67/// string. Returns `None` if the header is not present or cannot be parsed.
68pub fn extract_traceparent(
69    headers: &[(alloc::string::String, alloc::vec::Vec<u8>)],
70) -> Option<TraceContext> {
71    let entry = headers.iter().find(|(k, _)| k == TRACEPARENT_HEADER)?;
72    let s = core::str::from_utf8(&entry.1).ok()?;
73    TraceContext::from_w3c_string(s).ok()
74}
75
76/// Extract a trace context from a header list, returning the parse error
77/// if the header is present but malformed.
78pub fn extract_traceparent_strict(
79    headers: &[(alloc::string::String, alloc::vec::Vec<u8>)],
80) -> Result<Option<TraceContext>, TraceContextError> {
81    let mut matches = headers.iter().filter(|(k, _)| k == TRACEPARENT_HEADER);
82    let entry = match matches.next() {
83        Some(e) => e,
84        None => return Ok(None),
85    };
86    if matches.next().is_some() {
87        return Err(TraceContextError::InvalidFormat);
88    }
89    let s = core::str::from_utf8(&entry.1).map_err(|_| TraceContextError::InvalidFormat)?;
90    TraceContext::from_w3c_string(s).map(Some)
91}
92
/// Generate a fresh W3C traceparent string suitable for correlation.
///
/// The seed passed to `TraceContext::new_root` combines three inputs, so no
/// external entropy source is needed:
///
/// - The wall-clock time in nanoseconds.
/// - A process-wide counter. It keeps seeds distinct within one process,
///   even when the clock does not advance between calls or steps backwards.
/// - The current process id. Every process starts its counter at 1, so
///   without the pid, two processes whose first calls land on the same clock
///   tick would build identical seeds. This is likely on platforms whose
///   clock only has microsecond resolution.
///
/// The seed is derived from predictable values, not from a CSPRNG.
#[cfg(feature = "std")]
pub fn generate_trace_id() -> alloc::string::String {
    use core::sync::atomic::{AtomicU64, Ordering};
    static COUNTER: AtomicU64 = AtomicU64::new(1);

    // `fetch_add` is atomic under any ordering; only uniqueness is needed.
    let cnt = COUNTER.fetch_add(1, Ordering::Relaxed);
    // The u128 -> u64 truncation only loses bits ~584 years after the epoch.
    // A pre-epoch clock yields 0, and the counter still disambiguates.
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos() as u64;
    let pid = u64::from(std::process::id());
    // XOR with a per-process constant is a bijection, so in-process
    // uniqueness from the counter is preserved. The pid lands in the high 32
    // bits, which the counter only reaches after 2^32 calls.
    let tagged_cnt = cnt ^ (pid << 32);

    let mut seed = [0u8; 24];
    seed[0..8].copy_from_slice(&nanos.to_be_bytes());
    seed[8..16].copy_from_slice(&tagged_cnt.to_be_bytes());
    seed[16..24].copy_from_slice(&(nanos ^ tagged_cnt).to_be_bytes());

    let ctx = crate::trace::TraceContext::new_root(&seed);
    ctx.to_w3c_string()
}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::{SpanId, TraceContext, TraceId};
    use alloc::string::String;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Fixed, non-empty root context with flags set to 0x01.
    fn test_ctx() -> TraceContext {
        TraceContext {
            trace_id: TraceId(0x0102030405060708090a0b0c0d0e0f10),
            span_id: SpanId(0x1112131415161718),
            parent_span_id: SpanId::INVALID,
            flags: 0x01,
        }
    }

    #[test]
    fn binary_roundtrip() {
        let ctx = test_ctx();
        let encoded = encode_binary(&ctx);
        let decoded = decode_binary(&encoded).unwrap();
        assert_eq!(decoded.trace_id, ctx.trace_id);
        assert_eq!(decoded.span_id, ctx.span_id);
        assert_eq!(decoded.flags, ctx.flags);
    }

    #[test]
    fn binary_empty_context_is_zeros() {
        let ctx = TraceContext::default();
        let encoded = encode_binary(&ctx);
        assert!(encoded.iter().all(|&b| b == 0));
        assert!(decode_binary(&encoded).is_none());
    }

    #[test]
    fn binary_all_zeros_returns_none() {
        let buf = [0u8; TRACE_CONTEXT_BYTES];
        assert!(decode_binary(&buf).is_none());
    }

    #[test]
    fn mq_inject_extract_roundtrip() {
        let ctx = test_ctx();
        let mut headers: Vec<(String, Vec<u8>)> = Vec::new();

        inject_traceparent(&mut headers, &ctx);
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, TRACEPARENT_HEADER);

        let extracted = extract_traceparent(&headers).unwrap();
        assert_eq!(extracted.trace_id, ctx.trace_id);
        assert_eq!(extracted.span_id, ctx.span_id);
        assert_eq!(extracted.flags, ctx.flags);
    }

    #[test]
    fn mq_inject_replaces_existing() {
        let ctx1 = test_ctx();
        let ctx2 = TraceContext {
            trace_id: TraceId(0xAAAABBBBCCCCDDDDEEEEFFFF00001111),
            span_id: SpanId(0x2222333344445555),
            parent_span_id: SpanId::INVALID,
            flags: 0x00,
        };

        let mut headers: Vec<(String, Vec<u8>)> = Vec::new();
        inject_traceparent(&mut headers, &ctx1);
        inject_traceparent(&mut headers, &ctx2);

        // Should still be only one traceparent header
        assert_eq!(headers.len(), 1);

        let extracted = extract_traceparent(&headers).unwrap();
        assert_eq!(extracted.trace_id, ctx2.trace_id);
    }

    #[test]
    fn mq_extract_missing_returns_none() {
        let headers: Vec<(String, Vec<u8>)> =
            vec![(String::from("other-header"), b"value".to_vec())];
        assert!(extract_traceparent(&headers).is_none());
    }

    #[test]
    fn mq_extract_malformed_returns_none() {
        let headers: Vec<(String, Vec<u8>)> = vec![(
            String::from(TRACEPARENT_HEADER),
            b"not-a-valid-traceparent".to_vec(),
        )];
        assert!(extract_traceparent(&headers).is_none());
    }

    #[test]
    fn mq_extract_non_utf8_returns_none() {
        let headers: Vec<(String, Vec<u8>)> =
            vec![(String::from(TRACEPARENT_HEADER), vec![0xff, 0xfe, 0xfd])];
        assert!(extract_traceparent(&headers).is_none());
    }

    #[test]
    fn mq_extract_lenient_duplicate_uses_first() {
        let ctx1 = test_ctx();
        let ctx2 = TraceContext {
            trace_id: TraceId(0xAAAABBBBCCCCDDDDEEEEFFFF00001111),
            span_id: SpanId(0x2222333344445555),
            parent_span_id: SpanId::INVALID,
            flags: 0x00,
        };
        let headers: Vec<(String, Vec<u8>)> = vec![
            (String::from(TRACEPARENT_HEADER), ctx1.to_w3c_string().into_bytes()),
            (String::from(TRACEPARENT_HEADER), ctx2.to_w3c_string().into_bytes()),
        ];
        let extracted = extract_traceparent(&headers).unwrap();
        assert_eq!(extracted.trace_id, ctx1.trace_id);
    }

    #[test]
    fn mq_extract_strict_valid_returns_some() {
        let ctx = test_ctx();
        let mut headers: Vec<(String, Vec<u8>)> = Vec::new();
        inject_traceparent(&mut headers, &ctx);

        let extracted = extract_traceparent_strict(&headers).unwrap().unwrap();
        assert_eq!(extracted.trace_id, ctx.trace_id);
        assert_eq!(extracted.span_id, ctx.span_id);
        assert_eq!(extracted.flags, ctx.flags);
    }

    #[test]
    fn mq_extract_strict_malformed_returns_error() {
        let headers: Vec<(String, Vec<u8>)> = vec![(
            String::from(TRACEPARENT_HEADER),
            b"not-a-valid-traceparent".to_vec(),
        )];
        let result = extract_traceparent_strict(&headers);
        assert!(result.is_err());
    }

    #[test]
    fn mq_extract_strict_non_utf8_returns_error() {
        let headers: Vec<(String, Vec<u8>)> =
            vec![(String::from(TRACEPARENT_HEADER), vec![0xff, 0xfe, 0xfd])];
        assert!(matches!(
            extract_traceparent_strict(&headers),
            Err(TraceContextError::InvalidFormat)
        ));
    }

    #[test]
    fn mq_extract_strict_duplicate_returns_error() {
        let ctx = test_ctx().to_w3c_string().into_bytes();
        let headers: Vec<(String, Vec<u8>)> = vec![
            (String::from(TRACEPARENT_HEADER), ctx.clone()),
            (String::from(TRACEPARENT_HEADER), ctx),
        ];
        let result = extract_traceparent_strict(&headers);
        assert!(matches!(result, Err(TraceContextError::InvalidFormat)));
    }

    #[test]
    fn mq_extract_strict_missing_returns_ok_none() {
        let headers: Vec<(String, Vec<u8>)> = Vec::new();
        assert_eq!(extract_traceparent_strict(&headers).unwrap(), None);
    }

    #[test]
    fn mq_preserves_other_headers() {
        let ctx = test_ctx();
        let mut headers: Vec<(String, Vec<u8>)> = vec![
            (String::from("content-type"), b"application/json".to_vec()),
            (String::from("x-custom"), b"value".to_vec()),
        ];

        inject_traceparent(&mut headers, &ctx);
        assert_eq!(headers.len(), 3);
        assert_eq!(headers[0].0, "content-type");
        assert_eq!(headers[1].0, "x-custom");
        assert_eq!(headers[2].0, TRACEPARENT_HEADER);
    }

    // `generate_trace_id` only exists with the `std` feature; gate its test
    // the same way so the suite still builds without that feature.
    #[cfg(feature = "std")]
    #[test]
    fn generate_trace_id_length_and_uniqueness() {
        let id1 = super::generate_trace_id();
        let id2 = super::generate_trace_id();
        // W3C traceparent is 55 chars: "00-<32 hex>-<16 hex>-<2 hex>"
        assert_eq!(id1.len(), 55, "trace_id should be 55 chars");
        assert_eq!(id2.len(), 55, "trace_id should be 55 chars");
        assert_ne!(id1, id2, "two calls should produce different trace IDs");
    }
}