crux_core/capability/shell_request.rs

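//! Async support for implementing capabilities: the [`ShellRequest`] future, which sends an
//! effect request to the shell when it is first polled and completes once the shell resolves it.
//!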
3use std::{
4    sync::{Arc, Mutex},
5    task::{Poll, Waker},
6};
7
8use futures::Future;
9
10use crate::Request;
11
12pub struct ShellRequest<T> {
13    shared_state: Arc<Mutex<SharedState<T>>>,
14}
15
#[cfg(test)]
impl ShellRequest<()> {
    /// Test-only constructor: a request whose shared state is entirely empty —
    /// nothing queued to send, no result delivered and no waker registered.
    pub(crate) fn new() -> Self {
        let empty = SharedState {
            result: None,
            waker: None,
            send_request: None,
        };

        Self {
            shared_state: Arc::new(Mutex::new(empty)),
        }
    }
}
28
/// State shared between a [`ShellRequest`] future and the resolve callback of
/// the `Request` it sends. The resolve callback is what moves the state from
/// pending to complete.
// FIXME this should be a tri-state enum instead:
// - ReadyToSend(Box<dyn FnOnce() + Send + 'static>)
// - Pending(Waker)
// - Complete(T)
struct SharedState<T> {
    /// The effect's output, once the shell has resolved the request.
    result: Option<T>,
    /// Sends the request to the shell; taken and run on the first poll.
    send_request: Option<Box<dyn FnOnce() + Send + 'static>>,
    /// Waker stored by the most recent pending poll; the resolve callback
    /// takes it to reschedule the waiting task.
    waker: Option<Waker>,
}
43
44impl<T> Future for ShellRequest<T> {
45    type Output = T;
46
47    fn poll(
48        self: std::pin::Pin<&mut Self>,
49        cx: &mut std::task::Context<'_>,
50    ) -> std::task::Poll<Self::Output> {
51        let mut shared_state = self.shared_state.lock().unwrap();
52
53        // If there's still a request to send, take it and send it
54        if let Some(send_request) = shared_state.send_request.take() {
55            send_request();
56        }
57
58        // If a result has been delivered, we're ready to continue
59        // Else we're pending with the waker from context
60        if let Some(result) = shared_state.result.take() {
61            Poll::Ready(result)
62        } else {
63            let cloned_waker = cx.waker().clone();
64            shared_state.waker = Some(cloned_waker);
65            Poll::Pending
66        }
67    }
68}
69
70impl<Op, Ev> crate::capability::CapabilityContext<Op, Ev>
71where
72    Op: crate::capability::Operation,
73    Ev: 'static,
74{
75    /// Send an effect request to the shell, expecting an output. The
76    /// provided `operation` describes the effect input in a serialisable fashion,
77    /// and must implement the [`Operation`](crate::capability::Operation) trait to declare the expected
78    /// output type.
79    ///
80    /// `request_from_shell` returns a future of the output, which can be
81    /// `await`ed. You should only call this method inside an async task
82    /// created with [`CapabilityContext::spawn`](crate::capability::CapabilityContext::spawn).
83    ///
84    /// # Panics
85    ///
86    /// Panics if we can't acquire the lock on the shared state.
87    pub fn request_from_shell(&self, operation: Op) -> ShellRequest<Op::Output> {
88        let shared_state = Arc::new(Mutex::new(SharedState {
89            result: None,
90            waker: None,
91            send_request: None,
92        }));
93
94        // Our callback holds a weak pointer to avoid circular references
95        // from shared_state -> send_request -> request -> shared_state
96        let callback_shared_state = Arc::downgrade(&shared_state);
97
98        // used in docs/internals/runtime.md
99        // ANCHOR: resolve
100        let request = Request::resolves_once(operation, move |result| {
101            let Some(shared_state) = callback_shared_state.upgrade() else {
102                // The ShellRequest was dropped before we were called, so just
103                // do nothing.
104                return;
105            };
106
107            let mut shared_state = shared_state.lock().unwrap();
108
109            // Attach the result to the shared state of the future
110            shared_state.result = Some(result);
111            // Signal the executor to wake the task holding this future
112            if let Some(waker) = shared_state.waker.take() {
113                waker.wake();
114            }
115        });
116        // ANCHOR_END: resolve
117
118        // Send the request on the next poll of the ShellRequest future
119        let send_req_context = self.clone();
120        let send_request = move || send_req_context.send_request(request);
121
122        shared_state.lock().unwrap().send_request = Some(Box::new(send_request));
123
124        ShellRequest { shared_state }
125    }
126}
127
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use crate::capability::{channel, executor_and_spawner, CapabilityContext, Operation};

    #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq, Debug)]
    struct TestOperation;

    impl Operation for TestOperation {
        type Output = ();
    }

    /// Walks a `ShellRequest` through its lifecycle:
    /// 1. created — nothing is sent;
    /// 2. spawned — still nothing is sent;
    /// 3. polled — one request is sent;
    /// 4. resolved — the task is woken;
    /// 5. completed — one event is emitted.
    #[test]
    fn test_effect_future() {
        let (request_sender, requests) = channel();
        let (event_sender, events) = channel::<()>();
        let (executor, spawner) = executor_and_spawner();
        let capability_context =
            CapabilityContext::new(request_sender, event_sender.clone(), spawner.clone());

        // Asserts that neither a request nor an event is waiting in the channels.
        macro_rules! assert_channels_empty {
            () => {
                assert_matches!(requests.receive(), None);
                assert_matches!(events.receive(), None);
            };
        }

        let future = capability_context.request_from_shell(TestOperation);

        // Merely creating the future must not send a request...
        assert_channels_empty!();

        // ...nor spawn anything for the executor to run.
        executor.run_all();
        assert_channels_empty!();

        spawner.spawn(async move {
            future.await;
            event_sender.send(());
        });

        // Spawning alone does not poll the future.
        assert_channels_empty!();

        // Running the executor polls the future, which sends exactly one request.
        executor.run_all();
        let mut request = requests.receive().expect("we should have a request here");
        assert_channels_empty!();

        // Resolving wakes the task but does not run it.
        request.resolve(()).expect("request should resolve");
        assert_channels_empty!();

        // Running the woken task completes the future and emits exactly one event.
        executor.run_all();
        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), Some(()));
        assert_matches!(events.receive(), None);
    }
}