# -*- coding: utf-8 -*-
"""
A basic Ethereum wallet library
"""
from typing import (
Any,
Dict, Union
)
import ecdsa
from cryptnoxpy import Derivation
from eth_account._utils.legacy_transactions import (
encode_transaction,
serializable_unsigned_transaction_from_dict
)
from eth_utils.curried import keccak
from hexbytes import HexBytes
from web3 import Web3
from . import endpoint as ep
from .. import validators
try:
from lib import cryptos
import enums
except ImportError:
from ...lib import cryptos
from ... import enums
[docs]
def address(public_key: str) -> str:
    """
    Derive the Ethereum account address for a hex-encoded public key.

    The first two characters of ``public_key`` are discarded before hashing
    (presumably the ``04`` uncompressed-point prefix -- confirm with callers).
    The address is the last 20 bytes of the Keccak-256 digest of the
    remaining key material.

    :param public_key: Public key as a hex string
    :return: Address as a lowercase hex string without a ``0x`` prefix
    """
    key_material = "0x" + public_key[2:]
    digest = keccak(hexstr=key_material)
    account_bytes = digest[-20:]
    return account_bytes.hex()
[docs]
def checksum_address(public_key: str) -> str:
    """
    Return the EIP-55 mixed-case (checksummed) address for ``public_key``.

    :param public_key: Public key as a hex string, in the same format that
        ``address`` expects
    :return: ``0x``-prefixed checksummed address
    """
    plain_address = address(public_key)
    return Web3.to_checksum_address(plain_address)
[docs]
class Api:
    """
    Access to one Ethereum network through a Web3 HTTP provider.

    The provider URL and the network come from the endpoint object created in
    ``__init__``. A fresh ``Web3`` instance is built on every access to
    ``_web3``.
    """
    # BIP-44 derivation path (coin type 60) and ticker symbol for this coin
    PATH = "m/44'/60'/0'/0/0"
    SYMBOL = "eth"
    # EIP-1559 fee fields dropped before hashing and serializing. This makes
    # the digest handed to the signer and the broadcast transaction the same
    # legacy (EIP-155) transaction.
    _FEE_MARKET_FIELDS = ("maxFeePerGas", "maxPriorityFeePerGas")

    def __init__(self, endpoint: str, network: Union[enums.EthNetwork, str],
                 api_key: str):
        """
        :param endpoint: Node URL (anything starting with ``http``) or the
            name of an endpoint provider understood by ``endpoint.factory``
        :param network: ``EthNetwork`` member or its case-insensitive name
        :param api_key: API key forwarded to ``endpoint.factory``
        :raises LookupError: If ``network`` is a string naming no network
        """
        if isinstance(network, str):
            try:
                network = enums.EthNetwork[network.upper()]
            except KeyError as error:
                raise LookupError("Network is invalid") from error
        if endpoint.startswith("http"):
            self.endpoint = ep.DirectEndpoint(endpoint, network)
        else:
            self.endpoint = ep.factory(endpoint, network, api_key)

    @property
    def block_number(self):
        """Number of the latest block as reported by the node."""
        return self._web3.eth.block_number

    def contract(self, address="", abi=""):
        """Return a Web3 contract object for ``address`` described by ``abi``."""
        return self._web3.eth.contract(address=address, abi=abi)

    def get_transaction_count(self, address: str,
                              blocks: Union[str, int, None] = None) -> int:
        """
        Number of transactions sent from ``address``.

        :param address: Account address, converted to checksum form first
        :param blocks: Block identifier forwarded to web3; ``None`` lets web3
            apply its default
        """
        return self._web3.eth.get_transaction_count(
            Web3.to_checksum_address(address), blocks)

    def get_balance(self, address: str) -> int:
        """Balance of ``address`` in wei."""
        return self._web3.eth.get_balance(Web3.to_checksum_address(address))

    @property
    def gas_price(self):
        """Current gas price (wei) as reported by the node."""
        return self._web3.eth.gas_price

    @property
    def network(self):
        """Network of the configured endpoint."""
        return self.endpoint.network

    def transaction_hash(self, transaction: Dict[str, Any], vrs: bool = False) -> bytes:
        """
        Return the digest that must be signed for ``transaction``.

        The digest is the Keccak-256 hash of the unsigned legacy transaction
        encoded with ``(chain_id, 0, 0)`` as ``(v, r, s)`` (EIP-155). EIP-1559
        fee fields are ignored. The caller's dict is NOT modified.

        :param transaction: Transaction fields as accepted by eth_account
        :param vrs: Unused; kept for backward compatibility
        :return: 32-byte digest
        """
        unsigned_transaction = self._unsigned_legacy_transaction(transaction)
        encoded_transaction = encode_transaction(unsigned_transaction, (self._chain_id, 0, 0))
        return keccak(encoded_transaction)

    def push(self, transaction, signature, public_key):
        """
        Attach ``signature`` to ``transaction`` and broadcast it.

        The transaction is serialized exactly as ``transaction_hash`` hashes
        it (legacy form, EIP-1559 fee fields dropped), so the signature
        always matches the broadcast payload.

        :param transaction: Transaction dict whose ``transaction_hash`` was signed
        :param signature: DER-encoded ECDSA signature of that digest
        :param public_key: Signer public key in a format accepted by
            ``cryptos.decode_pubkey``
        :return: Result of ``send_raw_transaction`` (the transaction hash)
        """
        unsigned_transaction = self._unsigned_legacy_transaction(transaction)
        var_v, var_r, var_s = Api._decode_vrs(signature, self._chain_id,
                                              self.transaction_hash(transaction),
                                              cryptos.decode_pubkey(public_key))
        rlp_encoded = encode_transaction(unsigned_transaction, (var_v, var_r, var_s))
        return self._web3.eth.send_raw_transaction(HexBytes(rlp_encoded))

    @classmethod
    def _unsigned_legacy_transaction(cls, transaction: Dict[str, Any]):
        """Build the unsigned transaction from a copy without EIP-1559 fee fields."""
        legacy_fields = {key: value for key, value in transaction.items()
                         if key not in cls._FEE_MARKET_FIELDS}
        return serializable_unsigned_transaction_from_dict(legacy_fields)

    @property
    def _chain_id(self) -> int:
        return self.endpoint.network.value

    @staticmethod
    def _decode_vrs(signature_der: bytes, chain_id: int, transaction: bytes, q_pub: tuple) -> tuple:
        """
        Method used for getting v, r and s values

        :param signature_der: DER signature generated by the Cryptnox card
        :param chain_id: Networks chain ID
        :param transaction: Keccak-256 digest of the unsigned transaction
        :param q_pub: Wallets q_pub
        :return: Tuple containing v, r, s values, with s in low-s form (EIP-2)
        """
        curve = ecdsa.curves.SECP256k1
        signature_decode = ecdsa.util.sigdecode_der
        order = curve.generator.order()
        var_r, var_s = signature_decode(signature_der, order)
        # Parity recovery: ecdsa returns the even-y candidate first and the
        # odd-y candidate second. A match on index 1 therefore means
        # recovery id 1.
        var_q = ecdsa.keys.VerifyingKey.from_public_key_recovery_with_digest(
            signature_der, transaction, curve, sigdecode=signature_decode)[1]
        recovery_id = 1 if var_q.to_string("uncompressed") == cryptos.encode_pubkey(q_pub, "bin") else 0
        # EIP-2: nodes reject s > n/2. (r, n - s) is an equally valid
        # signature whose recovery id is the opposite one.
        if var_s > order // 2:
            var_s = order - var_s
            recovery_id ^= 1
        # EIP-155: v = recovery_id + 35 + 2 * chain_id
        var_v = 2 * chain_id + 35 + recovery_id
        return var_v, var_r, var_s

    @property
    def _provider(self) -> str:
        return self.endpoint.provider

    @property
    def _web3(self) -> Web3:
        return Web3(Web3.HTTPProvider(self._provider))
