1#![allow(clippy::unsafe_derive_deserialize)]
2pub mod command;
9pub mod protocol;
10
11use std::{
12 collections::HashSet,
13 future::Future,
14 marker::PhantomData,
15 pin::Pin,
16 sync::{
17 LazyLock, Mutex,
18 atomic::{AtomicUsize, Ordering},
19 },
20 task::Poll,
21 time,
22};
23
24use crux_core::{Command, Request, command::RequestBuilder};
25use futures::{
26 FutureExt,
27 channel::oneshot::{self, Sender},
28 select_biased,
29};
30
31pub use protocol::*;
32
33#[derive(Debug, PartialEq, Eq, Clone)]
35#[repr(C)]
36pub enum TimerOutcome {
37 Completed(CompletedTimerHandle),
39 Cleared,
41}
42
/// Entry point for the time capability's commands ([`Time::now`],
/// [`Time::notify_at`], [`Time::notify_after`]).
///
/// The struct carries no data: `Effect` and `Event` are type parameters only,
/// tracked through zero-sized [`PhantomData`] markers.
pub struct Time<Effect, Event> {
    /// Marker for the app's effect type.
    effect: PhantomData<Effect>,
    /// Marker for the app's event type.
    event: PhantomData<Event>,
}
55
56impl<Effect, Event> Time<Effect, Event>
57where
58 Effect: Send + From<Request<TimeRequest>> + 'static,
59 Event: Send + 'static,
60{
61 #[must_use]
66 pub fn now() -> RequestBuilder<Effect, Event, impl Future<Output = time::SystemTime>> {
67 Command::request_from_shell(TimeRequest::Now).map(|r| {
68 let TimeResponse::Now { instant } = r else {
69 panic!("Incorrect response received for TimeRequest::Now")
70 };
71
72 instant.into()
73 })
74 }
75
76 #[must_use]
84 pub fn notify_at(
85 system_time: time::SystemTime,
86 ) -> (
87 RequestBuilder<Effect, Event, impl Future<Output = TimerOutcome>>,
88 TimerHandle,
89 ) {
90 let timer_id = get_timer_id();
91 let (sender, mut receiver) = oneshot::channel();
92
93 let handle = TimerHandle {
94 timer_id,
95 abort: sender,
96 };
97
98 let completed_handle = CompletedTimerHandle { timer_id };
99
100 let builder = RequestBuilder::new(move |ctx| {
105 async move {
106 if let Ok(Some(cleared_id)) = receiver.try_recv()
107 && cleared_id == timer_id
108 {
109 return TimerOutcome::Cleared;
110 }
111
112 select_biased! {
113 response = ctx.request_from_shell(
114 TimeRequest::NotifyAt {
115 id: timer_id,
116 instant: system_time.into()
117 }
118 ).fuse() => {
119 let TimeResponse::InstantArrived { id } = response else {
120 panic!("Unexpected response to TimeRequest::NotifyAt");
121 };
122
123 assert!(id == timer_id, "InstantArrived with unexpected timer ID");
124
125 TimerOutcome::Completed(completed_handle)
126 },
127 cleared = receiver => {
128 let cleared_id = cleared.unwrap();
134
135 let TimeResponse::Cleared { id } = ctx.request_from_shell(TimeRequest::Clear { id: cleared_id }).await else {
137 panic!("Unexpected response to TimeRequest::Clear");
138 };
139
140 assert!(id == cleared_id, "Cleared with unexpected timer ID");
141
142 TimerOutcome::Cleared
143 }
144 }
145 }
146 });
147
148 (builder, handle)
149 }
150
151 #[must_use]
159 pub fn notify_after(
160 duration: time::Duration,
161 ) -> (
162 RequestBuilder<Effect, Event, impl Future<Output = TimerOutcome>>,
163 TimerHandle,
164 ) {
165 let timer_id = get_timer_id();
166 let (sender, mut receiver) = oneshot::channel();
167
168 let handle = TimerHandle {
169 timer_id,
170 abort: sender,
171 };
172
173 let completed_handle = CompletedTimerHandle { timer_id };
174
175 let builder = RequestBuilder::new(move |ctx| async move {
176 if let Ok(Some(cleared_id)) = receiver.try_recv()
177 && cleared_id == timer_id
178 {
179 return TimerOutcome::Cleared;
180 }
181
182 select_biased! {
183 response = ctx.request_from_shell(
184 TimeRequest::NotifyAfter {
185 id: timer_id,
186 duration: duration.into()
187 }
188 ).fuse() => {
189 let TimeResponse::DurationElapsed { id } = response else {
190 panic!("Unexpected response to TimeRequest::NotifyAt");
191 };
192
193 assert!(id == timer_id, "InstantArrived with unexpected timer ID");
194
195 TimerOutcome::Completed(completed_handle)
196 }
197 cleared = receiver => {
198 let cleared_id = cleared.unwrap();
204 if cleared_id != timer_id {
205 unreachable!("Cleared with the wrong ID");
206 }
207
208 let TimeResponse::Cleared { id } = ctx.request_from_shell(TimeRequest::Clear { id: cleared_id }).await else {
210 panic!("Unexpected response to TimeRequest::Clear");
211 };
212
213 assert!(id == cleared_id, "Cleared resolved with unexpected timer ID");
214
215 TimerOutcome::Cleared
216 }
217 }
218 });
219
220 (builder, handle)
221 }
222}
223
224#[derive(Debug)]
227pub struct TimerHandle {
228 timer_id: TimerId,
229 abort: Sender<TimerId>,
230}
231
232impl TimerHandle {
233 pub fn clear(self) {
240 let _ = self.abort.send(self.timer_id);
241 }
242}
243
244#[derive(Debug, PartialEq, Eq, Clone)]
249pub struct CompletedTimerHandle {
250 timer_id: TimerId,
251}
252
253impl Eq for TimerHandle {}
254
255impl PartialEq for TimerHandle {
256 fn eq(&self, other: &Self) -> bool {
257 self.timer_id == other.timer_id
258 }
259}
260
261impl PartialEq<CompletedTimerHandle> for TimerHandle {
262 fn eq(&self, other: &CompletedTimerHandle) -> bool {
263 self.timer_id == other.timer_id
264 }
265}
266
267impl PartialEq<TimerHandle> for CompletedTimerHandle {
268 fn eq(&self, other: &TimerHandle) -> bool {
269 self.timer_id == other.timer_id
270 }
271}
272
273impl From<TimerHandle> for CompletedTimerHandle {
274 fn from(value: TimerHandle) -> Self {
275 Self {
276 timer_id: value.timer_id,
277 }
278 }
279}
280
281fn get_timer_id() -> TimerId {
282 static COUNTER: AtomicUsize = AtomicUsize::new(1);
283 TimerId(COUNTER.fetch_add(1, Ordering::Relaxed))
284}
285
286pub struct TimerFuture<F>
287where
288 F: Future<Output = TimeResponse> + Unpin,
289{
290 timer_id: TimerId,
291 is_cleared: bool,
292 future: F,
293}
294
295impl<F> Future for TimerFuture<F>
296where
297 F: Future<Output = TimeResponse> + Unpin,
298{
299 type Output = TimeResponse;
300
301 fn poll(
302 self: Pin<&mut Self>,
303 cx: &mut std::task::Context<'_>,
304 ) -> std::task::Poll<Self::Output> {
305 if self.is_cleared {
306 return Poll::Ready(TimeResponse::Cleared { id: self.timer_id });
308 }
309 let timer_is_cleared = {
311 let mut lock = CLEARED_TIMER_IDS.lock().unwrap();
312 lock.remove(&self.timer_id)
313 };
314 let this = self.get_mut();
315 this.is_cleared = timer_is_cleared;
316 if timer_is_cleared {
317 Poll::Ready(TimeResponse::Cleared { id: this.timer_id })
320 } else {
321 Pin::new(&mut this.future).poll(cx)
323 }
324 }
325}
326
327static CLEARED_TIMER_IDS: LazyLock<Mutex<HashSet<TimerId>>> =
332 LazyLock::new(|| Mutex::new(HashSet::new()));
333
#[cfg(test)]
mod test {
    // `super::*` already brings in `Time`, `TimerOutcome`, `TimerId`, `Request`,
    // `TimeRequest` and `TimeResponse`; only the extra names are imported explicitly.
    use super::*;

