// verdict_parser/common/depend.rs

// Copied from Vest, with some tweaks to `Depend`, to work around a Verus export/import issue.

3use super::*;
4use vstd::prelude::*;
5use vstd::slice::slice_subrange;
6
7verus! {
8
9/// Spec version of [`Depend`].
10pub struct SpecDepend<Fst, Snd> where Fst: SecureSpecCombinator, Snd: SpecCombinator {
11    /// combinators that contain dependencies
12    pub fst: Fst,
13    /// closure that captures dependencies and maps them to the dependent combinators
14    pub snd: spec_fn(Fst::SpecResult) -> Snd,
15}
16
17impl<Fst, Snd> SpecCombinator for SpecDepend<Fst, Snd> where
18    Fst: SecureSpecCombinator,
19    Snd: SpecCombinator,
20 {
21    type SpecResult = (Fst::SpecResult, Snd::SpecResult);
22
23    open spec fn spec_parse(&self, s: Seq<u8>) -> Result<(usize, Self::SpecResult), ()> {
24        if Fst::is_prefix_secure() {
25            if let Ok((n, v1)) = self.fst.spec_parse(s) {
26                let snd = (self.snd)(v1);
27                if let Ok((m, v2)) = snd.spec_parse(s.subrange(n as int, s.len() as int)) {
28                    if n <= usize::MAX - m {
29                        Ok(((n + m) as usize, (v1, v2)))
30                    } else {
31                        Err(())
32                    }
33                } else {
34                    Err(())
35                }
36            } else {
37                Err(())
38            }
39        } else {
40            Err(())
41        }
42    }
43
44    proof fn spec_parse_wf(&self, s: Seq<u8>) {
45        if let Ok((n, v1)) = self.fst.spec_parse(s) {
46            let snd = (self.snd)(v1);
47            if let Ok((m, v2)) = snd.spec_parse(s.subrange(n as int, s.len() as int)) {
48                self.fst.spec_parse_wf(s);
49                snd.spec_parse_wf(s.subrange(n as int, s.len() as int));
50            }
51        }
52    }
53
54    open spec fn spec_serialize(&self, v: Self::SpecResult) -> Result<Seq<u8>, ()> {
55        if Fst::is_prefix_secure() {
56            if let Ok(buf1) = self.fst.spec_serialize(v.0) {
57                let snd = (self.snd)(v.0);
58                if let Ok(buf2) = snd.spec_serialize(v.1) {
59                    if buf1.len() + buf2.len() <= usize::MAX {
60                        Ok(buf1.add(buf2))
61                    } else {
62                        Err(())
63                    }
64                } else {
65                    Err(())
66                }
67            } else {
68                Err(())
69            }
70        } else {
71            Err(())
72        }
73    }
74}
75
76impl<Fst, Snd> SecureSpecCombinator for SpecDepend<Fst, Snd> where
77    Fst: SecureSpecCombinator,
78    Snd: SecureSpecCombinator,
79 {
80    proof fn theorem_serialize_parse_roundtrip(&self, v: Self::SpecResult) {
81        if let Ok((buf)) = self.spec_serialize(v) {
82            let buf0 = self.fst.spec_serialize(v.0).unwrap();
83            let buf1 = (self.snd)(v.0).spec_serialize(v.1).unwrap();
84            self.fst.theorem_serialize_parse_roundtrip(v.0);
85            self.fst.lemma_prefix_secure(buf0, buf1);
86            (self.snd)(v.0).theorem_serialize_parse_roundtrip(v.1);
87            assert(buf0.add(buf1).subrange(buf0.len() as int, buf.len() as int) == buf1);
88        }
89    }
90
91    proof fn theorem_parse_serialize_roundtrip(&self, buf: Seq<u8>) {
92        if let Ok((nm, (v0, v1))) = self.spec_parse(buf) {
93            let (n, v0_) = self.fst.spec_parse(buf).unwrap();
94            self.fst.spec_parse_wf(buf);
95            let buf0 = buf.subrange(0, n as int);
96            let buf1 = buf.subrange(n as int, buf.len() as int);
97            assert(buf == buf0.add(buf1));
98            self.fst.theorem_parse_serialize_roundtrip(buf);
99            let (m, v1_) = (self.snd)(v0).spec_parse(buf1).unwrap();
100            (self.snd)(v0).theorem_parse_serialize_roundtrip(buf1);
101            (self.snd)(v0).spec_parse_wf(buf1);
102            let buf2 = self.spec_serialize((v0, v1)).unwrap();
103            assert(buf2 == buf.subrange(0, nm as int));
104        } else {
105        }
106    }
107
108    open spec fn is_prefix_secure() -> bool {
109        Fst::is_prefix_secure() && Snd::is_prefix_secure()
110    }
111
112    proof fn lemma_prefix_secure(&self, buf: Seq<u8>, s2: Seq<u8>) {
113        if Fst::is_prefix_secure() && Snd::is_prefix_secure() {
114            if let Ok((nm, (v0, v1))) = self.spec_parse(buf) {
115                let (n, _) = self.fst.spec_parse(buf).unwrap();
116                self.fst.spec_parse_wf(buf);
117                let buf0 = buf.subrange(0, n as int);
118                let buf1 = buf.subrange(n as int, buf.len() as int);
119                self.fst.lemma_prefix_secure(buf0, buf1);
120                self.fst.lemma_prefix_secure(buf0, buf1.add(s2));
121                self.fst.lemma_prefix_secure(buf, s2);
122                let snd = (self.snd)(v0);
123                let (m, v1_) = snd.spec_parse(buf1).unwrap();
124                assert(buf.add(s2).subrange(0, n as int) == buf0);
125                assert(buf.add(s2).subrange(n as int, buf.add(s2).len() as int) == buf1.add(s2));
126                snd.lemma_prefix_secure(buf1, s2);
127            } else {
128            }
129        } else {
130        }
131    }
132}
133
134/// Use this Continuation trait instead of Fn(Input) -> Output
135/// to avoid unsupported Verus features
136pub trait Continuation {
137    type Input<'a>;
138    type Output;
139
140    fn apply<'a>(&self, i: Self::Input<'a>) -> (o: Self::Output)
141        requires self.requires(i)
142        ensures self.ensures(i, o);
143
144    spec fn requires<'a>(&self, i: Self::Input<'a>) -> bool;
145    spec fn ensures<'a>(&self, i: Self::Input<'a>, o: Self::Output) -> bool;
146}
147
148/// Combinator that sequentially applies two combinators, where the second combinator depends on
149/// the result of the first one.
150#[verifier::reject_recursive_types(Snd)]
151pub struct Depend<Fst, Snd, C> where
152    Fst: Combinator,
153    Snd: Combinator,
154    Fst::V: SecureSpecCombinator<SpecResult = <Fst::Owned as View>::V>,
155    Snd::V: SecureSpecCombinator<SpecResult = <Snd::Owned as View>::V>,
156    C: for <'a>Continuation<Input<'a> = Fst::Result<'a>, Output = Snd>,
157 {
158    /// combinators that contain dependencies
159    pub fst: Fst,
160    /// closure that captures dependencies and maps them to the dependent combinators
161    // pub snd: for <'a>fn(Fst::Result<'a>) -> Snd,
162    pub snd: C,
163    /// spec closure for well-formedness
164    pub spec_snd: Ghost<spec_fn(<Fst::Owned as View>::V) -> Snd::V>,
165}
166
167impl<Fst, Snd, C> Depend<Fst, Snd, C> where
168    Fst: Combinator,
169    Snd: Combinator,
170    Fst::V: SecureSpecCombinator<SpecResult = <Fst::Owned as View>::V>,
171    Snd::V: SecureSpecCombinator<SpecResult = <Snd::Owned as View>::V>,
172    C: for <'a>Continuation<Input<'a> = Fst::Result<'a>, Output = Snd>,
173 {
174    /// well-formed [`DepPair`] should have its clousre [`snd`] well-formed w.r.t. [`spec_snd`]
175    pub open spec fn wf(&self) -> bool {
176        let Ghost(spec_snd_dep) = self.spec_snd;
177        &&& forall|i| #[trigger] self.snd.requires(i)
178        &&& forall|i, snd| self.snd.ensures(i, snd) ==> spec_snd_dep(i@) == snd@
179    }
180}
181
182impl<Fst, Snd, C> View for Depend<Fst, Snd, C> where
183    Fst: Combinator,
184    Snd: Combinator,
185    Fst::V: SecureSpecCombinator<SpecResult = <Fst::Owned as View>::V>,
186    Snd::V: SecureSpecCombinator<SpecResult = <Snd::Owned as View>::V>,
187    C: for <'a>Continuation<Input<'a> = Fst::Result<'a>, Output = Snd>,
188 {
189    type V = SpecDepend<Fst::V, Snd::V>;
190
191    open spec fn view(&self) -> Self::V {
192        let Ghost(spec_snd) = self.spec_snd;
193        SpecDepend { fst: self.fst@, snd: spec_snd }
194    }
195}
196
197impl<Fst, Snd, C> Combinator for Depend<Fst, Snd, C> where
198    Fst: Combinator,
199    Snd: Combinator,
200    Fst::V: SecureSpecCombinator<SpecResult = <Fst::Owned as View>::V>,
201    Snd::V: SecureSpecCombinator<SpecResult = <Snd::Owned as View>::V>,
202    C: for <'a>Continuation<Input<'a> = Fst::Result<'a>, Output = Snd>,
203    for <'a>Fst::Result<'a>: PolyfillClone,
204 {
205    type Result<'a> = (Fst::Result<'a>, Snd::Result<'a>);
206
207    type Owned = (Fst::Owned, Snd::Owned);
208
209    open spec fn spec_length(&self) -> Option<usize> {
210        None
211    }
212
213    fn length(&self) -> Option<usize> {
214        None
215    }
216
217    open spec fn parse_requires(&self) -> bool {
218        &&& self.wf()
219        &&& self.fst.parse_requires()
220        &&& forall |i, snd| self.snd.ensures(i, snd) ==> snd.parse_requires()
221        &&& Fst::V::is_prefix_secure()
222    }
223
224    #[inline]
225    fn parse<'a>(&self, s: &'a [u8]) -> (res: Result<(usize, Self::Result<'a>), ParseError>) {
226        let (n, v1) = self.fst.parse(s)?;
227        let s_ = slice_subrange(s, n, s.len());
228        let snd = self.snd.apply(v1.clone());
229        let (m, v2) = snd.parse(s_)?;
230        if n <= usize::MAX - m {
231            Ok(((n + m), (v1, v2)))
232        } else {
233            Err(ParseError::SizeOverflow)
234        }
235    }
236
237    open spec fn serialize_requires(&self) -> bool {
238        &&& self.wf()
239        &&& self.fst.serialize_requires()
240        &&& forall |i, snd| self.snd.ensures(i, snd) ==> snd.serialize_requires()
241        &&& Fst::V::is_prefix_secure()
242    }
243
244    #[inline]
245    fn serialize(&self, v: Self::Result<'_>, data: &mut Vec<u8>, pos: usize) -> (res: Result<
246        usize,
247        SerializeError,
248    >) {
249        let n = self.fst.serialize(v.0.clone(), data, pos)?;
250        if n <= usize::MAX - pos && n + pos <= data.len() {
251            let snd = self.snd.apply(v.0);
252            let m = snd.serialize(v.1, data, pos + n)?;
253            if m <= usize::MAX - n {
254                assert(data@.subrange(pos as int, pos + n + m as int) == self@.spec_serialize(
255                    v@,
256                ).unwrap());
257                assert(data@ == seq_splice(old(data)@, pos, self@.spec_serialize(v@).unwrap()));
258                Ok(n + m)
259            } else {
260                Err(SerializeError::SizeOverflow)
261            }
262        } else {
263            Err(SerializeError::InsufficientBuffer)
264        }
265    }
266}
267
268}