// crux_http/protocol.rs

//! The protocol for communicating with the shell.
//!
//! Crux capabilities don't interface with the outside world themselves; they carry
//! out all their operations by exchanging messages with the platform-specific shell.
//! This module defines the protocol `crux_http` uses to communicate with the shell.
6
7use async_trait::async_trait;
8use derive_builder::Builder;
9use facet_generate_attrs as typegen;
10use serde::{Deserialize, Serialize};
11
12use crate::HttpError;
13
14#[derive(facet::Facet, Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
15pub struct HttpHeader {
16    pub name: String,
17    pub value: String,
18}
19
20#[derive(facet::Facet, Serialize, Deserialize, Default, Clone, PartialEq, Eq, Builder)]
21#[builder(
22    custom_constructor,
23    build_fn(private, name = "fallible_build"),
24    setter(into)
25)]
26pub struct HttpRequest {
27    pub method: String,
28    pub url: String,
29    #[builder(setter(custom))]
30    pub headers: Vec<HttpHeader>,
31    #[serde(with = "serde_bytes")]
32    #[facet(typegen::bytes)]
33    pub body: Vec<u8>,
34}
35
36impl std::fmt::Debug for HttpRequest {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        let body_repr = std::str::from_utf8(&self.body).map_or_else(
39            |_| format!("<binary data - {} bytes>", self.body.len()),
40            |s| {
41                if s.len() < 50 {
42                    format!("\"{s}\"")
43                } else {
44                    format!("\"{}\"...", s.chars().take(50).collect::<String>())
45                }
46            },
47        );
48        let mut builder = f.debug_struct("HttpRequest");
49        builder
50            .field("method", &self.method)
51            .field("url", &self.url);
52        if !self.headers.is_empty() {
53            builder.field("headers", &self.headers);
54        }
55        builder.field("body", &format_args!("{body_repr}")).finish()
56    }
57}
58
// Generates `pub fn $name(url) -> HttpRequestBuilder`, a constructor that pre-fills the builder
// with the HTTP method `$method`, the given URL, and empty headers and body.
macro_rules! http_method {
    ($name:ident, $method:expr) => {
        /// Starts building a request with this HTTP method for the given `url`.
        ///
        /// Headers and body start out empty; finish with [`HttpRequestBuilder::build`].
        #[must_use]
        pub fn $name(url: impl Into<String>) -> HttpRequestBuilder {
            HttpRequestBuilder {
                method: Some($method.to_string()),
                url: Some(url.into()),
                headers: Some(vec![]),
                body: Some(vec![]),
            }
        }
    };
}
71
72impl HttpRequest {
73    http_method!(get, "GET");
74    http_method!(put, "PUT");
75    http_method!(delete, "DELETE");
76    http_method!(post, "POST");
77    http_method!(patch, "PATCH");
78    http_method!(head, "HEAD");
79    http_method!(options, "OPTIONS");
80}
81
82impl HttpRequestBuilder {
83    pub fn header(&mut self, name: impl Into<String>, value: impl Into<String>) -> &mut Self {
84        self.headers.get_or_insert_with(Vec::new).push(HttpHeader {
85            name: name.into(),
86            value: value.into(),
87        });
88        self
89    }
90
91    /// Sets the query parameters of the request to the given value.
92    ///
93    /// # Errors
94    /// Returns an [`HttpError`] if the serialization fails.
95    pub fn query(&mut self, query: &impl Serialize) -> crate::Result<&mut Self> {
96        if let Some(url) = &mut self.url {
97            if url.contains('?') {
98                url.push('&');
99            } else {
100                url.push('?');
101            }
102            url.push_str(&serde_qs::to_string(query)?);
103        }
104
105        Ok(self)
106    }
107
108    /// Sets the body of the request to the JSON representation of the given value.
109    ///
110    /// # Panics
111    /// Panics if the serialization fails.
112    pub fn json(&mut self, body: impl serde::Serialize) -> &mut Self {
113        self.body = Some(serde_json::to_vec(&body).unwrap());
114        self
115    }
116
117    /// Builds the request.
118    ///
119    /// # Panics
120    /// Panics if any required fields are missing.
121    #[must_use]
122    pub fn build(&self) -> HttpRequest {
123        self.fallible_build()
124            .expect("All required fields were initialized")
125    }
126}
127
128#[derive(facet::Facet, Serialize, Deserialize, Default, Clone, Debug, PartialEq, Eq, Builder)]
129#[builder(
130    custom_constructor,
131    build_fn(private, name = "fallible_build"),
132    setter(into)
133)]
134pub struct HttpResponse {
135    pub status: u16, // FIXME this probably should be a giant enum instead.
136    #[builder(setter(custom))]
137    pub headers: Vec<HttpHeader>,
138    #[serde(with = "serde_bytes")]
139    #[facet(typegen::bytes)]
140    pub body: Vec<u8>,
141}
142
143impl HttpResponse {
144    #[must_use]
145    #[allow(clippy::missing_const_for_fn)]
146    pub fn status(status: u16) -> HttpResponseBuilder {
147        HttpResponseBuilder {
148            status: Some(status),
149            headers: Some(vec![]),
150            body: Some(vec![]),
151        }
152    }
153    #[must_use]
154    pub fn ok() -> HttpResponseBuilder {
155        Self::status(200)
156    }
157}
158
159impl HttpResponseBuilder {
160    pub fn header(&mut self, name: impl Into<String>, value: impl Into<String>) -> &mut Self {
161        self.headers.get_or_insert_with(Vec::new).push(HttpHeader {
162            name: name.into(),
163            value: value.into(),
164        });
165        self
166    }
167
168    /// Sets the body of the response to the given JSON.
169    ///
170    /// # Panics
171    /// If the JSON serialization fails.
172    pub fn json(&mut self, body: impl serde::Serialize) -> &mut Self {
173        self.body = Some(serde_json::to_vec(&body).unwrap());
174        self
175    }
176
177    /// Builds the response.
178    ///
179    /// # Panics
180    /// If a required field has not been initialized.
181    #[must_use]
182    pub fn build(&self) -> HttpResponse {
183        self.fallible_build()
184            .expect("All required fields were initialized")
185    }
186}
187
188#[derive(facet::Facet, Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
189#[repr(C)]
190pub enum HttpResult {
191    Ok(HttpResponse),
192    Err(HttpError),
193}
194
195impl From<crate::Result<HttpResponse>> for HttpResult {
196    fn from(result: Result<HttpResponse, HttpError>) -> Self {
197        match result {
198            Ok(response) => Self::Ok(response),
199            Err(err) => Self::Err(err),
200        }
201    }
202}
203
204impl crux_core::capability::Operation for HttpRequest {
205    type Output = HttpResult;
206
207    #[cfg(feature = "typegen")]
208    fn register_types(
209        generator: &mut crux_core::type_generation::serde::TypeGen,
210    ) -> crux_core::type_generation::serde::Result {
211        generator.register_type::<HttpError>()?;
212        generator.register_type::<Self>()?;
213        generator.register_type::<Self::Output>()?;
214        Ok(())
215    }
216}
217
218#[async_trait]
219pub(crate) trait EffectSender {
220    async fn send(&self, effect: HttpRequest) -> HttpResult;
221}
222
223#[async_trait]
224pub(crate) trait ProtocolRequestBuilder {
225    async fn into_protocol_request(mut self) -> crate::Result<HttpRequest>;
226}
227
228#[async_trait]
229impl ProtocolRequestBuilder for crate::Request {
230    async fn into_protocol_request(mut self) -> crate::Result<HttpRequest> {
231        let body = if self.is_empty() == Some(false) {
232            self.take_body().into_bytes().await?
233        } else {
234            vec![]
235        };
236
237        Ok(HttpRequest {
238            method: self.method().to_string(),
239            url: self.url().to_string(),
240            headers: self
241                .iter()
242                .flat_map(|(name, values)| {
243                    values.iter().map(|value| HttpHeader {
244                        name: name.to_string(),
245                        value: value.to_string(),
246                    })
247                })
248                .collect(),
249            body,
250        })
251    }
252}
253
254impl From<HttpResponse> for crate::ResponseAsync {
255    fn from(effect_response: HttpResponse) -> Self {
256        let mut res = http_types::Response::new(effect_response.status);
257        res.set_body(effect_response.body);
258        for header in effect_response.headers {
259            res.append_header(header.name.as_str(), header.value);
260        }
261
262        Self::new(res)
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// A `GET` builder for `url` with the given headers and an empty body, constructed directly
    /// rather than through [`HttpRequest::get`].
    fn get_builder(url: &str, headers: Vec<HttpHeader>) -> HttpRequestBuilder {
        HttpRequestBuilder {
            method: Some("GET".to_string()),
            url: Some(url.to_string()),
            headers: Some(headers),
            body: Some(vec![]),
        }
    }

    /// Renders `request` through its `Debug` implementation.
    fn debug_string(request: &HttpRequest) -> String {
        format!("{request:?}")
    }

    #[test]
    fn test_http_request_get() {
        let request = HttpRequest::get("https://example.com").build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com",
            body: "",
        }
        "#);
    }

    #[test]
    fn test_http_request_get_with_fields() {
        let request = HttpRequest::get("https://example.com")
            .header("foo", "bar")
            .body("123")
            .build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com",
            headers: [
                HttpHeader {
                    name: "foo",
                    value: "bar",
                },
            ],
            body: "123",
        }
        "#);
    }

    #[test]
    fn test_http_response_status() {
        let response = HttpResponse::status(302).build();

        insta::assert_debug_snapshot!(response, @"
        HttpResponse {
            status: 302,
            headers: [],
            body: [],
        }
        ");
    }

    #[test]
    fn test_http_response_status_with_fields() {
        let response = HttpResponse::status(302)
            .header("foo", "bar")
            .body("hey")
            .build();

        insta::assert_debug_snapshot!(response, @r#"
        HttpResponse {
            status: 302,
            headers: [
                HttpHeader {
                    name: "foo",
                    value: "bar",
                },
            ],
            body: [
                104,
                101,
                121,
            ],
        }
        "#);
    }

    #[test]
    fn test_http_request_debug_repr() {
        // Short UTF-8 body: shown in full, alongside the headers.
        let small = HttpRequest::post("http://example.com")
            .header("foo", "bar")
            .body("hello world!")
            .build();
        assert_eq!(
            debug_string(&small),
            r#"HttpRequest { method: "POST", url: "http://example.com", headers: [HttpHeader { name: "foo", value: "bar" }], body: "hello world!" }"#
        );

        // Long UTF-8 body: truncated, and the cut must respect unicode boundaries.
        let big = HttpRequest::post("http://example.com")
            .body("abcdefghijklmnopqrstuvwxyz abcdefghijklmnopqrstu😀😀😀😀😀😀")
            .build();
        assert_eq!(
            debug_string(&big),
            r#"HttpRequest { method: "POST", url: "http://example.com", body: "abcdefghijklmnopqrstuvwxyz abcdefghijklmnopqrstu😀😀"... }"#
        );

        // Non-UTF-8 body: summarised by its length.
        let binary = HttpRequest::post("http://example.com")
            .body(vec![255, 254, 253, 252])
            .build();
        assert_eq!(
            debug_string(&binary),
            r#"HttpRequest { method: "POST", url: "http://example.com", body: <binary data - 4 bytes> }"#
        );
    }

    #[test]
    fn test_http_request_query() {
        #[derive(Serialize, Deserialize)]
        struct QueryParams {
            page: u32,
            limit: u32,
            search: String,
        }

        let params = QueryParams {
            page: 2,
            limit: 10,
            search: "test".to_string(),
        };
        let headers = vec![HttpHeader {
            name: "foo".to_string(),
            value: "bar".to_string(),
        }];

        let request = get_builder("https://example.com", headers)
            .query(&params)
            .expect("should serialize query params")
            .build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com?page=2&limit=10&search=test",
            headers: [
                HttpHeader {
                    name: "foo",
                    value: "bar",
                },
            ],
            body: "",
        }
        "#);
    }

    #[test]
    fn test_http_request_query_with_special_chars() {
        #[derive(Serialize, Deserialize)]
        struct QueryParams {
            allowed: String,
            disallowed: String,
            delimiters: String,
            alpha_numeric_and_space: String,
        }

        let params = QueryParams {
            // allowed chars (RFC 3986)
            allowed: ";/?:@$,-.!~*'()".to_string(),
            // disallowed chars (RFC 3986)
            disallowed: "#".to_string(),
            // delimiters in key value pairs, need encoding
            delimiters: "&=+".to_string(),
            // not RFC 3986 Compliant (space should be %20 not +)
            // but "+" is very common so we allow it
            alpha_numeric_and_space: "ABC abc 123".to_string(),
        };

        let request = get_builder("https://example.com", vec![])
            .query(&params)
            .expect("should serialize query params with special chars")
            .build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com?allowed=;/?:@$,-.!~*'()&disallowed=%23&delimiters=%26%3D%2B&alpha_numeric_and_space=ABC+abc+123",
            body: "",
        }
        "#);
    }

    #[test]
    fn test_http_request_query_with_empty_values() {
        #[derive(Serialize, Deserialize)]
        struct QueryParams {
            empty: String,
            none: Option<String>,
        }

        let params = QueryParams {
            empty: String::new(),
            none: None,
        };

        let request = get_builder("https://example.com", vec![])
            .query(&params)
            .expect("should serialize query params with empty values")
            .build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com?empty=&none",
            body: "",
        }
        "#);
    }

    #[test]
    fn test_http_request_query_with_url_with_existing_query_params() {
        #[derive(Serialize, Deserialize)]
        struct QueryParams {
            name: String,
            email: String,
        }

        let params = QueryParams {
            name: "John Doe".to_string(),
            email: "john@example.com".to_string(),
        };

        let request = get_builder("https://example.com?foo=bar", vec![])
            .query(&params)
            .expect("should serialize query params")
            .build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com?foo=bar&name=John+Doe&email=john@example.com",
            body: "",
        }
        "#);
    }
}