import os import csv import io from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional from collections import OrderedDict class MalformedMsg(Exception): pass class UnknownMsgFieldType(MalformedMsg): pass class UnexpectedEndOfStream(MalformedMsg): pass class FieldEncodingNotMinimal(MalformedMsg): pass class UnknownMandatoryTLVRecordType(MalformedMsg): pass class MsgTrailingGarbage(MalformedMsg): pass class MsgInvalidFieldOrder(MalformedMsg): pass class UnexpectedFieldSizeForEncoder(MalformedMsg): pass def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int: cur_pos = fd.tell() end_pos = fd.seek(0, io.SEEK_END) fd.seek(cur_pos) return end_pos - cur_pos def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None: nremaining = _num_remaining_bytes_to_read(fd) if nremaining < n: raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {nremaining} bytes left") def write_bigsize_int(i: int) -> bytes: assert i >= 0, i if i < 0xfd: return int.to_bytes(i, length=1, byteorder="big", signed=False) elif i < 0x1_0000: return b"\xfd" + int.to_bytes(i, length=2, byteorder="big", signed=False) elif i < 0x1_0000_0000: return b"\xfe" + int.to_bytes(i, length=4, byteorder="big", signed=False) else: return b"\xff" + int.to_bytes(i, length=8, byteorder="big", signed=False) def read_bigsize_int(fd: io.BytesIO) -> Optional[int]: try: first = fd.read(1)[0] except IndexError: return None # end of file if first < 0xfd: return first elif first == 0xfd: _assert_can_read_at_least_n_bytes(fd, 2) val = int.from_bytes(fd.read(2), byteorder="big", signed=False) if not (0xfd <= val < 0x1_0000): raise FieldEncodingNotMinimal() return val elif first == 0xfe: _assert_can_read_at_least_n_bytes(fd, 4) val = int.from_bytes(fd.read(4), byteorder="big", signed=False) if not (0x1_0000 <= val < 0x1_0000_0000): raise FieldEncodingNotMinimal() return val elif first == 0xff: _assert_can_read_at_least_n_bytes(fd, 8) val = int.from_bytes(fd.read(8), byteorder="big", signed=False) if not (0x1_0000_0000 <= val): raise FieldEncodingNotMinimal() return val raise Exception() # TODO: maybe if field_type is not "byte", we could return a list of type_len sized chunks? # if field_type is a numeric, we could return a list of ints? def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> Union[bytes, int]: if not fd: raise Exception() if isinstance(count, int): assert count >= 0, f"{count!r} must be non-neg int" elif count == "...": pass else: raise Exception(f"unexpected field count: {count!r}") if count == 0: return b"" type_len = None if field_type == 'byte': type_len = 1 elif field_type in ('u8', 'u16', 'u32', 'u64'): if field_type == 'u8': type_len = 1 elif field_type == 'u16': type_len = 2 elif field_type == 'u32': type_len = 4 else: assert field_type == 'u64' type_len = 8 assert count == 1, count _assert_can_read_at_least_n_bytes(fd, type_len) return int.from_bytes(fd.read(type_len), byteorder="big", signed=False) elif field_type in ('tu16', 'tu32', 'tu64'): if field_type == 'tu16': type_len = 2 elif field_type == 'tu32': type_len = 4 else: assert field_type == 'tu64' type_len = 8 assert count == 1, count raw = fd.read(type_len) if len(raw) > 0 and raw[0] == 0x00: raise FieldEncodingNotMinimal() return int.from_bytes(raw, byteorder="big", signed=False) elif field_type == 'varint': assert count == 1, count val = read_bigsize_int(fd) if val is None: raise UnexpectedEndOfStream() return val elif field_type == 'chain_hash': type_len = 32 elif field_type == 'channel_id': type_len = 32 elif field_type == 'sha256': type_len = 32 elif field_type == 'signature': type_len = 64 elif field_type == 'point': type_len = 33 elif field_type == 'short_channel_id': type_len = 8 if count == "...": total_len = -1 # read all else: if type_len is None: raise UnknownMsgFieldType(f"unknown field type: {field_type!r}") total_len = count * type_len _assert_can_read_at_least_n_bytes(fd, total_len) return fd.read(total_len) # TODO: maybe for "value" we could accept a list with len "count" of appropriate items def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str], value: Union[bytes, int]) -> None: if not fd: raise Exception() if isinstance(count, int): assert count >= 0, f"{count!r} must be non-neg int" elif count == "...": pass else: raise Exception(f"unexpected field count: {count!r}") if count == 0: return type_len = None if field_type == 'byte': type_len = 1 elif field_type == 'u8': type_len = 1 elif field_type == 'u16': type_len = 2 elif field_type == 'u32': type_len = 4 elif field_type == 'u64': type_len = 8 elif field_type in ('tu16', 'tu32', 'tu64'): if field_type == 'tu16': type_len = 2 elif field_type == 'tu32': type_len = 4 else: assert field_type == 'tu64' type_len = 8 assert count == 1, count if isinstance(value, int): value = int.to_bytes(value, length=type_len, byteorder="big", signed=False) if not isinstance(value, (bytes, bytearray)): raise Exception(f"can only write bytes into fd. got: {value!r}") while len(value) > 0 and value[0] == 0x00: value = value[1:] nbytes_written = fd.write(value) if nbytes_written != len(value): raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?") return elif field_type == 'varint': assert count == 1, count if isinstance(value, int): value = write_bigsize_int(value) if not isinstance(value, (bytes, bytearray)): raise Exception(f"can only write bytes into fd. got: {value!r}") nbytes_written = fd.write(value) if nbytes_written != len(value): raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?") return elif field_type == 'chain_hash': type_len = 32 elif field_type == 'channel_id': type_len = 32 elif field_type == 'sha256': type_len = 32 elif field_type == 'signature': type_len = 64 elif field_type == 'point': type_len = 33 elif field_type == 'short_channel_id': type_len = 8 total_len = -1 if count != "...": if type_len is None: raise UnknownMsgFieldType(f"unknown field type: {field_type!r}") total_len = count * type_len if isinstance(value, int) and (count == 1 or field_type == 'byte'): value = int.to_bytes(value, length=total_len, byteorder="big", signed=False) if not isinstance(value, (bytes, bytearray)): raise Exception(f"can only write bytes into fd. got: {value!r}") if count != "..." and total_len != len(value): raise UnexpectedFieldSizeForEncoder(f"expected: {total_len}, got {len(value)}") nbytes_written = fd.write(value) if nbytes_written != len(value): raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?") def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]: if not fd: raise Exception() tlv_type = _read_field(fd=fd, field_type="varint", count=1) tlv_len = _read_field(fd=fd, field_type="varint", count=1) tlv_val = _read_field(fd=fd, field_type="byte", count=tlv_len) return tlv_type, tlv_val def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None: if not fd: raise Exception() tlv_len = len(tlv_val) _write_field(fd=fd, field_type="varint", count=1, value=tlv_type) _write_field(fd=fd, field_type="varint", count=1, value=tlv_len) _write_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val) def _resolve_field_count(field_count_str: str, *, vars_dict: dict, allow_any=False) -> Union[int, str]: """Returns an evaluated field count, typically an int. If allow_any is True, the return value can be a str with value=="...". """ if field_count_str == "": field_count = 1 elif field_count_str == "...": if not allow_any: raise Exception("field count is '...' but allow_any is False") return field_count_str else: try: field_count = int(field_count_str) except ValueError: field_count = vars_dict[field_count_str] if isinstance(field_count, (bytes, bytearray)): field_count = int.from_bytes(field_count, byteorder="big") assert isinstance(field_count, int) return field_count class LNSerializer: def __init__(self): # TODO msg_type could be 'int' everywhere... self.msg_scheme_from_type = {} # type: Dict[bytes, List[Sequence[str]]] self.msg_type_from_name = {} # type: Dict[str, bytes] self.in_tlv_stream_get_tlv_record_scheme_from_type = {} # type: Dict[str, Dict[int, List[Sequence[str]]]] self.in_tlv_stream_get_record_type_from_name = {} # type: Dict[str, Dict[str, int]] self.in_tlv_stream_get_record_name_from_type = {} # type: Dict[str, Dict[int, str]] path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv") with open(path, newline='') as f: csvreader = csv.reader(f) for row in csvreader: #print(f">>> {row!r}") if row[0] == "msgtype": # msgtype,,[,