[docs]
class EthValidator:
    """
    Validated settings for the Ethereum wallet.

    Each attribute is backed by a validator object declared on the class
    (presumably a data descriptor that checks values on assignment -- see the
    ``validators`` module). ``validate`` adds a cross-field check that the
    selected endpoint supports the selected network.
    """
    network = validators.EnumValidator(enums.EthNetwork)
    price = validators.IntValidator(min_value=0)
    limit = validators.IntValidator(min_value=0)
    derivation = validators.EnumValidator(Derivation)
    api_key = validators.AnyValidator()
    endpoint = ep.EndpointValidator()

    def __init__(self, endpoint: str = "publicnode", network: str = "sepolia", price: int = 8,
                 limit: int = 2500, derivation: str = "DERIVE", api_key=""):
        """
        Assign every setting through its class-level validator, in this order:
        endpoint, network, price, limit, derivation, api_key.
        """
        self.endpoint = endpoint
        self.network = network
        self.price = price
        self.limit = limit
        self.derivation = derivation
        self.api_key = api_key

    def validate(self):
        """
        Check that ``network`` is available on the configured endpoint.

        The check applies to each direct subclass of ``ep.Endpoint`` whose
        ``name`` equals ``self.endpoint``. If ``self.network`` is not in that
        subclass's ``available_networks``, two things happen. First, the
        class-level ``network`` validator's ``valid_values`` is set to the
        supported names (lowercased, newline-separated). Then
        ``validators.ValidationError`` is raised.

        NOTE(review): this mutates a validator shared by every instance. Also,
        ``self.__class__.__dict__["network"]`` raises KeyError on subclasses
        of this class.
        """
        matching_endpoints = (candidate for candidate in ep.Endpoint.__subclasses__()
                              if candidate.name == self.endpoint)
        for endpoint_cls in matching_endpoints:
            supported = endpoint_cls.available_networks
            if self.network in supported:
                continue
            choices = "\n".join(name.lower() for name in supported)
            self.__class__.__dict__["network"].valid_values = choices
            raise validators.ValidationError("Invalid value for network")