Skip to content

Commit

Permalink
combine circle decoders to a single payload (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
barnjamin authored Oct 13, 2023
1 parent 5c96c09 commit 3d76ba0
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 96 deletions.
3 changes: 2 additions & 1 deletion connect/src/protocols/cctpTransfer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,13 @@ export class CCTPTransfer implements WormholeTransfer {
wh: Wormhole,
messageId: CircleMessageId,
): Promise<CCTPTransfer> {
const [message, burnMessage, hash] = deserializeCircleMessage(
const [message, hash] = deserializeCircleMessage(
hexByteStringToUint8Array(messageId.message),
);
// If no hash is passed, set to the one we just computed
if (messageId.hash === "") messageId.hash = hash;

const { payload: burnMessage } = message;
const xferSender = burnMessage.messageSender;
const xferReceiver = burnMessage.mintRecipient;

Expand Down
30 changes: 4 additions & 26 deletions core/definitions/__tests__/circleMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import {
import { circleMessageLayout } from "../src/protocols/cctp";
import { UniversalAddress } from "../src";
import { circleContracts } from "@wormhole-foundation/sdk-base/src/constants/contracts";
import { circleBurnMessageLayout } from "../dist/cjs";

const ethAddressToUniversal = (address: string) => {
return new UniversalAddress("00".repeat(12) + address.slice(2));
Expand Down Expand Up @@ -40,39 +39,18 @@ describe("Circle Message tests", function () {
);

const decoded = deserializeLayout(circleMessageLayout, orig);

expect(decoded.version).toEqual(0);
expect(decoded.sourceDomain).toEqual(circleChainId(fromChain));
expect(decoded.destinationDomain).toEqual(circleChainId(toChain));
expect(decoded.nonce).toEqual(235558n);
expect(decoded.sender.equals(actualSender)).toBeTruthy();
expect(decoded.recipient.equals(actualReceiver)).toBeTruthy();
expect(decoded.messageBody.length).toEqual(132);

const decodedPayload = deserializeLayout(
circleBurnMessageLayout as Layout,
decoded.messageBody,
);

const burnToken = new UniversalAddress(
decodedPayload.burnToken.toUint8Array(),
);
const mintRecipient = new UniversalAddress(
decodedPayload.mintRecipient.toUint8Array(),
);
const messageSender = new UniversalAddress(
decodedPayload.messageSender.toUint8Array(),
);

// TODO: why does this fail? not passing instanceof check??
console.log(decodedPayload.burnToken.equals(tokenAddress));

const decodedPayload = decoded.payload;
expect(decodedPayload.version).toEqual(0);
expect(decodedPayload.amount).toEqual(1000000n);
expect(burnToken.equals(tokenAddress)).toBeTruthy();
expect(mintRecipient.equals(accountSender)).toBeTruthy();
expect(messageSender.equals(accountSender)).toBeTruthy();

console.log(decodedPayload);
expect(decodedPayload.burnToken.equals(tokenAddress)).toBeTruthy();
expect(decodedPayload.mintRecipient.equals(accountSender)).toBeTruthy();
expect(decodedPayload.messageSender.equals(accountSender)).toBeTruthy();
});
});
44 changes: 19 additions & 25 deletions core/definitions/src/protocols/cctp.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
import {
PlatformName,
CircleChainId,
Layout,
LayoutToType,
PlatformName,
deserializeLayout,
uint8ArrayToHexByteString,
LayoutToType,
} from "@wormhole-foundation/sdk-base";
import { UniversalOrNative, ChainAddress } from "../address";
import { ChainAddress, UniversalOrNative } from "../address";
import { CircleMessageId } from "../attestation";
import { UnsignedTransaction } from "../unsignedTransaction";
import { TokenId, TxHash } from "../types";
import { universalAddressItem } from "../layout-items";
import "../payloads/connect";
import { RpcConnection } from "../rpc";
import { universalAddressItem } from "../layout-items";
import { TokenId } from "../types";
import { UnsignedTransaction } from "../unsignedTransaction";
import { keccak256 } from "../utils";

// https://developers.circle.com/stablecoin/docs/cctp-technical-reference#message
const circleBurnMessageLayout: Layout = [
{ name: "version", binary: "uint", size: 4 },
{ name: "burnToken", ...universalAddressItem },
{ name: "mintRecipient", ...universalAddressItem },
{ name: "amount", binary: "uint", size: 32 },
{ name: "messageSender", ...universalAddressItem },
];

// TODO: convert domain to chain name?
export const circleMessageLayout: Layout = [
{ name: "version", binary: "uint", size: 4 },
{ name: "sourceDomain", binary: "uint", size: 4 },
Expand All @@ -24,30 +32,16 @@ export const circleMessageLayout: Layout = [
{ name: "sender", ...universalAddressItem },
{ name: "recipient", ...universalAddressItem },
{ name: "destinationCaller", ...universalAddressItem },
{ name: "messageBody", binary: "bytes" },
];

export const circleBurnMessageLayout: Layout = [
{ name: "version", binary: "uint", size: 4 },
{ name: "burnToken", ...universalAddressItem },
{ name: "mintRecipient", ...universalAddressItem },
{ name: "amount", binary: "uint", size: 32 },
{ name: "messageSender", ...universalAddressItem },
// TODO: is this the only message body we'll get?
{ name: "payload", binary: "object", layout: circleBurnMessageLayout },
];

export const deserializeCircleMessage = (
data: Uint8Array,
): [
LayoutToType<typeof circleMessageLayout>,
LayoutToType<typeof circleBurnMessageLayout>,
string,
] => {
): [LayoutToType<typeof circleMessageLayout>, string] => {
const msg = deserializeLayout(circleMessageLayout, data);
// Expect a body message, only Burn atm
if (msg.messageBody.length === 0) throw new Error("empty message body");
const burnMsg = deserializeLayout(circleBurnMessageLayout, msg.messageBody);
const messsageHash = uint8ArrayToHexByteString(keccak256(data));
return [msg, burnMsg, messsageHash];
return [msg, messsageHash];
};

export type CircleTransferMessage = {
Expand Down
114 changes: 74 additions & 40 deletions core/definitions/src/vaa.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ type PayloadLiteral = keyof Wormhole.PayloadLiteralToLayoutMapping & string;
type LayoutOf<PL extends PayloadLiteral> =
//TODO check if this lazy instantiation hack is actually necessary
PL extends infer V extends PayloadLiteral
? Wormhole.PayloadLiteralToLayoutMapping[V]
: never;
? Wormhole.PayloadLiteralToLayoutMapping[V]
: never;

type PayloadLiteralToPayloadType<PL extends PayloadLiteral> =
LayoutToType<LayoutOf<PL>>;
type PayloadLiteralToPayloadType<PL extends PayloadLiteral> = LayoutToType<
LayoutOf<PL>
>;

const guardianSignatureLayout = [
{ name: "guardianIndex", binary: "uint", size: 1 },
Expand Down Expand Up @@ -75,9 +76,7 @@ type BaseLayout = LayoutToType<typeof baseLayout>;

type ExtendedLiteral = PayloadLiteral | "Uint8Array";
type ExtendedLiteralToPayloadType<EL extends ExtendedLiteral> =
EL extends PayloadLiteral
? PayloadLiteralToPayloadType<EL>
: Uint8Array;
EL extends PayloadLiteral ? PayloadLiteralToPayloadType<EL> : Uint8Array;

export interface VAA<EL extends ExtendedLiteral = ExtendedLiteral>
extends BaseLayout {
Expand All @@ -102,26 +101,38 @@ function getPayloadLayout<PL extends PayloadLiteral>(payloadLiteral: PL) {

const extendedLiteralToPayloadItem = <EL extends ExtendedLiteral>(
extendedLiteral: EL,
) => (
extendedLiteral === "Uint8Array"
? { name: "payload", binary: "bytes" } as const
: { name: "payload", binary: "object", layout: getPayloadLayout(extendedLiteral) } as const
) satisfies LayoutItem;
) =>
(extendedLiteral === "Uint8Array"
? ({ name: "payload", binary: "bytes" } as const)
: ({
name: "payload",
binary: "object",
layout: getPayloadLayout(extendedLiteral),
} as const)) satisfies LayoutItem;

//annoyingly we can't use the return value of extendedLiteralToPayloadItem
type ExtendedLiteralToDynamicItems<EL extends ExtendedLiteral> =
DynamicItemsOfLayout<[
...typeof baseLayout,
EL extends PayloadLiteral
? { name: "payload", binary: "object", layout: DynamicItemsOfLayout<LayoutOf<EL>> }
: { name: "payload", binary: "bytes" }
]>;
DynamicItemsOfLayout<
[
...typeof baseLayout,
EL extends PayloadLiteral
? {
name: "payload";
binary: "object";
layout: DynamicItemsOfLayout<LayoutOf<EL>>;
}
: { name: "payload"; binary: "bytes" },
]
>;

export const create = <EL extends ExtendedLiteral = "Uint8Array">(
extendedLiteral: EL,
vaaData: LayoutToType<ExtendedLiteralToDynamicItems<EL>>,
): VAA<EL> => {
const bodyLayout = [...envelopeLayout, extendedLiteralToPayloadItem(extendedLiteral)] as const;
const bodyLayout = [
...envelopeLayout,
extendedLiteralToPayloadItem(extendedLiteral),
] as const;
const bodyWithFixed = addFixedValues(
bodyLayout,
//not sure why the unknown cast here is required and why the type isn't inferred correctly
Expand All @@ -148,7 +159,9 @@ export function registerPayloadType<PL extends PayloadLiteral>(
payloadFactory.set(payloadLiteral, payloadLayout);
}

export const serialize = <EL extends ExtendedLiteral>(vaa: VAA<EL>): Uint8Array => {
export const serialize = <EL extends ExtendedLiteral>(
vaa: VAA<EL>,
): Uint8Array => {
const layout = [
...baseLayout,
extendedLiteralToPayloadItem(vaa.payloadLiteral),
Expand All @@ -160,23 +173,23 @@ export const serializePayload = <EL extends ExtendedLiteral>(
extendedLiteral: EL,
payload: ExtendedLiteralToPayloadType<EL>,
) => {
if (extendedLiteral === "Uint8Array")
return payload;
if (extendedLiteral === "Uint8Array") return payload;

const layout = getPayloadLayout(extendedLiteral);
return serializeLayout(layout, payload as LayoutToType<typeof layout>);
}
};

export type NamedPayloads = readonly (readonly [string, Layout])[];

export const payloadDiscriminator = <NP extends NamedPayloads>(namedPayloads: NP) => {
export const payloadDiscriminator = <NP extends NamedPayloads>(
namedPayloads: NP,
) => {
const literals = column(namedPayloads, 0);
const layouts = column(namedPayloads, 1);
const discriminator = layoutDiscriminator(layouts);

return (data: Uint8Array | string): (typeof literals)[number] | null => {
if (typeof data === "string")
data = hexByteStringToUint8Array(data);
if (typeof data === "string") data = hexByteStringToUint8Array(data);

const index = discriminator(data);
return index !== null ? literals[index] : null;
Expand All @@ -194,12 +207,19 @@ export function deserialize<PL extends PayloadLiteral>(
): VAA<PL>;

export function deserialize<EL extends ExtendedLiteral>(
payloadDet: EL | ((data: Uint8Array | string) => (EL & PayloadLiteral) | null),
payloadDet:
| EL
| ((data: Uint8Array | string) => (EL & PayloadLiteral) | null),
data: Uint8Array | string,
): VAA<EL> {
if (typeof data === "string") data = hexByteStringToUint8Array(data);

const [header, envelopeOffset] = deserializeLayout(headerLayout, data, 0, false);
const [header, envelopeOffset] = deserializeLayout(
headerLayout,
data,
0,
false,
);

//ensure that guardian signature indicies are unique and in ascending order - see:
//https://github.com/wormhole-foundation/wormhole/blob/8e0cf4c31f39b5ba06b0f6cdb6e690d3adf3d6a3/ethereum/contracts/Messages.sol#L121
Expand All @@ -212,22 +232,30 @@ export function deserialize<EL extends ExtendedLiteral>(
"Guardian signatures must be in ascending order of guardian set index",
);

const [envelope, payloadOffset] = deserializeLayout(envelopeLayout, data, envelopeOffset, false);
const [envelope, payloadOffset] = deserializeLayout(
envelopeLayout,
data,
envelopeOffset,
false,
);

const [payloadLiteral, payload] =
payloadDet === "Uint8Array"
? ["Uint8Array", data.slice(payloadOffset)]
: typeof payloadDet === "string"
? [payloadDet, deserializeLayout(getPayloadLayout(payloadDet), data, payloadOffset)]
: deserializePayload(payloadDet, data, payloadOffset);
? ["Uint8Array", data.slice(payloadOffset)]
: typeof payloadDet === "string"
? [
payloadDet,
deserializeLayout(getPayloadLayout(payloadDet), data, payloadOffset),
]
: deserializePayload(payloadDet, data, payloadOffset);
const hash = keccak256(data.slice(envelopeOffset));

return { payloadLiteral, ...header, ...envelope, payload, hash } as VAA<EL>;
}

type DeserializePayloadReturn<EL extends ExtendedLiteral> =
ExtendedLiteralToPayloadType<EL> |
[EL & PayloadLiteral, PayloadLiteralToPayloadType<EL & PayloadLiteral>];
| ExtendedLiteralToPayloadType<EL>
| [EL & PayloadLiteral, PayloadLiteralToPayloadType<EL & PayloadLiteral>];

export function deserializePayload<EL extends ExtendedLiteral>(
payloadLiteral: EL,
Expand All @@ -242,24 +270,30 @@ export function deserializePayload<PL extends PayloadLiteral>(
): [PL, PayloadLiteralToPayloadType<PL>];

export function deserializePayload<EL extends ExtendedLiteral>(
payloadDet: EL | ((data: Uint8Array | string) => (EL & PayloadLiteral) | null),
payloadDet:
| EL
| ((data: Uint8Array | string) => (EL & PayloadLiteral) | null),
data: Uint8Array | string,
offset = 0,
): DeserializePayloadReturn<EL> {
//grouped together to have just a single cast on the return type
return (() => {
if (typeof data === "string") data = hexByteStringToUint8Array(data);

if (payloadDet === "Uint8Array")
return data.slice(offset);
if (payloadDet === "Uint8Array") return data.slice(offset);

if (typeof payloadDet === "string")
return deserializeLayout(getPayloadLayout(payloadDet), data, offset);

const candidate = payloadDet(data);
if (candidate === null)
throw new Error(`Encoded data does not match any of the given payload types - ${data}`);
throw new Error(
`Encoded data does not match any of the given payload types - ${data}`,
);

return [candidate, deserializeLayout(getPayloadLayout(candidate), data, offset)];
return [
candidate,
deserializeLayout(getPayloadLayout(candidate), data, offset),
];
})() as DeserializePayloadReturn<EL>;
}
7 changes: 4 additions & 3 deletions platforms/evm/src/protocols/circleBridge.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,16 @@ export class EvmCircleBridge implements CircleBridge<'Evm'> {

const [messageLog] = messageLogs;
const { message } = messageLog.args;
const [header, body, hash] = deserializeCircleMessage(
const [circleMsg, hash] = deserializeCircleMessage(
hexByteStringToUint8Array(message),
);
const { payload: body } = circleMsg;

const xferSender = body.messageSender;
const xferReceiver = body.mintRecipient;

const sendChain = toCircleChainName(header.sourceDomain);
const rcvChain = toCircleChainName(header.destinationDomain);
const sendChain = toCircleChainName(circleMsg.sourceDomain);
const rcvChain = toCircleChainName(circleMsg.destinationDomain);

const token = nativeChainAddress([sendChain, body.burnToken]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ export function getUpgradeContractAccounts(
vaa: VAA<'CoreBridgeUpgradeContract'>,
spill?: PublicKeyInitData,
): UpgradeContractAccounts {
const {newContract} = vaa.payload;
const { newContract } = vaa.payload;

return {
payer: new PublicKey(payer),
Expand Down

0 comments on commit 3d76ba0

Please sign in to comment.