    use crate::Instant;
    use crate::protocol::duration::Duration;

    #[test]
    fn test_serializing_the_request_types_as_json() {
        let request = TimeRequest::Now;

        let serialized = serde_json::to_string(&request).unwrap();
        assert_eq!(&serialized, "\"now\"");

        let deserialized: TimeRequest = serde_json::from_str(&serialized).unwrap();
        assert_eq!(request, deserialized);

        let request = TimeRequest::NotifyAt {
            id: TimerId(1),
            instant: Instant::new(1, 2),
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert_eq!(
            &serialized,
            r#"{"notifyAt":{"id":1,"instant":{"seconds":1,"nanos":2}}}"#
        );

        let deserialized: TimeRequest = serde_json::from_str(&serialized).unwrap();
        assert_eq!(request, deserialized);

        let request = TimeRequest::NotifyAfter {
            id: TimerId(2),
            duration: Duration::from_secs(1),
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert_eq!(
            &serialized,
            r#"{"notifyAfter":{"id":2,"duration":{"nanos":1000000000}}}"#
        );

        let deserialized: TimeRequest = serde_json::from_str(&serialized).unwrap();
        assert_eq!(request, deserialized);
    }

    #[test]
    fn test_serializing_the_response_types_as_json() {
        let response = TimeResponse::Now {
            instant: Instant::new(1, 2),
        };

        let serialized = serde_json::to_string(&response).unwrap();
        assert_eq!(
            &serialized,
            r#"{"now":{"instant":{"seconds":1,"nanos":2}}}"#
        );

        let deserialized: TimeResponse = serde_json::from_str(&serialized).unwrap();
        assert_eq!(response, deserialized);

        let response = TimeResponse::DurationElapsed { id: TimerId(1) };

        let serialized = serde_json::to_string(&response).unwrap();
        assert_eq!(&serialized, r#"{"durationElapsed":{"id":1}}"#);

        let deserialized: TimeResponse = serde_json::from_str(&serialized).unwrap();
        assert_eq!(response, deserialized);

        let response = TimeResponse::InstantArrived { id: TimerId(2) };

        let serialized = serde_json::to_string(&response).unwrap();
        assert_eq!(&serialized, r#"{"instantArrived":{"id":2}}"#);

        let deserialized: TimeResponse = serde_json::from_str(&serialized).unwrap();
        assert_eq!(response, deserialized);
    }

    // Minimal app effect type: the only effect is a time request.
    enum Effect {
        Time(Request<TimeRequest>),
    }

    impl From<Request<TimeRequest>> for Effect {
        fn from(value: Request<TimeRequest>) -> Self {
            Self::Time(value)
        }
    }

    // Minimal app event type: one variant per kind of command output under test.
    #[derive(Debug, PartialEq)]
    enum Event {
        Elapsed(TimerOutcome),
        Now(std::time::SystemTime),
    }

    #[test]
    fn timer_can_be_cleared() {
        let (cmd, handle) = Time::notify_after(core::time::Duration::from_secs(2));
        let mut cmd = cmd.then_send(Event::Elapsed);

        let effect = cmd.effects().next();

        assert!(cmd.events().next().is_none());

        let Some(Effect::Time(_request)) = effect else {
            panic!("should get an effect");
        };

        handle.clear();

        let effect = cmd.effects().next();
        assert!(cmd.events().next().is_none());

        let Some(Effect::Time(mut request)) = effect else {
            panic!("should get an effect");
        };

        let TimeRequest::Clear { id } = request.operation else {
            panic!("expected a Clear request");
        };

        request
            .resolve(TimeResponse::Cleared { id })
            .expect("should resolve");

        let event = cmd.events().next();

        assert!(matches!(event, Some(Event::Elapsed(TimerOutcome::Cleared))));
    }

    #[test]
    fn clearing_before_the_first_poll_skips_the_shell() {
        let (cmd, handle) =
            Time::<Effect, Event>::notify_after(core::time::Duration::from_secs(2));
        handle.clear();

        let mut cmd = cmd.then_send(Event::Elapsed);

        // The command notices the clear up front and never asks the shell for anything.
        assert!(cmd.effects().next().is_none());

        let event = cmd.events().next();
        assert!(matches!(event, Some(Event::Elapsed(TimerOutcome::Cleared))));
        assert!(cmd.is_done());
    }

    #[test]
    fn dropping_a_timer_handle_does_not_clear_the_request() {
        let (cmd, handle) = Time::notify_after(core::time::Duration::from_secs(2));
        drop(handle);

        let mut cmd = cmd.then_send(Event::Elapsed);
        let effect = cmd.effects().next();

        assert!(cmd.events().next().is_none());

        let Some(Effect::Time(mut request)) = effect else {
            panic!("should get an effect");
        };

        let TimeRequest::NotifyAfter { id, .. } = request.operation else {
            panic!("Expected a NotifyAfter");
        };

        request
            .resolve(TimeResponse::DurationElapsed { id })
            .expect("should resolve");

        let event = cmd.events().next();

        assert!(matches!(
            event,
            Some(Event::Elapsed(TimerOutcome::Completed(_)))
        ));
    }

    #[test]
    fn dropping_a_timer_request_while_holding_a_handle_and_polling() {
        let (cmd, handle) = Time::notify_after(core::time::Duration::from_secs(2));
        let mut cmd = cmd.then_send(Event::Elapsed);

        let effect: Effect = cmd.effects().next().expect("Expected an effect!");

        drop(effect);
        assert!(!cmd.is_done());

        drop(handle);
        assert!(cmd.is_done());
    }

    #[test]
    fn notify_at_completes_when_the_instant_arrives() {
        let at = std::time::SystemTime::UNIX_EPOCH + core::time::Duration::from_secs(10);
        let (cmd, _handle) = Time::<Effect, Event>::notify_at(at);
        let mut cmd = cmd.then_send(Event::Elapsed);

        let effect = cmd.effects().next();
        assert!(cmd.events().next().is_none());

        let Some(Effect::Time(mut request)) = effect else {
            panic!("should get an effect");
        };

        let TimeRequest::NotifyAt { id, .. } = request.operation else {
            panic!("Expected a NotifyAt");
        };

        request
            .resolve(TimeResponse::InstantArrived { id })
            .expect("should resolve");

        let event = cmd.events().next();

        assert!(matches!(
            event,
            Some(Event::Elapsed(TimerOutcome::Completed(_)))
        ));
    }

    #[test]
    fn notify_at_can_be_cleared() {
        let at = std::time::SystemTime::UNIX_EPOCH + core::time::Duration::from_secs(10);
        let (cmd, handle) = Time::<Effect, Event>::notify_at(at);
        let mut cmd = cmd.then_send(Event::Elapsed);

        let effect = cmd.effects().next();
        assert!(cmd.events().next().is_none());

        // Keep the NotifyAt request alive while the timer is cleared.
        let Some(Effect::Time(notify_request)) = effect else {
            panic!("should get an effect");
        };
        assert!(matches!(
            notify_request.operation,
            TimeRequest::NotifyAt { .. }
        ));

        handle.clear();

        let effect = cmd.effects().next();
        assert!(cmd.events().next().is_none());

        let Some(Effect::Time(mut request)) = effect else {
            panic!("should get an effect");
        };

        let TimeRequest::Clear { id } = request.operation else {
            panic!("expected a Clear request");
        };

        request
            .resolve(TimeResponse::Cleared { id })
            .expect("should resolve");

        let event = cmd.events().next();

        assert!(matches!(event, Some(Event::Elapsed(TimerOutcome::Cleared))));
    }

    #[test]
    fn now_resolves_to_a_system_time() {
        let mut cmd = Time::<Effect, Event>::now().then_send(Event::Now);

        let effect = cmd.effects().next();
        assert!(cmd.events().next().is_none());

        let Some(Effect::Time(mut request)) = effect else {
            panic!("should get an effect");
        };
        assert!(matches!(request.operation, TimeRequest::Now));

        request
            .resolve(TimeResponse::Now {
                instant: Instant::new(1, 2),
            })
            .expect("should resolve");

        let event = cmd.events().next();

        assert!(matches!(event, Some(Event::Now(_))));
    }
}