// vest/regular/choice.rs

1use super::disjoint::DisjointFrom;
2use crate::properties::*;
3use vstd::prelude::*;
4
5verus! {
6
/// A value that is one of two alternatives.
///
/// This is the result type of the `OrdChoice` combinator: `Left` carries the
/// first combinator's result and `Right` carries the second combinator's result.
#[allow(missing_docs)]
#[derive(Debug)]
pub enum Either<A, B> {
    /// The first alternative.
    Left(A),
    /// The second alternative.
    Right(B),
}
13
14impl<A: View, B: View> View for Either<A, B> {
15    type V = Either<A::V, B::V>;
16
17    open spec fn view(&self) -> Either<A::V, B::V> {
18        match self {
19            Either::Left(v) => Either::Left(v@),
20            Either::Right(v) => Either::Right(v@),
21        }
22    }
23}
24
/// Ordered-choice combinator: run the first combinator (`Fst`, field `.0`)
/// and, only if it fails, run the second one (`Snd`, field `.1`).
///
/// At the spec level, parsing and serializing succeed only when `Snd` is
/// disjoint from `Fst`; otherwise both always return `Err(())`.
pub struct OrdChoice<Fst, Snd>(pub Fst, pub Snd);
27
28impl<Fst: View, Snd: View> View for OrdChoice<Fst, Snd> where  {
29    type V = OrdChoice<Fst::V, Snd::V>;
30
31    open spec fn view(&self) -> Self::V {
32        OrdChoice(self.0@, self.1@)
33    }
34}
35
36impl<Fst, Snd> SpecCombinator for OrdChoice<Fst, Snd> where
37    Fst: SpecCombinator,
38    Snd: SpecCombinator + DisjointFrom<Fst>,
39 {
40    type SpecResult = Either<Fst::SpecResult, Snd::SpecResult>;
41
42    open spec fn spec_parse(&self, s: Seq<u8>) -> Result<(usize, Self::SpecResult), ()> {
43        if self.1.disjoint_from(&self.0) {
44            if let Ok((n, v)) = self.0.spec_parse(s) {
45                Ok((n, Either::Left(v)))
46            } else {
47                if let Ok((n, v)) = self.1.spec_parse(s) {
48                    Ok((n, Either::Right(v)))
49                } else {
50                    Err(())
51                }
52            }
53        } else {
54            Err(())
55        }
56    }
57
58    proof fn spec_parse_wf(&self, s: Seq<u8>) {
59        if let Ok((n, v)) = self.0.spec_parse(s) {
60            self.0.spec_parse_wf(s);
61        } else {
62            if let Ok((n, v)) = self.1.spec_parse(s) {
63                self.1.spec_parse_wf(s);
64            }
65        }
66    }
67
68    open spec fn spec_serialize(&self, v: Self::SpecResult) -> Result<Seq<u8>, ()> {
69        if self.1.disjoint_from(&self.0) {
70            match v {
71                Either::Left(v) => self.0.spec_serialize(v),
72                Either::Right(v) => self.1.spec_serialize(v),
73            }
74        } else {
75            Err(())
76        }
77    }
78}
79
80impl<Fst, Snd> SecureSpecCombinator for OrdChoice<Fst, Snd> where
81    Fst: SecureSpecCombinator,
82    Snd: SecureSpecCombinator + DisjointFrom<Fst>,
83 {
84    open spec fn is_prefix_secure() -> bool {
85        Fst::is_prefix_secure() && Snd::is_prefix_secure()
86    }
87
88    proof fn lemma_prefix_secure(&self, s1: Seq<u8>, s2: Seq<u8>) {
89        if self.1.disjoint_from(&self.0) {
90            // must also explicitly state that parser1 will fail on anything that parser2 will succeed on
91            self.1.parse_disjoint_on(&self.0, s1.add(s2));
92            if Self::is_prefix_secure() {
93                self.0.lemma_prefix_secure(s1, s2);
94                self.1.lemma_prefix_secure(s1, s2);
95            }
96        }
97    }
98
99    proof fn theorem_serialize_parse_roundtrip(&self, v: Self::SpecResult) {
100        match v {
101            Either::Left(v) => {
102                self.0.theorem_serialize_parse_roundtrip(v);
103            },
104            Either::Right(v) => {
105                self.1.theorem_serialize_parse_roundtrip(v);
106                let buf = self.1.spec_serialize(v).unwrap();
107                if self.1.disjoint_from(&self.0) {
108                    self.1.parse_disjoint_on(&self.0, buf);
109                }
110            },
111        }
112    }
113
114    proof fn theorem_parse_serialize_roundtrip(&self, buf: Seq<u8>) {
115        if let Ok((n, v)) = self.0.spec_parse(buf) {
116            self.0.theorem_parse_serialize_roundtrip(buf);
117        } else {
118            if let Ok((n, v)) = self.1.spec_parse(buf) {
119                self.1.theorem_parse_serialize_roundtrip(buf);
120            }
121        }
122    }
123}
124
125impl<Fst, Snd> Combinator for OrdChoice<Fst, Snd> where
126    Fst: Combinator,
127    Snd: Combinator,
128    Fst::V: SecureSpecCombinator<SpecResult = <Fst::Owned as View>::V>,
129    Snd::V: SecureSpecCombinator<SpecResult = <Snd::Owned as View>::V>,
130    Snd::V: DisjointFrom<Fst::V>,
131 {
132    type Result<'a> = Either<Fst::Result<'a>, Snd::Result<'a>>;
133
134    type Owned = Either<Fst::Owned, Snd::Owned>;
135
136    open spec fn spec_length(&self) -> Option<usize> {
137        None
138    }
139
140    fn length(&self) -> Option<usize> {
141        None
142    }
143
144    open spec fn parse_requires(&self) -> bool {
145        self.0.parse_requires() && self.1.parse_requires() && self@.1.disjoint_from(&self@.0)
146    }
147
148    fn parse<'a>(&self, s: &'a [u8]) -> (res: Result<(usize, Self::Result<'a>), ParseError>) {
149        if let Ok((n, v)) = self.0.parse(s) {
150            Ok((n, Either::Left(v)))
151        } else {
152            if let Ok((n, v)) = self.1.parse(s) {
153                Ok((n, Either::Right(v)))
154            } else {
155                Err(ParseError::OrdChoiceNoMatch)
156            }
157        }
158    }
159
160    open spec fn serialize_requires(&self) -> bool {
161        self.0.serialize_requires() && self.1.serialize_requires() && self@.1.disjoint_from(
162            &self@.0,
163        )
164    }
165
166    fn serialize(&self, v: Self::Result<'_>, data: &mut Vec<u8>, pos: usize) -> (res: Result<
167        usize,
168        SerializeError,
169    >) {
170        match v {
171            Either::Left(v) => {
172                let n = self.0.serialize(v, data, pos)?;
173                if n <= usize::MAX - pos && n + pos <= data.len() {
174                    Ok(n)
175                } else {
176                    Err(SerializeError::InsufficientBuffer)
177                }
178            },
179            Either::Right(v) => {
180                let n = self.1.serialize(v, data, pos)?;
181                if n <= usize::MAX - pos && n + pos <= data.len() {
182                    Ok(n)
183                } else {
184                    Err(SerializeError::InsufficientBuffer)
185                }
186            },
187        }
188    }
189}
190
/// Builds a right-nested `OrdChoice` from a comma-separated list of
/// combinators: `ord_choice!(a, b, c)` expands to
/// `OrdChoice(a, OrdChoice(b, c))`, and a single argument expands to itself.
/// A trailing comma is accepted. `OrdChoice` must be in scope at the call site.
#[allow(unused_macros)]
#[macro_export]
macro_rules! ord_choice {
    ($last:expr $(,)?) => ($last);

    ($head:expr, $($tail:expr),* $(,)?) => (OrdChoice($head, ord_choice!($($tail),*)));
}
204
205pub use ord_choice;
206
/// Type-level counterpart of `ord_choice!`: `ord_choice_type!(A, B, C)` is
/// `OrdChoice<A, OrdChoice<B, C>>`, and a single type expands to itself.
/// A trailing comma is accepted. `OrdChoice` must be in scope at the call site.
#[allow(unused_macros)]
#[macro_export]
macro_rules! ord_choice_type {
    ($last:ty $(,)?) => ($last);

    ($head:ty, $($tail:ty),* $(,)?) => (OrdChoice<$head, ord_choice_type!($($tail),*)>);
}
219
220pub use ord_choice_type;
221
/// Result type of an `ord_choice!` combinator: `ord_choice_result!(A, B, C)`
/// is `Either<A, Either<B, C>>`, and a single type expands to itself.
/// A trailing comma is accepted. `Either` must be in scope at the call site.
#[allow(unused_macros)]
#[macro_export]
macro_rules! ord_choice_result {
    ($last:ty $(,)?) => ($last);

    ($head:ty, $($tail:ty),* $(,)?) => (Either<$head, ord_choice_result!($($tail),*)>);
}
234
235pub use ord_choice_result;
236
/// Injects a value into the nested `Either` produced by `ord_choice_result!`
/// (maps `x: Ti` to `ord_choice_result!(T1, ..., Tn)`).
///
/// Exactly one argument is the value; every other argument is a `*`
/// placeholder for an alternative that is not taken. Each leading `*` wraps
/// the result in `Either::Right`, and the value is wrapped in `Either::Left`
/// unless it is in the last position. For three alternatives:
///
/// - `inj_ord_choice_result!(x, *, *)` => `Either::Left(x)`
/// - `inj_ord_choice_result!(*, x, *)` => `Either::Right(Either::Left(x))`
/// - `inj_ord_choice_result!(*, *, x)` => `Either::Right(Either::Right(x))`
///
/// The value may be any expression (e.g. `self.field` or `f(a, b)`), not just
/// a single token. `Either` must be in scope at the call site.
#[allow(unused_macros)]
#[macro_export]
macro_rules! inj_ord_choice_result {
    // Skip one placeholder. The remaining input is forwarded as raw tokens
    // rather than as comma-separated single token trees, so multi-token
    // expressions after a `*` survive intact.
    (*, $($rest:tt)*) => {
        Either::Right(inj_ord_choice_result!($($rest)*))
    };

    // The value is in the last position: no `Left` wrapper.
    ($x:expr $(,)?) => {
        $x
    };

    // The value is followed only by placeholders.
    ($x:expr, $(*),* $(,)?) => {
        Either::Left($x)
    };
}
253
254pub use inj_ord_choice_result;
255
/// Pattern counterpart of `inj_ord_choice_result!`: builds a pattern matching
/// one alternative of the nested `Either` produced by `ord_choice_result!`.
///
/// Exactly one argument is a pattern; every other argument is a `*`
/// placeholder. For three alternatives:
///
/// - `inj_ord_choice_pat!(p, *, *)` => `Either::Left(p)`
/// - `inj_ord_choice_pat!(*, p, *)` => `Either::Right(Either::Left(p))`
/// - `inj_ord_choice_pat!(*, *, p)` => `Either::Right(Either::Right(p))`
///
/// The pattern may span several tokens (e.g. `Some(x)` or `Foo { a, .. }`).
/// `Either` must be in scope at the call site.
#[allow(unused_macros)]
#[macro_export]
macro_rules! inj_ord_choice_pat {
    // Skip one placeholder. The remaining input is forwarded as raw tokens
    // rather than as comma-separated single token trees, so multi-token
    // patterns after a `*` survive intact.
    (*, $($rest:tt)*) => {
        Either::Right(inj_ord_choice_pat!($($rest)*))
    };

    // The pattern is in the last position: no `Left` wrapper.
    ($x:pat $(,)?) => {
        $x
    };

    // The pattern is followed only by placeholders.
    ($x:pat, $(*),* $(,)?) => {
        Either::Left($x)
    };
}
272
273pub use inj_ord_choice_pat;
274
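// What would it look like if we implemented the `match` combinator by hand?
// The sketch below is commented out and not compiled. It does not match the
// trait API used by the live code above: it uses `spec_is_prefix_secure` /
// `exec_is_prefix_secure` and `()` errors, whereas the live code uses
// `is_prefix_secure`, `ParseError`, and `SerializeError`.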
276//
277// use super::uints::*;
278// use super::tail::*;
279//
280// pub struct MatchU8With123 {
281//     pub val: u8,
282//     pub arm1: U8,
283//     pub arm2: U16,
284//     pub arm3: U32,
285//     pub default: Tail,
286// }
287//
288// impl View for MatchU8With123 {
289//     type V = Self;
290//
291//     open spec fn view(&self) -> Self::V {
292//         MatchU8With123 {
293//             val: self.val,
294//             arm1: self.arm1@,
295//             arm2: self.arm2@,
296//             arm3: self.arm3@,
297//             default: self.default@,
298//         }
299//     }
300// }
301//
302// pub enum SpecMsgMatchU8With123 {
303//     Arm1(u8),
304//     Arm2(u16),
305//     Arm3(u32),
306//     Default(Seq<u8>),
307// }
308//
309// pub enum MsgMatchU8With123<'a> {
310//     Arm1(u8),
311//     Arm2(u16),
312//     Arm3(u32),
313//     Default(&'a [u8]),
314// }
315//
316// pub enum MsgOwnedMatchU8With123 {
317//     Arm1(u8),
318//     Arm2(u16),
319//     Arm3(u32),
320//     Default(Vec<u8>),
321// }
322//
323// impl View for MsgMatchU8With123<'_> {
324//     type V = SpecMsgMatchU8With123;
325//
326//     open spec fn view(&self) -> Self::V {
327//         match self {
328//             MsgMatchU8With123::Arm1(v) => SpecMsgMatchU8With123::Arm1(v@),
329//             MsgMatchU8With123::Arm2(v) => SpecMsgMatchU8With123::Arm2(v@),
330//             MsgMatchU8With123::Arm3(v) => SpecMsgMatchU8With123::Arm3(v@),
331//             MsgMatchU8With123::Default(v) => SpecMsgMatchU8With123::Default(v@),
332//         }
333//     }
334// }
335//
336// impl View for MsgOwnedMatchU8With123 {
337//     type V = SpecMsgMatchU8With123;
338//
339//     open spec fn view(&self) -> Self::V {
340//         match self {
341//             MsgOwnedMatchU8With123::Arm1(v) => SpecMsgMatchU8With123::Arm1(v@),
342//             MsgOwnedMatchU8With123::Arm2(v) => SpecMsgMatchU8With123::Arm2(v@),
343//             MsgOwnedMatchU8With123::Arm3(v) => SpecMsgMatchU8With123::Arm3(v@),
344//             MsgOwnedMatchU8With123::Default(v) => SpecMsgMatchU8With123::Default(v@),
345//         }
346//     }
347// }
348//
349// impl SpecCombinator for MatchU8With123 {
350//     type SpecResult = SpecMsgMatchU8With123;
351//
352//     open spec fn spec_parse(&self, s: Seq<u8>) -> Result<(usize, Self::SpecResult), ()> {
353//         match self.val {
354//             1u8 => {
355//                 if let Ok((n, v)) = self.arm1.spec_parse(s) {
356//                     Ok((n, SpecMsgMatchU8With123::Arm1(v)))
357//                 } else {
358//                     Err(())
359//                 }
360//             },
361//             2u8 => {
362//                 if let Ok((n, v)) = self.arm2.spec_parse(s) {
363//                     Ok((n, SpecMsgMatchU8With123::Arm2(v)))
364//                 } else {
365//                     Err(())
366//                 }
367//             },
368//             3u8 => {
369//                 if let Ok((n, v)) = self.arm3.spec_parse(s) {
370//                     Ok((n, SpecMsgMatchU8With123::Arm3(v)))
371//                 } else {
372//                     Err(())
373//                 }
374//             },
375//             _ => {
376//                 if let Ok((n, v)) = self.default.spec_parse(s) {
377//                     Ok((n, SpecMsgMatchU8With123::Default(v)))
378//                 } else {
379//                     Err(())
380//                 }
381//             },
382//         }
383//     }
384//
385//     proof fn spec_parse_wf(&self, s: Seq<u8>) {
386//         match self.val {
387//             1u8 => {
388//                 if let Ok((n, v)) = self.arm1.spec_parse(s) {
389//                     self.arm1.spec_parse_wf(s);
390//                 }
391//             },
392//             2u8 => {
393//                 if let Ok((n, v)) = self.arm2.spec_parse(s) {
394//                     self.arm2.spec_parse_wf(s);
395//                 }
396//             },
397//             3u8 => {
398//                 if let Ok((n, v)) = self.arm3.spec_parse(s) {
399//                     self.arm3.spec_parse_wf(s);
400//                 }
401//             },
402//             _ => {
403//                 if let Ok((n, v)) = self.default.spec_parse(s) {
404//                     self.default.spec_parse_wf(s);
405//                 }
406//             },
407//         }
408//     }
409//
410//     open spec fn spec_serialize(&self, v: Self::SpecResult) -> Result<Seq<u8>, ()> {
411//         match self.val {
412//             1u8 => {
413//                 if let SpecMsgMatchU8With123::Arm1(v) = v {
414//                     self.arm1.spec_serialize(v)
415//                 } else {
416//                     Err(())
417//                 }
418//             },
419//             2u8 => {
420//                 if let SpecMsgMatchU8With123::Arm2(v) = v {
421//                     self.arm2.spec_serialize(v)
422//                 } else {
423//                     Err(())
424//                 }
425//             },
426//             3u8 => {
427//                 if let SpecMsgMatchU8With123::Arm3(v) = v {
428//                     self.arm3.spec_serialize(v)
429//                 } else {
430//                     Err(())
431//                 }
432//             },
433//             _ => {
434//                 if let SpecMsgMatchU8With123::Default(v) = v {
435//                     self.default.spec_serialize(v)
436//                 } else {
437//                     Err(())
438//                 }
439//             },
440//         }
441//     }
442// }
443//
444// impl SecureSpecCombinator for MatchU8With123 {
445//     open spec fn spec_is_prefix_secure() -> bool {
446//         U8::spec_is_prefix_secure() && U16::spec_is_prefix_secure() && U32::spec_is_prefix_secure()
447//             && Tail::spec_is_prefix_secure()
448//     }
449//
450//     proof fn lemma_prefix_secure(&self, s1: Seq<u8>, s2: Seq<u8>) {
451//         match self.val {
452//             1u8 => {
453//                 self.arm1.lemma_prefix_secure(s1, s2);
454//             },
455//             2u8 => {
456//                 self.arm2.lemma_prefix_secure(s1, s2);
457//             },
458//             3u8 => {
459//                 self.arm3.lemma_prefix_secure(s1, s2);
460//             },
461//             _ => {
462//                 self.default.lemma_prefix_secure(s1, s2);
463//             },
464//         }
465//     }
466//
467//     proof fn theorem_serialize_parse_roundtrip(&self, v: Self::SpecResult) {
468//         match self.val {
469//             1u8 => {
470//                 if let SpecMsgMatchU8With123::Arm1(v) = v {
471//                     self.arm1.theorem_serialize_parse_roundtrip(v);
472//                 }
473//             },
474//             2u8 => {
475//                 if let SpecMsgMatchU8With123::Arm2(v) = v {
476//                     self.arm2.theorem_serialize_parse_roundtrip(v);
477//                 }
478//             },
479//             3u8 => {
480//                 if let SpecMsgMatchU8With123::Arm3(v) = v {
481//                     self.arm3.theorem_serialize_parse_roundtrip(v);
482//                 }
483//             },
484//             _ => {
485//                 if let SpecMsgMatchU8With123::Default(v) = v {
486//                     self.default.theorem_serialize_parse_roundtrip(v);
487//                 }
488//             },
489//         }
490//     }
491//
492//     proof fn theorem_parse_serialize_roundtrip(&self, buf: Seq<u8>) {
493//         match self.val {
494//             1u8 => {
495//                 self.arm1.theorem_parse_serialize_roundtrip(buf);
496//             },
497//             2u8 => {
498//                 self.arm2.theorem_parse_serialize_roundtrip(buf);
499//             },
500//             3u8 => {
501//                 self.arm3.theorem_parse_serialize_roundtrip(buf);
502//             },
503//             _ => {
504//                 self.default.theorem_parse_serialize_roundtrip(buf);
505//             },
506//         }
507//     }
508// }
509//
510// impl Combinator for MatchU8With123 {
511//     type Result<'a> = MsgMatchU8With123<'a>;
512//
513//     type Owned = MsgOwnedMatchU8With123;
514//
515//     open spec fn spec_length(&self) -> Option<usize> {
516//         None
517//     }
518//
519//     fn length(&self) -> Option<usize> {
520//         None
521//     }
522//
523//     fn exec_is_prefix_secure() -> bool {
524//         U8::exec_is_prefix_secure() && U16::exec_is_prefix_secure() && U32::exec_is_prefix_secure()
525//             && Tail::exec_is_prefix_secure()
526//     }
527//
528//     open spec fn parse_requires(&self) -> bool {
529//         self.arm1.parse_requires() && self.arm2.parse_requires() && self.arm3.parse_requires()
530//             && self.default.parse_requires()
531//     }
532//
533//     fn parse<'a>(&self, s: &'a [u8]) -> (res: Result<(usize, Self::Result<'a>), ()>) {
534//         match self.val {
535//             1u8 => {
536//                 if let Ok((n, v)) = self.arm1.parse(s) {
537//                     Ok((n, MsgMatchU8With123::Arm1(v)))
538//                 } else {
539//                     Err(())
540//                 }
541//             },
542//             2u8 => {
543//                 if let Ok((n, v)) = self.arm2.parse(s) {
544//                     Ok((n, MsgMatchU8With123::Arm2(v)))
545//                 } else {
546//                     Err(())
547//                 }
548//             },
549//             3u8 => {
550//                 if let Ok((n, v)) = self.arm3.parse(s) {
551//                     Ok((n, MsgMatchU8With123::Arm3(v)))
552//                 } else {
553//                     Err(())
554//                 }
555//             },
556//             _ => {
557//                 if let Ok((n, v)) = self.default.parse(s) {
558//                     Ok((n, MsgMatchU8With123::Default(v)))
559//                 } else {
560//                     Err(())
561//                 }
562//             },
563//         }
564//     }
565//
566//     open spec fn serialize_requires(&self) -> bool {
567//         self.arm1.serialize_requires() && self.arm2.serialize_requires()
568//             && self.arm3.serialize_requires() && self.default.serialize_requires()
569//     }
570//
571//     fn serialize(&self, v: Self::Result<'_>, data: &mut Vec<u8>, pos: usize) -> (res: Result<
572//         usize,
573//         (),
574//     >) {
575//         match self.val {
576//             1u8 => {
577//                 if let MsgMatchU8With123::Arm1(v) = v {
578//                     self.arm1.serialize(v, data, pos)
579//                 } else {
580//                     Err(())
581//                 }
582//             },
583//             2u8 => {
584//                 if let MsgMatchU8With123::Arm2(v) = v {
585//                     self.arm2.serialize(v, data, pos)
586//                 } else {
587//                     Err(())
588//                 }
589//             },
590//             3u8 => {
591//                 if let MsgMatchU8With123::Arm3(v) = v {
592//                     self.arm3.serialize(v, data, pos)
593//                 } else {
594//                     Err(())
595//                 }
596//             },
597//             _ => {
598//                 if let MsgMatchU8With123::Default(v) = v {
599//                     self.default.serialize(v, data, pos)
600//                 } else {
601//                     Err(())
602//                 }
603//             },
604//         }
605//     }
606// }
607} // verus!