verdict_parser/asn1/
utf8_string.rs

1use super::*;
2use vstd::prelude::*;
3use vstd::vstd::slice::slice_subrange;
4
5verus! {
6
/// Combinator for UTF8String in ASN.1
///
/// This is a field-less marker type: all parsing and serialization logic
/// lives in the `SpecCombinator`, `SecureSpecCombinator` and `Combinator`
/// implementations for it in this file.
#[derive(Debug, View)]
pub struct UTF8String;

// Associates `UTF8String` with the `UTF8_STRING` tag through the project's
// `asn1_tagged!` macro. NOTE(review): the macro is defined outside this file;
// presumably it implements the trait that tag-checking wrappers rely on —
// confirm at the macro definition.
asn1_tagged!(UTF8String, tag_of!(UTF8_STRING));
12
/// Specification-level value of a UTF8String: a sequence of characters.
/// Used as `SpecCombinator::SpecResult` for `UTF8String`.
pub type SpecUTF8StringValue = Seq<char>;
/// Executable value of a UTF8String: a string slice with lifetime `'a`.
/// Used as `Combinator::Result<'a>` for `UTF8String`, where `'a` is the
/// lifetime of the parsed input buffer.
pub type UTF8StringValue<'a> = &'a str;
/// Owned executable value of a UTF8String.
/// Used as `Combinator::Owned` for `UTF8String`.
pub type UTF8StringValueOwned = String;
16
/// Specification-level parser and serializer for the content of a UTF8String:
/// a length prefix (handled by the `Length` combinator) followed by that many
/// content bytes, which are decoded/encoded with `spec_parse_utf8` /
/// `spec_serialize_utf8`.
impl SpecCombinator for UTF8String {
    type SpecResult = SpecUTF8StringValue;

    /// Spec parser.
    ///
    /// Parses a length `l` with `Length` (which consumes `n` bytes), then
    /// decodes the following `l` bytes with `spec_parse_utf8`. On success the
    /// result is the total number of bytes consumed (`n + l`) together with
    /// the decoded character sequence.
    ///
    /// Returns `Err(())` when the length prefix fails to parse, when `n + l`
    /// exceeds `usize::MAX` or the input length, or when `spec_parse_utf8`
    /// returns `None`.
    closed spec fn spec_parse(&self, s: Seq<u8>) -> Result<(usize, Self::SpecResult), ()> {
        match Length.spec_parse(s) {
            Ok((n, l)) => {
                // The total must be representable as `usize` (it is cast below)
                // and the content must lie entirely inside the input.
                if n + l <= usize::MAX && n + l <= s.len() {
                    // Content bytes are the `l` bytes right after the length prefix.
                    match spec_parse_utf8(s.skip(n as int).take(l as int)) {
                        Some(parsed) => Ok(((n + l) as usize, parsed)),
                        None => Err(()),
                    }
                } else {
                    Err(())
                }
            }
            Err(()) => Err(()),
        }
    }

    /// Well-formedness obligation required by `SpecCombinator`. The statement
    /// is declared on the trait (not shown in this file); the empty body means
    /// the verifier discharges it without extra hints.
    proof fn spec_parse_wf(&self, s: Seq<u8>) {}

    /// Spec serializer.
    ///
    /// Encodes `v` with `spec_serialize_utf8`, serializes the byte length of
    /// that encoding with `Length`, and returns the length prefix followed by
    /// the encoded bytes.
    ///
    /// Returns `Err(())` when `Length` cannot serialize the length, or when
    /// the combined size would exceed `usize::MAX`.
    closed spec fn spec_serialize(&self, v: Self::SpecResult) -> Result<Seq<u8>, ()> {
        let s = spec_serialize_utf8(v);
        match Length.spec_serialize(s.len() as LengthValue) {
            Ok(buf) =>
                if buf.len() + s.len() <= usize::MAX {
                    Ok(buf + s)
                } else {
                    Err(())
                },
            Err(()) => Err(()),
        }
    }
}
51
/// Proofs of the security properties of the UTF8String spec combinator:
/// serialize/parse round trips and prefix security. The exact statements of
/// the obligations are declared on the `SecureSpecCombinator` trait, which is
/// not shown in this file.
impl SecureSpecCombinator for UTF8String {
    /// Declares this combinator to be prefix-secure.
    open spec fn is_prefix_secure() -> bool {
        true
    }

    /// Round trip in the serialize-then-parse direction for a value `v`.
    ///
    /// Proof outline: invoke the corresponding round-trip facts for `Length`
    /// (on the encoded byte length) and for the UTF-8 encoding of `v`, then
    /// show that parsing the concatenation `buf + s` recovers exactly `s` as
    /// the content bytes.
    proof fn theorem_serialize_parse_roundtrip(&self, v: Self::SpecResult) {
        let s = spec_serialize_utf8(v);

        Length.theorem_serialize_parse_roundtrip(s.len() as LengthValue);
        spec_utf8_serialize_parse_roundtrip(v);

        if let Ok(buf) = Length.spec_serialize(s.len() as LengthValue)  {
            if buf.len() + s.len() <= usize::MAX {
                // Relates parsing the length from `buf + s` to parsing it from
                // `buf` alone (prefix security of `Length`).
                Length.lemma_prefix_secure(buf, s);
                // Skipping the length prefix and taking `s.len()` bytes of the
                // serialized output yields the content bytes again.
                assert((buf + s).skip(buf.len() as int).take(s.len() as int) == s);
            }
        }
    }

    /// Round trip in the parse-then-serialize direction for an input `buf`.
    ///
    /// Proof outline: when the length prefix parses to `(n, l)` and the
    /// content fits in `buf`, invoke the corresponding round-trip facts for
    /// `Length` and for the UTF-8 decoding of the content bytes, then split
    /// the consumed prefix of `buf` into its length part and content part.
    proof fn theorem_parse_serialize_roundtrip(&self, buf: Seq<u8>) {
        if let Ok((n, l)) = Length.spec_parse(buf) {
            if n + l <= buf.len() {
                // The content bytes that `spec_parse` hands to `spec_parse_utf8`.
                let s = buf.skip(n as int).take(l as int);

                Length.theorem_parse_serialize_roundtrip(buf);
                spec_utf8_parse_serialize_roundtrip(s);
                // The first `n + l` bytes of `buf` are the `n` length bytes
                // followed by the `l` content bytes.
                assert(buf.subrange(0, (n + l) as int) == buf.subrange(0, n as int) + buf.skip(n as int).take(l as int));
            }
        }
    }

    /// Prefix security: relates parsing `s1` to parsing `s1 + s2`.
    ///
    /// Proof outline: use prefix security of `Length` for the length prefix,
    /// then show that the content bytes selected from `s1` are the same as
    /// those selected from `s1 + s2` whenever the content fits inside `s1`.
    proof fn lemma_prefix_secure(&self, s1: Seq<u8>, s2: Seq<u8>) {
        Length.lemma_prefix_secure(s1, s2);

        if let Ok((n, l)) = Length.spec_parse(s1) {
            if n + l <= s1.len() {
                // Extensional equality of the two content slices.
                assert(s1.skip(n as int).take(l as int) =~= (s1 + s2).skip(n as int).take(l as int));
            }
        }
    }
}
93
94impl Combinator for UTF8String {
95    type Result<'a> = UTF8StringValue<'a>;
96    type Owned = UTF8StringValueOwned;
97
98    closed spec fn spec_length(&self) -> Option<usize> {
99        None
100    }
101
102    fn length(&self) -> Option<usize> {
103        None
104    }
105
106    #[inline(always)]
107    fn parse<'a>(&self, s: &'a [u8]) -> (res: Result<(usize, Self::Result<'a>), ParseError>) {
108        let (n, l) = Length.parse(s)?;
109
110        if let Some(total_len) = n.checked_add(l as usize) {
111            if total_len <= s.len() {
112                match utf8_to_str(slice_take(slice_subrange(s, n, s.len()), l as usize)) {
113                    Some(parsed) => Ok((total_len, parsed)),
114                    _ => Err(ParseError::Other("Invalid UTF-8".to_string()))
115                }
116            } else {
117                Err(ParseError::UnexpectedEndOfInput)
118            }
119        } else {
120            Err(ParseError::SizeOverflow)
121        }
122    }
123
124    #[inline(always)]
125    fn serialize(&self, v: Self::Result<'_>, data: &mut Vec<u8>, pos: usize) -> (res: Result<usize, SerializeError>) {
126        let s = str_to_utf8(v);
127        let n = Length.serialize(s.len() as LengthValue, data, pos)?;
128
129        if pos.checked_add(n).is_none() {
130            return Err(SerializeError::SizeOverflow);
131        }
132
133        if (pos + n).checked_add(s.len()).is_none() {
134            return Err(SerializeError::SizeOverflow);
135        }
136
137        if pos + n + s.len() >= data.len() {
138            return Err(SerializeError::InsufficientBuffer);
139        }
140
141        let ghost data_after_len = data@;
142
143        // No Vec::splice yet in Verus
144        for i in 0..s.len()
145            invariant
146                pos + n + s.len() <= usize::MAX,
147                pos + n + s.len() < data.len() == data_after_len.len(),
148
149                data@ =~= seq_splice(data_after_len, (pos + n) as usize, s@.take(i as int)),
150        {
151            data.set(pos + n + i, s[i]);
152        }
153
154        assert(data@ =~= seq_splice(old(data)@, pos, Length.spec_serialize(s@.len() as LengthValue).unwrap() + s@));
155
156        Ok(n + s.len())
157    }
158}
159
160}
161
#[cfg(test)]
mod test {
    use super::*;
    use der::Encode;

