crux_core/capability/
shell_request.rs

//! Async support for implementing capabilities.
//!
//! Provides [`ShellRequest`], the future returned by `CapabilityContext::request_from_shell`.
3use std::{
4    sync::{Arc, Mutex},
5    task::{Poll, Waker},
6};
7
8use futures::Future;
9
10use crate::Request;
11
12pub struct ShellRequest<T> {
13    shared_state: Arc<Mutex<SharedState<T>>>,
14}
15
#[cfg(test)]
impl ShellRequest<()> {
    /// Test-only constructor: a unit `ShellRequest` with nothing to send,
    /// no result yet and no registered waker.
    pub(crate) fn new() -> Self {
        let empty = SharedState {
            send_request: None,
            waker: None,
            result: None,
        };

        Self {
            shared_state: Arc::new(Mutex::new(empty)),
        }
    }
}
28
/// State shared by a `ShellRequest` future and the resolve callback of the
/// request it sends. The future sends the request and parks its waker here;
/// the resolve callback stores the output and wakes the parked task.
///
/// FIXME: model this as a tri-state enum instead of three options:
/// - `ReadyToSend(Box<dyn FnOnce() + Send + 'static>)`
/// - `Pending(Waker)`
/// - `Complete(T)`
struct SharedState<T> {
    /// Sends the request to the shell; taken and invoked by the first poll.
    send_request: Option<Box<dyn FnOnce() + Send + 'static>>,
    /// Waker registered by `poll`; taken by the resolve callback to wake the task.
    waker: Option<Waker>,
    /// The effect's output, stored by the resolve callback and taken by `poll`.
    result: Option<T>,
}
43
44impl<T> Future for ShellRequest<T> {
45    type Output = T;
46
47    fn poll(
48        self: std::pin::Pin<&mut Self>,
49        cx: &mut std::task::Context<'_>,
50    ) -> std::task::Poll<Self::Output> {
51        let mut shared_state = self.shared_state.lock().unwrap();
52
53        // If there's still a request to send, take it and send it
54        if let Some(send_request) = shared_state.send_request.take() {
55            send_request();
56        }
57
58        // If a result has been delivered, we're ready to continue
59        // Else we're pending with the waker from context
60        match shared_state.result.take() {
61            Some(result) => Poll::Ready(result),
62            None => {
63                let cloned_waker = cx.waker().clone();
64                shared_state.waker = Some(cloned_waker);
65                Poll::Pending
66            }
67        }
68    }
69}
70
71impl<Op, Ev> crate::capability::CapabilityContext<Op, Ev>
72where
73    Op: crate::capability::Operation,
74    Ev: 'static,
75{
76    /// Send an effect request to the shell, expecting an output. The
77    /// provided `operation` describes the effect input in a serialisable fashion,
78    /// and must implement the [`Operation`](crate::capability::Operation) trait to declare the expected
79    /// output type.
80    ///
81    /// `request_from_shell` returns a future of the output, which can be
82    /// `await`ed. You should only call this method inside an async task
83    /// created with [`CapabilityContext::spawn`](crate::capability::CapabilityContext::spawn).
84    pub fn request_from_shell(&self, operation: Op) -> ShellRequest<Op::Output> {
85        let shared_state = Arc::new(Mutex::new(SharedState {
86            result: None,
87            waker: None,
88            send_request: None,
89        }));
90
91        // Our callback holds a weak pointer to avoid circular references
92        // from shared_state -> send_request -> request -> shared_state
93        let callback_shared_state = Arc::downgrade(&shared_state);
94
95        // used in docs/internals/runtime.md
96        // ANCHOR: resolve
97        let request = Request::resolves_once(operation, move |result| {
98            let Some(shared_state) = callback_shared_state.upgrade() else {
99                // The ShellRequest was dropped before we were called, so just
100                // do nothing.
101                return;
102            };
103
104            let mut shared_state = shared_state.lock().unwrap();
105
106            // Attach the result to the shared state of the future
107            shared_state.result = Some(result);
108            // Signal the executor to wake the task holding this future
109            if let Some(waker) = shared_state.waker.take() {
110                waker.wake()
111            }
112        });
113        // ANCHOR_END: resolve
114
115        // Send the request on the next poll of the ShellRequest future
116        let send_req_context = self.clone();
117        let send_request = move || send_req_context.send_request(request);
118
119        shared_state.lock().unwrap().send_request = Some(Box::new(send_request));
120
121        ShellRequest { shared_state }
122    }
123}
124
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use crate::capability::{channel, executor_and_spawner, CapabilityContext, Operation};

    #[derive(serde::Serialize, Clone, PartialEq, Eq, Debug)]
    struct TestOperation;

    impl Operation for TestOperation {
        type Output = ();
    }

    // Asserts that neither a shell request nor an app event is waiting,
    // checking the request channel first and then the event channel.
    macro_rules! assert_nothing_queued {
        ($requests:expr, $events:expr) => {{
            assert_matches!($requests.receive(), None);
            assert_matches!($events.receive(), None);
        }};
    }

    #[test]
    fn test_effect_future() {
        let (request_tx, request_rx) = channel();
        let (event_tx, event_rx) = channel::<()>();
        let (executor, spawner) = executor_and_spawner();
        let context = CapabilityContext::new(request_tx, event_tx.clone(), spawner.clone());

        let effect = context.request_from_shell(TestOperation);

        // Creating the future must not send anything, because it is lazy until polled.
        assert_nothing_queued!(request_rx, event_rx);

        // Nothing has been spawned yet, so running the executor does nothing.
        executor.run_all();
        assert_nothing_queued!(request_rx, event_rx);

        spawner.spawn(async move {
            effect.await;
            event_tx.send(());
        });

        // Spawning alone doesn't poll the task, so there is still no request.
        assert_nothing_queued!(request_rx, event_rx);

        // The first poll sends exactly one request, then waits for the shell.
        executor.run_all();
        let mut request = request_rx.receive().expect("we should have a request here");
        assert_nothing_queued!(request_rx, event_rx);

        // Resolving hands the result over but does not run the task.
        request.resolve(()).expect("request should resolve");
        assert_nothing_queued!(request_rx, event_rx);

        // Running the executor again completes the task, which emits one event.
        executor.run_all();
        assert_matches!(request_rx.receive(), None);
        assert_matches!(event_rx.receive(), Some(()));
        assert_matches!(event_rx.receive(), None);
    }
}