// vest/regular/and_then.rs

1use crate::properties::*;
2use vstd::prelude::*;
3
4use super::bytes::Bytes;
5
6verus! {
7
/// Combinator that monadically chains two combinators: the value produced by
/// the first one (`Prev`, field `.0`) is used as the *input* of the second one
/// (`Next`, field `.1`).
///
/// The impls in this file cover `Prev = Bytes`. There, `Prev` extracts a byte
/// region from the input, and `Next` must parse that region completely.
/// Parsing is rejected if `Next` leaves any bytes of the region unconsumed.
pub struct AndThen<Prev, Next>(pub Prev, pub Next);
10
11impl<Prev: View, Next: View> View for AndThen<Prev, Next> {
12    type V = AndThen<Prev::V, Next::V>;
13
14    open spec fn view(&self) -> Self::V {
15        AndThen(self.0@, self.1@)
16    }
17}
18
19impl<Next: SpecCombinator> SpecCombinator for AndThen<Bytes, Next> {
20    type SpecResult = Next::SpecResult;
21
22    open spec fn spec_parse(&self, s: Seq<u8>) -> Result<(usize, Self::SpecResult), ()> {
23        if let Ok((n, v1)) = self.0.spec_parse(s) {
24            if let Ok((m, v2)) = self.1.spec_parse(v1) {
25                // !! for security, can only proceed if the `Next` parser consumed the entire
26                // !! output from the `Prev` parser
27                if m == n {
28                    Ok((n, v2))
29                } else {
30                    Err(())
31                }
32            } else {
33                Err(())
34            }
35        } else {
36            Err(())
37        }
38    }
39
40    proof fn spec_parse_wf(&self, s: Seq<u8>) {
41        if let Ok((n, v1)) = self.0.spec_parse(s) {
42            self.0.spec_parse_wf(s);
43            self.1.spec_parse_wf(v1);
44        }
45    }
46
47    open spec fn spec_serialize(&self, v: Self::SpecResult) -> Result<Seq<u8>, ()> {
48        if let Ok(buf1) = self.1.spec_serialize(v) {
49            self.0.spec_serialize(buf1)
50        } else {
51            Err(())
52        }
53    }
54}
55
/// Security proofs for `AndThen<Bytes, Next>`.
///
/// Each proof is assembled from the corresponding lemmas of the two components.
impl<Next: SecureSpecCombinator> SecureSpecCombinator for AndThen<Bytes, Next> {
    /// Serialize → parse roundtrip obligation.
    ///
    /// `spec_serialize` runs `Next` first and then `Bytes`. So once `Next` has
    /// produced `buf1`, the roundtrip lemmas for `Next` (on `v`) and for `Bytes`
    /// (on `buf1`) are brought into scope.
    proof fn theorem_serialize_parse_roundtrip(&self, v: Self::SpecResult) {
        if let Ok(buf1) = self.1.spec_serialize(v) {
            self.1.theorem_serialize_parse_roundtrip(v);
            self.0.theorem_serialize_parse_roundtrip(buf1);
        }
    }

    /// Parse → serialize roundtrip obligation.
    ///
    /// It walks the same success path as `spec_parse`: `Bytes` yields
    /// `(n, v1)`, `Next` parses `v1` into `(m, v2)`, and `m == n`. It then
    /// brings in both components' roundtrip lemmas and asserts that
    /// re-serializing `v2` reproduces the first `n` bytes of `buf`.
    proof fn theorem_parse_serialize_roundtrip(&self, buf: Seq<u8>) {
        if let Ok((n, v1)) = self.0.spec_parse(buf) {
            if let Ok((m, v2)) = self.1.spec_parse(v1) {
                self.0.theorem_parse_serialize_roundtrip(buf);
                self.1.theorem_parse_serialize_roundtrip(v1);
                // `spec_parse` only succeeds on this branch, so the roundtrip only
                // needs to be established here.
                if m == n {
                    if let Ok(buf2) = self.1.spec_serialize(v2) {
                        if let Ok(buf1) = self.0.spec_serialize(buf2) {
                            assert(buf1 == buf.subrange(0, n as int));
                        }
                    }
                }
            }
        }
    }

    /// Prefix security is inherited from `Bytes`.
    ///
    /// The byte count reported by `spec_parse` is always the count produced
    /// by `Bytes`.
    open spec fn is_prefix_secure() -> bool {
        Bytes::is_prefix_secure()
    }

    /// Delegates to the `Bytes` prefix-security lemma.
    ///
    /// NOTE(review): this presumably relies on the lemma showing that appending
    /// `s2` does not change the region `Bytes` extracts from `buf`. `Next`
    /// therefore sees the same input. Confirm against the trait's ensures clause.
    proof fn lemma_prefix_secure(&self, buf: Seq<u8>, s2: Seq<u8>) {
        self.0.lemma_prefix_secure(buf, s2);
    }
}
88
89impl<Next: Combinator> Combinator for AndThen<Bytes, Next> where
90    Next::V: SecureSpecCombinator<SpecResult = <Next::Owned as View>::V>,
91 {
92    type Result<'a> = Next::Result<'a>;
93
94    type Owned = Next::Owned;
95
96    open spec fn spec_length(&self) -> Option<usize> {
97        self.0.spec_length()
98    }
99
100    fn length(&self) -> Option<usize> {
101        self.0.length()
102    }
103
104    open spec fn parse_requires(&self) -> bool {
105        self.1.parse_requires()
106    }
107
108    fn parse<'a>(&self, s: &'a [u8]) -> Result<(usize, Self::Result<'a>), ParseError> {
109        let (n, v1) = self.0.parse(s)?;
110        let (m, v2) = self.1.parse(v1)?;
111        if m == n {
112            Ok((n, v2))
113        } else {
114            Err(ParseError::AndThenUnusedBytes)
115        }
116    }
117
118    open spec fn serialize_requires(&self) -> bool {
119        self.1.serialize_requires()
120    }
121
122    fn serialize(&self, v: Self::Result<'_>, data: &mut Vec<u8>, pos: usize) -> Result<
123        usize,
124        SerializeError,
125    > {
126        let n = self.1.serialize(v, data, pos)?;
127        if n == self.0.0 {
128            Ok(n)
129        } else {
130            Err(SerializeError::AndThenUnusedBytes)
131        }
132    }
133}
134
135} // verus!