crux_core/capability/
executor.rs1use std::{
2 sync::{Arc, Mutex},
3 task::{Context, Wake},
4};
5
6use crossbeam_channel::{Receiver, Sender};
7use futures::{future, Future, FutureExt};
8use slab::Slab;
9
10type BoxFuture = future::BoxFuture<'static, ()>;
11
/// An executor that drains its work queues on demand rather than running a
/// background event loop.
///
/// `run_all` first registers newly spawned futures from `spawn_queue`, then
/// re-polls woken tasks from `ready_queue`, repeating while progress is made.
/// Tasks are stored in a slab; a slot holding `None` marks a future that some
/// thread has taken out and is polling right now (see `run_task`).
pub(crate) struct QueuingExecutor {
    // Futures submitted via `Spawner::spawn` that have not yet been assigned
    // a slab slot / `TaskId`.
    spawn_queue: Receiver<BoxFuture>,
    // Ids of tasks whose wakers have fired and which need re-polling.
    ready_queue: Receiver<TaskId>,
    // Cloned into every `TaskWaker` so a woken task can enqueue its own id.
    ready_sender: Sender<TaskId>,
    // Slot storage: `Some(future)` = suspended, `None` = currently being polled.
    tasks: Mutex<Slab<Option<BoxFuture>>>,
}
/// Cloneable handle for submitting futures to the paired `QueuingExecutor`.
#[derive(Clone)]
pub struct Spawner {
    // Sending half of the executor's `spawn_queue`.
    future_sender: Sender<BoxFuture>,
}
/// Index of a task's slot in the executor's slab, used as the waker payload.
#[derive(Clone, Copy, Debug)]
struct TaskId(u32);
32
33impl std::ops::Deref for TaskId {
34 type Target = u32;
35
36 fn deref(&self) -> &Self::Target {
37 &self.0
38 }
39}
40
41pub(crate) fn executor_and_spawner() -> (QueuingExecutor, Spawner) {
42 let (future_sender, spawn_queue) = crossbeam_channel::unbounded();
43 let (ready_sender, ready_queue) = crossbeam_channel::unbounded();
44
45 (
46 QueuingExecutor {
47 ready_queue,
48 spawn_queue,
49 ready_sender,
50 tasks: Mutex::new(Slab::new()),
51 },
52 Spawner { future_sender },
53 )
54}
55
56impl Spawner {
59 pub fn spawn(&self, future: impl Future<Output = ()> + 'static + Send) {
60 let future = future.boxed();
61 self.future_sender
62 .send(future)
63 .expect("unable to spawn an async task, task sender channel is disconnected.")
64 }
65}
/// Waker that re-queues its task on the executor's ready channel when fired.
#[derive(Clone)]
struct TaskWaker {
    // Slab index of the task this waker belongs to.
    task_id: TaskId,
    // Sending half of the executor's `ready_queue`.
    sender: Sender<TaskId>,
}
73
74impl Wake for TaskWaker {
77 fn wake(self: Arc<Self>) {
78 self.wake_by_ref();
79 }
80
81 fn wake_by_ref(self: &Arc<Self>) {
82 let _ = self.sender.send(self.task_id);
85 }
86}
impl QueuingExecutor {
    /// Drive all spawned and woken tasks until no further progress can be made.
    ///
    /// Safe to call from multiple threads concurrently: the slab's `None`
    /// slots act as "in flight" markers, and a task found in flight is simply
    /// requeued (see the `Unavailable` arm below).
    pub fn run_all(&self) {
        // Tracks whether the previous sweep did anything; once a full sweep
        // makes no progress, any remaining work belongs to other threads or
        // to wakers that have not fired yet.
        let mut did_some_work = true;

        while did_some_work {
            did_some_work = false;
            // Register each newly spawned future in the slab and give it an
            // initial poll immediately.
            while let Ok(task) = self.spawn_queue.try_recv() {
                let task_id = self
                    .tasks
                    .lock()
                    .expect("Task slab poisoned")
                    .insert(Some(task));
                self.run_task(TaskId(task_id.try_into().expect("TaskId overflow")));
                did_some_work = true;
            }
            // Re-poll tasks whose wakers have fired.
            while let Ok(task_id) = self.ready_queue.try_recv() {
                match self.run_task(task_id) {
                    RunTask::Unavailable => {
                        // Another thread is polling this task right now; put
                        // the wake back so it gets re-polled once that thread
                        // has returned the future to its slot.
                        self.ready_sender.send(task_id).expect("could not requeue");
                    }
                    RunTask::Missing => {
                        // Stale wake: the task already completed and its slot
                        // was removed. Nothing to do.
                    }
                    RunTask::Suspended | RunTask::Completed => did_some_work = true,
                }
            }
        }
    }

