crux_core/capability/
shell_stream.rs

1use std::{
2    sync::{Arc, Mutex},
3    task::{Poll, Waker},
4};
5
6use futures::Stream;
7
8use super::{channel, channel::Receiver};
9use crate::core::Request;
10
11pub struct ShellStream<T> {
12    shared_state: Arc<Mutex<SharedState<T>>>,
13}
14
15struct SharedState<T> {
16    receiver: Receiver<T>,
17    waker: Option<Waker>,
18    send_request: Option<Box<dyn FnOnce() + Send + 'static>>,
19}
20
21impl<T> Stream for ShellStream<T> {
22    type Item = T;
23
24    fn poll_next(
25        self: std::pin::Pin<&mut Self>,
26        cx: &mut std::task::Context<'_>,
27    ) -> Poll<Option<Self::Item>> {
28        let mut shared_state = self.shared_state.lock().unwrap();
29
30        if let Some(send_request) = shared_state.send_request.take() {
31            send_request();
32        }
33
34        match shared_state.receiver.try_receive() {
35            Ok(Some(next)) => Poll::Ready(Some(next)),
36            Ok(None) => {
37                shared_state.waker = Some(cx.waker().clone());
38                Poll::Pending
39            }
40            Err(()) => Poll::Ready(None),
41        }
42    }
43}
44
45impl<Op, Ev> crate::capability::CapabilityContext<Op, Ev>
46where
47    Op: crate::capability::Operation,
48    Ev: 'static,
49{
50    /// Send an effect request to the shell, expecting a stream of responses
51    ///
52    /// # Panics
53    ///
54    /// Panics if we can't acquire the shared state lock.
55    pub fn stream_from_shell(&self, operation: Op) -> ShellStream<Op::Output> {
56        let (sender, receiver) = channel();
57        let shared_state = Arc::new(Mutex::new(SharedState {
58            receiver,
59            waker: None,
60            send_request: None,
61        }));
62
63        // Our callback holds a weak pointer so the channel can be freed
64        // whenever the associated task ends.
65        let callback_shared_state = Arc::downgrade(&shared_state);
66
67        let request = Request::resolves_many_times(operation, move |result| {
68            let Some(shared_state) = callback_shared_state.upgrade() else {
69                // Let the caller know that the associated task has finished.
70                return Err(());
71            };
72
73            let mut shared_state = shared_state.lock().unwrap();
74
75            sender.send(result);
76            if let Some(waker) = shared_state.waker.take() {
77                waker.wake();
78            }
79
80            Ok(())
81        });
82
83        // Put a callback into our shared_state so that we only send
84        // our request to the shell when the stream is first polled.
85        let send_req_context = self.clone();
86        let send_request = move || send_req_context.send_request(request);
87        shared_state.lock().unwrap().send_request = Some(Box::new(send_request));
88
89        ShellStream { shared_state }
90    }
91}
92
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use crate::capability::{channel, executor_and_spawner, CapabilityContext, Operation};

    /// Operation whose output tells the consuming task whether to stop
    /// (`Some(Done)`) or keep reading (`None`).
    #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq, Debug)]
    struct TestOperation;

    impl Operation for TestOperation {
        type Output = Option<Done>;
    }

    #[derive(serde::Deserialize, PartialEq, Eq, Debug)]
    struct Done;

    #[test]
    fn test_shell_stream() {
        let (request_tx, request_rx) = channel();
        let (event_tx, event_rx) = channel::<()>();
        let (executor, spawner) = executor_and_spawner();
        let context = CapabilityContext::new(request_tx, event_tx.clone(), spawner.clone());

        // Asserts that neither a request nor an event is waiting.
        let assert_quiet = || {
            assert_matches!(request_rx.receive(), None);
            assert_matches!(event_rx.receive(), None);
        };

        let mut stream = context.stream_from_shell(TestOperation);

        // Nothing has polled the stream yet, so no request should exist.
        assert_quiet();

        // Creating the stream should not have spawned any work either.
        executor.run_all();
        assert_quiet();

        // Consume the stream, emitting one event per output, until `Some(Done)` arrives.
        spawner.spawn(async move {
            use futures::StreamExt;
            loop {
                let Some(output) = stream.next().await else {
                    break;
                };
                event_tx.send(());
                if output.is_some() {
                    break;
                }
            }
        });

        // The task has been spawned but not run: still no request.
        assert_quiet();

        // Running the task polls the stream, which sends the request.
        executor.run_all();
        let mut request = request_rx
            .receive()
            .expect("we should have a request here");

        assert_quiet();

        request.resolve(None).unwrap();

        executor.run_all();

        // Exactly one event for the single resolve.
        assert_matches!(request_rx.receive(), None);
        assert_matches!(event_rx.receive(), Some(()));
        assert_matches!(event_rx.receive(), None);

        // Two more non-final outputs, then the one that ends the task.
        for output in [None, None, Some(Done)] {
            request.resolve(output).unwrap();
        }
        executor.run_all();

        // One event per resolve: three in total.
        assert_matches!(request_rx.receive(), None);
        for _ in 0..3 {
            assert_matches!(event_rx.receive(), Some(()));
        }
        assert_matches!(event_rx.receive(), None);

        // The task has finished, so resolving again must fail.
        request
            .resolve(None)
            .expect_err("resolving a finished task should error");
    }
}