use ecdsa for signing/veryfing instead of coincurve due to compatibility issues

This commit is contained in:
Lex Berezhny 2021-12-13 00:22:18 -05:00
parent 8216f4a873
commit d815a6f02c
7 changed files with 48 additions and 42 deletions

View file

@ -2859,7 +2859,9 @@ class Daemon(metaclass=JSONRPCServerType):
new_txo = tx.outputs[0] new_txo = tx.outputs[0]
if new_signing_key: if new_signing_key:
await new_txo.generate_channel_private_key() new_txo.set_channel_private_key(
await funding_accounts[0].generate_channel_private_key()
)
else: else:
new_txo.private_key = old_txo.private_key new_txo.private_key = old_txo.private_key
@ -2868,7 +2870,6 @@ class Daemon(metaclass=JSONRPCServerType):
await tx.sign(funding_accounts) await tx.sign(funding_accounts)
if not preview: if not preview:
account.add_channel_private_key(new_txo.private_key)
wallet.save() wallet.save()
await self.broadcast_or_release(tx, blocking) await self.broadcast_or_release(tx, blocking)
self.component_manager.loop.create_task(self.storage.save_claims([self._old_get_temp_claim_info( self.component_manager.loop.create_task(self.storage.save_claims([self._old_get_temp_claim_info(

View file

@ -44,25 +44,36 @@ class DeterministicChannelKeyManager:
self.cache = {} self.cache = {}
def maybe_generate_deterministic_key_for_channel(self, txo): def maybe_generate_deterministic_key_for_channel(self, txo):
if self.private_key is None:
return
next_key = self.private_key.child(self.last_known) next_key = self.private_key.child(self.last_known)
if txo.claim.channel.public_key_bytes == next_key.public_key.pubkey_bytes: signing_key = ecdsa.SigningKey.from_secret_exponent(
self.cache[next_key.address()] = next_key next_key.secret_exponent(), ecdsa.SECP256k1
)
public_key_bytes = signing_key.get_verifying_key().to_der()
if txo.claim.channel.public_key_bytes == public_key_bytes:
self.cache[self.account.ledger.public_key_to_address(public_key_bytes)] = signing_key
self.last_known += 1 self.last_known += 1
async def ensure_cache_primed(self): async def ensure_cache_primed(self):
if self.private_key is not None:
await self.generate_next_key() await self.generate_next_key()
async def generate_next_key(self): async def generate_next_key(self) -> ecdsa.SigningKey:
db = self.account.ledger.db db = self.account.ledger.db
while True: while True:
next_key = self.private_key.child(self.last_known) next_key = self.private_key.child(self.last_known)
key_address = next_key.address() signing_key = ecdsa.SigningKey.from_secret_exponent(
self.cache[key_address] = next_key next_key.secret_exponent(), ecdsa.SECP256k1
if not await db.is_channel_key_used(self.account.wallet, key_address): )
return next_key public_key_bytes = signing_key.get_verifying_key().to_der()
key_address = self.account.ledger.public_key_to_address(public_key_bytes)
self.cache[key_address] = signing_key
if not await db.is_channel_key_used(self.account.wallet, signing_key):
return signing_key
self.last_known += 1 self.last_known += 1
def get_private_key_from_pubkey_hash(self, pubkey_hash): def get_private_key_from_pubkey_hash(self, pubkey_hash) -> ecdsa.SigningKey:
return self.cache.get(pubkey_hash) return self.cache.get(pubkey_hash)
@ -561,7 +572,7 @@ class Account:
channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes) channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes)
self.channel_keys[channel_pubkey_hash] = private_key.to_pem().decode() self.channel_keys[channel_pubkey_hash] = private_key.to_pem().decode()
async def get_channel_private_key(self, public_key_bytes): async def get_channel_private_key(self, public_key_bytes) -> ecdsa.SigningKey:
channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes) channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes)
private_key_pem = self.channel_keys.get(channel_pubkey_hash) private_key_pem = self.channel_keys.get(channel_pubkey_hash)
if private_key_pem: if private_key_pem:

View file

@ -3,6 +3,7 @@ import logging
import asyncio import asyncio
import sqlite3 import sqlite3
import platform import platform
import ecdsa
from binascii import hexlify from binascii import hexlify
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
@ -1241,10 +1242,11 @@ class Database(SQLiteMixin):
async def set_address_history(self, address, history): async def set_address_history(self, address, history):
await self._set_address_history(address, history) await self._set_address_history(address, history)
async def is_channel_key_used(self, wallet, address): async def is_channel_key_used(self, wallet, key: ecdsa.SigningKey):
channels = await self.get_txos(wallet, txo_type=TXO_TYPES['channel']) channels = await self.get_txos(wallet, txo_type=TXO_TYPES['channel'])
other_key_string = key.to_string()
for channel in channels: for channel in channels:
if channel.private_key is not None and channel.private_key.address() == address: if channel.private_key is not None and channel.private_key.to_string() == other_key_string:
return True return True
return False return False

View file

@ -2,7 +2,6 @@ import struct
import hashlib import hashlib
import logging import logging
import typing import typing
import asyncio
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from typing import List, Iterable, Optional, Tuple from typing import List, Iterable, Optional, Tuple
@ -28,7 +27,6 @@ from .constants import COIN, NULL_HASH32
from .bcd_data_stream import BCDataStream from .bcd_data_stream import BCDataStream
from .hash import TXRef, TXRefImmutable from .hash import TXRef, TXRefImmutable
from .util import ReadOnlyList from .util import ReadOnlyList
from .bip32 import PubKey
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from lbry.wallet.account import Account from lbry.wallet.account import Account
@ -470,14 +468,14 @@ class Output(InputOutput):
self.channel = None self.channel = None
self.signable.clear_signature() self.signable.clear_signature()
def set_channel_private_key(self, private_key): def set_channel_private_key(self, private_key: ecdsa.SigningKey):
self.private_key = private_key self.private_key = private_key
self.claim.channel.public_key_bytes = private_key.public_key.pubkey_bytes self.claim.channel.public_key_bytes = self.private_key.get_verifying_key().to_der()
self.script.generate() self.script.generate()
return private_key return self.private_key
def is_channel_private_key(self, private_key): def is_channel_private_key(self, private_key: ecdsa.SigningKey):
return self.claim.channel.public_key_bytes == private_key.signing_key.to_der() return self.claim.channel.public_key_bytes == private_key.get_verifying_key().to_der()
@classmethod @classmethod
def pay_claim_name_pubkey_hash( def pay_claim_name_pubkey_hash(

View file

@ -63,17 +63,6 @@ class AccountManagement(CommandTestCase):
accounts = await self.daemon.jsonrpc_account_list(account_id, include_claims=True) accounts = await self.daemon.jsonrpc_account_list(account_id, include_claims=True)
self.assertEqual(accounts['items'][0]['name'], 'recreated account') self.assertEqual(accounts['items'][0]['name'], 'recreated account')
async def test_wallet_migration(self):
# null certificates should get deleted
await self.channel_create('@foo1')
await self.channel_create('@foo2')
await self.channel_create('@foo3')
keys = list(self.account.channel_keys.keys())
self.account.channel_keys[keys[0]] = None
self.account.channel_keys[keys[1]] = "some invalid junk"
await self.account.maybe_migrate_certificates()
self.assertEqual(list(self.account.channel_keys.keys()), [keys[2]])
async def assertFindsClaims(self, claim_names, awaitable): async def assertFindsClaims(self, claim_names, awaitable):
self.assertEqual(claim_names, [txo.claim_name for txo in (await awaitable)['items']]) self.assertEqual(claim_names, [txo.claim_name for txo in (await awaitable)['items']])
@ -207,33 +196,33 @@ class AccountManagement(CommandTestCase):
self.assertTrue(channel1b.has_private_key) self.assertTrue(channel1b.has_private_key)
self.assertEqual( self.assertEqual(
channel1a['outputs'][0]['value']['public_key_id'], channel1a['outputs'][0]['value']['public_key_id'],
channel1b.private_key.public_key.address self.ledger.public_key_to_address(channel1b.private_key.verifying_key.to_der())
) )
self.assertTrue(channel2b.has_private_key) self.assertTrue(channel2b.has_private_key)
self.assertEqual( self.assertEqual(
channel2a['outputs'][0]['value']['public_key_id'], channel2a['outputs'][0]['value']['public_key_id'],
channel2b.private_key.public_key.address self.ledger.public_key_to_address(channel2b.private_key.verifying_key.to_der())
) )
# repeatedly calling next channel key returns the same key when not used # repeatedly calling next channel key returns the same key when not used
current_known = keys.last_known current_known = keys.last_known
next_key = await keys.generate_next_key() next_key = await keys.generate_next_key()
self.assertEqual(current_known, keys.last_known) self.assertEqual(current_known, keys.last_known)
self.assertEqual(next_key.address(), (await keys.generate_next_key()).address()) self.assertEqual(next_key.to_string(), (await keys.generate_next_key()).to_string())
# again, should be idempotent # again, should be idempotent
next_key = await keys.generate_next_key() next_key = await keys.generate_next_key()
self.assertEqual(current_known, keys.last_known) self.assertEqual(current_known, keys.last_known)
self.assertEqual(next_key.address(), (await keys.generate_next_key()).address()) self.assertEqual(next_key.to_string(), (await keys.generate_next_key()).to_string())
# create third channel while both daemons running, second daemon should pick it up # create third channel while both daemons running, second daemon should pick it up
channel3a = await self.channel_create('@foo3') channel3a = await self.channel_create('@foo3')
self.assertEqual(current_known+1, keys.last_known) self.assertEqual(current_known+1, keys.last_known)
self.assertNotEqual(next_key.address(), (await keys.generate_next_key()).address()) self.assertNotEqual(next_key.to_string(), (await keys.generate_next_key()).to_string())
channel3b, = (await self.daemon2.jsonrpc_channel_list(name='@foo3'))['items'] channel3b, = (await self.daemon2.jsonrpc_channel_list(name='@foo3'))['items']
self.assertTrue(channel3b.has_private_key) self.assertTrue(channel3b.has_private_key)
self.assertEqual( self.assertEqual(
channel3a['outputs'][0]['value']['public_key_id'], channel3a['outputs'][0]['value']['public_key_id'],
channel3b.private_key.public_key.address self.ledger.public_key_to_address(channel3b.private_key.verifying_key.to_der())
) )
# channel key cache re-populated after simulated restart # channel key cache re-populated after simulated restart

View file

@ -32,7 +32,9 @@ class BasicTransactionTest(IntegrationTestCase):
channel_txo = Output.pay_claim_name_pubkey_hash( channel_txo = Output.pay_claim_name_pubkey_hash(
l2d('1.0'), '@bar', channel, self.account.ledger.address_to_hash160(address1) l2d('1.0'), '@bar', channel, self.account.ledger.address_to_hash160(address1)
) )
await channel_txo.generate_channel_private_key() channel_txo.set_channel_private_key(
await self.account.generate_channel_private_key()
)
channel_txo.script.generate() channel_txo.script.generate()
channel_tx = await Transaction.create([], [channel_txo], [self.account], self.account) channel_tx = await Transaction.create([], [channel_txo], [self.account], self.account)

View file

@ -1,5 +1,7 @@
from binascii import unhexlify from binascii import unhexlify
import ecdsa
from lbry.testcase import AsyncioTestCase from lbry.testcase import AsyncioTestCase
from lbry.wallet.constants import CENT, NULL_HASH32 from lbry.wallet.constants import CENT, NULL_HASH32
from lbry.wallet.bip32 import PrivateKey from lbry.wallet.bip32 import PrivateKey
@ -24,10 +26,11 @@ def get_tx():
async def get_channel(claim_name='@foo'): async def get_channel(claim_name='@foo'):
seed = Mnemonic.mnemonic_to_seed(Mnemonic().make_seed(), '')
bip32_key = PrivateKey.from_seed(Ledger, seed)
signing_key = ecdsa.SigningKey.from_secret_exponent(bip32_key.secret_exponent(), ecdsa.SECP256k1)
channel_txo = Output.pay_claim_name_pubkey_hash(CENT, claim_name, Claim(), b'abc') channel_txo = Output.pay_claim_name_pubkey_hash(CENT, claim_name, Claim(), b'abc')
channel_txo.set_channel_private_key(PrivateKey.from_seed( channel_txo.set_channel_private_key(signing_key)
Ledger, Mnemonic.mnemonic_to_seed(Mnemonic().make_seed(), '')
))
get_tx().add_outputs([channel_txo]) get_tx().add_outputs([channel_txo])
return channel_txo return channel_txo