    /// Poll the task in slot `task_id` once, reporting what happened.
    fn run_task(&self, task_id: TaskId) -> RunTask {
        let mut lock = self.tasks.lock().expect("Task slab poisoned");
        let Some(task) = lock.get_mut(*task_id as usize) else {
            // Slot no longer exists: the task finished earlier and was removed.
            return RunTask::Missing;
        };
        let Some(mut task) = task.take() else {
            // Slot exists but is empty: another thread has taken the future
            // out and is polling it at this very moment.
            return RunTask::Unavailable;
        };

        // Release the slab lock while polling so other threads can progress;
        // the `None` we left behind marks this task as in flight.
        drop(lock);

        let waker = Arc::new(TaskWaker {
            task_id,
            sender: self.ready_sender.clone(),
        })
        .into();
        let context = &mut Context::from_waker(&waker);

        if task.as_mut().poll(context).is_pending() {
            // Still pending: store the future back so a later wake finds it.
            // The slot must still exist — only the thread holding the future
            // (us) ever removes it.
            self.tasks
                .lock()
                .expect("Task slab poisoned")
                .get_mut(*task_id as usize)
                .expect("Task slot is missing")
                .replace(task);
            RunTask::Suspended
        } else {
            // Completed: free the slot for reuse by future spawns.
            self.tasks.lock().unwrap().remove(*task_id as usize);
            RunTask::Completed
        }
    }
}
172
/// Outcome of a single `QueuingExecutor::run_task` call.
enum RunTask {
    // No slab slot for this id — the task already completed and was removed.
    Missing,
    // Slot exists but the future was taken out: another thread is polling it.
    Unavailable,
    // Polled `Pending`; the future was stored back to await its next wake.
    Suspended,
    // Polled `Ready`; the slot has been freed.
    Completed,
}
179
#[cfg(test)]
mod tests {

    use rand::Rng;
    use std::{
        sync::atomic::{AtomicI32, Ordering},
        task::Poll,
    };

    use super::*;
    use crate::capability::shell_request::ShellRequest;

    #[test]
    fn test_task_does_not_leak() {
        // The Arc's strong count doubles as a leak detector for state
        // captured by the spawned future.
        let counter = Arc::new(());
        assert_eq!(Arc::strong_count(&counter), 1);

        let (executor, spawner) = executor_and_spawner();

        let future = {
            let counter = counter.clone();
            async move {
                assert_eq!(Arc::strong_count(&counter), 2);
                // Await a shell request that is never resolved, so the future
                // suspends and is only dropped together with the executor.
                ShellRequest::<()>::new().await;
            }
        };

        spawner.spawn(future);
        executor.run_all();
        drop(executor);
        drop(spawner);
        // Dropping executor + spawner must release the suspended future and,
        // with it, the cloned Arc.
        assert_eq!(Arc::strong_count(&counter), 1);
    }

    #[test]
    fn test_multithreaded_executor() {
        // A future that randomly self-wakes or polls a tree of child futures,
        // stressing concurrent `run_all` calls from many threads.
        struct Chaotic {
            // Set once this subtree has completed, so later polls are Ready.
            ready_once: bool,
            children: Vec<Chaotic>,
        }

        // Global count of live (not-yet-completed) Chaotic nodes.
        static CHAOS_COUNT: AtomicI32 = AtomicI32::new(0);

        impl Chaotic {
            // Build a tree where each node has one fewer child than its
            // parent; a root built with n has f(n) = 1 + n * f(n - 1) nodes.
            fn new_with_children(num_children: usize) -> Self {
                CHAOS_COUNT.fetch_add(1, Ordering::SeqCst);
                Self {
                    ready_once: false,
                    children: (0..num_children)
                        .map(|_| Chaotic::new_with_children(num_children - 1))
                        .collect(),
                }
            }
        }

        impl Future for Chaotic {
            type Output = ();

            fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                if self.ready_once {
                    return Poll::Ready(());
                }
                // On 10% of polls, self-wake and yield without doing work.
                if rand::rng().random_bool(0.1) {
                    cx.waker().wake_by_ref();

                    Poll::Pending
                } else {
                    let mut ready = true;
                    let this = self.get_mut();
                    for child in &mut this.children {
                        if child.poll_unpin(cx).is_pending() {
                            ready = false;
                        }
                    }
                    if ready {
                        this.ready_once = true;
                        // Wake once more even though we return Ready, which
                        // exercises the executor's stale-wake (Missing) path.
                        cx.waker().wake_by_ref();
                        CHAOS_COUNT.fetch_sub(1, Ordering::SeqCst);
                        Poll::Ready(())
                    } else {
                        Poll::Pending
                    }
                }
            }
        }

        let (executor, spawner) = executor_and_spawner();
        for _ in 0..100 {
            let future = Chaotic::new_with_children(6);
            spawner.spawn(future);
        }
        // 100 trees of f(6) = 1957 nodes each: 100 * 1957 = 195700.
        assert_eq!(CHAOS_COUNT.load(Ordering::SeqCst), 195700);
        let executor = Arc::new(executor);
        assert_eq!(executor.spawn_queue.len(), 100);

        // Run the same executor from 10 threads at once.
        let handles = (0..10)
            .map(|_| {
                let executor = executor.clone();
                std::thread::spawn(move || {
                    executor.run_all();
                })
            })
            .collect::<Vec<_>>();
        for handle in handles {
            handle.join().unwrap();
        }
        // Every task must have completed and both queues drained.
        assert_eq!(executor.spawn_queue.len(), 0);
        assert_eq!(executor.ready_queue.len(), 0);
        assert_eq!(CHAOS_COUNT.load(Ordering::SeqCst), 0);
    }
}