    /// Serializes `v` as a complete tagged value: the tag byte `0x0c`
    /// followed by the output of `UTF8String.serialize` (length prefix and
    /// content bytes).
    fn serialize_utf8_string(v: &str) -> Result<Vec<u8>, SerializeError> {
        // Room for the tag byte, the length prefix and the content, plus slack.
        let mut data = vec![0; v.len() + 10];
        data[0] = 0x0c; // Prepend the tag byte
        let len = UTF8String.serialize(v, &mut data, 1)?;
        // Drop the unused tail: keep the tag byte plus the `len` bytes written.
        data.truncate(len + 1);
        Ok(data)
    }

    /// Differential test: the encoding produced by `UTF8String` must match the
    /// DER encoding produced by the `der` crate for the same string.
    #[test]
    fn diff_with_der() {
        let diff = |s: &str| {
            // Error details are erased so that only success/failure and the
            // produced bytes are compared.
            let res1 = serialize_utf8_string(s).map_err(|_| ());
            let res2 = s.to_string().to_der().map_err(|_| ());
            assert_eq!(res1, res2);
        };

        diff("");
        diff("asdsad");
        // Multi-byte (3 bytes per character) UTF-8 text.
        diff("黑风雷");
        // Family emoji: four 4-byte code points joined by zero-width joiners.
        // Written with escapes so the invisible joiners cannot be lost.
        diff("\u{1F468}\u{200D}\u{1F469}\u{200D}\u{1F467}\u{200D}\u{1F466}");
        // Long enough (2304 bytes) to need a multi-byte length prefix.
        diff("黑风雷".repeat(256).as_str());
    }
}