diff --git a/contrib/import-checker.py b/contrib/import-checker.py
--- a/contrib/import-checker.py
+++ b/contrib/import-checker.py
@@ -36,6 +36,8 @@
     'mercurial.pure.parsers',
     # third-party imports should be directly imported
     'mercurial.thirdparty',
+    'mercurial.thirdparty.cbor',
+    'mercurial.thirdparty.cbor.cbor2',
     'mercurial.thirdparty.zope',
     'mercurial.thirdparty.zope.interface',
 )
diff --git a/mercurial/utils/cborutil.py b/mercurial/utils/cborutil.py
new file mode 100644
--- /dev/null
+++ b/mercurial/utils/cborutil.py
@@ -0,0 +1,258 @@
+# cborutil.py - CBOR extensions
+#
+# Copyright 2018 Gregory Szorc
+#
+# This software may be used and distributed according to the terms of the
+# GNU General Public License version 2 or any later version.
+
+from __future__ import absolute_import
+
+import struct
+
+from ..thirdparty.cbor.cbor2 import (
+    decoder as decodermod,
+)
+
+# Very short version of RFC 7049...
+#
+# Each item begins with a byte. The 3 high bits of that byte denote the
+# "major type." The lower 5 bits denote the "subtype." Each major type
+# has its own encoding mechanism.
+#
+# Most types have lengths. However, bytestring, string, array, and map
+# can be indefinite length. These are denoted by a subtype with value 31.
+# Sub-components of those types then come afterwards and are terminated
+# by a "break" byte.
+
+MAJOR_TYPE_UINT = 0
+MAJOR_TYPE_NEGINT = 1
+MAJOR_TYPE_BYTESTRING = 2
+MAJOR_TYPE_STRING = 3
+MAJOR_TYPE_ARRAY = 4
+MAJOR_TYPE_MAP = 5
+MAJOR_TYPE_SEMANTIC = 6
+MAJOR_TYPE_SPECIAL = 7
+
+SUBTYPE_MASK = 0b00011111
+
+SUBTYPE_HALF_FLOAT = 25
+SUBTYPE_SINGLE_FLOAT = 26
+SUBTYPE_DOUBLE_FLOAT = 27
+SUBTYPE_INDEFINITE = 31
+
+# Indefinite types begin with their major type ORed with information value 31.
+BEGIN_INDEFINITE_BYTESTRING = struct.pack(
+    r'>B', MAJOR_TYPE_BYTESTRING << 5 | SUBTYPE_INDEFINITE)
+BEGIN_INDEFINITE_ARRAY = struct.pack(
+    r'>B', MAJOR_TYPE_ARRAY << 5 | SUBTYPE_INDEFINITE)
+BEGIN_INDEFINITE_MAP = struct.pack(
+    r'>B', MAJOR_TYPE_MAP << 5 | SUBTYPE_INDEFINITE)
+
+ENCODED_LENGTH_1 = struct.Struct(r'>B')
+ENCODED_LENGTH_2 = struct.Struct(r'>BB')
+ENCODED_LENGTH_3 = struct.Struct(r'>BH')
+ENCODED_LENGTH_4 = struct.Struct(r'>BL')
+ENCODED_LENGTH_5 = struct.Struct(r'>BQ')
+
+# The break byte ends an indefinite length item.
+BREAK = b'\xff'
+BREAK_INT = 255
+
+def encodelength(majortype, length):
+    """Obtain the packed initial byte(s) for a major type and its length."""
+    if length < 24:
+        return ENCODED_LENGTH_1.pack(majortype << 5 | length)
+    elif length < 256:
+        return ENCODED_LENGTH_2.pack(majortype << 5 | 24, length)
+    elif length < 65536:
+        return ENCODED_LENGTH_3.pack(majortype << 5 | 25, length)
+    elif length < 4294967296:
+        return ENCODED_LENGTH_4.pack(majortype << 5 | 26, length)
+    else:
+        return ENCODED_LENGTH_5.pack(majortype << 5 | 27, length)
+
+def streamencodebytestring(v):
+    yield encodelength(MAJOR_TYPE_BYTESTRING, len(v))
+    yield v
+
+def streamencodebytestringfromiter(it):
+    """Convert an iterator of chunks to an indefinite bytestring.
+
+    Given an input that is iterable and each element in the iterator is
+    representable as bytes, emit an indefinite length bytestring.
+    """
+    yield BEGIN_INDEFINITE_BYTESTRING
+
+    for chunk in it:
+        yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
+        yield chunk
+
+    yield BREAK
+
+def streamencodeindefinitebytestring(source, chunksize=65536):
+    """Given a large source buffer, emit as an indefinite length bytestring.
+
+    This is a generator of chunks constituting the encoded CBOR data.
+    """
+    yield BEGIN_INDEFINITE_BYTESTRING
+
+    i = 0
+    l = len(source)
+
+    while True:
+        chunk = source[i:i + chunksize]
+        i += len(chunk)
+
+        yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
+        yield chunk
+
+        if i >= l:
+            break
+
+    yield BREAK
+
+def streamencodeint(v):
+    if v >= 18446744073709551616 or v < -18446744073709551616:
+        raise ValueError('big integers not supported')
+
+    if v >= 0:
+        yield encodelength(MAJOR_TYPE_UINT, v)
+    else:
+        yield encodelength(MAJOR_TYPE_NEGINT, abs(v) - 1)
+
+def streamencodearray(l):
+    """Encode a known size iterable to an array."""
+
+    yield encodelength(MAJOR_TYPE_ARRAY, len(l))
+
+    for i in l:
+        for chunk in streamencode(i):
+            yield chunk
+
+def streamencodearrayfromiter(it):
+    """Encode an iterator of items to an indefinite length array."""
+
+    yield BEGIN_INDEFINITE_ARRAY
+
+    for i in it:
+        for chunk in streamencode(i):
+            yield chunk
+
+    yield BREAK
+
+def streamencodeset(s):
+    # https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml defines
+    # semantic tag 258 for finite sets.
+    yield encodelength(MAJOR_TYPE_SEMANTIC, 258)
+
+    for chunk in streamencodearray(sorted(s)):
+        yield chunk
+
+def streamencodemap(d):
+    """Encode dictionary to a generator.
+
+    Does not support indefinite length dictionaries.
+    """
+    yield encodelength(MAJOR_TYPE_MAP, len(d))
+
+    for key, value in sorted(d.items()):
+        for chunk in streamencode(key):
+            yield chunk
+        for chunk in streamencode(value):
+            yield chunk
+
+def streamencodemapfromiter(it):
+    """Given an iterable of (key, value), encode to an indefinite length map."""
+    yield BEGIN_INDEFINITE_MAP
+
+    for key, value in it:
+        for chunk in streamencode(key):
+            yield chunk
+        for chunk in streamencode(value):
+            yield chunk
+
+    yield BREAK
+
+def streamencodebool(b):
+    # major type 7, simple value 20 and 21.
+    yield b'\xf5' if b else b'\xf4'
+
+def streamencodenone(v):
+    # major type 7, simple value 22.
+    yield b'\xf6'
+
+STREAM_ENCODERS = {
+    bytes: streamencodebytestring,
+    int: streamencodeint,
+    list: streamencodearray,
+    tuple: streamencodearray,
+    dict: streamencodemap,
+    set: streamencodeset,
+    bool: streamencodebool,
+    type(None): streamencodenone,
+}
+
+def streamencode(v):
+    """Encode a value in a streaming manner.
+
+    Given an input object, encode it to CBOR recursively.
+
+    Returns a generator of CBOR encoded bytes. There is no guarantee
+    that each emitted chunk fully decodes to a value or sub-value.
+
+    Encoding is deterministic - unordered collections are sorted.
+    """
+    fn = STREAM_ENCODERS.get(v.__class__)
+
+    if not fn:
+        raise ValueError('do not know how to encode %s' % type(v))
+
+    return fn(v)
+
+def readindefinitebytestringtoiter(fh, expectheader=True):
+    """Read an indefinite bytestring to a generator.
+
+    Receives an object with a ``read(N)`` method to read N bytes.
+
+    If ``expectheader`` is True, it is expected that the first byte read
+    will represent an indefinite length bytestring. Otherwise, we
+    expect the first byte to be part of the first bytestring chunk.
+    """
+    read = fh.read
+    decodeuint = decodermod.decode_uint
+    byteasinteger = decodermod.byte_as_integer
+
+    if expectheader:
+        initial = byteasinteger(read(1))
+
+        majortype = initial >> 5
+        subtype = initial & SUBTYPE_MASK
+
+        if majortype != MAJOR_TYPE_BYTESTRING:
+            raise decodermod.CBORDecodeError(
+                'expected major type %d; got %d' % (MAJOR_TYPE_BYTESTRING,
+                                                    majortype))
+
+        if subtype != SUBTYPE_INDEFINITE:
+            raise decodermod.CBORDecodeError(
+                'expected indefinite subtype; got %d' % subtype)
+
+    # The indefinite bytestring is composed of chunks of normal bytestrings.
+    # Read chunks until we hit a BREAK byte.
+
+    while True:
+        # We need to sniff for the BREAK byte.
+        initial = byteasinteger(read(1))
+
+        if initial == BREAK_INT:
+            break
+
+        length = decodeuint(fh, initial & SUBTYPE_MASK)
+        chunk = read(length)
+
+        if len(chunk) != length:
+            raise decodermod.CBORDecodeError(
+                'failed to read bytestring chunk: got %d bytes; expected %d' % (
+                    len(chunk), length))
+
+        yield chunk
diff --git a/tests/test-cbor.py b/tests/test-cbor.py
new file mode 100644
--- /dev/null
+++ b/tests/test-cbor.py
@@ -0,0 +1,210 @@
+from __future__ import absolute_import
+
+import io
+import unittest
+
+from mercurial.thirdparty import (
+    cbor,
+)
+from mercurial.utils import (
+    cborutil,
+)
+
+def loadit(it):
+    return cbor.loads(b''.join(it))
+
+class BytestringTests(unittest.TestCase):
+    def testsimple(self):
+        self.assertEqual(
+            list(cborutil.streamencode(b'foobar')),
+            [b'\x46', b'foobar'])
+
+        self.assertEqual(
+            loadit(cborutil.streamencode(b'foobar')),
+            b'foobar')
+
+    def testlong(self):
+        source = b'x' * 1048576
+
+        self.assertEqual(loadit(cborutil.streamencode(source)), source)
+
+    def testfromiter(self):
+        # This is the example from RFC 7049 Section 2.2.2.
+        source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99']
+
+        self.assertEqual(
+            list(cborutil.streamencodebytestringfromiter(source)),
+            [
+                b'\x5f',
+                b'\x44',
+                b'\xaa\xbb\xcc\xdd',
+                b'\x43',
+                b'\xee\xff\x99',
+                b'\xff',
+            ])
+
+        self.assertEqual(
+            loadit(cborutil.streamencodebytestringfromiter(source)),
+            b''.join(source))
+
+    def testfromiterlarge(self):
+        source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576]
+
+        self.assertEqual(
+            loadit(cborutil.streamencodebytestringfromiter(source)),
+            b''.join(source))
+
+    def testindefinite(self):
+        source = b'\x00\x01\x02\x03' + b'\xff' * 16384
+
+        it = cborutil.streamencodeindefinitebytestring(source, chunksize=2)
+
+        self.assertEqual(next(it), b'\x5f')
+        self.assertEqual(next(it), b'\x42')
+        self.assertEqual(next(it), b'\x00\x01')
+        self.assertEqual(next(it), b'\x42')
+        self.assertEqual(next(it), b'\x02\x03')
+        self.assertEqual(next(it), b'\x42')
+        self.assertEqual(next(it), b'\xff\xff')
+
+        dest = b''.join(cborutil.streamencodeindefinitebytestring(
+            source, chunksize=42))
+        self.assertEqual(cbor.loads(dest), source)
+
+    def testreadtoiter(self):
+        source = io.BytesIO(b'\x5f\x44\xaa\xbb\xcc\xdd\x43\xee\xff\x99\xff')
+
+        it = cborutil.readindefinitebytestringtoiter(source)
+        self.assertEqual(next(it), b'\xaa\xbb\xcc\xdd')
+        self.assertEqual(next(it), b'\xee\xff\x99')
+
+        with self.assertRaises(StopIteration):
+            next(it)
+
+class IntTests(unittest.TestCase):
+    def testsmall(self):
+        self.assertEqual(list(cborutil.streamencode(0)), [b'\x00'])
+        self.assertEqual(list(cborutil.streamencode(1)), [b'\x01'])
+        self.assertEqual(list(cborutil.streamencode(2)), [b'\x02'])
+        self.assertEqual(list(cborutil.streamencode(3)), [b'\x03'])
+        self.assertEqual(list(cborutil.streamencode(4)), [b'\x04'])
+
+    def testnegativesmall(self):
+        self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20'])
+        self.assertEqual(list(cborutil.streamencode(-2)), [b'\x21'])
+        self.assertEqual(list(cborutil.streamencode(-3)), [b'\x22'])
+        self.assertEqual(list(cborutil.streamencode(-4)), [b'\x23'])
+        self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24'])
+
+    def testrange(self):
+        for i in range(-70000, 70000, 10):
+            self.assertEqual(
+                b''.join(cborutil.streamencode(i)),
+                cbor.dumps(i))
+
+class ArrayTests(unittest.TestCase):
+    def testempty(self):
+        self.assertEqual(list(cborutil.streamencode([])), [b'\x80'])
+        self.assertEqual(loadit(cborutil.streamencode([])), [])
+
+    def testbasic(self):
+        source = [b'foo', b'bar', 1, -10]
+
+        self.assertEqual(list(cborutil.streamencode(source)), [
+            b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29'])
+
+    def testemptyfromiter(self):
+        self.assertEqual(b''.join(cborutil.streamencodearrayfromiter([])),
+                         b'\x9f\xff')
+
+    def testfromiter1(self):
+        source = [b'foo']
+
+        self.assertEqual(list(cborutil.streamencodearrayfromiter(source)), [
+            b'\x9f',
+            b'\x43', b'foo',
+            b'\xff',
+        ])
+
+        dest = b''.join(cborutil.streamencodearrayfromiter(source))
+        self.assertEqual(cbor.loads(dest), source)
+
+    def testtuple(self):
+        source = (b'foo', None, 42)
+
+        self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
+                         list(source))
+
+class SetTests(unittest.TestCase):
+    def testempty(self):
+        self.assertEqual(list(cborutil.streamencode(set())), [
+            b'\xd9\x01\x02',
+            b'\x80',
+        ])
+
+    def testset(self):
+        source = {b'foo', None, 42}
+
+        self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
+                         source)
+
+class BoolTests(unittest.TestCase):
+    def testbasic(self):
+        self.assertEqual(list(cborutil.streamencode(True)), [b'\xf5'])
+        self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4'])
+
+        self.assertIs(loadit(cborutil.streamencode(True)), True)
+        self.assertIs(loadit(cborutil.streamencode(False)), False)
+
+class NoneTests(unittest.TestCase):
+    def testbasic(self):
+        self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6'])
+
+        self.assertIs(loadit(cborutil.streamencode(None)), None)
+
+class MapTests(unittest.TestCase):
+    def testempty(self):
+        self.assertEqual(list(cborutil.streamencode({})), [b'\xa0'])
+        self.assertEqual(loadit(cborutil.streamencode({})), {})
+
+    def testemptyindefinite(self):
+        self.assertEqual(list(cborutil.streamencodemapfromiter([])), [
+            b'\xbf', b'\xff'])
+
+        self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {})
+
+    def testone(self):
+        source = {b'foo': b'bar'}
+        self.assertEqual(list(cborutil.streamencode(source)), [
+            b'\xa1', b'\x43', b'foo', b'\x43', b'bar'])
+
+        self.assertEqual(loadit(cborutil.streamencode(source)), source)
+
+    def testmultiple(self):
+        source = {
+            b'foo': b'bar',
+            b'baz': b'value1',
+        }
+
+        self.assertEqual(loadit(cborutil.streamencode(source)), source)
+
+        self.assertEqual(
+            loadit(cborutil.streamencodemapfromiter(source.items())),
+            source)
+
+    def testcomplex(self):
+        source = {
+            b'key': 1,
+            2: -10,
+        }
+
+        self.assertEqual(loadit(cborutil.streamencode(source)),
+                         source)
+
+        self.assertEqual(
+            loadit(cborutil.streamencodemapfromiter(source.items())),
+            source)
+
+if __name__ == '__main__':
+    import silenttestrunner
+    silenttestrunner.main(__name__)