crux_core/capability/
shell_stream.rs1use std::{
2 sync::{Arc, Mutex},
3 task::{Poll, Waker},
4};
5
6use futures::Stream;
7
8use super::{channel, channel::Receiver};
9use crate::core::Request;
10
11pub struct ShellStream<T> {
12 shared_state: Arc<Mutex<SharedState<T>>>,
13}
14
15struct SharedState<T> {
16 receiver: Receiver<T>,
17 waker: Option<Waker>,
18 send_request: Option<Box<dyn FnOnce() + Send + 'static>>,
19}
20
21impl<T> Stream for ShellStream<T> {
22 type Item = T;
23
24 fn poll_next(
25 self: std::pin::Pin<&mut Self>,
26 cx: &mut std::task::Context<'_>,
27 ) -> Poll<Option<Self::Item>> {
28 let mut shared_state = self.shared_state.lock().unwrap();
29
30 if let Some(send_request) = shared_state.send_request.take() {
31 send_request();
32 }
33
34 match shared_state.receiver.try_receive() {
35 Ok(Some(next)) => Poll::Ready(Some(next)),
36 Ok(None) => {
37 shared_state.waker = Some(cx.waker().clone());
38 Poll::Pending
39 }
40 Err(_) => Poll::Ready(None),
41 }
42 }
43}
44
45impl<Op, Ev> crate::capability::CapabilityContext<Op, Ev>
46where
47 Op: crate::capability::Operation,
48 Ev: 'static,
49{
50 pub fn stream_from_shell(&self, operation: Op) -> ShellStream<Op::Output> {
52 let (sender, receiver) = channel();
53 let shared_state = Arc::new(Mutex::new(SharedState {
54 receiver,
55 waker: None,
56 send_request: None,
57 }));
58
59 let callback_shared_state = Arc::downgrade(&shared_state);
62
63 let request = Request::resolves_many_times(operation, move |result| {
64 let Some(shared_state) = callback_shared_state.upgrade() else {
65 return Err(());
67 };
68
69 let mut shared_state = shared_state.lock().unwrap();
70
71 sender.send(result);
72 if let Some(waker) = shared_state.waker.take() {
73 waker.wake();
74 }
75
76 Ok(())
77 });
78
79 let send_req_context = self.clone();
82 let send_request = move || send_req_context.send_request(request);
83 shared_state.lock().unwrap().send_request = Some(Box::new(send_request));
84
85 ShellStream { shared_state }
86 }
87}
88
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::StreamExt;

    use crate::capability::{channel, executor_and_spawner, CapabilityContext, Operation};

    /// Operation resolved with an optional `Done` marker; `Some(Done)` tells the
    /// consuming task to stop.
    #[derive(serde::Serialize, Clone, PartialEq, Eq, Debug)]
    struct TestOperation;

    impl Operation for TestOperation {
        type Output = Option<Done>;
    }

    #[derive(serde::Deserialize, PartialEq, Eq, Debug)]
    struct Done;

    #[test]
    fn test_shell_stream() {
        let (request_sender, requests) = channel();
        let (event_sender, events) = channel::<()>();
        let (executor, spawner) = executor_and_spawner();
        let capability_context =
            CapabilityContext::new(request_sender, event_sender.clone(), spawner.clone());

        // Asserts that neither a request nor an event is waiting. A macro rather than
        // a closure, so a failing assertion reports the line it was invoked from.
        macro_rules! assert_quiet {
            () => {
                assert_matches!(requests.receive(), None);
                assert_matches!(events.receive(), None);
            };
        }

        let mut stream = capability_context.stream_from_shell(TestOperation);

        // Creating the stream sends nothing.
        assert_quiet!();

        // Nothing is polling the stream yet, so running the executor sends nothing.
        executor.run_all();
        assert_quiet!();

        // Emit one event per stream item, stopping after `Some(Done)`.
        spawner.spawn(async move {
            while let Some(maybe_done) = stream.next().await {
                event_sender.send(());
                if maybe_done.is_some() {
                    break;
                }
            }
        });

        // Spawning alone does not poll the task.
        assert_quiet!();

        // The first poll sends exactly one request.
        executor.run_all();
        let mut request = requests.receive().expect("we should have a request here");
        assert_quiet!();

        // One resolution produces one event.
        request.resolve(None).unwrap();
        executor.run_all();
        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), Some(()));
        assert_matches!(events.receive(), None);

        // Resolutions made before the executor runs are all delivered.
        for output in [None, None, Some(Done)] {
            request.resolve(output).unwrap();
        }
        executor.run_all();
        assert_matches!(requests.receive(), None);
        for _ in 0..3 {
            assert_matches!(events.receive(), Some(()));
        }
        assert_matches!(events.receive(), None);

        // The task has finished and dropped the stream, so resolving again fails.
        request
            .resolve(None)
            .expect_err("resolving a finished task should error");
    }
}