Skip to content

Commit

Permalink
Add get missing metadata contracts task (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
moisses89 authored Jan 14, 2025
1 parent 7b42103 commit 2e7b883
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 17 deletions.
3 changes: 3 additions & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class Settings(BaseSettings):
ETHERSCAN_MAX_REQUESTS: int = 1
BLOCKSCOUT_MAX_REQUESTS: int = 1
SOURCIFY_MAX_REQUESTS: int = 100
CONTRACT_MAX_DOWNLOAD_RETRIES: int = (
90 # Task running once per day, means 3 months trying.
)


settings = Settings()
Expand Down
21 changes: 21 additions & 0 deletions app/datasources/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,24 @@ async def get_abi_by_contract_address(
if result := results.first():
return cast(ABI, result)
return None

@classmethod
async def get_contracts_without_abi(
    cls, session: AsyncSession, max_retries: int = 0
):
    """
    Stream the contracts that have no ABI assigned and whose ``fetch_retries``
    is not greater than ``max_retries``.

    The statement is executed through ``session.stream`` so rows are consumed
    incrementally instead of being loaded all at once, which reduces memory
    usage for large datasets. More information about streaming results:
    https://docs.sqlalchemy.org/en/20/core/connections.html#streaming-with-a-dynamically-growing-buffer-using-stream-results

    :param session: database session used to run the streaming query.
    :param max_retries: highest ``fetch_retries`` value (inclusive) a contract
        may have to be yielded.
    :return: async generator yielding the streamed result rows; callers read
        the contract from the first position of each row (``row[0]``).
    """
    # Both conditions are combined with AND, exactly as chained ``.where()`` calls.
    pending_statement = select(cls).where(
        cls.abi_id == None,  # noqa: E711
        cls.fetch_retries <= max_retries,
    )
    streamed_rows = await session.stream(pending_statement)
    async for row in streamed_rows:
        yield row
10 changes: 6 additions & 4 deletions app/services/contract_metadata_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ async def should_attempt_download(
session: AsyncSession,
contract_address: ChecksumAddress,
chain_id: int,
retries: int,
max_retries: int,
) -> bool:
"""
Return True if fetch retries is less than the number of retries and there is not ABI, False otherwise.
Expand All @@ -226,11 +226,13 @@ async def should_attempt_download(
:param session:
:param contract_address:
:param chain_id:
:param retries:
:param max_retries:
:return:
"""
redis = get_redis()
cache_key = f"should_attempt_download:{contract_address}:{chain_id}:{retries}"
cache_key = (
f"should_attempt_download:{contract_address}:{chain_id}:{max_retries}"
)
# Try from cache first
cached_retries = cast(str, redis.get(cache_key))
if cached_retries:
Expand All @@ -240,7 +242,7 @@ async def should_attempt_download(
session, address=HexBytes(contract_address), chain_id=chain_id
)

if contract and (contract.fetch_retries > retries or contract.abi_id):
if contract and (contract.fetch_retries > max_retries or contract.abi_id):
redis.set(cache_key, 0)
return False

Expand Down
4 changes: 3 additions & 1 deletion app/services/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ def process_event(self, message: str) -> None:
if self._is_processable_event(tx_service_event):
chain_id = int(tx_service_event["chainId"])
contract_address = tx_service_event["to"]
get_contract_metadata_task.send(contract_address, chain_id)
get_contract_metadata_task.send(
address=contract_address, chain_id=chain_id
)
except json.JSONDecodeError:
logging.error(f"Unsupported message. Cannot parse as JSON: {message}")

Expand Down
29 changes: 29 additions & 0 deletions app/tests/datasources/db/test_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from eth_account import Account
from hexbytes import HexBytes
from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.database import database_session
Expand Down Expand Up @@ -134,3 +136,30 @@ async def test_timestamped_model(self, session: AsyncSession):
self.assertEqual(result_updated[0].created, contract_created_date)
self.assertNotEqual(result_updated[0].modified, contract_modified_date)
self.assertTrue(contract_modified_date < result_updated[0].modified)

@database_session
async def test_get_contracts_without_abi(self, session: AsyncSession):
    """
    Check that ``Contract.get_contracts_without_abi`` only yields contracts
    that have no ABI and whose ``fetch_retries`` do not exceed ``max_retries``.
    """
    random_address = HexBytes(Account.create().address)
    abi_json = {"name": "A Test ABI"}
    source = AbiSource(name="local", url="")
    await source.create(session)
    abi = Abi(abi_json=abi_json, source_id=source.id)
    await abi.create(session)

    # Should return the contract. The rows are collected before asserting so
    # the check cannot pass vacuously when the generator yields nothing.
    expected_contract = await Contract(
        address=random_address, name="A test contract", chain_id=1
    ).create(session)
    returned_contracts = [
        row[0] async for row in Contract.get_contracts_without_abi(session, 0)
    ]
    self.assertEqual([expected_contract], returned_contracts)

    # Contracts with more retries shouldn't be returned
    expected_contract.fetch_retries = 1
    await expected_contract.update(session)
    async for _ in Contract.get_contracts_without_abi(session, 0):
        self.fail("Expected no contracts, but found one.")

    # Contracts with abi shouldn't be returned
    expected_contract.abi_id = abi.id
    await expected_contract.update(session)
    async for _ in Contract.get_contracts_without_abi(session, 10):
        self.fail("Expected no contracts, but found one.")
2 changes: 1 addition & 1 deletion app/tests/services/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,5 @@ def test_process_event_calls_send(self, mock_get_contract_metadata_task):
EventsService().process_event(valid_message)

mock_get_contract_metadata_task.assert_called_once_with(
"0x6ED857dc1da2c41470A95589bB482152000773e9", 1
address="0x6ED857dc1da2c41470A95589bB482152000773e9", chain_id=1
)
48 changes: 44 additions & 4 deletions app/tests/workers/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
from dramatiq.worker import Worker
from eth_account import Account
from hexbytes import HexBytes
from safe_eth.eth import EthereumNetwork
from safe_eth.eth.clients import AsyncEtherscanClientV2
from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.database import database_session
from app.datasources.db.models import Contract
from app.datasources.db.models import AbiSource, Contract
from app.workers.tasks import get_contract_metadata_task, redis_broker, test_task

from ...datasources.cache.redis import get_redis
from ...services.contract_metadata_service import ContractMetadataService
from ..datasources.db.db_async_conn import DbAsyncConn
from ..mocks.contract_metadata_mocks import (
etherscan_metadata_mock,
Expand Down Expand Up @@ -75,23 +78,60 @@ def _wait_tasks_execution(self):
while len(redis_tasks) > 0:
redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1)

@mock.patch.object(ContractMetadataService, "enabled_clients")
@mock.patch.object(
AsyncEtherscanClientV2, "async_get_contract_metadata", autospec=True
)
@database_session
async def test_get_contract_metadata_task(
self, etherscan_get_contract_metadata_mock: MagicMock, session: AsyncSession
self,
etherscan_get_contract_metadata_mock: MagicMock,
mock_enabled_clients: MagicMock,
session: AsyncSession,
):
etherscan_get_contract_metadata_mock.return_value = etherscan_metadata_mock
contract_address = "0xd9Db270c1B5E3Bd161E8c8503c55cEABeE709552"
chain_id = 100
get_contract_metadata_task.fn(contract_address, chain_id)
cache_key = f"should_attempt_download:{contract_address}:{chain_id}:0"
redis = get_redis()
redis.delete(cache_key)
await AbiSource(name="Etherscan", url="").create(session)
etherscan_get_contract_metadata_mock.return_value = None
mock_enabled_clients.return_value = [
AsyncEtherscanClientV2(EthereumNetwork(chain_id))
]
# Should try one time
get_contract_metadata_task.fn(address=contract_address, chain_id=chain_id)
contract = await Contract.get_contract(
session, HexBytes(contract_address), chain_id
)
self.assertIsNotNone(contract)
self.assertIsNone(contract.abi_id)
self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 1)

# Shouldn't try second time
etherscan_get_contract_metadata_mock.return_value = etherscan_metadata_mock
chain_id = 100
get_contract_metadata_task.fn(address=contract_address, chain_id=chain_id)
contract = await Contract.get_contract(
session, HexBytes(contract_address), chain_id
)
self.assertIsNotNone(contract)
self.assertIsNone(contract.abi_id)
self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 1)

