crux_core/capability/
shell_stream.rs1use std::{
2 sync::{Arc, Mutex},
3 task::{Poll, Waker},
4};
5
6use futures::Stream;
7
8use super::{channel, channel::Receiver};
9use crate::core::Request;
10
11pub struct ShellStream<T> {
12 shared_state: Arc<Mutex<SharedState<T>>>,
13}
14
15struct SharedState<T> {
16 receiver: Receiver<T>,
17 waker: Option<Waker>,
18 send_request: Option<Box<dyn FnOnce() + Send + 'static>>,
19}
20
21impl<T> Stream for ShellStream<T> {
22 type Item = T;
23
24 fn poll_next(
25 self: std::pin::Pin<&mut Self>,
26 cx: &mut std::task::Context<'_>,
27 ) -> Poll<Option<Self::Item>> {
28 let mut shared_state = self.shared_state.lock().unwrap();
29
30 if let Some(send_request) = shared_state.send_request.take() {
31 send_request();
32 }
33
34 match shared_state.receiver.try_receive() {
35 Ok(Some(next)) => Poll::Ready(Some(next)),
36 Ok(None) => {
37 shared_state.waker = Some(cx.waker().clone());
38 Poll::Pending
39 }
40 Err(()) => Poll::Ready(None),
41 }
42 }
43}
44
45impl<Op, Ev> crate::capability::CapabilityContext<Op, Ev>
46where
47 Op: crate::capability::Operation,
48 Ev: 'static,
49{
50 pub fn stream_from_shell(&self, operation: Op) -> ShellStream<Op::Output> {
56 let (sender, receiver) = channel();
57 let shared_state = Arc::new(Mutex::new(SharedState {
58 receiver,
59 waker: None,
60 send_request: None,
61 }));
62
63 let callback_shared_state = Arc::downgrade(&shared_state);
66
67 let request = Request::resolves_many_times(operation, move |result| {
68 let Some(shared_state) = callback_shared_state.upgrade() else {
69 return Err(());
71 };
72
73 let mut shared_state = shared_state.lock().unwrap();
74
75 sender.send(result);
76 if let Some(waker) = shared_state.waker.take() {
77 waker.wake();
78 }
79
80 Ok(())
81 });
82
83 let send_req_context = self.clone();
86 let send_request = move || send_req_context.send_request(request);
87 shared_state.lock().unwrap().send_request = Some(Box::new(send_request));
88
89 ShellStream { shared_state }
90 }
91}
92
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use crate::capability::{channel, executor_and_spawner, CapabilityContext, Operation};

    /// Test operation with output `Option<Done>`. The test resolves it with `None`
    /// for intermediate values and with `Some(Done)` for the final one.
    #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq, Debug)]
    struct TestOperation;

    impl Operation for TestOperation {
        type Output = Option<Done>;
    }

    /// Marker value whose arrival makes the consuming task stop reading the stream.
    #[derive(serde::Deserialize, PartialEq, Eq, Debug)]
    struct Done;

    /// Walks a `ShellStream` through its lifecycle:
    /// - the request is sent lazily, on the first poll;
    /// - each resolution becomes one stream item;
    /// - resolving after the consuming task has finished is an error.
    #[test]
    fn test_shell_stream() {
        let (request_sender, requests) = channel();
        let (event_sender, events) = channel::<()>();
        let (executor, spawner) = executor_and_spawner();
        let capability_context =
            CapabilityContext::new(request_sender, event_sender.clone(), spawner.clone());

        let mut stream = capability_context.stream_from_shell(TestOperation);

        // Creating the stream alone must not send anything to the shell.
        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), None);

        // Nothing has been spawned yet, so running the executor changes nothing.
        executor.run_all();
        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), None);

        // Consume the stream in a task. It emits one event per item and stops after
        // the first `Some(Done)`; `stream` is moved into (and dropped with) the task.
        spawner.spawn(async move {
            use futures::StreamExt;
            while let Some(maybe_done) = stream.next().await {
                event_sender.send(());
                if maybe_done.is_some() {
                    break;
                }
            }
        });

        // Before the executor runs the task, still no request has been sent.
        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), None);

        // Running the task polls the stream, which dispatches exactly one request.
        executor.run_all();
        let mut request = requests.receive().expect("we should have a request here");

        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), None);

        // One resolution yields one stream item, and therefore one event.
        request.resolve(None).unwrap();

        executor.run_all();

        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), Some(()));
        assert_matches!(events.receive(), None);

        // Several resolutions queued before the executor runs are all delivered, in
        // order. The final `Some(Done)` makes the task leave its loop.
        request.resolve(None).unwrap();
        request.resolve(None).unwrap();
        request.resolve(Some(Done)).unwrap();
        executor.run_all();

        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), Some(()));
        assert_matches!(events.receive(), Some(()));
        assert_matches!(events.receive(), Some(()));
        assert_matches!(events.receive(), None);

        // The task has finished and dropped the stream, so a further resolution is
        // expected to fail.
        request
            .resolve(None)
            .expect_err("resolving a finished task should error");
    }
}