diff --git a/electrum/lnmsg.py b/electrum/lnmsg.py index 49afb7319..824339857 100644 --- a/electrum/lnmsg.py +++ b/electrum/lnmsg.py @@ -69,16 +69,25 @@ def read_int_from_bigsize(fd: io.BytesIO) -> Optional[int]: raise Exception() -def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes, int]: +# TODO: maybe if field_type is not "byte", we could return a list of type_len sized chunks? +# if field_type is a numeric, we could return a list of ints? +def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> Union[bytes, int]: if not fd: raise Exception() - assert isinstance(count, int) and count >= 0, f"{count!r} must be non-neg int" + if isinstance(count, int): + assert count >= 0, f"{count!r} must be non-neg int" + elif count == "...": + pass + else: + raise Exception(f"unexpected field count: {count!r}") if count == 0: return b"" type_len = None if field_type == 'byte': type_len = 1 - elif field_type in ('u16', 'u32', 'u64'): - if field_type == 'u16': + elif field_type in ('u8', 'u16', 'u32', 'u64'): + if field_type == 'u8': + type_len = 1 + elif field_type == 'u16': type_len = 2 elif field_type == 'u32': type_len = 4 @@ -119,22 +128,33 @@ def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes, type_len = 33 elif field_type == 'short_channel_id': type_len = 8 - if type_len is None: - raise UnknownMsgFieldType(f"unexpected field type: {field_type!r}") - total_len = count * type_len - _assert_can_read_at_least_n_bytes(fd, total_len) + if count == "...": + total_len = -1 # read all + else: + if type_len is None: + raise UnknownMsgFieldType(f"unknown field type: {field_type!r}") + total_len = count * type_len + _assert_can_read_at_least_n_bytes(fd, total_len) return fd.read(total_len) -def _write_field(*, fd: io.BytesIO, field_type: str, count: int, +# TODO: maybe for "value" we could accept a list with len "count" of appropriate items +def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str], value: Union[bytes, int]) -> None: if not fd: raise Exception() - assert isinstance(count, int) and count >= 0, f"{count!r} must be non-neg int" + if isinstance(count, int): + assert count >= 0, f"{count!r} must be non-neg int" + elif count == "...": + pass + else: + raise Exception(f"unexpected field count: {count!r}") if count == 0: return type_len = None if field_type == 'byte': type_len = 1 + elif field_type == 'u8': + type_len = 1 elif field_type == 'u16': type_len = 2 elif field_type == 'u32': @@ -182,14 +202,16 @@ def _write_field(*, fd: io.BytesIO, field_type: str, count: int, type_len = 33 elif field_type == 'short_channel_id': type_len = 8 - if type_len is None: - raise UnknownMsgFieldType(f"unexpected fundamental type: {field_type!r}") - total_len = count * type_len - if isinstance(value, int) and (count == 1 or field_type == 'byte'): - value = int.to_bytes(value, length=total_len, byteorder="big", signed=False) + total_len = -1 + if count != "...": + if type_len is None: + raise UnknownMsgFieldType(f"unknown field type: {field_type!r}") + total_len = count * type_len + if isinstance(value, int) and (count == 1 or field_type == 'byte'): + value = int.to_bytes(value, length=total_len, byteorder="big", signed=False) if not isinstance(value, (bytes, bytearray)): raise Exception(f"can only write bytes into fd. got: {value!r}") - if total_len != len(value): + if count != "..." and total_len != len(value): raise Exception(f"unexpected field size. expected: {total_len}, got {len(value)}") nbytes_written = fd.write(value) if nbytes_written != len(value): @@ -212,11 +234,16 @@ def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None: _write_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val) -def _resolve_field_count(field_count_str: str, *, vars_dict: dict) -> int: +def _resolve_field_count(field_count_str: str, *, vars_dict: dict, allow_any=False) -> Union[int, str]: + """Returns an evaluated field count, typically an int. + If allow_any is True, the return value can be a str with value=="...". + """ if field_count_str == "": field_count = 1 elif field_count_str == "...": - raise NotImplementedError() # TODO... + if not allow_any: + raise Exception("field count is '...' but allow_any is False") + return field_count_str else: try: field_count = int(field_count_str) @@ -301,7 +328,9 @@ class LNSerializer: field_name = row[3] field_type = row[4] field_count_str = row[5] - field_count = _resolve_field_count(field_count_str, vars_dict=kwargs[tlv_record_name]) + field_count = _resolve_field_count(field_count_str, + vars_dict=kwargs[tlv_record_name], + allow_any=True) field_value = kwargs[tlv_record_name][field_name] _write_field(fd=tlv_record_fd, field_type=field_type, @@ -343,7 +372,9 @@ class LNSerializer: field_name = row[3] field_type = row[4] field_count_str = row[5] - field_count = _resolve_field_count(field_count_str, vars_dict=parsed[tlv_record_name]) + field_count = _resolve_field_count(field_count_str, + vars_dict=parsed[tlv_record_name], + allow_any=True) #print(f">> count={field_count}. parsed={parsed}") parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd, field_type=field_type,