Skip to main content

crux_http/
protocol.rs

//! The protocol for communicating with the shell.
//!
//! Crux capabilities don't interface with the outside world themselves; they carry
//! out all their operations by exchanging messages with the platform-specific shell.
//! This module defines the protocol for `crux_http` to communicate with the shell.
6
7use async_trait::async_trait;
8use derive_builder::Builder;
9use facet_generate_attrs as typegen;
10use serde::{Deserialize, Serialize};
11
12use crate::HttpError;
13
14#[derive(facet::Facet, Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
15pub struct HttpHeader {
16    pub name: String,
17    pub value: String,
18}
19
20#[derive(facet::Facet, Serialize, Deserialize, Default, Clone, PartialEq, Eq, Builder)]
21#[builder(
22    custom_constructor,
23    build_fn(private, name = "fallible_build"),
24    setter(into)
25)]
26pub struct HttpRequest {
27    pub method: String,
28    pub url: String,
29    #[builder(setter(custom))]
30    pub headers: Vec<HttpHeader>,
31    #[serde(with = "serde_bytes")]
32    #[facet(typegen::bytes)]
33    pub body: Vec<u8>,
34}
35
36impl std::fmt::Debug for HttpRequest {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        let body_repr = if let Ok(s) = std::str::from_utf8(&self.body) {
39            if s.len() < 50 {
40                format!("\"{s}\"")
41            } else {
42                format!("\"{}\"...", s.chars().take(50).collect::<String>())
43            }
44        } else {
45            format!("<binary data - {} bytes>", self.body.len())
46        };
47        let mut builder = f.debug_struct("HttpRequest");
48        builder
49            .field("method", &self.method)
50            .field("url", &self.url);
51        if !self.headers.is_empty() {
52            builder.field("headers", &self.headers);
53        }
54        builder.field("body", &format_args!("{body_repr}")).finish()
55    }
56}
57
58macro_rules! http_method {
59    ($name:ident, $method:expr) => {
60        pub fn $name(url: impl Into<String>) -> HttpRequestBuilder {
61            HttpRequestBuilder {
62                method: Some($method.to_string()),
63                url: Some(url.into()),
64                headers: Some(vec![]),
65                body: Some(vec![]),
66            }
67        }
68    };
69}
70
71impl HttpRequest {
72    http_method!(get, "GET");
73    http_method!(put, "PUT");
74    http_method!(delete, "DELETE");
75    http_method!(post, "POST");
76    http_method!(patch, "PATCH");
77    http_method!(head, "HEAD");
78    http_method!(options, "OPTIONS");
79}
80
81impl HttpRequestBuilder {
82    pub fn header(&mut self, name: impl Into<String>, value: impl Into<String>) -> &mut Self {
83        self.headers.get_or_insert_with(Vec::new).push(HttpHeader {
84            name: name.into(),
85            value: value.into(),
86        });
87        self
88    }
89
90    /// Sets the query parameters of the request to the given value.
91    ///
92    /// # Errors
93    /// Returns an [`HttpError`] if the serialization fails.
94    pub fn query(&mut self, query: &impl Serialize) -> crate::Result<&mut Self> {
95        if let Some(url) = &mut self.url {
96            if url.contains('?') {
97                url.push('&');
98            } else {
99                url.push('?');
100            }
101            url.push_str(&serde_qs::to_string(query)?);
102        }
103
104        Ok(self)
105    }
106
107    /// Sets the body of the request to the JSON representation of the given value.
108    ///
109    /// # Panics
110    /// Panics if the serialization fails.
111    pub fn json(&mut self, body: impl serde::Serialize) -> &mut Self {
112        self.body = Some(serde_json::to_vec(&body).unwrap());
113        self
114    }
115
116    /// Builds the request.
117    ///
118    /// # Panics
119    /// Panics if any required fields are missing.
120    #[must_use]
121    pub fn build(&self) -> HttpRequest {
122        self.fallible_build()
123            .expect("All required fields were initialized")
124    }
125}
126
127#[derive(facet::Facet, Serialize, Deserialize, Default, Clone, Debug, PartialEq, Eq, Builder)]
128#[builder(
129    custom_constructor,
130    build_fn(private, name = "fallible_build"),
131    setter(into)
132)]
133pub struct HttpResponse {
134    pub status: u16, // FIXME this probably should be a giant enum instead.
135    #[builder(setter(custom))]
136    pub headers: Vec<HttpHeader>,
137    #[serde(with = "serde_bytes")]
138    #[facet(typegen::bytes)]
139    pub body: Vec<u8>,
140}
141
142impl HttpResponse {
143    #[must_use]
144    pub fn status(status: u16) -> HttpResponseBuilder {
145        HttpResponseBuilder {
146            status: Some(status),
147            headers: Some(vec![]),
148            body: Some(vec![]),
149        }
150    }
151    #[must_use]
152    pub fn ok() -> HttpResponseBuilder {
153        Self::status(200)
154    }
155}
156
157impl HttpResponseBuilder {
158    pub fn header(&mut self, name: impl Into<String>, value: impl Into<String>) -> &mut Self {
159        self.headers.get_or_insert_with(Vec::new).push(HttpHeader {
160            name: name.into(),
161            value: value.into(),
162        });
163        self
164    }
165
166    /// Sets the body of the response to the given JSON.
167    ///
168    /// # Panics
169    /// If the JSON serialization fails.
170    pub fn json(&mut self, body: impl serde::Serialize) -> &mut Self {
171        self.body = Some(serde_json::to_vec(&body).unwrap());
172        self
173    }
174
175    /// Builds the response.
176    ///
177    /// # Panics
178    /// If a required field has not been initialized.
179    #[must_use]
180    pub fn build(&self) -> HttpResponse {
181        self.fallible_build()
182            .expect("All required fields were initialized")
183    }
184}
185
186#[derive(facet::Facet, Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
187#[repr(C)]
188pub enum HttpResult {
189    Ok(HttpResponse),
190    Err(HttpError),
191}
192
193impl From<crate::Result<HttpResponse>> for HttpResult {
194    fn from(result: Result<HttpResponse, HttpError>) -> Self {
195        match result {
196            Ok(response) => HttpResult::Ok(response),
197            Err(err) => HttpResult::Err(err),
198        }
199    }
200}
201
202impl crux_core::capability::Operation for HttpRequest {
203    type Output = HttpResult;
204
205    #[cfg(feature = "typegen")]
206    fn register_types(
207        generator: &mut crux_core::type_generation::serde::TypeGen,
208    ) -> crux_core::type_generation::serde::Result {
209        generator.register_type::<HttpError>()?;
210        generator.register_type::<Self>()?;
211        generator.register_type::<Self::Output>()?;
212        Ok(())
213    }
214}
215
216#[async_trait]
217pub(crate) trait EffectSender {
218    async fn send(&self, effect: HttpRequest) -> HttpResult;
219}
220
221#[async_trait]
222pub(crate) trait ProtocolRequestBuilder {
223    async fn into_protocol_request(mut self) -> crate::Result<HttpRequest>;
224}
225
226#[async_trait]
227impl ProtocolRequestBuilder for crate::Request {
228    async fn into_protocol_request(mut self) -> crate::Result<HttpRequest> {
229        let body = if self.is_empty() == Some(false) {
230            self.take_body().into_bytes().await?
231        } else {
232            vec![]
233        };
234
235        Ok(HttpRequest {
236            method: self.method().to_string(),
237            url: self.url().to_string(),
238            headers: self
239                .iter()
240                .flat_map(|(name, values)| {
241                    values.iter().map(|value| HttpHeader {
242                        name: name.to_string(),
243                        value: value.to_string(),
244                    })
245                })
246                .collect(),
247            body,
248        })
249    }
250}
251
252impl From<HttpResponse> for crate::ResponseAsync {
253    fn from(effect_response: HttpResponse) -> Self {
254        let mut res = http_types::Response::new(effect_response.status);
255        res.set_body(effect_response.body);
256        for header in effect_response.headers {
257            res.append_header(header.name.as_str(), header.value);
258        }
259
260        crate::ResponseAsync::new(res)
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// A bare GET request prints without a `headers` field and with an empty body.
    #[test]
    fn test_http_request_get() {
        let request = HttpRequest::get("https://example.com").build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com",
            body: "",
        }
        "#);
    }

    /// Headers and body set through the builder end up in the built request.
    #[test]
    fn test_http_request_get_with_fields() {
        let request = HttpRequest::get("https://example.com")
            .header("foo", "bar")
            .body("123")
            .build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com",
            headers: [
                HttpHeader {
                    name: "foo",
                    value: "bar",
                },
            ],
            body: "123",
        }
        "#);
    }

    /// A response built from a status alone has no headers and an empty body.
    #[test]
    fn test_http_response_status() {
        let response = HttpResponse::status(302).build();

        insta::assert_debug_snapshot!(response, @"
        HttpResponse {
            status: 302,
            headers: [],
            body: [],
        }
        ");
    }

    /// Headers and body set through the builder end up in the built response.
    #[test]
    fn test_http_response_status_with_fields() {
        let response = HttpResponse::status(302)
            .header("foo", "bar")
            .body("hey")
            .build();

        insta::assert_debug_snapshot!(response, @r#"
        HttpResponse {
            status: 302,
            headers: [
                HttpHeader {
                    name: "foo",
                    value: "bar",
                },
            ],
            body: [
                104,
                101,
                121,
            ],
        }
        "#);
    }

    /// Covers the three body renderings of the hand-written `Debug` impl.
    #[test]
    fn test_http_request_debug_repr() {
        // small
        let small = HttpRequest::post("http://example.com")
            .header("foo", "bar")
            .body("hello world!")
            .build();
        assert_eq!(
            format!("{small:?}"),
            r#"HttpRequest { method: "POST", url: "http://example.com", headers: [HttpHeader { name: "foo", value: "bar" }], body: "hello world!" }"#
        );

        // big
        let big = HttpRequest::post("http://example.com")
            // we check that we handle unicode boundaries correctly
            .body("abcdefghijklmnopqrstuvwxyz abcdefghijklmnopqrstu😀😀😀😀😀😀")
            .build();
        assert_eq!(
            format!("{big:?}"),
            r#"HttpRequest { method: "POST", url: "http://example.com", body: "abcdefghijklmnopqrstuvwxyz abcdefghijklmnopqrstu😀😀"... }"#
        );

        // binary
        let binary = HttpRequest::post("http://example.com")
            .body(vec![255, 254, 253, 252])
            .build();
        assert_eq!(
            format!("{binary:?}"),
            r#"HttpRequest { method: "POST", url: "http://example.com", body: <binary data - 4 bytes> }"#
        );
    }

    /// Query parameters are appended after `?` on a URL without a query string.
    #[test]
    fn test_http_request_query() {
        #[derive(Serialize, Deserialize)]
        struct QueryParams {
            page: u32,
            limit: u32,
            search: String,
        }

        let query = QueryParams {
            page: 2,
            limit: 10,
            search: "test".to_string(),
        };

        let mut builder = HttpRequest::get("https://example.com");
        builder.header("foo", "bar");

        builder
            .query(&query)
            .expect("should serialize query params");
        let request = builder.build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com?page=2&limit=10&search=test",
            headers: [
                HttpHeader {
                    name: "foo",
                    value: "bar",
                },
            ],
            body: "",
        }
        "#);
    }

    /// Pins how reserved and special characters are encoded in query values.
    #[test]
    fn test_http_request_query_with_special_chars() {
        #[derive(Serialize, Deserialize)]
        struct QueryParams {
            allowed: String,
            disallowed: String,
            delimiters: String,
            alpha_numeric_and_space: String,
        }

        let query = QueryParams {
            // allowed chars (RFC 3986)
            allowed: ";/?:@$,-.!~*'()".to_string(),
            // disallowed chars (RFC 3986)
            disallowed: "#".to_string(),
            // delimiters in key value pairs, need encoding
            delimiters: "&=+".to_string(),
            // not RFC 3986 Compliant (space should be %20 not +)
            // but "+" is very common so we allow it
            alpha_numeric_and_space: "ABC abc 123".to_string(),
        };

        let mut builder = HttpRequest::get("https://example.com");

        builder
            .query(&query)
            .expect("should serialize query params with special chars");
        let request = builder.build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com?allowed=;/?:@$,-.!~*'()&disallowed=%23&delimiters=%26%3D%2B&alpha_numeric_and_space=ABC+abc+123",
            body: "",
        }
        "#);
    }

    /// Pins how empty strings and `None` values are serialized.
    #[test]
    fn test_http_request_query_with_empty_values() {
        #[derive(Serialize, Deserialize)]
        struct QueryParams {
            empty: String,
            none: Option<String>,
        }

        let query = QueryParams {
            empty: String::new(),
            none: None,
        };

        let mut builder = HttpRequest::get("https://example.com");

        builder
            .query(&query)
            .expect("should serialize query params with empty values");
        let request = builder.build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com?empty=&none",
            body: "",
        }
        "#);
    }

    /// Parameters are joined with `&` when the URL already has a query string.
    #[test]
    fn test_http_request_query_with_url_with_existing_query_params() {
        #[derive(Serialize, Deserialize)]
        struct QueryParams {
            name: String,
            email: String,
        }

        let query = QueryParams {
            name: "John Doe".to_string(),
            email: "john@example.com".to_string(),
        };

        let mut builder = HttpRequest::get("https://example.com?foo=bar");

        builder
            .query(&query)
            .expect("should serialize query params");
        let request = builder.build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com?foo=bar&name=John+Doe&email=john@example.com",
            body: "",
        }
        "#);
    }
}