# After reset cache and database retries should download the contract
contract.fetch_retries = 0
redis.delete(cache_key)
await contract.update(session)
get_contract_metadata_task.fn(address=contract_address, chain_id=chain_id)
await session.refresh(contract)
contract = await Contract.get_contract(
session, HexBytes(contract_address), chain_id
)
self.assertIsNotNone(contract)
self.assertIsNotNone(contract.abi_id)
self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 2)

@mock.patch.object(
AsyncEtherscanClientV2, "async_get_contract_metadata", autospec=True
)
Expand Down
35 changes: 28 additions & 7 deletions app/workers/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@
import dramatiq
from dramatiq.brokers.redis import RedisBroker
from dramatiq.middleware import AsyncIO
from periodiq import PeriodiqMiddleware
from hexbytes import HexBytes
from periodiq import PeriodiqMiddleware, cron
from safe_eth.eth.utils import fast_to_checksum_address
from sqlmodel.ext.asyncio.session import AsyncSession

from ..config import settings
from ..datasources.db.database import database_session
from ..services.contract_metadata_service import get_contract_metadata_service
from app.config import settings
from app.datasources.db.database import database_session
from app.datasources.db.models import Contract
from app.services.contract_metadata_service import get_contract_metadata_service

logger = logging.getLogger(__name__)


