1extern crate alloc;
11
12use alloc::boxed::Box;
13use alloc::vec::Vec;
14use core::sync::atomic::{AtomicU64, Ordering};
15
16use grafos_std::error::Result;
17use grafos_std::mem::MemLease;
18
19use crate::mux::{RpcMuxClient, DEFAULT_NUM_SLOTS, DEFAULT_SLOT_PAYLOAD_SIZE};
20
21pub trait RpcTransport {
27 fn call(&self, method_id: u32, payload: &[u8]) -> Result<Vec<u8>>;
30}
31
32pub struct SharedMemoryTransport<'a> {
40 mux: RpcMuxClient<'a>,
41}
42
43impl<'a> SharedMemoryTransport<'a> {
44 pub fn new(lease: &'a MemLease) -> Self {
47 SharedMemoryTransport {
48 mux: RpcMuxClient::new(lease, DEFAULT_NUM_SLOTS, DEFAULT_SLOT_PAYLOAD_SIZE),
49 }
50 }
51
52 pub fn with_slots(lease: &'a MemLease, num_slots: usize, slot_payload_size: usize) -> Self {
54 SharedMemoryTransport {
55 mux: RpcMuxClient::new(lease, num_slots, slot_payload_size),
56 }
57 }
58}
59
60impl RpcTransport for SharedMemoryTransport<'_> {
61 fn call(&self, method_id: u32, payload: &[u8]) -> Result<Vec<u8>> {
62 self.mux.call(method_id, payload)
63 }
64}
65
66#[allow(clippy::type_complexity)]
89pub struct QuicTransport {
90 sender: Box<dyn Fn(u32, u64, &[u8]) -> Result<Vec<u8>> + Send + Sync>,
92 next_request_id: AtomicU64,
93}
94
95impl QuicTransport {
96 #[allow(clippy::type_complexity)]
101 pub fn new(sender: Box<dyn Fn(u32, u64, &[u8]) -> Result<Vec<u8>> + Send + Sync>) -> Self {
102 QuicTransport {
103 sender,
104 next_request_id: AtomicU64::new(1),
105 }
106 }
107}
108
109impl RpcTransport for QuicTransport {
110 fn call(&self, method_id: u32, payload: &[u8]) -> Result<Vec<u8>> {
111 let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
112 (self.sender)(method_id, request_id, payload)
113 }
114}
115
116pub enum AutoTransport<'a> {
125 SharedMemory(SharedMemoryTransport<'a>),
126 Quic(QuicTransport),
127}
128
129impl RpcTransport for AutoTransport<'_> {
130 fn call(&self, method_id: u32, payload: &[u8]) -> Result<Vec<u8>> {
131 match self {
132 AutoTransport::SharedMemory(t) => t.call(method_id, payload),
133 AutoTransport::Quic(t) => t.call(method_id, payload),
134 }
135 }
136}
137
138pub struct ServiceHandlerAdapter<F> {
154 dispatch: F,
155}
156
157impl<F> ServiceHandlerAdapter<F>
158where
159 F: Fn(u32, &[u8]) -> Result<Vec<u8>>,
160{
161 pub fn new(dispatch: F) -> Self {
162 ServiceHandlerAdapter { dispatch }
163 }
164}
165
166impl<F> crate::RpcHandler for ServiceHandlerAdapter<F>
167where
168 F: Fn(u32, &[u8]) -> Result<Vec<u8>>,
169{
170 fn handle(&self, method_id: u32, payload: &[u8]) -> Result<Vec<u8>> {
171 (self.dispatch)(method_id, payload)
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use grafos_std::error::FabricError;
    use grafos_std::host;
    use grafos_std::mem::MemBuilder;

    #[test]
    fn shared_memory_transport_call() {
        host::reset_mock();
        host::mock_set_fbmu_arena_size(65536);

        let lease = MemBuilder::new().min_bytes(65536).acquire().unwrap();
        // Exercise the default constructor; the bounded client below is the one called.
        let _default = SharedMemoryTransport::new(&lease);

        // Cap polling so the call gives up quickly; in this mock setup nothing
        // appears to answer, and the client reports LeaseExpired.
        let transport = SharedMemoryTransport {
            mux: RpcMuxClient::new(&lease, 8, 4096).with_max_poll_iterations(5),
        };
        let result = transport.call(0, b"hello");
        assert_eq!(result.unwrap_err(), FabricError::LeaseExpired);
    }

    #[test]
    fn quic_transport_call() {
        let transport = QuicTransport::new(Box::new(|method_id, _req_id, payload| {
            let mut resp = method_id.to_le_bytes().to_vec();
            resp.extend_from_slice(payload);
            Ok(resp)
        }));

        let resp = transport.call(42, b"test").unwrap();
        assert_eq!(&resp[..4], &42u32.to_le_bytes());
        assert_eq!(&resp[4..], b"test");
    }

    #[test]
    fn quic_transport_increments_request_id() {
        use alloc::sync::Arc;

        let seen_ids = Arc::new(std::sync::Mutex::new(Vec::new()));
        let ids = seen_ids.clone();
        let transport = QuicTransport::new(Box::new(move |_mid, req_id, _payload| {
            ids.lock().unwrap().push(req_id);
            Ok(Vec::new())
        }));

        transport.call(0, b"a").unwrap();
        transport.call(0, b"b").unwrap();
        transport.call(0, b"c").unwrap();

        let ids = seen_ids.lock().unwrap();
        assert_eq!(*ids, vec![1, 2, 3]);
    }

    #[test]
    fn quic_transport_error_propagation() {
        let transport = QuicTransport::new(Box::new(|_mid, _req_id, _payload| {
            Err(FabricError::IoError(-500))
        }));

        let result = transport.call(0, b"fail");
        assert_eq!(result.unwrap_err(), FabricError::IoError(-500));
    }

    #[test]
    fn auto_transport_quic_variant() {
        let quic = QuicTransport::new(Box::new(|_mid, _req_id, payload| Ok(payload.to_vec())));
        let auto = AutoTransport::Quic(quic);

        let resp = auto.call(0, b"echo").unwrap();
        assert_eq!(resp, b"echo");
    }

    #[test]
    fn auto_transport_shared_memory_variant() {
        host::reset_mock();
        host::mock_set_fbmu_arena_size(65536);

        let lease = MemBuilder::new().min_bytes(65536).acquire().unwrap();
        let auto = AutoTransport::SharedMemory(SharedMemoryTransport {
            mux: RpcMuxClient::new(&lease, 8, 4096).with_max_poll_iterations(5),
        });

        // The enum must surface the shared-memory transport's error unchanged,
        // matching `shared_memory_transport_call` above.
        let result = auto.call(0, b"hello");
        assert_eq!(result.unwrap_err(), FabricError::LeaseExpired);
    }

    #[test]
    fn service_handler_adapter() {
        use crate::RpcHandler;

        let adapter = ServiceHandlerAdapter::new(|method_id: u32, payload: &[u8]| {
            if method_id == 0 {
                Ok(payload.to_vec())
            } else {
                Err(FabricError::Unsupported)
            }
        });

        let resp = adapter.handle(0, b"hello").unwrap();
        assert_eq!(resp, b"hello");

        let err = adapter.handle(99, b"fail").unwrap_err();
        assert_eq!(err, FabricError::Unsupported);
    }
}