diff --git a/core/definitions/src/platform.ts b/core/definitions/src/platform.ts index 93d7ffcb0..40ae8bfbf 100644 --- a/core/definitions/src/platform.ts +++ b/core/definitions/src/platform.ts @@ -7,7 +7,8 @@ import { NativeAddress } from "./address"; import { WormholeMessageId } from "./attestation"; import { ChainContext } from "./chain"; import { RpcConnection } from "./rpc"; -import { ChainsConfig, SignedTx, TokenId, TxHash } from "./types"; +import { AnyAddress, Balances, ChainsConfig, TokenId, TxHash } from "./types"; +import { SignedTx } from "./types"; import { UniversalAddress } from "./universalAddress"; export interface PlatformUtils

{ @@ -21,14 +22,20 @@ export interface PlatformUtils

{ getDecimals( chain: ChainName, rpc: RpcConnection

, - token: NativeAddress

| UniversalAddress | "native", + token: AnyAddress, ): Promise; getBalance( chain: ChainName, rpc: RpcConnection

, walletAddr: string, - token: NativeAddress

| UniversalAddress | "native", + token: AnyAddress, ): Promise; + getBalances( + chain: ChainName, + rpc: RpcConnection

, + walletAddress: string, + tokens: AnyAddress[], + ): Promise; getCurrentBlock(rpc: RpcConnection

): Promise; // Platform interaction utils diff --git a/core/definitions/src/protocols/cctp.ts b/core/definitions/src/protocols/cctp.ts index 5705077ec..f1e265cb2 100644 --- a/core/definitions/src/protocols/cctp.ts +++ b/core/definitions/src/protocols/cctp.ts @@ -5,12 +5,12 @@ import { deserializeLayout, uint8ArrayToHexByteString, } from "@wormhole-foundation/sdk-base"; -import { ChainAddress, UniversalOrNative } from "../address"; +import { ChainAddress } from "../address"; import { CircleMessageId } from "../attestation"; import { universalAddressItem } from "../layout-items"; import "../payloads/connect"; import { RpcConnection } from "../rpc"; -import { TokenId } from "../types"; +import { AnyAddress, TokenId } from "../types"; import { UnsignedTransaction } from "../unsignedTransaction"; import { keccak256 } from "../utils"; @@ -55,7 +55,7 @@ export type CircleTransferMessage = { export interface AutomaticCircleBridge

{ transfer( token: ChainAddress, - sender: UniversalOrNative

, + sender: AnyAddress, recipient: ChainAddress, amount: bigint, nativeGas?: bigint, @@ -66,13 +66,13 @@ export interface AutomaticCircleBridge

{ // https://github.com/circlefin/evm-cctp-contracts export interface CircleBridge

{ redeem( - sender: UniversalOrNative

, + sender: AnyAddress, message: string, attestation: string, ): AsyncGenerator; transfer( token: ChainAddress, - sender: UniversalOrNative

, + sender: AnyAddress, recipient: ChainAddress, amount: bigint, ): AsyncGenerator; diff --git a/core/definitions/src/protocols/core.ts b/core/definitions/src/protocols/core.ts index 0ca8a74e4..31f5479da 100644 --- a/core/definitions/src/protocols/core.ts +++ b/core/definitions/src/protocols/core.ts @@ -1,5 +1,5 @@ import { PlatformName } from "@wormhole-foundation/sdk-base"; -import { UniversalOrNative } from "../address"; +import { AnyAddress } from "../types"; import { UnsignedTransaction } from "../unsignedTransaction"; import { RpcConnection } from "../rpc"; @@ -15,7 +15,7 @@ export function supportsWormholeCore

( export interface WormholeCore

{ publishMessage( - sender: UniversalOrNative

, + sender: AnyAddress, message: string | Uint8Array ): AsyncGenerator; // TODO: parseTransactionDetails diff --git a/core/definitions/src/protocols/ibc.ts b/core/definitions/src/protocols/ibc.ts index 7f5fa7666..c7e93120d 100644 --- a/core/definitions/src/protocols/ibc.ts +++ b/core/definitions/src/protocols/ibc.ts @@ -4,10 +4,10 @@ import { PlatformName, toChainId, } from "@wormhole-foundation/sdk-base"; -import { ChainAddress, NativeAddress, UniversalOrNative } from "../address"; +import { ChainAddress, NativeAddress } from "../address"; import { IbcMessageId, WormholeMessageId } from "../attestation"; import { RpcConnection } from "../rpc"; -import { TokenId, TxHash } from "../types"; +import { AnyAddress, TokenId, TxHash } from "../types"; import { UnsignedTransaction } from "../unsignedTransaction"; // Configuration for a transfer through the Gateway @@ -209,9 +209,9 @@ export function supportsIbcBridge

( export interface IbcBridge

{ //alternative naming: initiateTransfer transfer( - sender: UniversalOrNative

, + sender: AnyAddress, recipient: ChainAddress, - token: UniversalOrNative

| "native", + token: AnyAddress, amount: bigint, payload?: Uint8Array, ): AsyncGenerator; diff --git a/core/definitions/src/protocols/tokenBridge.ts b/core/definitions/src/protocols/tokenBridge.ts index aaba7866a..31bc9480a 100644 --- a/core/definitions/src/protocols/tokenBridge.ts +++ b/core/definitions/src/protocols/tokenBridge.ts @@ -1,6 +1,6 @@ import { PlatformName } from "@wormhole-foundation/sdk-base"; -import { UniversalOrNative, NativeAddress, ChainAddress } from "../address"; -import { TokenId } from "../types"; +import { NativeAddress, ChainAddress } from "../address"; +import { AnyAddress, TokenId } from "../types"; import { VAA } from "../vaa"; import { UnsignedTransaction } from "../unsignedTransaction"; import "../payloads/tokenBridge"; @@ -36,9 +36,9 @@ export function supportsAutomaticTokenBridge

( export interface TokenBridge

{ // checks a native address to see if its a wrapped version - isWrappedAsset(nativeAddress: UniversalOrNative

): Promise; + isWrappedAsset(nativeAddress: AnyAddress): Promise; // returns the original asset with its foreign chain - getOriginalAsset(nativeAddress: UniversalOrNative

): Promise; + getOriginalAsset(nativeAddress: AnyAddress): Promise; // returns the wrapped version of the native asset getWrappedNative(): Promise>; @@ -53,24 +53,24 @@ export interface TokenBridge

{ ): Promise; //signer required: createAttestation( - token_to_attest: UniversalOrNative

, - payer?: UniversalOrNative

+ token_to_attest: AnyAddress, + payer?: AnyAddress ): AsyncGenerator; submitAttestation( vaa: VAA<"AttestMeta">, - payer?: UniversalOrNative

+ payer?: AnyAddress ): AsyncGenerator; //alternative naming: initiateTransfer transfer( - sender: UniversalOrNative

, + sender: AnyAddress, recipient: ChainAddress, - token: UniversalOrNative

| "native", + token: AnyAddress, amount: bigint, payload?: Uint8Array ): AsyncGenerator; //alternative naming: completeTransfer redeem( - sender: UniversalOrNative

, + sender: AnyAddress, vaa: VAA<"Transfer"> | VAA<"TransferWithPayload">, unwrapNative?: boolean //default: true ): AsyncGenerator; @@ -79,15 +79,15 @@ export interface TokenBridge

{ export interface AutomaticTokenBridge

{ transfer( - sender: UniversalOrNative

, + sender: AnyAddress, recipient: ChainAddress, - token: UniversalOrNative

| "native", + token: AnyAddress, amount: bigint, relayerFee: bigint, nativeGas?: bigint ): AsyncGenerator; redeem( - sender: UniversalOrNative

, + sender: AnyAddress, vaa: VAA<"TransferWithPayload"> ): AsyncGenerator; getRelayerFee( diff --git a/core/definitions/src/relayer.ts b/core/definitions/src/relayer.ts index 626733801..e34498346 100644 --- a/core/definitions/src/relayer.ts +++ b/core/definitions/src/relayer.ts @@ -1,19 +1,19 @@ import { Chain, PlatformName } from "@wormhole-foundation/sdk-base"; -import { UniversalOrNative } from "./address"; +import { AnyAddress } from "./types"; export interface Relayer

{ relaySupported(chain: Chain): boolean; getRelayerFee( sourceChain: Chain, destChain: Chain, - tokenId: UniversalOrNative

, + tokenId: AnyAddress, ): Promise; // TODO: What should this be named? // I don't think it should return an UnisgnedTransaction // rather it should take some signing callbacks and // a ref to track the progress startTransferWithRelay( - token: UniversalOrNative

| "native", + token: AnyAddress, amount: bigint, toNativeToken: string, sendingChain: Chain, @@ -24,13 +24,13 @@ export interface Relayer

{ ): Promise; calculateNativeTokenAmt( destChain: Chain, - tokenId: UniversalOrNative

, + tokenId: AnyAddress, amount: bigint, walletAddress: string, ): Promise; calculateMaxSwapAmount( destChain: Chain, - tokenId: UniversalOrNative

, + tokenId: AnyAddress, walletAddress: string, ): Promise; } diff --git a/core/definitions/src/testing/mocks/platform.ts b/core/definitions/src/testing/mocks/platform.ts index 51d8ce5f8..5ceb10cab 100644 --- a/core/definitions/src/testing/mocks/platform.ts +++ b/core/definitions/src/testing/mocks/platform.ts @@ -19,6 +19,8 @@ import { nativeIsRegistered, NativeAddress, UniversalAddress, + AnyAddress, + Balances, } from "../.."; import { MockRpc } from "./rpc"; import { MockChain } from "./chain"; @@ -84,6 +86,15 @@ export class MockPlatform

implements Platform

{ throw new Error("Method not implemented."); } + getBalances( + chain: ChainName, + rpc: RpcConnection, + walletAddress: string, + tokens: AnyAddress[], + ): Promise { + throw new Error("method not implemented"); + } + getChain(chain: ChainName): ChainContext

{ if (chain in this.conf) return new MockChain

(this.conf[chain]!); throw new Error("No configuration available for chain: " + chain); diff --git a/core/definitions/src/testing/mocks/tokenBridge.ts b/core/definitions/src/testing/mocks/tokenBridge.ts index 7bc27d385..bdc3b0909 100644 --- a/core/definitions/src/testing/mocks/tokenBridge.ts +++ b/core/definitions/src/testing/mocks/tokenBridge.ts @@ -1,11 +1,10 @@ import { PlatformName } from "@wormhole-foundation/sdk-base"; import { + AnyAddress, ChainAddress, NativeAddress, RpcConnection, TokenBridge, - TokenId, - UniversalOrNative, UnsignedTransaction, VAA, } from "../.."; @@ -19,10 +18,10 @@ import { export class MockTokenBridge

implements TokenBridge

{ constructor(readonly rpc: RpcConnection

) {} - isWrappedAsset(token: UniversalOrNative

): Promise { + isWrappedAsset(token: AnyAddress): Promise { throw new Error("Method not implemented."); } - getOriginalAsset(token: UniversalOrNative

): Promise { + getOriginalAsset(token: AnyAddress): Promise { throw new Error("Method not implemented."); } hasWrappedAsset(original: ChainAddress): Promise { @@ -36,9 +35,7 @@ export class MockTokenBridge

implements TokenBridge

{ ): Promise { throw new Error("Method not implemented."); } - createAttestation( - address: UniversalOrNative

, - ): AsyncGenerator { + createAttestation(address: AnyAddress): AsyncGenerator { throw new Error("Method not implemented."); } submitAttestation( @@ -47,16 +44,16 @@ export class MockTokenBridge

implements TokenBridge

{ throw new Error("Method not implemented."); } transfer( - sender: UniversalOrNative

, + sender: AnyAddress, recipient: ChainAddress, - token: "native" | UniversalOrNative

, + token: "native" | AnyAddress, amount: bigint, payload?: Uint8Array | undefined, ): AsyncGenerator { throw new Error("Method not implemented."); } redeem( - sender: UniversalOrNative

, + sender: AnyAddress, vaa: VAA<"Transfer"> | VAA<"TransferWithPayload">, unwrapNative?: boolean | undefined, ): AsyncGenerator { diff --git a/core/definitions/src/types.ts b/core/definitions/src/types.ts index 673b57bbd..acc47ad70 100644 --- a/core/definitions/src/types.ts +++ b/core/definitions/src/types.ts @@ -15,6 +15,14 @@ export type SequenceId = bigint; export type SignedTx = any; +export type AnyAddress = + | NativeAddress + | UniversalAddress + | string + | number + | Uint8Array + | number[]; + export type TokenId = ChainAddress; export function isTokenId(thing: TokenId | any): thing is TokenId { return ( @@ -23,6 +31,10 @@ export function isTokenId(thing: TokenId | any): thing is TokenId { ); } +export type Balances = { + [key: string]: BigInt | null; +}; + export interface Signer { chain(): ChainName; address(): string; diff --git a/core/definitions/src/universalAddress.ts b/core/definitions/src/universalAddress.ts index bea3bbd25..c4188ab3a 100644 --- a/core/definitions/src/universalAddress.ts +++ b/core/definitions/src/universalAddress.ts @@ -8,6 +8,7 @@ import { Address, NativeAddress, toNative } from "./address"; export class UniversalAddress implements Address { static readonly byteSize = 32; + private readonly type = "Universal"; private readonly address: Uint8Array; @@ -45,7 +46,7 @@ export class UniversalAddress implements Address { } equals(other: UniversalAddress): boolean { - if (other instanceof UniversalAddress) { + if (UniversalAddress.instanceof(other)) { return other.toString() === this.toString(); } return false; @@ -54,4 +55,8 @@ export class UniversalAddress implements Address { static isValidAddress(address: string) { return isHexByteString(address, UniversalAddress.byteSize); } + + static instanceof(address: any) { + return address.type === "Universal"; + } } diff --git a/platforms/cosmwasm/__tests__/unit/platform.test.ts b/platforms/cosmwasm/__tests__/unit/platform.test.ts index 7bd5d9715..f0f78257c 100644 --- a/platforms/cosmwasm/__tests__/unit/platform.test.ts +++ b/platforms/cosmwasm/__tests__/unit/platform.test.ts @@ -2,8 +2,6 @@ import { expect, test } from "@jest/globals"; import { chains, chainConfigs, - DEFAULT_NETWORK, - chainToPlatform, } from "@wormhole-foundation/connect-sdk"; import { CosmwasmPlatform } from "../../src/platform"; diff --git a/platforms/cosmwasm/src/address.ts b/platforms/cosmwasm/src/address.ts index d33b91979..20717f95c 100644 --- a/platforms/cosmwasm/src/address.ts +++ b/platforms/cosmwasm/src/address.ts @@ -6,6 +6,7 @@ import { } from "@wormhole-foundation/connect-sdk"; import { CosmwasmPlatform } from "./platform"; import { nativeDenomToChain } from "./constants"; +import { AnyCosmwasmAddress } from "./types"; declare global { namespace Wormhole { @@ -102,6 +103,7 @@ function tryDecode(data: string): { data: Uint8Array; prefix?: string } { export class CosmwasmAddress implements Address { static readonly contractAddressByteSize = 32; static readonly accountAddressByteSize = 20; + public readonly platform = CosmwasmPlatform.platform; // the actual bytes of the address private readonly address: Uint8Array; @@ -113,7 +115,15 @@ export class CosmwasmAddress implements Address { // The denomType is "native", "ibc", or "factory" private readonly denomType?: string; - constructor(address: string | Uint8Array | UniversalAddress) { + constructor(address: AnyCosmwasmAddress) { + if (CosmwasmAddress.instanceof(address)) { + const a = address as unknown as CosmwasmAddress; + this.address = a.address; + this.domain = a.domain; + this.denom = a.denom; + this.denomType = a.denomType; + return; + } if (typeof address === "string") { // A native denom like "uatom" if (nativeDenomToChain.has(CosmwasmPlatform.network, address)) { @@ -152,7 +162,7 @@ export class CosmwasmAddress implements Address { } } else if (address instanceof Uint8Array) { this.address = address; - } else if (address instanceof UniversalAddress) { + } else if (UniversalAddress.instanceof(address)) { this.address = address.toUint8Array(); } else throw new Error(`Invalid Cosmwasm address ${address}`); @@ -224,6 +234,10 @@ export class CosmwasmAddress implements Address { return true; } + static instanceof(address: any) { + return address.platform === CosmwasmPlatform.platform; + } + equals(other: UniversalAddress): boolean { return other.equals(this.toUniversalAddress()); } diff --git a/platforms/cosmwasm/src/platform.ts b/platforms/cosmwasm/src/platform.ts index 9b3d7b37f..44e10285a 100644 --- a/platforms/cosmwasm/src/platform.ts +++ b/platforms/cosmwasm/src/platform.ts @@ -1,5 +1,11 @@ import { CosmWasmClient } from "@cosmjs/cosmwasm-stargate"; -import { IbcExtension, QueryClient, setupIbcExtension } from "@cosmjs/stargate"; +import { + BankExtension, + IbcExtension, + QueryClient, + setupBankExtension, + setupIbcExtension, +} from "@cosmjs/stargate"; import { TendermintClient } from "@cosmjs/tendermint-rpc"; import { @@ -49,6 +55,7 @@ export module CosmwasmPlatform { isSupportedChain, getDecimals, getBalance, + getBalances, sendWait, getCurrentBlock, chainFromRpc, @@ -112,10 +119,14 @@ export module CosmwasmPlatform { export const getQueryClient = ( rpc: CosmWasmClient, - ): QueryClient & IbcExtension => { + ): QueryClient & BankExtension & IbcExtension => { // @ts-ignore const tmClient: TendermintClient = rpc.getTmClient()!; - return QueryClient.withExtensions(tmClient, setupIbcExtension); + return QueryClient.withExtensions( + tmClient, + setupBankExtension, + setupIbcExtension, + ); }; // cached channels from config if available diff --git a/platforms/cosmwasm/src/platformUtils.ts b/platforms/cosmwasm/src/platformUtils.ts index 33a28cb29..6d5ae8634 100644 --- a/platforms/cosmwasm/src/platformUtils.ts +++ b/platforms/cosmwasm/src/platformUtils.ts @@ -5,11 +5,10 @@ import { SignedTx, Network, PlatformToChains, - WormholeMessageId, nativeDecimals, chainToPlatform, PlatformUtils, - UniversalOrNative, + Balances, } from "@wormhole-foundation/connect-sdk"; import { IBC_TRANSFER_PORT, @@ -19,7 +18,7 @@ import { import { CosmWasmClient } from "@cosmjs/cosmwasm-stargate"; import { CosmwasmPlatform } from "./platform"; import { CosmwasmAddress } from "./address"; -import { IbcExtension, QueryClient, setupIbcExtension } from "@cosmjs/stargate"; +import { AnyCosmwasmAddress } from "./types"; // forces CosmwasmUtils to implement PlatformUtils var _: PlatformUtils<"Cosmwasm"> = CosmwasmUtils; @@ -55,11 +54,12 @@ export module CosmwasmUtils { export async function getDecimals( chain: ChainName, rpc: CosmWasmClient, - token: UniversalOrNative<"Cosmwasm"> | "native", + token: AnyCosmwasmAddress | "native", ): Promise { if (token === "native") return nativeDecimals(CosmwasmPlatform.platform); - const { decimals } = await rpc.queryContractSmart(token.toString(), { + const addrStr = new CosmwasmAddress(token).toString(); + const { decimals } = await rpc.queryContractSmart(addrStr, { token_info: {}, }); return decimals; @@ -69,7 +69,7 @@ export module CosmwasmUtils { chain: ChainName, rpc: CosmWasmClient, walletAddress: string, - token: UniversalOrNative<"Cosmwasm"> | "native", + token: AnyCosmwasmAddress | "native", ): Promise { if (token === "native") { const { amount } = await rpc.getBalance( @@ -79,10 +79,32 @@ export module CosmwasmUtils { return BigInt(amount); } - const { amount } = await rpc.getBalance(walletAddress, token.toString()); + const addrStr = new CosmwasmAddress(token).toString(); + const { amount } = await rpc.getBalance(walletAddress, addrStr); return BigInt(amount); } + export async function getBalances( + chain: ChainName, + rpc: CosmWasmClient, + walletAddress: string, + tokens: (AnyCosmwasmAddress | "native")[], + ): Promise { + const client = CosmwasmPlatform.getQueryClient(rpc); + const allBalances = await client.bank.allBalances(walletAddress); + const balancesArr = tokens.map((token) => { + const address = + token === "native" + ? getNativeDenom(chain) + : new CosmwasmAddress(token).toString(); + const balance = allBalances.find((balance) => balance.denom === address); + const balanceBigInt = balance ? BigInt(balance.amount) : null; + return { [address]: balanceBigInt }; + }); + + return balancesArr.reduce((obj, item) => Object.assign(obj, item), {}); + } + function getNativeDenom(chain: ChainName): string { // TODO: required because of const map if (CosmwasmPlatform.network === "Devnet") @@ -135,19 +157,11 @@ export module CosmwasmUtils { sourceChannel: string, rpc: CosmWasmClient, ): Promise { - const queryClient = asQueryClient(rpc); + const queryClient = CosmwasmPlatform.getQueryClient(rpc); const conn = await queryClient.ibc.channel.channel( IBC_TRANSFER_PORT, sourceChannel, ); return conn.channel?.counterparty?.channelId ?? null; } - - export const asQueryClient = ( - rpc: CosmWasmClient, - ): QueryClient & IbcExtension => { - // @ts-ignore - const tmClient: TendermintClient = rpc.getTmClient()!; - return QueryClient.withExtensions(tmClient, setupIbcExtension); - }; } diff --git a/platforms/cosmwasm/src/protocols/ibc.ts b/platforms/cosmwasm/src/protocols/ibc.ts index 77a417904..a45cfada7 100644 --- a/platforms/cosmwasm/src/protocols/ibc.ts +++ b/platforms/cosmwasm/src/protocols/ibc.ts @@ -37,13 +37,13 @@ import { import { CosmwasmContracts } from "../contracts"; import { Gateway } from "../gateway"; import { CosmwasmPlatform } from "../platform"; -import { CosmwasmUtils } from "../platformUtils"; -import { CosmwasmChainName, UniversalOrCosmwasm } from "../types"; +import { CosmwasmChainName, AnyCosmwasmAddress } from "../types"; import { CosmwasmTransaction, CosmwasmUnsignedTransaction, computeFee, } from "../unsignedTransaction"; +import { CosmwasmAddress } from "../address"; const millisToNano = (seconds: number) => seconds * 1_000_000; @@ -87,12 +87,12 @@ export class CosmwasmIbcBridge implements IbcBridge<"Cosmwasm"> { } async *transfer( - sender: UniversalOrCosmwasm, + sender: AnyCosmwasmAddress, recipient: ChainAddress, - token: UniversalOrCosmwasm | "native", + token: AnyCosmwasmAddress | "native", amount: bigint, ): AsyncGenerator { - const senderAddress = sender.toString(); + const senderAddress = new CosmwasmAddress(sender).toString(); const nonce = Math.round(Math.random() * 10000); // TODO: needs heavy testing @@ -124,7 +124,13 @@ export class CosmwasmIbcBridge implements IbcBridge<"Cosmwasm"> { const timeout = millisToNano(Date.now() + IBC_TIMEOUT_MILLIS); const memo = JSON.stringify(payload); - const ibcDenom = Gateway.deriveIbcDenom(this.chain, token.toString()); + const ibcDenom = + token === "native" + ? CosmwasmPlatform.getNativeDenom(this.chain) + : Gateway.deriveIbcDenom( + this.chain, + new CosmwasmAddress(token).toString(), + ); const ibcToken = coin(amount.toString(), ibcDenom.toString()); const ibcMessage: MsgTransferEncodeObject = { @@ -332,7 +338,7 @@ export class CosmwasmIbcBridge implements IbcBridge<"Cosmwasm"> { for (const xfer of xfers) { // If its present in the commitment results, its interpreted as in-flight // the client throws an error and we report any error as not in-flight - const qc = CosmwasmUtils.asQueryClient(this.rpc); + const qc = CosmwasmPlatform.getQueryClient(this.rpc); try { await qc.ibc.channel.packetCommitment( IBC_TRANSFER_PORT, diff --git a/platforms/cosmwasm/src/protocols/tokenBridge.ts b/platforms/cosmwasm/src/protocols/tokenBridge.ts index 1dce6897e..0f7f3b603 100644 --- a/platforms/cosmwasm/src/protocols/tokenBridge.ts +++ b/platforms/cosmwasm/src/protocols/tokenBridge.ts @@ -24,12 +24,12 @@ import { } from "../unsignedTransaction"; import { CosmwasmContracts } from "../contracts"; import { + AnyCosmwasmAddress, CosmwasmChainName, - UniversalOrCosmwasm, WrappedRegistryResponse, - toCosmwasmAddrString, } from "../types"; import { CosmwasmPlatform } from "../platform"; +import { CosmwasmAddress } from "../address"; export class CosmwasmTokenBridge implements TokenBridge<"Cosmwasm"> { private tokenBridge: string; @@ -54,7 +54,7 @@ export class CosmwasmTokenBridge implements TokenBridge<"Cosmwasm"> { return new CosmwasmTokenBridge(network, chain, rpc, contracts); } - async isWrappedAsset(token: UniversalOrCosmwasm): Promise { + async isWrappedAsset(token: AnyCosmwasmAddress): Promise { try { await this.getOriginalAsset(token); return true; @@ -92,8 +92,8 @@ export class CosmwasmTokenBridge implements TokenBridge<"Cosmwasm"> { return toNative(this.chain, address); } - async getOriginalAsset(token: UniversalOrCosmwasm): Promise { - const wrappedAddress = token.toString(); + async getOriginalAsset(token: AnyCosmwasmAddress): Promise { + const wrappedAddress = new CosmwasmAddress(token).toString(); const response = await this.rpc.queryContractSmart(wrappedAddress, { wrapped_asset_info: {}, @@ -119,11 +119,14 @@ export class CosmwasmTokenBridge implements TokenBridge<"Cosmwasm"> { } async *createAttestation( - token: UniversalOrCosmwasm | "native", - payer?: UniversalOrCosmwasm, + token: AnyCosmwasmAddress | "native", + payer?: AnyCosmwasmAddress, ): AsyncGenerator { if (!payer) throw new Error("Payer required to create attestation"); + const tokenStr = new CosmwasmAddress(token).toString(); + const payerStr = new CosmwasmAddress(payer).toString(); + // TODO nonce? const nonce = 0; const assetInfo = @@ -134,13 +137,13 @@ export class CosmwasmTokenBridge implements TokenBridge<"Cosmwasm"> { }, } : { - token: { contract_addr: token.toString() }, + token: { contract_addr: tokenStr }, }; yield this.createUnsignedTx( { msgs: [ - buildExecuteMsg(payer.toString(), this.tokenBridge, { + buildExecuteMsg(payerStr, this.tokenBridge, { create_asset_meta: { asset_info: assetInfo, nonce }, }), ], @@ -153,14 +156,16 @@ export class CosmwasmTokenBridge implements TokenBridge<"Cosmwasm"> { async *submitAttestation( vaa: VAA<"AttestMeta">, - payer?: UniversalOrCosmwasm, + payer?: AnyCosmwasmAddress, ): AsyncGenerator { if (!payer) throw new Error("Payer required to submit attestation"); + const payerStr = new CosmwasmAddress(payer).toString(); + yield this.createUnsignedTx( { msgs: [ - buildExecuteMsg(payer.toString(), this.tokenBridge, { + buildExecuteMsg(payerStr, this.tokenBridge, { submit_vaa: { data: serialize(vaa) }, }), ], @@ -172,9 +177,9 @@ export class CosmwasmTokenBridge implements TokenBridge<"Cosmwasm"> { } async *transfer( - sender: UniversalOrCosmwasm, + sender: AnyCosmwasmAddress, recipient: ChainAddress, - token: UniversalOrCosmwasm | "native", + token: AnyCosmwasmAddress | "native", amount: bigint, payload?: Uint8Array, ): AsyncGenerator { @@ -194,7 +199,7 @@ export class CosmwasmTokenBridge implements TokenBridge<"Cosmwasm"> { const tokenAddress = isNative ? denom : token.toString(); - const senderAddress = sender.toString(); + const senderAddress = new CosmwasmAddress(sender).toString(); const mk_initiate_transfer = (info: object) => { const common = { @@ -273,14 +278,14 @@ export class CosmwasmTokenBridge implements TokenBridge<"Cosmwasm"> { } async *redeem( - sender: UniversalOrCosmwasm, + sender: AnyCosmwasmAddress, vaa: VAA<"Transfer"> | VAA<"TransferWithPayload">, unwrapNative: boolean = true, ): AsyncGenerator { // TODO: unwrapNative const data = Buffer.from(serialize(vaa)).toString("base64"); - const senderAddress = toCosmwasmAddrString(sender); + const senderAddress = new CosmwasmAddress(sender).toString(); const toTranslator = this.translator && diff --git a/platforms/cosmwasm/src/types.ts b/platforms/cosmwasm/src/types.ts index 7cd074be9..ada3e54c0 100644 --- a/platforms/cosmwasm/src/types.ts +++ b/platforms/cosmwasm/src/types.ts @@ -1,26 +1,17 @@ import { - UniversalAddress, UniversalOrNative, PlatformToChains, - GatewayTransferMsg, } from "@wormhole-foundation/connect-sdk"; import { logs as cosmosLogs } from "@cosmjs/stargate"; export type CosmwasmChainName = PlatformToChains<"Cosmwasm">; -export type UniversalOrCosmwasm = UniversalOrNative<"Cosmwasm"> | string; +export type UniversalOrCosmwasm = UniversalOrNative<"Cosmwasm">; +export type AnyCosmwasmAddress = UniversalOrCosmwasm | string | Uint8Array; export interface WrappedRegistryResponse { address: string; } -export const toCosmwasmAddrString = (addr: UniversalOrCosmwasm) => - typeof addr === "string" - ? addr - : (addr instanceof UniversalAddress - ? addr.toNative("Cosmwasm") - : addr - ).unwrap(); - // TODO: do >1 key at a time export const searchCosmosLogs = ( key: string, diff --git a/platforms/evm/src/address.ts b/platforms/evm/src/address.ts index 5fb15ec05..f5bac5dae 100644 --- a/platforms/evm/src/address.ts +++ b/platforms/evm/src/address.ts @@ -1,6 +1,8 @@ -import { Address, UniversalAddress } from '@wormhole-foundation/connect-sdk'; +import { Address, UniversalAddress, registerNative } from '@wormhole-foundation/connect-sdk'; import { ethers } from 'ethers'; +import { AnyEvmAddress } from './types'; +import { EvmPlatform } from './platform'; declare global { namespace Wormhole { @@ -15,11 +17,17 @@ export const EvmZeroAddress = ethers.ZeroAddress; export class EvmAddress implements Address { static readonly byteSize = 20; + public readonly platform = EvmPlatform.platform; - //stored as checksum address + // stored as checksum address private readonly address: string; - constructor(address: string | Uint8Array | UniversalAddress) { + constructor(address: AnyEvmAddress) { + if (EvmAddress.instanceof(address)) { + const a = address as unknown as EvmAddress; + this.address = a.address; + return; + } if (typeof address === 'string') { if (!EvmAddress.isValidAddress(address)) throw new Error( @@ -34,7 +42,7 @@ export class EvmAddress implements Address { ); this.address = ethers.getAddress(ethers.hexlify(address)); - } else if (address instanceof UniversalAddress) { + } else if (UniversalAddress.instanceof(address)) { // If its a universal address and we want it to be an ethereum address, // we need to chop off the first 12 bytes of padding const addressBytes = address.toUint8Array(); @@ -71,7 +79,12 @@ export class EvmAddress implements Address { static isValidAddress(address: string) { return ethers.isAddress(address); } + static instanceof(address: any) { + return address.platform === EvmPlatform.platform; + } equals(other: UniversalAddress): boolean { return other.equals(this.toUniversalAddress()); } } + +registerNative('Evm', EvmAddress); \ No newline at end of file diff --git a/platforms/evm/src/index.ts b/platforms/evm/src/index.ts index 616021eec..bc96e9373 100644 --- a/platforms/evm/src/index.ts +++ b/platforms/evm/src/index.ts @@ -3,6 +3,7 @@ export * from './address'; export * from './contracts'; export * from './unsignedTransaction'; export * from './platform'; +export * from './types'; export * from './chain'; export * from './protocols/tokenBridge'; @@ -10,4 +11,3 @@ export * from './protocols/automaticTokenBridge'; export * from './protocols/circleBridge'; export * from './protocols/automaticCircleBridge'; -export * from './types'; diff --git a/platforms/evm/src/platform.ts b/platforms/evm/src/platform.ts index 99c9c8889..750aee34e 100644 --- a/platforms/evm/src/platform.ts +++ b/platforms/evm/src/platform.ts @@ -44,6 +44,7 @@ export module EvmPlatform { isSupportedChain, getDecimals, getBalance, + getBalances, sendWait, getCurrentBlock, chainFromRpc, diff --git a/platforms/evm/src/platformUtils.ts b/platforms/evm/src/platformUtils.ts index ec9018d4b..5553fb722 100644 --- a/platforms/evm/src/platformUtils.ts +++ b/platforms/evm/src/platformUtils.ts @@ -8,7 +8,7 @@ import { nativeDecimals, chainToPlatform, PlatformUtils, - UniversalOrNative, + Balances, } from '@wormhole-foundation/connect-sdk'; import { Provider } from 'ethers'; @@ -16,6 +16,7 @@ import { evmChainIdToNetworkChainPair } from './constants'; import { EvmAddress, EvmZeroAddress } from './address'; import { EvmContracts } from './contracts'; import { EvmPlatform } from './platform'; +import { AnyEvmAddress } from './types'; // forces EvmUtils to implement PlatformUtils var _: PlatformUtils<'Evm'> = EvmUtils; @@ -48,32 +49,47 @@ export module EvmUtils { export async function getDecimals( chain: ChainName, rpc: Provider, - token: UniversalOrNative<'Evm'> | 'native', + token: AnyEvmAddress | 'native', ): Promise { if (token === 'native') return nativeDecimals(EvmPlatform.platform); const tokenContract = EvmContracts.getTokenImplementation( rpc, - token.toString(), + new EvmAddress(token).toString(), ); - const decimals = await tokenContract.decimals(); - return decimals; + return tokenContract.decimals(); } export async function getBalance( chain: ChainName, rpc: Provider, walletAddr: string, - token: UniversalOrNative<'Evm'> | 'native', + token: AnyEvmAddress | 'native', ): Promise { - if (token === 'native') return await rpc.getBalance(walletAddr); + if (token === 'native') return rpc.getBalance(walletAddr); const tokenImpl = EvmContracts.getTokenImplementation( rpc, - token.toString(), + new EvmAddress(token).toString(), ); - const balance = await tokenImpl.balanceOf(walletAddr); - return balance; + return tokenImpl.balanceOf(walletAddr); + } + + export async function getBalances( + chain: ChainName, + rpc: Provider, + walletAddr: string, + tokens: (AnyEvmAddress | 'native')[], + ): Promise { + const balancesArr = await Promise.all( + tokens.map(async (token) => { + const balance = await getBalance(chain, rpc, walletAddr, token); + const address = + token === 'native' ? 'native' : new EvmAddress(token).toString(); + return { [address]: balance }; + }), + ); + return balancesArr.reduce((obj, item) => Object.assign(obj, item), {}); } export async function sendWait( diff --git a/platforms/evm/src/protocols/automaticCircleBridge.ts b/platforms/evm/src/protocols/automaticCircleBridge.ts index 17482221a..c4e43ae50 100644 --- a/platforms/evm/src/protocols/automaticCircleBridge.ts +++ b/platforms/evm/src/protocols/automaticCircleBridge.ts @@ -8,18 +8,13 @@ import { import { evmNetworkChainToEvmChainId } from '../constants'; -import { - EvmChainName, - UniversalOrEvm, - addChainId, - addFrom, - toEvmAddrString, -} from '../types'; +import { AnyEvmAddress, EvmChainName, addChainId, addFrom } from '../types'; import { EvmUnsignedTransaction } from '../unsignedTransaction'; import { CircleRelayer } from '../ethers-contracts'; import { Provider, TransactionRequest } from 'ethers'; import { EvmContracts } from '../contracts'; import { EvmPlatform } from '../platform'; +import { EvmAddress } from '../address'; export class EvmAutomaticCircleBridge implements AutomaticCircleBridge<'Evm'> { readonly circleRelayer: CircleRelayer; @@ -53,19 +48,21 @@ export class EvmAutomaticCircleBridge implements AutomaticCircleBridge<'Evm'> { async *transfer( token: TokenId, - sender: UniversalOrEvm, + sender: AnyEvmAddress, recipient: ChainAddress, amount: bigint, nativeGas?: bigint, ): AsyncGenerator { - const senderAddr = toEvmAddrString(sender); + const senderAddr = new EvmAddress(sender).toString(); const recipientChainId = chainToChainId(recipient.chain); const recipientAddress = recipient.address .toUniversalAddress() .toUint8Array(); const nativeTokenGas = nativeGas ? nativeGas : 0n; - const tokenAddr = toEvmAddrString(token.address.toUniversalAddress()); + const tokenAddr = new EvmAddress( + token.address.toUniversalAddress(), + ).toString(); const tokenContract = EvmContracts.getTokenImplementation( this.provider, diff --git a/platforms/evm/src/protocols/automaticTokenBridge.ts b/platforms/evm/src/protocols/automaticTokenBridge.ts index 481509d5f..966cd5ca3 100644 --- a/platforms/evm/src/protocols/automaticTokenBridge.ts +++ b/platforms/evm/src/protocols/automaticTokenBridge.ts @@ -1,6 +1,5 @@ import { ChainAddress, - UniversalOrNative, AutomaticTokenBridge, VAA, serialize, @@ -13,17 +12,12 @@ import { import { Provider, TransactionRequest } from 'ethers'; import { evmNetworkChainToEvmChainId } from '../constants'; -import { - EvmChainName, - UniversalOrEvm, - addChainId, - addFrom, - toEvmAddrString, -} from '../types'; +import { AnyEvmAddress, EvmChainName, addChainId, addFrom } from '../types'; import { EvmUnsignedTransaction } from '../unsignedTransaction'; import { TokenBridgeRelayer } from '../ethers-contracts'; import { EvmContracts } from '../contracts'; import { EvmPlatform } from '../platform'; +import { EvmAddress } from '../address'; export class EvmAutomaticTokenBridge implements AutomaticTokenBridge<'Evm'> { readonly tokenBridgeRelayer: TokenBridgeRelayer; @@ -47,10 +41,10 @@ export class EvmAutomaticTokenBridge implements AutomaticTokenBridge<'Evm'> { ); } async *redeem( - sender: UniversalOrNative<'Evm'>, + sender: AnyEvmAddress, vaa: VAA<'TransferWithPayload'>, ): AsyncGenerator { - const senderAddr = toEvmAddrString(sender); + const senderAddr = new EvmAddress(sender).toString(); const txReq = await this.tokenBridgeRelayer.completeTransferWithRelay.populateTransaction( serialize(vaa), @@ -72,14 +66,14 @@ export class EvmAutomaticTokenBridge implements AutomaticTokenBridge<'Evm'> { //alternative naming: initiateTransfer async *transfer( - sender: UniversalOrEvm, + sender: AnyEvmAddress, recipient: ChainAddress, - token: UniversalOrEvm | 'native', + token: AnyEvmAddress | 'native', amount: bigint, relayerFee: bigint, nativeGas?: bigint, ): AsyncGenerator { - const senderAddr = toEvmAddrString(sender); + const senderAddr = new EvmAddress(sender).toString(); const recipientChainId = chainToChainId(recipient.chain); const recipientAddress = recipient.address .toUniversalAddress() @@ -101,7 +95,7 @@ export class EvmAutomaticTokenBridge implements AutomaticTokenBridge<'Evm'> { ); } else { //TODO check for ERC-2612 (permit) support on token? - const tokenAddr = toEvmAddrString(token); + const tokenAddr = new EvmAddress(token).toString(); // TODO: allowance? const txReq = @@ -132,7 +126,9 @@ export class EvmAutomaticTokenBridge implements AutomaticTokenBridge<'Evm'> { : token; const destChainId = toChainId(recipient.chain); - const destTokenAddress = toEvmAddrString(tokenId.address.toString()); + const destTokenAddress = new EvmAddress( + tokenId.address.toString(), + ).toString(); const tokenContract = EvmContracts.getTokenImplementation( this.provider, diff --git a/platforms/evm/src/protocols/circleBridge.ts b/platforms/evm/src/protocols/circleBridge.ts index 7ffefb94f..9548d4ad5 100644 --- a/platforms/evm/src/protocols/circleBridge.ts +++ b/platforms/evm/src/protocols/circleBridge.ts @@ -17,8 +17,8 @@ import { evmNetworkChainToEvmChainId } from '../constants'; import { addFrom, addChainId, - toEvmAddrString, EvmChainName, + AnyEvmAddress, UniversalOrEvm, } from '../types'; import { EvmUnsignedTransaction } from '../unsignedTransaction'; @@ -26,6 +26,7 @@ import { MessageTransmitter, TokenMessenger } from '../ethers-contracts'; import { LogDescription, Provider, TransactionRequest } from 'ethers'; import { EvmContracts } from '../contracts'; import { EvmPlatform } from '../platform'; +import { EvmAddress } from '../address'; //https://github.com/circlefin/evm-cctp-contracts @@ -77,11 +78,11 @@ export class EvmCircleBridge implements CircleBridge<'Evm'> { } async *redeem( - sender: UniversalOrEvm, + sender: AnyEvmAddress, message: string, attestation: string, ): AsyncGenerator { - const senderAddr = toEvmAddrString(sender); + const senderAddr = new EvmAddress(sender).toString(); const txReq = await this.msgTransmitter.receiveMessage.populateTransaction( hexByteStringToUint8Array(message), @@ -96,15 +97,17 @@ export class EvmCircleBridge implements CircleBridge<'Evm'> { //alternative naming: initiateTransfer async *transfer( token: TokenId, - sender: UniversalOrEvm, + sender: AnyEvmAddress, recipient: ChainAddress, amount: bigint, ): AsyncGenerator { - const senderAddr = toEvmAddrString(sender); + const senderAddr = new EvmAddress(sender).toString(); const recipientAddress = recipient.address .toUniversalAddress() .toUint8Array(); - const tokenAddr = toEvmAddrString(token.address as UniversalOrEvm); + const tokenAddr = new EvmAddress( + token.address as UniversalOrEvm, + ).toString(); const tokenContract = EvmContracts.getTokenImplementation( this.provider, diff --git a/platforms/evm/src/protocols/tokenBridge.ts b/platforms/evm/src/protocols/tokenBridge.ts index da8bd27ca..672a7a2f2 100644 --- a/platforms/evm/src/protocols/tokenBridge.ts +++ b/platforms/evm/src/protocols/tokenBridge.ts @@ -29,14 +29,13 @@ import { EvmUnsignedTransaction } from '../unsignedTransaction'; import { EvmContracts } from '../contracts'; import { EvmChainName, - UniversalOrEvm, addFrom, addChainId, - toEvmAddrString, unusedArbiterFee, unusedNonce, + AnyEvmAddress, } from '../types'; -import { EvmZeroAddress } from '../address'; +import { EvmAddress, EvmZeroAddress } from '../address'; import { EvmPlatform } from '../platform'; //Currently the code does not consider Wormhole msg fee (because it is and always has been 0). @@ -67,16 +66,18 @@ export class EvmTokenBridge implements TokenBridge<'Evm'> { return new EvmTokenBridge(network, chain, provider, contracts); } - async isWrappedAsset(token: UniversalOrEvm): Promise { - return await this.tokenBridge.isWrappedAsset(toEvmAddrString(token)); + async isWrappedAsset(token: AnyEvmAddress): Promise { + return await this.tokenBridge.isWrappedAsset( + new EvmAddress(token).toString(), + ); } - async getOriginalAsset(token: UniversalOrEvm): Promise { + async getOriginalAsset(token: AnyEvmAddress): Promise { if (!(await this.isWrappedAsset(token))) throw ErrNotWrapped(token.toString()); const tokenContract = TokenContractFactory.connect( - toEvmAddrString(token), + new EvmAddress(token).toString(), this.provider, ); const [chain, address] = await Promise.all([ @@ -127,12 +128,12 @@ export class EvmTokenBridge implements TokenBridge<'Evm'> { } async *createAttestation( - token: UniversalOrEvm, + token: AnyEvmAddress, ): AsyncGenerator { const ignoredNonce = 0; yield this.createUnsignedTx( await this.tokenBridge.attestToken.populateTransaction( - toEvmAddrString(token), + new EvmAddress(token).toString(), ignoredNonce, ), 'TokenBridge.createAttestation', @@ -155,13 +156,13 @@ export class EvmTokenBridge implements TokenBridge<'Evm'> { //alternative naming: initiateTransfer async *transfer( - sender: UniversalOrEvm, + sender: AnyEvmAddress, recipient: ChainAddress, - token: UniversalOrEvm | 'native', + token: AnyEvmAddress | 'native', amount: bigint, payload?: Uint8Array, ): AsyncGenerator { - const senderAddr = toEvmAddrString(sender); + const senderAddr = new EvmAddress(sender).toString(); const recipientChainId = toChainId(recipient.chain); const recipientAddress = recipient.address .toUniversalAddress() @@ -189,7 +190,7 @@ export class EvmTokenBridge implements TokenBridge<'Evm'> { ); } else { //TODO check for ERC-2612 (permit) support on token? - const tokenAddr = toEvmAddrString(token); + const tokenAddr = new EvmAddress(token).toString(); const tokenContract = TokenContractFactory.connect( tokenAddr, this.provider, @@ -235,11 +236,11 @@ export class EvmTokenBridge implements TokenBridge<'Evm'> { //alternative naming: completeTransfer async *redeem( - sender: UniversalOrEvm, + sender: AnyEvmAddress, vaa: VAA<'Transfer'> | VAA<'TransferWithPayload'>, unwrapNative: boolean = true, ): AsyncGenerator { - const senderAddr = toEvmAddrString(sender); + const senderAddr = new EvmAddress(sender).toString(); if (vaa.payload.token.chain !== this.chain) if (vaa.payloadLiteral === 'TransferWithPayload') { const fromAddr = toNative(this.chain, vaa.payload.from).unwrap(); diff --git a/platforms/evm/src/protocols/wormholeCore.ts b/platforms/evm/src/protocols/wormholeCore.ts index 147911289..ec54314f6 100644 --- a/platforms/evm/src/protocols/wormholeCore.ts +++ b/platforms/evm/src/protocols/wormholeCore.ts @@ -8,14 +8,9 @@ import { } from '../constants'; import { EvmUnsignedTransaction } from '../unsignedTransaction'; import { EvmContracts } from '../contracts'; -import { - EvmChainName, - UniversalOrEvm, - addChainId, - addFrom, - toEvmAddrString, -} from '../types'; +import { AnyEvmAddress, EvmChainName, addChainId, addFrom } from '../types'; import { EvmPlatform } from '../platform'; +import { EvmAddress } from '../address'; export class EvmWormholeCore implements WormholeCore<'Evm'> { readonly chainId: bigint; @@ -44,10 +39,10 @@ export class EvmWormholeCore implements WormholeCore<'Evm'> { } async *publishMessage( - sender: UniversalOrEvm, + sender: AnyEvmAddress, message: Uint8Array | string, ): AsyncGenerator { - const senderAddr = toEvmAddrString(sender); + const senderAddr = new EvmAddress(sender).toString(); const txReq = await this.core.publishMessage.populateTransaction( 0, diff --git a/platforms/evm/src/types.ts b/platforms/evm/src/types.ts index 26d641e0a..ebb04bfff 100644 --- a/platforms/evm/src/types.ts +++ b/platforms/evm/src/types.ts @@ -1,25 +1,15 @@ import { - UniversalAddress, UniversalOrNative, - registerNative, PlatformToChains, } from '@wormhole-foundation/connect-sdk'; import { TransactionRequest } from 'ethers'; -import { EvmAddress } from './address'; - -registerNative('Evm', EvmAddress); - export const unusedNonce = 0; export const unusedArbiterFee = 0n; export type EvmChainName = PlatformToChains<'Evm'>; -export type UniversalOrEvm = UniversalOrNative<'Evm'> | string; - -export const toEvmAddrString = (addr: UniversalOrEvm) => - typeof addr === 'string' - ? addr - : (addr instanceof UniversalAddress ? addr.toNative('Evm') : addr).unwrap(); +export type UniversalOrEvm = UniversalOrNative<'Evm'>; +export type AnyEvmAddress = UniversalOrEvm | string | Uint8Array; export const addFrom = (txReq: TransactionRequest, from: string) => ({ ...txReq, diff --git a/platforms/solana/__tests__/unit/platform.test.ts b/platforms/solana/__tests__/unit/platform.test.ts index 662a6f5fc..3f3faf83d 100644 --- a/platforms/solana/__tests__/unit/platform.test.ts +++ b/platforms/solana/__tests__/unit/platform.test.ts @@ -12,8 +12,6 @@ import { import { SolanaPlatform } from '../../src'; -import { PublicKey } from '@solana/web3.js'; - // @ts-ignore -- this is the mock we import above import { getDefaultProvider } from '@solana/web3.js'; diff --git a/platforms/solana/src/address.ts b/platforms/solana/src/address.ts index 930e77532..9ac418b96 100644 --- a/platforms/solana/src/address.ts +++ b/platforms/solana/src/address.ts @@ -3,9 +3,13 @@ import { hexByteStringToUint8Array, Address, UniversalAddress, + PlatformName, + registerNative, } from '@wormhole-foundation/connect-sdk'; -import { PublicKey, PublicKeyInitData } from '@solana/web3.js'; +import { PublicKey } from '@solana/web3.js'; +import { AnySolanaAddress } from './types'; +import { SolanaPlatform } from './platform'; declare global { namespace Wormhole { @@ -21,12 +25,20 @@ export const SolanaZeroAddress = '11111111111111111111111111111111'; export class SolanaAddress implements Address { static readonly byteSize = 32; + public readonly platform: PlatformName = SolanaPlatform.platform; private readonly address: PublicKey; - constructor(address: PublicKeyInitData | UniversalAddress) { - if (address instanceof UniversalAddress) - this.address = new PublicKey(address.toUint8Array()); + constructor(address: AnySolanaAddress) { + if (SolanaAddress.instanceof(address)) { + const a = address as unknown as SolanaAddress; + this.address = a.address; + return; + } + if (UniversalAddress.instanceof(address)) + this.address = new PublicKey( + (address as UniversalAddress).toUint8Array(), + ); if (typeof address === 'string' && isHexByteString(address)) this.address = new PublicKey(hexByteStringToUint8Array(address)); else this.address = new PublicKey(address); @@ -48,7 +60,13 @@ export class SolanaAddress implements Address { return new UniversalAddress(this.address.toBytes()); } + static instanceof(address: any) { + return address.platform === SolanaPlatform.platform; + } + equals(other: UniversalAddress): boolean { return this.toUniversalAddress().equals(other); } } + +registerNative('Solana', SolanaAddress); \ No newline at end of file diff --git a/platforms/solana/src/chain.ts b/platforms/solana/src/chain.ts index 1c7666737..432b4b7cd 100644 --- a/platforms/solana/src/chain.ts +++ b/platforms/solana/src/chain.ts @@ -4,7 +4,6 @@ import { NativeAddress, UniversalAddress, UniversalOrNative, - Platform, toNative, } from '@wormhole-foundation/connect-sdk'; import { getAssociatedTokenAddress } from '@solana/spl-token'; diff --git a/platforms/solana/src/platform.ts b/platforms/solana/src/platform.ts index 417ceb499..6555d8218 100644 --- a/platforms/solana/src/platform.ts +++ b/platforms/solana/src/platform.ts @@ -37,6 +37,7 @@ export module SolanaPlatform { isSupportedChain, getDecimals, getBalance, + getBalances, sendWait, getCurrentBlock, chainFromRpc, diff --git a/platforms/solana/src/platformUtils.ts b/platforms/solana/src/platformUtils.ts index 53a9c3547..80d443113 100644 --- a/platforms/solana/src/platformUtils.ts +++ b/platforms/solana/src/platformUtils.ts @@ -9,12 +9,14 @@ import { nativeDecimals, PlatformUtils, chainToPlatform, - UniversalOrNative, + Balances, } from '@wormhole-foundation/connect-sdk'; import { Connection, ParsedAccountData, PublicKey } from '@solana/web3.js'; import { solGenesisHashToNetworkChainPair } from './constants'; import { SolanaPlatform } from './platform'; import { SolanaAddress, SolanaZeroAddress } from './address'; +import { AnySolanaAddress } from './types'; +import { TOKEN_PROGRAM_ID } from '@solana/spl-token'; // forces SolanaUtils to implement PlatformUtils var _: PlatformUtils<'Solana'> = SolanaUtils; @@ -46,12 +48,12 @@ export module SolanaUtils { export async function getDecimals( chain: ChainName, rpc: Connection, - token: UniversalOrNative<'Solana'> | 'native', + token: AnySolanaAddress | 'native', ): Promise { if (token === 'native') return nativeDecimals(SolanaPlatform.platform); let mint = await rpc.getParsedAccountInfo( - new PublicKey(token.toUint8Array()), + new SolanaAddress(token).unwrap(), ); if (!mint || !mint.value) throw new Error('could not fetch token details'); @@ -64,7 +66,7 @@ export module SolanaUtils { chain: ChainName, rpc: Connection, walletAddress: string, - token: UniversalOrNative<'Solana'> | 'native', + token: AnySolanaAddress | 'native', ): Promise { if (token === 'native') return BigInt(await rpc.getBalance(new PublicKey(walletAddress))); @@ -76,7 +78,7 @@ export module SolanaUtils { const splToken = await rpc.getTokenAccountsByOwner( new PublicKey(walletAddress), - { mint: new PublicKey(token.toUint8Array()) }, + { mint: new SolanaAddress(token).unwrap() }, ); if (!splToken.value[0]) return null; @@ -84,6 +86,39 @@ export module SolanaUtils { return BigInt(balance.value.amount); } + export async function getBalances( + chain: ChainName, + rpc: Connection, + walletAddress: string, + tokens: (AnySolanaAddress | 'native')[], + ): Promise { + let native: bigint; + if (tokens.includes('native')) { + native = BigInt(await rpc.getBalance(new PublicKey(walletAddress))); + } + + const splParsedTokenAccounts = await rpc.getParsedTokenAccountsByOwner( + new PublicKey(walletAddress), + { + programId: new PublicKey(TOKEN_PROGRAM_ID), + }, + ); + + const balancesArr = tokens.map((token) => { + if (token === 'native') { + return { ['native']: native }; + } + const addrString = new SolanaAddress(token).toString(); + const amount = splParsedTokenAccounts.value.find( + (v) => v?.account.data.parsed?.info?.mint === token, + )?.account.data.parsed?.info?.tokenAmount?.amount; + if (!amount) return { [addrString]: null }; + return { [addrString]: BigInt(amount) }; + }); + + return balancesArr.reduce((obj, item) => Object.assign(obj, item), {}); + } + export async function sendWait( chain: ChainName, rpc: Connection, diff --git a/platforms/solana/src/protocols/tokenBridge.ts b/platforms/solana/src/protocols/tokenBridge.ts index 884e89f27..7e51cfca9 100644 --- a/platforms/solana/src/protocols/tokenBridge.ts +++ b/platforms/solana/src/protocols/tokenBridge.ts @@ -58,8 +58,9 @@ import { import { SolanaContracts } from '../contracts'; import { SolanaUnsignedTransaction } from '../unsignedTransaction'; -import { SolanaChainName, UniversalOrSolana } from '../types'; +import { AnySolanaAddress, SolanaChainName } from '../types'; import { SolanaPlatform } from '../platform'; +import { SolanaAddress } from '../address'; export class SolanaTokenBridge implements TokenBridge<'Solana'> { readonly chainId: ChainId; @@ -86,27 +87,28 @@ export class SolanaTokenBridge implements TokenBridge<'Solana'> { return new SolanaTokenBridge(network, chain, connection, contracts); } - async isWrappedAsset(token: UniversalOrSolana): Promise { + async isWrappedAsset(token: AnySolanaAddress): Promise { return getWrappedMeta( this.connection, this.tokenBridge.programId, - token.toUint8Array(), + new SolanaAddress(token).toUint8Array(), ) .catch((_) => null) .then((meta) => meta != null); } - async getOriginalAsset(token: UniversalOrSolana): Promise { + async getOriginalAsset(token: AnySolanaAddress): Promise { if (!(await this.isWrappedAsset(token))) throw ErrNotWrapped(token.toString()); - const mint = new PublicKey(token.toUint8Array()); + const tokenAddr = new SolanaAddress(token).toUint8Array(); + const mint = new PublicKey(tokenAddr); try { const meta = await getWrappedMeta( this.connection, this.tokenBridge.programId, - token.toUint8Array(), + tokenAddr, ); if (meta === null) @@ -165,10 +167,11 @@ export class SolanaTokenBridge implements TokenBridge<'Solana'> { } async *createAttestation( - token: UniversalOrSolana, - sender: UniversalOrSolana, + token: AnySolanaAddress, + payer?: AnySolanaAddress, ): AsyncGenerator { - const senderAddress = new PublicKey(sender.toUint8Array()); + if (!payer) throw new Error('Payer required to create attestation'); + const senderAddress = new SolanaAddress(payer).unwrap(); // TODO: const nonce = 0; // createNonce().readUInt32LE(0); @@ -183,7 +186,7 @@ export class SolanaTokenBridge implements TokenBridge<'Solana'> { this.tokenBridge.programId, this.coreBridge.programId, senderAddress, - token.toUint8Array(), + new SolanaAddress(token).toUint8Array(), messageKey.publicKey, nonce, ); @@ -199,9 +202,10 @@ export class SolanaTokenBridge implements TokenBridge<'Solana'> { async *submitAttestation( vaa: VAA<'AttestMeta'>, - sender: UniversalOrSolana, + payer?: AnySolanaAddress, ): AsyncGenerator { - const senderAddress = new PublicKey(sender.toUint8Array()); + if (!payer) throw new Error('Payer required to create attestation'); + const senderAddress = new SolanaAddress(payer).unwrap(); const transaction = new Transaction().add( createCreateWrappedInstruction( @@ -220,15 +224,15 @@ export class SolanaTokenBridge implements TokenBridge<'Solana'> { } private async transferSol( - sender: UniversalOrSolana, + sender: AnySolanaAddress, recipient: ChainAddress, amount: bigint, payload?: Uint8Array, ): Promise { // https://github.com/wormhole-foundation/wormhole-connect/blob/development/sdk/src/contexts/solana/context.ts#L245 - const senderAddress = new PublicKey(sender.toUint8Array()); - // TODO: why? + const senderAddress = new SolanaAddress(sender).unwrap(); + // TODO: the payer can actually be different from the sender. We need to allow the user to pass in an optional payer const payerPublicKey = senderAddress; const recipientAddress = recipient.address @@ -332,9 +336,9 @@ export class SolanaTokenBridge implements TokenBridge<'Solana'> { } async *transfer( - sender: UniversalOrSolana, + sender: AnySolanaAddress, recipient: ChainAddress, - token: UniversalOrSolana | 'native', + token: AnySolanaAddress | 'native', amount: bigint, payload?: Uint8Array, ): AsyncGenerator { @@ -345,9 +349,9 @@ export class SolanaTokenBridge implements TokenBridge<'Solana'> { return; } - const tokenAddress = new PublicKey(token.toUint8Array()); + const tokenAddress = new SolanaAddress(token).unwrap(); - const senderAddress = new PublicKey(sender.toUint8Array()); + const senderAddress = new SolanaAddress(sender).unwrap(); const senderTokenAddress = await getAssociatedTokenAddress( tokenAddress, senderAddress, @@ -454,10 +458,10 @@ export class SolanaTokenBridge implements TokenBridge<'Solana'> { } private async postVaa( - sender: UniversalOrSolana, + sender: AnySolanaAddress, vaa: VAA<'Transfer'> | VAA<'TransferWithPayload'>, ) { - const senderAddr = new PublicKey(sender.toUint8Array()); + const senderAddr = new SolanaAddress(sender).unwrap(); const signatureSet = Keypair.generate(); const verifySignaturesInstructions = @@ -495,7 +499,7 @@ export class SolanaTokenBridge implements TokenBridge<'Solana'> { } async *redeem( - sender: UniversalOrSolana, + sender: AnySolanaAddress, vaa: VAA<'Transfer'> | VAA<'TransferWithPayload'>, unwrapNative: boolean = true, ): AsyncGenerator { @@ -503,8 +507,8 @@ export class SolanaTokenBridge implements TokenBridge<'Solana'> { // TODO: check if vaa.payload.token.address is native Sol const { blockhash } = await this.connection.getLatestBlockhash(); - const senderAddress = new PublicKey(sender.toUint8Array()); - const ataAddress = new PublicKey(vaa.payload.to.address.toUint8Array()); + const senderAddress = new SolanaAddress(sender).unwrap(); + const ataAddress = new SolanaAddress(vaa.payload.to.address).unwrap(); const wrappedToken = await this.getWrappedAsset(vaa.payload.token); // If the ata doesn't exist yet, create it diff --git a/platforms/solana/src/types.ts b/platforms/solana/src/types.ts index addb870f3..e955cb250 100644 --- a/platforms/solana/src/types.ts +++ b/platforms/solana/src/types.ts @@ -1,23 +1,12 @@ import { PlatformToChains, - UniversalAddress, UniversalOrNative, - registerNative, } from '@wormhole-foundation/connect-sdk'; -import { SolanaAddress } from './address'; +import { PublicKeyInitData } from '@solana/web3.js'; export const unusedNonce = 0; export const unusedArbiterFee = 0n; -registerNative('Solana', SolanaAddress); - export type SolanaChainName = PlatformToChains<'Solana'>; export type UniversalOrSolana = UniversalOrNative<'Solana'>; - -export const toSolanaAddrString = (addr: UniversalOrSolana) => - typeof addr === 'string' - ? addr - : (addr instanceof UniversalAddress - ? addr.toNative('Solana') - : addr - ).unwrap(); +export type AnySolanaAddress = UniversalOrSolana | PublicKeyInitData;