redis_broker = RedisBroker(url=settings.REDIS_URL)
redis_broker.add_middleware(PeriodiqMiddleware(skip_delay=60))
redis_broker.add_middleware(AsyncIO())
Expand All @@ -37,11 +40,14 @@ async def test_task(message: str) -> None:
@dramatiq.actor
@database_session
async def get_contract_metadata_task(
address: str, chain_id: int, session: AsyncSession
session: AsyncSession,
address: str,
chain_id: int,
skip_attemp_download: bool = False,
) -> None:
contract_metadata_service = get_contract_metadata_service()
# Just try the first time, following retries should be scheduled
if await contract_metadata_service.should_attempt_download(
if skip_attemp_download or await contract_metadata_service.should_attempt_download(
session, address, chain_id, 0
):
logger.info(
Expand Down Expand Up @@ -77,6 +83,21 @@ async def get_contract_metadata_task(
address,
chain_id,
)
get_contract_metadata_task.send(proxy_implementation_address, chain_id)
get_contract_metadata_task.send(
address=proxy_implementation_address, chain_id=chain_id
)
else:
logger.debug("Skipping contract=%s and chain=%s", address, chain_id)


@dramatiq.actor(periodic=cron("0 0 * * *"))  # Every midnight
@database_session
async def get_missing_contract_metadata_task(session: AsyncSession) -> None:
    """
    Periodic task that re-enqueues the metadata download for contracts
    that still have no ABI.

    Every contract streamed by ``Contract.get_contracts_without_abi`` (bounded
    by ``settings.CONTRACT_MAX_DOWNLOAD_RETRIES``) is sent to
    ``get_contract_metadata_task`` with ``skip_attemp_download=True``, so that
    task does not call ``should_attempt_download`` before downloading.

    :param session: database session injected by ``database_session``.
    :return: None
    """
    pending_rows = Contract.get_contracts_without_abi(
        session, settings.CONTRACT_MAX_DOWNLOAD_RETRIES
    )
    async for row in pending_rows:
        # The contract is the first element of each streamed row.
        pending_contract = row[0]
        get_contract_metadata_task.send(
            address=HexBytes(pending_contract.address).hex(),
            chain_id=pending_contract.chain_id,
            skip_attemp_download=True,
        )

0 comments on commit 2e7b883

Please sign in to comment.