From 1c5074b16546843275115bafcdc7e164fd9de0f4 Mon Sep 17 00:00:00 2001 From: Ben Guidarelli Date: Tue, 9 Jan 2024 07:38:07 -0500 Subject: [PATCH] add chain context to constructor args so we always use the same chain context obj --- connect/src/protocols/cctpTransfer.ts | 55 ++++++++++++----- connect/src/protocols/tokenTransfer.ts | 81 ++++++++++++++++---------- 2 files changed, 89 insertions(+), 47 deletions(-) diff --git a/connect/src/protocols/cctpTransfer.ts b/connect/src/protocols/cctpTransfer.ts index 825bb80534..08070f1493 100644 --- a/connect/src/protocols/cctpTransfer.ts +++ b/connect/src/protocols/cctpTransfer.ts @@ -51,6 +51,9 @@ export class CircleTransfer { private readonly wh: Wormhole; + fromChain: ChainContext; + toChain: ChainContext; + // state machine tracker private _state: TransferState; @@ -62,10 +65,18 @@ export class CircleTransfer attestations?: AttestationReceipt[]; - private constructor(wh: Wormhole, transfer: CircleTransferDetails) { + private constructor( + wh: Wormhole, + transfer: CircleTransferDetails, + fromChain?: ChainContext, + toChain?: ChainContext, + ) { this._state = TransferState.Created; this.wh = wh; this.transfer = transfer; + + this.fromChain = fromChain ?? wh.getChain(transfer.from.chain); + this.toChain = toChain ?? wh.getChain(transfer.to.chain); } getTransferState(): TransferState { @@ -76,30 +87,41 @@ export class CircleTransfer static async from( wh: Wormhole, from: CircleTransferDetails, + timeout?: number, + fromChain?: ChainContext, + toChain?: ChainContext, ): Promise>; static async from( wh: Wormhole, from: WormholeMessageId, timeout?: number, + fromChain?: ChainContext, + toChain?: ChainContext, ): Promise>; static async from( wh: Wormhole, from: string, // CircleMessage hex encoded timeout?: number, + fromChain?: ChainContext, + toChain?: ChainContext, ): Promise>; static async from( wh: Wormhole, from: TransactionId, timeout?: number, + fromChain?: ChainContext, + toChain?: ChainContext, ): Promise>; static async from( wh: Wormhole, from: CircleTransferDetails | WormholeMessageId | string | TransactionId, timeout: number = DEFAULT_TASK_TIMEOUT, + fromChain?: ChainContext, + toChain?: ChainContext, ): Promise> { // This is a new transfer, just return the object if (isCircleTransferDetails(from)) { - return new CircleTransfer(wh, from); + return new CircleTransfer(wh, from, fromChain, toChain); } // This is an existing transfer, fetch the details @@ -107,12 +129,16 @@ export class CircleTransfer if (isWormholeMessageId(from)) { tt = await CircleTransfer.fromWormholeMessageId(wh, from, timeout); } else if (isTransactionIdentifier(from)) { - tt = await CircleTransfer.fromTransaction(wh, from, timeout); + tt = await CircleTransfer.fromTransaction(wh, from, timeout, fromChain); } else if (isCircleMessageId(from)) { - tt = await CircleTransfer.fromCircleMessage(wh, from, timeout); + tt = await CircleTransfer.fromCircleMessage(wh, from); } else { throw new Error("Invalid `from` parameter for CircleTransfer"); } + + tt.fromChain = fromChain ?? wh.getChain(tt.transfer.from.chain); + tt.toChain = toChain ?? wh.getChain(tt.transfer.to.chain); + await tt.fetchAttestation(timeout); return tt; @@ -158,7 +184,6 @@ export class CircleTransfer private static async fromCircleMessage( wh: Wormhole, message: string, - timeout: number, ): Promise> { const [msg, hash] = CircleBridge.deserialize(encoding.hex.decode(message)); @@ -188,14 +213,15 @@ export class CircleTransfer wh: Wormhole, from: TransactionId, timeout: number, + fromChain?: ChainContext, ): Promise> { const { chain, txid } = from; - const originChain = wh.getChain(chain); + fromChain = fromChain ?? wh.getChain(chain); // First try to parse out a WormholeMessage // If we get one or more, we assume its a Wormhole attested // transfer - const msgIds: WormholeMessageId[] = await originChain.parseTransaction(txid); + const msgIds: WormholeMessageId[] = await fromChain.parseTransaction(txid); // If we found a VAA message, use it let ct: CircleTransfer; @@ -203,7 +229,7 @@ export class CircleTransfer ct = await CircleTransfer.fromWormholeMessageId(wh, msgIds[0]!, timeout); } else { // Otherwise try to parse out a circle message - const cb = await originChain.getCircleBridge(); + const cb = await fromChain.getCircleBridge(); const circleMessage = await cb.parseTransactionDetails(txid); const details: CircleTransferDetails = { ...circleMessage, @@ -234,11 +260,9 @@ export class CircleTransfer if (this._state !== TransferState.Created) throw new Error("Invalid state transition in `start`"); - const fromChain = this.wh.getChain(this.transfer.from.chain); - let xfer: AsyncGenerator>; if (this.transfer.automatic) { - const cr = await fromChain.getAutomaticCircleBridge(); + const cr = await this.fromChain.getAutomaticCircleBridge(); xfer = cr.transfer( this.transfer.from.address, { chain: this.transfer.to.chain, address: this.transfer.to.address }, @@ -246,7 +270,7 @@ export class CircleTransfer this.transfer.nativeGas, ); } else { - const cb = await fromChain.getCircleBridge(); + const cb = await this.fromChain.getCircleBridge(); xfer = cb.transfer( this.transfer.from.address, { chain: this.transfer.to.chain, address: this.transfer.to.address }, @@ -254,7 +278,7 @@ export class CircleTransfer ); } - this.txids = await signSendWait(fromChain, xfer, signer); + this.txids = await signSendWait(this.fromChain, xfer, signer); this._state = TransferState.SourceInitiated; return this.txids.map(({ txid }) => txid); @@ -371,12 +395,11 @@ export class CircleTransfer const { message, attestation: signatures } = attestation; if (!signatures) throw new Error(`No Circle Attestation for ${id.hash}`); - const toChain = this.wh.getChain(this.transfer.to.chain); - const tb = await toChain.getCircleBridge(); + const tb = await this.toChain.getCircleBridge(); const xfer = tb.redeem(this.transfer.to.address, message, signatures!); - const txids = await signSendWait(toChain, xfer, signer); + const txids = await signSendWait(this.toChain, xfer, signer); this.txids?.push(...txids); return txids.map(({ txid }) => txid); } diff --git a/connect/src/protocols/tokenTransfer.ts b/connect/src/protocols/tokenTransfer.ts index c512955e31..9516a0f6a5 100644 --- a/connect/src/protocols/tokenTransfer.ts +++ b/connect/src/protocols/tokenTransfer.ts @@ -5,7 +5,7 @@ import { Platform, PlatformToChains, encoding, - toChain, + toChain as toChainName, } from "@wormhole-foundation/sdk-base"; import { AttestationId, @@ -54,6 +54,9 @@ export class TokenTransfer { private readonly wh: Wormhole; + fromChain: ChainContext; + toChain: ChainContext; + // state machine tracker private _state: TransferState; @@ -67,10 +70,18 @@ export class TokenTransfer // on the source chain (if its been completed and finalized) attestations?: AttestationReceipt[]; - private constructor(wh: Wormhole, transfer: TokenTransferDetails) { + private constructor( + wh: Wormhole, + transfer: TokenTransferDetails, + fromChain?: ChainContext, + toChain?: ChainContext, + ) { this._state = TransferState.Created; this.wh = wh; this.transfer = transfer; + + this.fromChain = fromChain ?? wh.getChain(transfer.from.chain); + this.toChain = toChain ?? wh.getChain(transfer.to.chain); } getTransferState(): TransferState { @@ -81,26 +92,37 @@ export class TokenTransfer static async from( wh: Wormhole, from: TokenTransferDetails, + timeout?: number, + fromChain?: ChainContext, + toChain?: ChainContext, ): Promise>; static async from( wh: Wormhole, from: WormholeMessageId, timeout?: number, + fromChain?: ChainContext, + toChain?: ChainContext, ): Promise>; static async from( wh: Wormhole, from: TransactionId, timeout?: number, + fromChain?: ChainContext, + toChain?: ChainContext, ): Promise>; static async from( wh: Wormhole, from: TokenTransferDetails | WormholeMessageId | TransactionId, timeout: number = 6000, + fromChain?: ChainContext, + toChain?: ChainContext, ): Promise> { if (isTokenTransferDetails(from)) { - await TokenTransfer.validateTransferDetails(wh, from); - const fromChain = wh.getChain(from.from.chain); - const toChain = wh.getChain(from.to.chain); + fromChain = fromChain ?? wh.getChain(from.from.chain); + toChain = toChain ?? wh.getChain(from.to.chain); + + // throws if invalid + await TokenTransfer.validateTransferDetails(wh, from, fromChain, toChain); // Apply hackery from = { @@ -108,17 +130,21 @@ export class TokenTransfer ...(await TokenTransfer.destinationOverrides(fromChain, toChain, from)), }; - return new TokenTransfer(wh, from); + return new TokenTransfer(wh, from, fromChain, toChain); } let tt: TokenTransfer; if (isWormholeMessageId(from)) { tt = await TokenTransfer.fromIdentifier(wh, from, timeout); } else if (isTransactionIdentifier(from)) { - tt = await TokenTransfer.fromTransaction(wh, from, timeout); + tt = await TokenTransfer.fromTransaction(wh, from, timeout, fromChain); } else { throw new Error("Invalid `from` parameter for TokenTransfer"); } + + tt.fromChain = fromChain ?? wh.getChain(tt.transfer.from.chain); + tt.toChain = toChain ?? wh.getChain(tt.transfer.to.chain); + await tt.fetchAttestation(timeout); return tt; } @@ -168,8 +194,10 @@ export class TokenTransfer wh: Wormhole, from: TransactionId, timeout: number, + fromChain?: ChainContext, ): Promise> { - const msg = await TokenTransfer.getTransferMessage(wh.getChain(from.chain), from.txid, timeout); + fromChain = fromChain ?? wh.getChain(from.chain); + const msg = await TokenTransfer.getTransferMessage(fromChain, from.txid, timeout); const tt = await TokenTransfer.fromIdentifier(wh, msg, timeout); tt.txids = [from]; return tt; @@ -189,8 +217,7 @@ export class TokenTransfer if (this._state !== TransferState.Created) throw new Error("Invalid state transition in `start`"); - const fromChain = this.wh.getChain(this.transfer.from.chain); - this.txids = await TokenTransfer.transfer(fromChain, this.transfer, signer); + this.txids = await TokenTransfer.transfer(this.fromChain, this.transfer, signer); this._state = TransferState.SourceInitiated; return this.txids.map(({ txid }) => txid); } @@ -213,11 +240,7 @@ export class TokenTransfer // TODO: assuming the _last_ transaction in the list will contain the msg id const txid = this.txids[this.txids.length - 1]!; - const msgId = await TokenTransfer.getTransferMessage( - this.wh.getChain(txid.chain), - txid.txid, - timeout, - ); + const msgId = await TokenTransfer.getTransferMessage(this.fromChain, txid.txid, timeout); this.attestations = [{ id: msgId }]; } @@ -253,9 +276,8 @@ export class TokenTransfer const { attestation } = this.attestations[0]!; if (!attestation) throw new Error(`No VAA found for ${this.attestations[0]!.id.sequence}`); - const toChain = this.wh.getChain(this.transfer.to.chain); const redeemTxids = await TokenTransfer.redeem( - toChain, + this.toChain, attestation as TokenTransferVAA, signer, ); @@ -394,6 +416,8 @@ export class TokenTransfer static async validateTransferDetails( wh: Wormhole, transfer: TokenTransferDetails, + fromChain?: ChainContext, + toChain?: ChainContext, ): Promise { if (transfer.from.chain === transfer.to.chain) throw new Error("Cannot transfer to the same chain"); @@ -404,8 +428,8 @@ export class TokenTransfer if (transfer.nativeGas && !transfer.automatic) throw new Error("Gas Dropoff is only supported for automatic transfers"); - const fromChain = wh.getChain(transfer.from.chain); - const toChain = wh.getChain(transfer.to.chain); + fromChain = fromChain ?? wh.getChain(transfer.from.chain); + toChain = toChain ?? wh.getChain(transfer.to.chain); if (!fromChain.supportsTokenBridge()) throw new Error(`Token Bridge not supported on ${transfer.from.chain}`); @@ -594,26 +618,21 @@ export class TokenTransfer wh: Wormhole, receipt: TransferReceipt, timeout: number = DEFAULT_TASK_TIMEOUT, - // Optional parameters to override chain context (typically for custom rpc) - _fromChain?: ChainContext, SC>, - _toChain?: ChainContext, DC>, + fromChain?: ChainContext, SC>, + toChain?: ChainContext, DC>, ) { const start = Date.now(); const leftover = (start: number, max: number) => Math.max(max - (Date.now() - start), 0); - _fromChain = _fromChain ?? wh.getChain(receipt.from); - _toChain = _toChain ?? wh.getChain(receipt.to); + fromChain = fromChain ?? wh.getChain(receipt.from); + toChain = toChain ?? wh.getChain(receipt.to); // Check the source chain for initiation transaction // and capture the message id if (isSourceInitiated(receipt)) { if (receipt.originTxs.length === 0) throw "Origin transactions required to fetch message id"; const { txid } = receipt.originTxs[receipt.originTxs.length - 1]!; - const msg = await TokenTransfer.getTransferMessage( - _fromChain, - txid, - leftover(start, timeout), - ); + const msg = await TokenTransfer.getTransferMessage(fromChain, txid, leftover(start, timeout)); receipt = { ...receipt, state: TransferState.SourceFinalized, @@ -648,7 +667,7 @@ export class TokenTransfer const { chainId, txHash } = txStatus.globalTx.destinationTx; receipt = { ...receipt, - destinationTxs: [{ chain: toChain(chainId) as DC, txid: txHash }], + destinationTxs: [{ chain: toChainName(chainId) as DC, txid: txHash }], state: TransferState.DestinationFinalized, } satisfies CompletedTransferReceipt; } @@ -661,7 +680,7 @@ export class TokenTransfer if (!receipt.attestation.attestation) throw "Signed Attestation required to check for redeem"; let isComplete = await TokenTransfer.isTransferComplete( - _toChain, + toChain, receipt.attestation.attestation as TokenTransferVAA, );