use std::{
sync::{Arc, Mutex},
task::{Poll, Waker},
};
use futures::Future;
use crate::Request;
/// A future that sends a request to the shell the first time it is polled and
/// completes with the output once that request has been resolved.
///
/// Sending is deferred until the first poll. A `ShellRequest` that is dropped
/// without being awaited or polled therefore never reaches the shell; `#[must_use]`
/// turns that mistake into a compiler warning.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ShellRequest<T> {
    /// State shared with the resolution callback.
    shared_state: Arc<Mutex<SharedState<T>>>,
}
#[cfg(test)]
impl ShellRequest<()> {
    /// Test-only constructor.
    ///
    /// Returns a `ShellRequest<()>` whose shared state is completely empty:
    /// - no pending send,
    /// - no registered waker,
    /// - no result.
    pub(crate) fn new() -> Self {
        let empty = SharedState {
            send_request: None,
            waker: None,
            result: None,
        };
        let shared_state = Arc::new(Mutex::new(empty));
        Self { shared_state }
    }
}
/// State shared between a [`ShellRequest`] future and the callback that resolves it.
struct SharedState<T> {
    /// Deferred action that sends the request.
    ///
    /// `poll` takes it out with `Option::take`, so it runs at most once.
    send_request: Option<Box<dyn FnOnce() + Send + 'static>>,
    /// Waker stored by a `poll` that returned `Pending`.
    ///
    /// The resolution callback takes it and wakes it.
    waker: Option<Waker>,
    /// Output stored by the resolution callback.
    ///
    /// `poll` takes it out and returns it as `Poll::Ready`.
    result: Option<T>,
}
impl<T> Future for ShellRequest<T> {
    type Output = T;

    /// Sends the deferred request on the first poll, then completes once a
    /// result has been stored in the shared state.
    ///
    /// The `send_request` closure is taken out under the lock but invoked only
    /// after the lock has been released. If sending ends up resolving the
    /// request synchronously, the resolution callback locks this same
    /// (non-reentrant) mutex. Calling the closure while holding the guard would
    /// then deadlock or panic.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // Bind to a local so the temporary guard is dropped at the end of this
        // statement (an `if let` on the same expression would keep it alive for
        // the whole body).
        let send_request = self.shared_state.lock().unwrap().send_request.take();
        if let Some(send_request) = send_request {
            send_request();
        }

        // Checking for a result and registering the waker happen under a single
        // lock acquisition. The callback stores the result and takes the waker
        // under the same lock, so a wakeup cannot be lost in between.
        let mut shared_state = self.shared_state.lock().unwrap();
        match shared_state.result.take() {
            Some(result) => Poll::Ready(result),
            None => {
                shared_state.waker = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}
impl<Op, Ev> crate::capability::CapabilityContext<Op, Ev>
where
    Op: crate::capability::Operation,
    Ev: 'static,
{
    /// Returns a [`ShellRequest`] future that sends `operation` to the shell
    /// and completes with the shell's output for it.
    ///
    /// Nothing is sent when this method is called. The send is stored as a
    /// deferred closure and only runs when the returned future is first polled.
    pub fn request_from_shell(&self, operation: Op) -> ShellRequest<Op::Output> {
        let shared_state = Arc::new(Mutex::new(SharedState {
            result: None,
            waker: None,
            send_request: None,
        }));

        // The callback only gets a weak reference. A strong one would create a
        // cycle: the state owns `send_request`, which owns `request`, which owns
        // this callback. It also lets a resolution for an already-dropped future
        // be discarded.
        let callback_shared_state = Arc::downgrade(&shared_state);
        let request = Request::resolves_once(operation, move |result| {
            let Some(shared_state) = callback_shared_state.upgrade() else {
                // The future is gone; there is nobody to deliver the result to.
                return;
            };

            // Store the result and take the waker under the lock, but wake only
            // after the guard is dropped. A waker that polls the task inline
            // would otherwise re-enter `poll` and block on this same mutex.
            let waker = {
                let mut shared_state = shared_state.lock().unwrap();
                shared_state.result = Some(result);
                shared_state.waker.take()
            };
            if let Some(waker) = waker {
                waker.wake();
            }
        });

        let send_req_context = self.clone();
        let send_request = move || send_req_context.send_request(request);
        shared_state.lock().unwrap().send_request = Some(Box::new(send_request));

        ShellRequest { shared_state }
    }
}
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use crate::capability::{channel, executor_and_spawner, CapabilityContext, Operation};

    /// Minimal operation used to drive `request_from_shell`; its output is `()`.
    #[derive(serde::Serialize, Clone, PartialEq, Eq, Debug)]
    struct TestOperation;

    impl Operation for TestOperation {
        type Output = ();
    }

    /// Walks a `ShellRequest` through its lifecycle. At each step it checks
    /// exactly what has (and has not) appeared on the request and event channels.
    #[test]
    fn test_effect_future() {
        let (request_sender, requests) = channel();
        let (event_sender, events) = channel::<()>();
        let (executor, spawner) = executor_and_spawner();
        let capability_context =
            CapabilityContext::new(request_sender, event_sender.clone(), spawner.clone());

        // Creating the future must not send anything. The request is only sent
        // when the future is first polled.
        let future = capability_context.request_from_shell(TestOperation);
        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), None);

        // Running the executor with nothing spawned still sends nothing.
        executor.run_all();
        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), None);

        // Spawning a task that awaits the future does not run it yet.
        spawner.spawn(async move {
            future.await;
            event_sender.send(());
        });
        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), None);

        // Running the executor polls the task. Exactly one request is emitted,
        // and the task is now pending, so no event has been sent yet.
        executor.run_all();
        let mut request = requests.receive().expect("we should have a request here");
        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), None);

        // Resolving the request alone produces no event. The awaiting task only
        // continues when the executor runs again.
        request.resolve(()).expect("request should resolve");
        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), None);

        // After the executor runs, the task completes and sends exactly one event.
        executor.run_all();
        assert_matches!(requests.receive(), None);
        assert_matches!(events.receive(), Some(()));
        assert_matches!(events.receive(), None);
    }
}