crux_core/capability/shell_stream.rs

1use std::{
2    sync::{Arc, Mutex},
3    task::{Poll, Waker},
4};
5
6use futures::Stream;
7
8use super::{channel, channel::Receiver};
9use crate::core::Request;
10
11pub struct ShellStream<T> {
12    shared_state: Arc<Mutex<SharedState<T>>>,
13}
14
15struct SharedState<T> {
16    receiver: Receiver<T>,
17    waker: Option<Waker>,
18    send_request: Option<Box<dyn FnOnce() + Send + 'static>>,
19}
20
21impl<T> Stream for ShellStream<T> {
22    type Item = T;
23
24    fn poll_next(
25        self: std::pin::Pin<&mut Self>,
26        cx: &mut std::task::Context<'_>,
27    ) -> Poll<Option<Self::Item>> {
28        let mut shared_state = self.shared_state.lock().unwrap();
29
30        if let Some(send_request) = shared_state.send_request.take() {
31            send_request();
32        }
33
34        match shared_state.receiver.try_receive() {
35            Ok(Some(next)) => Poll::Ready(Some(next)),
36            Ok(None) => {
37                shared_state.waker = Some(cx.waker().clone());
38                Poll::Pending
39            }
40            Err(_) => Poll::Ready(None),
41        }
42    }
43}
44
45impl<Op, Ev> crate::capability::CapabilityContext<Op, Ev>
46where
47    Op: crate::capability::Operation,
48    Ev: 'static,
49{
50    /// Send an effect request to the shell, expecting a stream of responses
51    pub fn stream_from_shell(&self, operation: Op) -> ShellStream<Op::Output> {
52        let (sender, receiver) = channel();
53        let shared_state = Arc::new(Mutex::new(SharedState {
54            receiver,
55            waker: None,
56            send_request: None,
57        }));
58
59        // Our callback holds a weak pointer so the channel can be freed
60        // whenever the associated task ends.
61        let callback_shared_state = Arc::downgrade(&shared_state);
62
63        let request = Request::resolves_many_times(operation, move |result| {
64            let Some(shared_state) = callback_shared_state.upgrade() else {
65                // Let the caller know that the associated task has finished.
66                return Err(());
67            };
68
69            let mut shared_state = shared_state.lock().unwrap();
70
71            sender.send(result);
72            if let Some(waker) = shared_state.waker.take() {
73                waker.wake();
74            }
75
76            Ok(())
77        });
78
79        // Put a callback into our shared_state so that we only send
80        // our request to the shell when the stream is first polled.
81        let send_req_context = self.clone();
82        let send_request = move || send_req_context.send_request(request);
83        shared_state.lock().unwrap().send_request = Some(Box::new(send_request));
84
85        ShellStream { shared_state }
86    }
87}
88
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::StreamExt;

    use crate::capability::{channel, executor_and_spawner, CapabilityContext, Operation};

    /// Operation whose outputs are either "keep going" (`None`) or `Some(Done)`.
    #[derive(serde::Serialize, Clone, PartialEq, Eq, Debug)]
    struct TestOperation;

    impl Operation for TestOperation {
        type Output = Option<Done>;
    }

    /// Marker output telling the consuming task to stop reading the stream.
    #[derive(serde::Deserialize, PartialEq, Eq, Debug)]
    struct Done;

    /// Walks a `ShellStream` through its lifecycle: deferred request, repeated
    /// resolution, task completion, and rejection of late resolutions.
    #[test]
    fn test_shell_stream() {
        let (request_sender, requests) = channel();
        let (event_sender, events) = channel::<()>();
        let (executor, spawner) = executor_and_spawner();
        let capability_context =
            CapabilityContext::new(request_sender, event_sender.clone(), spawner.clone());

        let mut stream = capability_context.stream_from_shell(TestOperation);

        // Neither a shell request nor an app event is waiting to be received.
        let assert_idle = || {
            assert_matches!(requests.receive(), None);
            assert_matches!(events.receive(), None);
        };

        // Creating the stream sends nothing: the request waits for the first poll.
        assert_idle();

        // Nothing has been spawned yet either, so running the executor is a no-op.
        executor.run_all();
        assert_idle();

        // Consume the stream, emitting one event per output, until `Done` arrives.
        spawner.spawn(async move {
            loop {
                let Some(output) = stream.next().await else {
                    break;
                };
                event_sender.send(());
                if output.is_some() {
                    break;
                }
            }
        });

        // Spawning alone does not poll, so there is still no request.
        assert_idle();

        // The first poll hands exactly one request to the shell.
        executor.run_all();
        let mut request = requests.receive().expect("we should have a request here");
        assert_idle();

        // A single resolution yields a single event.
        request.resolve(None).unwrap();
        executor.run_all();
        assert_matches!(events.receive(), Some(()));
        assert_idle();

        // Resolve it a few more times and then finish.
        for _ in 0..2 {
            request.resolve(None).unwrap();
        }
        request.resolve(Some(Done)).unwrap();
        executor.run_all();

        // Three further resolutions, three further events, then silence.
        for _ in 0..3 {
            assert_matches!(events.receive(), Some(()));
        }
        assert_idle();

        // The next resolve should error as we've terminated the task
        request
            .resolve(None)
            .expect_err("resolving a finished task should error");
    }
}