Skip to content

Commit

Permalink
add chain context to constructor args so we always use the same chain…
Browse files Browse the repository at this point in the history
… context obj
  • Loading branch information
barnjamin committed Jan 9, 2024
1 parent dfb32e3 commit ae5376c
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 47 deletions.
55 changes: 39 additions & 16 deletions connect/src/protocols/cctpTransfer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ export class CircleTransfer<N extends Network = Network>
{
private readonly wh: Wormhole<N>;

protected fromChain: ChainContext<N, Platform, Chain>;
protected toChain: ChainContext<N, Platform, Chain>;

// state machine tracker
private _state: TransferState;

Expand All @@ -62,10 +65,18 @@ export class CircleTransfer<N extends Network = Network>

attestations?: AttestationReceipt<CircleTransferProtocol>[];

private constructor(wh: Wormhole<N>, transfer: CircleTransferDetails) {
private constructor(
wh: Wormhole<N>,
transfer: CircleTransferDetails,
fromChain?: ChainContext<N, Platform, Chain>,
toChain?: ChainContext<N, Platform, Chain>,
) {
this._state = TransferState.Created;
this.wh = wh;
this.transfer = transfer;

this.fromChain = fromChain ?? wh.getChain(transfer.from.chain);
this.toChain = toChain ?? wh.getChain(transfer.to.chain);
}

getTransferState(): TransferState {
Expand All @@ -76,43 +87,58 @@ export class CircleTransfer<N extends Network = Network>
static async from<N extends Network>(
wh: Wormhole<N>,
from: CircleTransferDetails,
timeout?: number,
fromChain?: ChainContext<N, Platform, Chain>,
toChain?: ChainContext<N, Platform, Chain>,
): Promise<CircleTransfer<N>>;
static async from<N extends Network>(
wh: Wormhole<N>,
from: WormholeMessageId,
timeout?: number,
fromChain?: ChainContext<N, Platform, Chain>,
toChain?: ChainContext<N, Platform, Chain>,
): Promise<CircleTransfer<N>>;
static async from<N extends Network>(
wh: Wormhole<N>,
from: string, // CircleMessage hex encoded
timeout?: number,
fromChain?: ChainContext<N, Platform, Chain>,
toChain?: ChainContext<N, Platform, Chain>,
): Promise<CircleTransfer<N>>;
static async from<N extends Network>(
wh: Wormhole<N>,
from: TransactionId,
timeout?: number,
fromChain?: ChainContext<N, Platform, Chain>,
toChain?: ChainContext<N, Platform, Chain>,
): Promise<CircleTransfer<N>>;
static async from<N extends Network>(
wh: Wormhole<N>,
from: CircleTransferDetails | WormholeMessageId | string | TransactionId,
timeout: number = DEFAULT_TASK_TIMEOUT,
fromChain?: ChainContext<N, Platform, Chain>,
toChain?: ChainContext<N, Platform, Chain>,
): Promise<CircleTransfer<N>> {
// This is a new transfer, just return the object
if (isCircleTransferDetails(from)) {
return new CircleTransfer(wh, from);
return new CircleTransfer(wh, from, fromChain, toChain);
}

// This is an existing transfer, fetch the details
let tt: CircleTransfer<N> | undefined;
if (isWormholeMessageId(from)) {
tt = await CircleTransfer.fromWormholeMessageId(wh, from, timeout);
} else if (isTransactionIdentifier(from)) {
tt = await CircleTransfer.fromTransaction(wh, from, timeout);
tt = await CircleTransfer.fromTransaction(wh, from, timeout, fromChain);
} else if (isCircleMessageId(from)) {
tt = await CircleTransfer.fromCircleMessage(wh, from, timeout);
tt = await CircleTransfer.fromCircleMessage(wh, from);
} else {
throw new Error("Invalid `from` parameter for CircleTransfer");
}

tt.fromChain = fromChain ?? wh.getChain(tt.transfer.from.chain);
tt.toChain = toChain ?? wh.getChain(tt.transfer.to.chain);

await tt.fetchAttestation(timeout);

return tt;
Expand Down Expand Up @@ -158,7 +184,6 @@ export class CircleTransfer<N extends Network = Network>
private static async fromCircleMessage<N extends Network>(
wh: Wormhole<N>,
message: string,
timeout: number,
): Promise<CircleTransfer<N>> {
const [msg, hash] = CircleBridge.deserialize(encoding.hex.decode(message));

Expand Down Expand Up @@ -188,22 +213,23 @@ export class CircleTransfer<N extends Network = Network>
wh: Wormhole<N>,
from: TransactionId,
timeout: number,
fromChain?: ChainContext<N, Platform, Chain>,
): Promise<CircleTransfer<N>> {
const { chain, txid } = from;
const originChain = wh.getChain(chain);
fromChain = fromChain ?? wh.getChain(chain);

// First try to parse out a WormholeMessage
// If we get one or more, we assume its a Wormhole attested
// transfer
const msgIds: WormholeMessageId[] = await originChain.parseTransaction(txid);
const msgIds: WormholeMessageId[] = await fromChain.parseTransaction(txid);

// If we found a VAA message, use it
let ct: CircleTransfer<N>;
if (msgIds.length > 0) {
ct = await CircleTransfer.fromWormholeMessageId(wh, msgIds[0]!, timeout);
} else {
// Otherwise try to parse out a circle message
const cb = await originChain.getCircleBridge();
const cb = await fromChain.getCircleBridge();
const circleMessage = await cb.parseTransactionDetails(txid);
const details: CircleTransferDetails = {
...circleMessage,
Expand Down Expand Up @@ -234,27 +260,25 @@ export class CircleTransfer<N extends Network = Network>
if (this._state !== TransferState.Created)
throw new Error("Invalid state transition in `start`");

const fromChain = this.wh.getChain(this.transfer.from.chain);

let xfer: AsyncGenerator<UnsignedTransaction<N>>;
if (this.transfer.automatic) {
const cr = await fromChain.getAutomaticCircleBridge();
const cr = await this.fromChain.getAutomaticCircleBridge();
xfer = cr.transfer(
this.transfer.from.address,
{ chain: this.transfer.to.chain, address: this.transfer.to.address },
this.transfer.amount,
this.transfer.nativeGas,
);
} else {
const cb = await fromChain.getCircleBridge();
const cb = await this.fromChain.getCircleBridge();
xfer = cb.transfer(
this.transfer.from.address,
{ chain: this.transfer.to.chain, address: this.transfer.to.address },
this.transfer.amount,
);
}

this.txids = await signSendWait<N, typeof fromChain.chain>(fromChain, xfer, signer);
this.txids = await signSendWait<N, Chain>(this.fromChain, xfer, signer);
this._state = TransferState.SourceInitiated;

return this.txids.map(({ txid }) => txid);
Expand Down Expand Up @@ -371,12 +395,11 @@ export class CircleTransfer<N extends Network = Network>
const { message, attestation: signatures } = attestation;
if (!signatures) throw new Error(`No Circle Attestation for ${id.hash}`);

const toChain = this.wh.getChain(this.transfer.to.chain);
const tb = await toChain.getCircleBridge();
const tb = await this.toChain.getCircleBridge();

const xfer = tb.redeem(this.transfer.to.address, message, signatures!);

const txids = await signSendWait<N, typeof toChain.chain>(toChain, xfer, signer);
const txids = await signSendWait<N, Chain>(this.toChain, xfer, signer);
this.txids?.push(...txids);
return txids.map(({ txid }) => txid);
}
Expand Down
81 changes: 50 additions & 31 deletions connect/src/protocols/tokenTransfer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
Platform,
PlatformToChains,
encoding,
toChain,
toChain as toChainName,
} from "@wormhole-foundation/sdk-base";
import {
AttestationId,
Expand Down Expand Up @@ -54,6 +54,9 @@ export class TokenTransfer<N extends Network = Network>
{
private readonly wh: Wormhole<N>;

protected fromChain: ChainContext<N, Platform, Chain>;
protected toChain: ChainContext<N, Platform, Chain>;

// state machine tracker
private _state: TransferState;

Expand All @@ -67,10 +70,18 @@ export class TokenTransfer<N extends Network = Network>
// on the source chain (if its been completed and finalized)
attestations?: AttestationReceipt<TokenTransferProtocol>[];

private constructor(wh: Wormhole<N>, transfer: TokenTransferDetails) {
private constructor(
wh: Wormhole<N>,
transfer: TokenTransferDetails,
fromChain?: ChainContext<N, Platform, Chain>,
toChain?: ChainContext<N, Platform, Chain>,
) {
this._state = TransferState.Created;
this.wh = wh;
this.transfer = transfer;

this.fromChain = fromChain ?? wh.getChain(transfer.from.chain);
this.toChain = toChain ?? wh.getChain(transfer.to.chain);
}

getTransferState(): TransferState {
Expand All @@ -81,44 +92,59 @@ export class TokenTransfer<N extends Network = Network>
static async from<N extends Network>(
wh: Wormhole<N>,
from: TokenTransferDetails,
timeout?: number,
fromChain?: ChainContext<N, Platform, Chain>,
toChain?: ChainContext<N, Platform, Chain>,
): Promise<TokenTransfer<N>>;
static async from<N extends Network>(
wh: Wormhole<N>,
from: WormholeMessageId,
timeout?: number,
fromChain?: ChainContext<N, Platform, Chain>,
toChain?: ChainContext<N, Platform, Chain>,
): Promise<TokenTransfer<N>>;
static async from<N extends Network>(
wh: Wormhole<N>,
from: TransactionId,
timeout?: number,
fromChain?: ChainContext<N, Platform, Chain>,
toChain?: ChainContext<N, Platform, Chain>,
): Promise<TokenTransfer<N>>;
static async from<N extends Network>(
wh: Wormhole<N>,
from: TokenTransferDetails | WormholeMessageId | TransactionId,
timeout: number = 6000,
fromChain?: ChainContext<N, Platform, Chain>,
toChain?: ChainContext<N, Platform, Chain>,
): Promise<TokenTransfer<N>> {
if (isTokenTransferDetails(from)) {
await TokenTransfer.validateTransferDetails(wh, from);
const fromChain = wh.getChain(from.from.chain);
const toChain = wh.getChain(from.to.chain);
fromChain = fromChain ?? wh.getChain(from.from.chain);
toChain = toChain ?? wh.getChain(from.to.chain);

// throws if invalid
await TokenTransfer.validateTransferDetails(wh, from, fromChain, toChain);

// Apply hackery
from = {
...from,
...(await TokenTransfer.destinationOverrides(fromChain, toChain, from)),
};

return new TokenTransfer(wh, from);
return new TokenTransfer(wh, from, fromChain, toChain);
}

let tt: TokenTransfer<N>;
if (isWormholeMessageId(from)) {
tt = await TokenTransfer.fromIdentifier(wh, from, timeout);
} else if (isTransactionIdentifier(from)) {
tt = await TokenTransfer.fromTransaction(wh, from, timeout);
tt = await TokenTransfer.fromTransaction(wh, from, timeout, fromChain);
} else {
throw new Error("Invalid `from` parameter for TokenTransfer");
}

tt.fromChain = fromChain ?? wh.getChain(tt.transfer.from.chain);
tt.toChain = toChain ?? wh.getChain(tt.transfer.to.chain);

await tt.fetchAttestation(timeout);
return tt;
}
Expand Down Expand Up @@ -168,8 +194,10 @@ export class TokenTransfer<N extends Network = Network>
wh: Wormhole<N>,
from: TransactionId,
timeout: number,
fromChain?: ChainContext<N, Platform, Chain>,
): Promise<TokenTransfer<N>> {
const msg = await TokenTransfer.getTransferMessage(wh.getChain(from.chain), from.txid, timeout);
fromChain = fromChain ?? wh.getChain(from.chain);
const msg = await TokenTransfer.getTransferMessage(fromChain, from.txid, timeout);
const tt = await TokenTransfer.fromIdentifier(wh, msg, timeout);
tt.txids = [from];
return tt;
Expand All @@ -189,8 +217,7 @@ export class TokenTransfer<N extends Network = Network>
if (this._state !== TransferState.Created)
throw new Error("Invalid state transition in `start`");

const fromChain = this.wh.getChain(this.transfer.from.chain);
this.txids = await TokenTransfer.transfer<N>(fromChain, this.transfer, signer);
this.txids = await TokenTransfer.transfer<N>(this.fromChain, this.transfer, signer);
this._state = TransferState.SourceInitiated;
return this.txids.map(({ txid }) => txid);
}
Expand All @@ -213,11 +240,7 @@ export class TokenTransfer<N extends Network = Network>

// TODO: assuming the _last_ transaction in the list will contain the msg id
const txid = this.txids[this.txids.length - 1]!;
const msgId = await TokenTransfer.getTransferMessage(
this.wh.getChain(txid.chain),
txid.txid,
timeout,
);
const msgId = await TokenTransfer.getTransferMessage(this.fromChain, txid.txid, timeout);
this.attestations = [{ id: msgId }];
}

Expand Down Expand Up @@ -253,9 +276,8 @@ export class TokenTransfer<N extends Network = Network>
const { attestation } = this.attestations[0]!;
if (!attestation) throw new Error(`No VAA found for ${this.attestations[0]!.id.sequence}`);

const toChain = this.wh.getChain(this.transfer.to.chain);
const redeemTxids = await TokenTransfer.redeem<N>(
toChain,
this.toChain,
attestation as TokenTransferVAA,
signer,
);
Expand Down Expand Up @@ -394,6 +416,8 @@ export class TokenTransfer<N extends Network = Network>
static async validateTransferDetails<N extends Network>(
wh: Wormhole<N>,
transfer: TokenTransferDetails,
fromChain?: ChainContext<N, Platform, Chain>,
toChain?: ChainContext<N, Platform, Chain>,
): Promise<void> {
if (transfer.from.chain === transfer.to.chain)
throw new Error("Cannot transfer to the same chain");
Expand All @@ -404,8 +428,8 @@ export class TokenTransfer<N extends Network = Network>
if (transfer.nativeGas && !transfer.automatic)
throw new Error("Gas Dropoff is only supported for automatic transfers");

const fromChain = wh.getChain(transfer.from.chain);
const toChain = wh.getChain(transfer.to.chain);
fromChain = fromChain ?? wh.getChain(transfer.from.chain);
toChain = toChain ?? wh.getChain(transfer.to.chain);

if (!fromChain.supportsTokenBridge())
throw new Error(`Token Bridge not supported on ${transfer.from.chain}`);
Expand Down Expand Up @@ -594,26 +618,21 @@ export class TokenTransfer<N extends Network = Network>
wh: Wormhole<N>,
receipt: TransferReceipt<TokenTransferProtocol, SC, DC>,
timeout: number = DEFAULT_TASK_TIMEOUT,
// Optional parameters to override chain context (typically for custom rpc)
_fromChain?: ChainContext<N, ChainToPlatform<SC>, SC>,
_toChain?: ChainContext<N, ChainToPlatform<DC>, DC>,
fromChain?: ChainContext<N, ChainToPlatform<SC>, SC>,
toChain?: ChainContext<N, ChainToPlatform<DC>, DC>,
) {
const start = Date.now();
const leftover = (start: number, max: number) => Math.max(max - (Date.now() - start), 0);

_fromChain = _fromChain ?? wh.getChain(receipt.from);
_toChain = _toChain ?? wh.getChain(receipt.to);
fromChain = fromChain ?? wh.getChain(receipt.from);
toChain = toChain ?? wh.getChain(receipt.to);

// Check the source chain for initiation transaction
// and capture the message id
if (isSourceInitiated(receipt)) {
if (receipt.originTxs.length === 0) throw "Origin transactions required to fetch message id";
const { txid } = receipt.originTxs[receipt.originTxs.length - 1]!;
const msg = await TokenTransfer.getTransferMessage(
_fromChain,
txid,
leftover(start, timeout),
);
const msg = await TokenTransfer.getTransferMessage(fromChain, txid, leftover(start, timeout));
receipt = {
...receipt,
state: TransferState.SourceFinalized,
Expand Down Expand Up @@ -648,7 +667,7 @@ export class TokenTransfer<N extends Network = Network>
const { chainId, txHash } = txStatus.globalTx.destinationTx;
receipt = {
...receipt,
destinationTxs: [{ chain: toChain(chainId) as DC, txid: txHash }],
destinationTxs: [{ chain: toChainName(chainId) as DC, txid: txHash }],
state: TransferState.DestinationFinalized,
} satisfies CompletedTransferReceipt<TokenTransferProtocol>;
}
Expand All @@ -661,7 +680,7 @@ export class TokenTransfer<N extends Network = Network>
if (!receipt.attestation.attestation) throw "Signed Attestation required to check for redeem";

let isComplete = await TokenTransfer.isTransferComplete(
_toChain,
toChain,
receipt.attestation.attestation as TokenTransferVAA,
);

Expand Down

0 comments on commit ae5376c

Please sign in to comment.