Skip to main content

crux_http/
protocol.rs

1//! The protocol for communicating with the shell
2//!
3//! Crux capabilities don't interface with the outside world themselves, they carry
4//! out all their operations by exchanging messages with the platform specific shell.
5//! This module defines the protocol for `crux_http` to communicate with the shell.
6
7use async_trait::async_trait;
8use derive_builder::Builder;
9use serde::{Deserialize, Serialize};
10
11use crate::HttpError;
12
13#[derive(facet::Facet, Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
14pub struct HttpHeader {
15    pub name: String,
16    pub value: String,
17}
18
19#[derive(facet::Facet, Serialize, Deserialize, Default, Clone, PartialEq, Eq, Builder)]
20#[builder(
21    custom_constructor,
22    build_fn(private, name = "fallible_build"),
23    setter(into)
24)]
25pub struct HttpRequest {
26    pub method: String,
27    pub url: String,
28    #[builder(setter(custom))]
29    pub headers: Vec<HttpHeader>,
30    #[serde(with = "serde_bytes")]
31    #[facet(bytes)]
32    pub body: Vec<u8>,
33}
34
35impl std::fmt::Debug for HttpRequest {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        let body_repr = if let Ok(s) = std::str::from_utf8(&self.body) {
38            if s.len() < 50 {
39                format!("\"{s}\"")
40            } else {
41                format!("\"{}\"...", s.chars().take(50).collect::<String>())
42            }
43        } else {
44            format!("<binary data - {} bytes>", self.body.len())
45        };
46        let mut builder = f.debug_struct("HttpRequest");
47        builder
48            .field("method", &self.method)
49            .field("url", &self.url);
50        if !self.headers.is_empty() {
51            builder.field("headers", &self.headers);
52        }
53        builder.field("body", &format_args!("{body_repr}")).finish()
54    }
55}
56
// Generates a public constructor `$name(url)` that returns an
// `HttpRequestBuilder` preset to HTTP method `$method`, with no headers and an
// empty body.
macro_rules! http_method {
    ($name:ident, $method:expr) => {
        /// Starts building a request for `url` using the HTTP method this
        /// constructor is named after.
        ///
        /// Headers and body start out empty; finish with
        /// [`HttpRequestBuilder::build`].
        #[must_use]
        pub fn $name(url: impl Into<String>) -> HttpRequestBuilder {
            HttpRequestBuilder {
                method: Some($method.to_string()),
                url: Some(url.into()),
                headers: Some(Vec::new()),
                body: Some(Vec::new()),
            }
        }
    };
}
69
70impl HttpRequest {
71    http_method!(get, "GET");
72    http_method!(put, "PUT");
73    http_method!(delete, "DELETE");
74    http_method!(post, "POST");
75    http_method!(patch, "PATCH");
76    http_method!(head, "HEAD");
77    http_method!(options, "OPTIONS");
78}
79
80impl HttpRequestBuilder {
81    pub fn header(&mut self, name: impl Into<String>, value: impl Into<String>) -> &mut Self {
82        self.headers.get_or_insert_with(Vec::new).push(HttpHeader {
83            name: name.into(),
84            value: value.into(),
85        });
86        self
87    }
88
89    /// Sets the query parameters of the request to the given value.
90    ///
91    /// # Errors
92    /// Returns an [`HttpError`] if the serialization fails.
93    pub fn query(&mut self, query: &impl Serialize) -> crate::Result<&mut Self> {
94        if let Some(url) = &mut self.url {
95            if url.contains('?') {
96                url.push('&');
97            } else {
98                url.push('?');
99            }
100            url.push_str(&serde_qs::to_string(query)?);
101        }
102
103        Ok(self)
104    }
105
106    /// Sets the body of the request to the JSON representation of the given value.
107    ///
108    /// # Panics
109    /// Panics if the serialization fails.
110    pub fn json(&mut self, body: impl serde::Serialize) -> &mut Self {
111        self.body = Some(serde_json::to_vec(&body).unwrap());
112        self
113    }
114
115    /// Builds the request.
116    ///
117    /// # Panics
118    /// Panics if any required fields are missing.
119    #[must_use]
120    pub fn build(&self) -> HttpRequest {
121        self.fallible_build()
122            .expect("All required fields were initialized")
123    }
124}
125
126#[derive(facet::Facet, Serialize, Deserialize, Default, Clone, Debug, PartialEq, Eq, Builder)]
127#[builder(
128    custom_constructor,
129    build_fn(private, name = "fallible_build"),
130    setter(into)
131)]
132pub struct HttpResponse {
133    pub status: u16, // FIXME this probably should be a giant enum instead.
134    #[builder(setter(custom))]
135    pub headers: Vec<HttpHeader>,
136    #[serde(with = "serde_bytes")]
137    #[facet(bytes)]
138    pub body: Vec<u8>,
139}
140
141impl HttpResponse {
142    #[must_use]
143    pub fn status(status: u16) -> HttpResponseBuilder {
144        HttpResponseBuilder {
145            status: Some(status),
146            headers: Some(vec![]),
147            body: Some(vec![]),
148        }
149    }
150    #[must_use]
151    pub fn ok() -> HttpResponseBuilder {
152        Self::status(200)
153    }
154}
155
156impl HttpResponseBuilder {
157    pub fn header(&mut self, name: impl Into<String>, value: impl Into<String>) -> &mut Self {
158        self.headers.get_or_insert_with(Vec::new).push(HttpHeader {
159            name: name.into(),
160            value: value.into(),
161        });
162        self
163    }
164
165    /// Sets the body of the response to the given JSON.
166    ///
167    /// # Panics
168    /// If the JSON serialization fails.
169    pub fn json(&mut self, body: impl serde::Serialize) -> &mut Self {
170        self.body = Some(serde_json::to_vec(&body).unwrap());
171        self
172    }
173
174    /// Builds the response.
175    ///
176    /// # Panics
177    /// If a required field has not been initialized.
178    #[must_use]
179    pub fn build(&self) -> HttpResponse {
180        self.fallible_build()
181            .expect("All required fields were initialized")
182    }
183}
184
185#[derive(facet::Facet, Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
186#[repr(C)]
187pub enum HttpResult {
188    Ok(HttpResponse),
189    Err(HttpError),
190}
191
192impl From<crate::Result<HttpResponse>> for HttpResult {
193    fn from(result: Result<HttpResponse, HttpError>) -> Self {
194        match result {
195            Ok(response) => HttpResult::Ok(response),
196            Err(err) => HttpResult::Err(err),
197        }
198    }
199}
200
201impl crux_core::capability::Operation for HttpRequest {
202    type Output = HttpResult;
203
204    #[cfg(feature = "typegen")]
205    fn register_types(
206        generator: &mut crux_core::type_generation::serde::TypeGen,
207    ) -> crux_core::type_generation::serde::Result {
208        generator.register_type::<HttpError>()?;
209        generator.register_type::<Self>()?;
210        generator.register_type::<Self::Output>()?;
211        Ok(())
212    }
213}
214
215#[async_trait]
216pub(crate) trait EffectSender {
217    async fn send(&self, effect: HttpRequest) -> HttpResult;
218}
219
220#[async_trait]
221pub(crate) trait ProtocolRequestBuilder {
222    async fn into_protocol_request(mut self) -> crate::Result<HttpRequest>;
223}
224
225#[async_trait]
226impl ProtocolRequestBuilder for crate::Request {
227    async fn into_protocol_request(mut self) -> crate::Result<HttpRequest> {
228        let body = if self.is_empty() == Some(false) {
229            self.take_body().into_bytes().await?
230        } else {
231            vec![]
232        };
233
234        Ok(HttpRequest {
235            method: self.method().to_string(),
236            url: self.url().to_string(),
237            headers: self
238                .iter()
239                .flat_map(|(name, values)| {
240                    values.iter().map(|value| HttpHeader {
241                        name: name.to_string(),
242                        value: value.to_string(),
243                    })
244                })
245                .collect(),
246            body,
247        })
248    }
249}
250
251impl From<HttpResponse> for crate::ResponseAsync {
252    fn from(effect_response: HttpResponse) -> Self {
253        let mut res = http_types::Response::new(effect_response.status);
254        res.set_body(effect_response.body);
255        for header in effect_response.headers {
256            res.append_header(header.name.as_str(), header.value);
257        }
258
259        crate::ResponseAsync::new(res)
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[test]
    fn test_http_request_get() {
        let request = HttpRequest::get("https://example.com").build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com",
            body: "",
        }
        "#);
    }

    #[test]
    fn test_http_request_get_with_fields() {
        let mut builder = HttpRequest::get("https://example.com");
        builder.header("foo", "bar").body("123");
        let request = builder.build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com",
            headers: [
                HttpHeader {
                    name: "foo",
                    value: "bar",
                },
            ],
            body: "123",
        }
        "#);
    }

    #[test]
    fn test_http_response_status() {
        let response = HttpResponse::status(302).build();

        insta::assert_debug_snapshot!(response, @"
        HttpResponse {
            status: 302,
            headers: [],
            body: [],
        }
        ");
    }

    #[test]
    fn test_http_response_status_with_fields() {
        let mut builder = HttpResponse::status(302);
        builder.header("foo", "bar").body("hey");
        let response = builder.build();

        insta::assert_debug_snapshot!(response, @r#"
        HttpResponse {
            status: 302,
            headers: [
                HttpHeader {
                    name: "foo",
                    value: "bar",
                },
            ],
            body: [
                104,
                101,
                121,
            ],
        }
        "#);
    }

    #[test]
    fn test_http_request_debug_repr() {
        let debug_of = |builder: &mut HttpRequestBuilder| format!("{:?}", builder.build());

        // small
        let small = debug_of(
            HttpRequest::post("http://example.com")
                .header("foo", "bar")
                .body("hello world!"),
        );
        assert_eq!(
            small,
            r#"HttpRequest { method: "POST", url: "http://example.com", headers: [HttpHeader { name: "foo", value: "bar" }], body: "hello world!" }"#
        );

        // big: unicode boundaries must be handled correctly
        let big = debug_of(
            HttpRequest::post("http://example.com")
                .body("abcdefghijklmnopqrstuvwxyz abcdefghijklmnopqrstu😀😀😀😀😀😀"),
        );
        assert_eq!(
            big,
            r#"HttpRequest { method: "POST", url: "http://example.com", body: "abcdefghijklmnopqrstuvwxyz abcdefghijklmnopqrstu😀😀"... }"#
        );

        // binary
        let binary = debug_of(
            HttpRequest::post("http://example.com").body(vec![255, 254, 253, 252]),
        );
        assert_eq!(
            binary,
            r#"HttpRequest { method: "POST", url: "http://example.com", body: <binary data - 4 bytes> }"#
        );
    }

    #[test]
    fn test_http_request_query() {
        #[derive(Serialize, Deserialize)]
        struct QueryParams {
            page: u32,
            limit: u32,
            search: String,
        }

        let params = QueryParams {
            page: 2,
            limit: 10,
            search: "test".to_string(),
        };

        let mut builder = HttpRequest::get("https://example.com");
        builder.header("foo", "bar");
        builder
            .query(&params)
            .expect("should serialize query params");
        let request = builder.build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com?page=2&limit=10&search=test",
            headers: [
                HttpHeader {
                    name: "foo",
                    value: "bar",
                },
            ],
            body: "",
        }
        "#);
    }

    #[test]
    fn test_http_request_query_with_special_chars() {
        #[derive(Serialize, Deserialize)]
        struct QueryParams {
            allowed: String,
            disallowed: String,
            delimiters: String,
            alpha_numeric_and_space: String,
        }

        let params = QueryParams {
            // allowed chars (RFC 3986)
            allowed: ";/?:@$,-.!~*'()".to_string(),
            // disallowed chars (RFC 3986)
            disallowed: "#".to_string(),
            // delimiters in key value pairs, need encoding
            delimiters: "&=+".to_string(),
            // not RFC 3986 compliant (space should be %20, not +),
            // but "+" is very common so we allow it
            alpha_numeric_and_space: "ABC abc 123".to_string(),
        };

        let mut builder = HttpRequest::get("https://example.com");
        builder
            .query(&params)
            .expect("should serialize query params with special chars");
        let request = builder.build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com?allowed=;/?:@$,-.!~*'()&disallowed=%23&delimiters=%26%3D%2B&alpha_numeric_and_space=ABC+abc+123",
            body: "",
        }
        "#);
    }

    #[test]
    fn test_http_request_query_with_empty_values() {
        #[derive(Serialize, Deserialize)]
        struct QueryParams {
            empty: String,
            none: Option<String>,
        }

        let params = QueryParams {
            empty: String::new(),
            none: None,
        };

        let mut builder = HttpRequest::get("https://example.com");
        builder
            .query(&params)
            .expect("should serialize query params with empty values");
        let request = builder.build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com?empty=&none",
            body: "",
        }
        "#);
    }

    #[test]
    fn test_http_request_query_with_url_with_existing_query_params() {
        #[derive(Serialize, Deserialize)]
        struct QueryParams {
            name: String,
            email: String,
        }

        let params = QueryParams {
            name: "John Doe".to_string(),
            email: "john@example.com".to_string(),
        };

        let mut builder = HttpRequest::get("https://example.com?foo=bar");
        builder
            .query(&params)
            .expect("should serialize query params");
        let request = builder.build();

        insta::assert_debug_snapshot!(request, @r#"
        HttpRequest {
            method: "GET",
            url: "https://example.com?foo=bar&name=John+Doe&email=john@example.com",
            body: "",
        }
        "#);
    }
}