Skip to content

Commit

Permalink
reduce some boilerplate, make it easier to access a static reference …
Browse files Browse the repository at this point in the history
…to the class instance (#164)
  • Loading branch information
barnjamin committed Nov 25, 2023
1 parent 5c0ea03 commit 8f8b035
Show file tree
Hide file tree
Showing 12 changed files with 60 additions and 188 deletions.
2 changes: 2 additions & 0 deletions core/definitions/src/attestation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ export type WormholeMessageId = {
chain: Chain;
emitter: UniversalAddress;
sequence: SequenceId;
// TODO
vaa?: VAA;
};
export function isWormholeMessageId(thing: WormholeMessageId | any): thing is WormholeMessageId {
return (
Expand Down
19 changes: 9 additions & 10 deletions core/definitions/src/chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import {

import { TokenAddress } from "./address";
import { WormholeMessageId } from "./attestation";
import { PlatformContext, PlatformUtils } from "./platform";
import { PlatformContext } from "./platform";
import { protocolIsRegistered } from "./protocol";
import { AutomaticCircleBridge, CircleBridge } from "./protocols/cctp";
import { IbcBridge } from "./protocols/ibc";
import { AutomaticTokenBridge, TokenBridge } from "./protocols/tokenBridge";
import { protocolIsRegistered } from "./protocol";
import { RpcConnection } from "./rpc";
import { ChainConfig, SignedTx } from "./types";

Expand All @@ -24,7 +24,6 @@ export abstract class ChainContext<
readonly network: N;

readonly platform: PlatformContext<N, P>;
readonly platformUtils: PlatformUtils<N, P>;

readonly chain: C;
readonly config: ChainConfig<N, C>;
Expand All @@ -40,7 +39,6 @@ export abstract class ChainContext<
constructor(chain: C, platform: PlatformContext<N, P>) {
this.config = platform.config[chain];
this.platform = platform;
this.platformUtils = platform.constructor as any as PlatformUtils<N, P>;
this.chain = this.config.key;
this.network = this.config.network;
}
Expand All @@ -52,29 +50,30 @@ export abstract class ChainContext<

// Get the number of decimals for a token
async getDecimals(token: TokenAddress<C>): Promise<bigint> {
return this.platformUtils.getDecimals(this.chain, this.getRpc(), token);
return this.platform.utils().getDecimals(this.chain, this.getRpc(), token);
}

// Get the balance of a token for a given address
async getBalance(walletAddr: string, token: TokenAddress<C>): Promise<bigint | null> {
return this.platformUtils.getBalance(this.chain, await this.getRpc(), walletAddr, token);
return this.platform.utils().getBalance(this.chain, await this.getRpc(), walletAddr, token);
}

async getLatestBlock(): Promise<number> {
return this.platformUtils.getLatestBlock(this.getRpc());
return this.platform.utils().getLatestBlock(this.getRpc());
}

async getLatestFinalizedBlock(): Promise<number> {
return this.platformUtils.getLatestFinalizedBlock(this.getRpc());
return this.platform.utils().getLatestFinalizedBlock(this.getRpc());
}

// Get details about the transaction
async parseTransaction(txid: string): Promise<WormholeMessageId[]> {
return this.platform.parseTransaction(this.chain, await this.getRpc(), txid);
return this.platform.parseWormholeMessages(this.chain, await this.getRpc(), txid);
}

// Send a transaction and wait for it to be confirmed
async sendWait(stxns: SignedTx): Promise<string[]> {
return this.platformUtils.sendWait(this.chain, await this.getRpc(), stxns);
return this.platform.utils().sendWait(this.chain, await this.getRpc(), stxns);
}

//
Expand Down
24 changes: 16 additions & 8 deletions core/definitions/src/platform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import {
PlatformToChains,
ProtocolName,
} from "@wormhole-foundation/sdk-base";
import { WormholeCore } from ".";
import { TokenAddress } from "./address";
import { WormholeMessageId } from "./attestation";
import { ChainContext } from "./chain";
import { create } from "./protocol";
import { RpcConnection } from "./rpc";
import { TokenAddress } from "./address";
import { Balances, ChainsConfig, SignedTx, TokenId, TxHash } from "./types";
import { ProtocolInitializer } from "./protocol";

// PlatformUtils represents the _static_ attributes available on
// the PlatformContext Class
Expand All @@ -21,9 +22,6 @@ export interface PlatformUtils<N extends Network, P extends Platform> {
// Initialize a new PlatformContext object
new (network: N, config?: ChainsConfig<N, P>): PlatformContext<N, P>;

// Get a protocol name
getProtocolInitializer<PN extends ProtocolName>(protocol: PN): ProtocolInitializer<P, PN>;

// Check if this chain is supported by this platform
// Note: purposely not adding generic parameters
isSupportedChain(chain: Chain): boolean;
Expand Down Expand Up @@ -89,19 +87,29 @@ export abstract class PlatformContext<N extends Network, P extends Platform> {
readonly config: ChainsConfig<N, P>,
) {}

// provides access to the static attributes of the PlatformContext class
utils(): PlatformUtils<N, P> {
return this.constructor as any;
}

// Create a _new_ RPC Connection
abstract getRpc<C extends PlatformToChains<P>>(chain: C): RpcConnection<P>;

// Create a new Chain context object
abstract getChain<C extends PlatformToChains<P>>(chain: C): ChainContext<N, P, C>;

// Create a new Protocol Client instance by protocol name
abstract getProtocol<PN extends ProtocolName, T>(protocol: PN, rpc: RpcConnection<P>): Promise<T>;
getProtocol<PN extends ProtocolName, T>(protocol: PN, rpc: RpcConnection<P>): Promise<T> {
return create(this.utils()._platform, protocol, rpc, this.config);
}

// Look up transaction logs and parse out Wormhole messages
abstract parseTransaction<C extends PlatformToChains<P>>(
async parseWormholeMessages<C extends PlatformToChains<P>>(
chain: C,
rpc: RpcConnection<P>,
txid: TxHash,
): Promise<WormholeMessageId[]>;
): Promise<WormholeMessageId[]> {
const wc: WormholeCore<N, P, C> = await this.getProtocol("WormholeCore", rpc);
return wc.parseTransaction(txid);
}
}
16 changes: 13 additions & 3 deletions core/definitions/src/protocol.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import {
Chain,
isChain,
Network,
Platform,
chainToPlatform,
ProtocolName,
Network,
chainToPlatform,
isChain,
} from "@wormhole-foundation/sdk-base";
import { RpcConnection } from "./rpc";
import { ChainsConfig } from "./types";
Expand Down Expand Up @@ -76,3 +76,13 @@ export function getProtocolInitializer<P extends Platform, PN extends ProtocolNa

return pctr as ProtocolInitializer<P, PN>;
}

export const create = <N extends Network, P extends Platform, PN extends ProtocolName, T>(
platform: P,
protocol: PN,
rpc: RpcConnection<P>,
config: ChainsConfig<N, P>,
): Promise<T> => {
const pctr = getProtocolInitializer(platform, protocol);
return pctr.fromRpc(rpc, config) as Promise<T>;
};
8 changes: 6 additions & 2 deletions core/definitions/src/protocols/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ import { AccountAddress } from "../address";
import { WormholeMessageId } from "../attestation";
import { TxHash } from "../types";
import { UnsignedTransaction } from "../unsignedTransaction";
export interface WormholeCore<N extends Network, P extends Platform, C extends PlatformToChains<P>> {
export interface WormholeCore<
N extends Network,
P extends Platform,
C extends PlatformToChains<P>,
> {
publishMessage(
sender: AccountAddress<C>,
message: string | Uint8Array
message: string | Uint8Array,
): AsyncGenerator<UnsignedTransaction<N, C>>;
parseTransaction(txid: TxHash): Promise<WormholeMessageId[]>;
// TODO: events?
Expand Down
25 changes: 3 additions & 22 deletions core/definitions/src/testing/mocks/platform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ import {
RpcConnection,
TokenAddress,
TokenId,
TxHash,
WormholeMessageId,
} from "../..";
import { MockChain } from "./chain";
import { MockRpc } from "./rpc";
Expand All @@ -27,22 +25,17 @@ export function mockPlatformFactory<N extends Network, P extends Platform>(
static _platform: P = platform;
constructor(network: N, _config?: ChainsConfig<N, P>) {
super(network, _config ? _config : config);
this.network = network;
}
}
// @ts-ignore
return ConcreteMockPlatform<N>;
return ConcreteMockPlatform;
}

// Note: don't use this directly, instead create a ConcreteMockPlatform with the
// mockPlatformFactory
export class MockPlatform<N extends Network, P extends Platform> implements PlatformContext<N, P> {
network: N;
config: ChainsConfig<N, P>;

export class MockPlatform<N extends Network, P extends Platform> extends PlatformContext<N, P> {
constructor(network: N, config: ChainsConfig<N, P>) {
this.network = network;
this.config = config;
super(network, config);
}

static getProtocol<PN extends ProtocolName, T extends any>(protocol: PN): T {
Expand Down Expand Up @@ -129,18 +122,6 @@ export class MockPlatform<N extends Network, P extends Platform> implements Plat
return 0n;
}

async parseTransaction<C extends PlatformToChains<P>>(
chain: C,
rpc: RpcConnection<P>,
txid: TxHash,
): Promise<WormholeMessageId[]> {
throw new Error("Method not implemented");
}

getProtocol<PN extends ProtocolName, T extends any>(protocol: PN): T {
throw new Error("Method not implemented.");
}

getDecimals<C extends PlatformToChains<P>>(
chain: C,
rpc: RpcConnection<P>,
Expand Down
6 changes: 3 additions & 3 deletions examples/src/helpers/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ export async function getStuff<
C extends PlatformToChains<P>,
>(chain: ChainContext<N, P, C>): Promise<TransferStuff<N, P, C>> {
let signer: Signer;

switch (chain.platformUtils._platform) {
const platform = chain.platform.utils()._platform;
switch (platform) {
case "Solana":
signer = await getSolanaSigner(await chain.getRpc(), getEnv("SOL_PRIVATE_KEY"));
break;
Expand All @@ -56,7 +56,7 @@ export async function getStuff<
signer = await getEvmSigner(await chain.getRpc(), getEnv("ETH_PRIVATE_KEY"));
break;
default:
throw new Error("Unrecognized platform: " + chain.platformUtils._platform);
throw new Error("Unrecognized platform: " + platform);
}

return { chain, signer, address: nativeChainAddress(chain.chain, signer.address()) };
Expand Down
4 changes: 1 addition & 3 deletions platforms/aptos/src/address.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,4 @@ export class AptosAddress implements Address {
}
}

try {
registerNative("Aptos", AptosAddress);
} catch {}
registerNative("Aptos", AptosAddress);
40 changes: 4 additions & 36 deletions platforms/aptos/src/platform.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,4 @@
import {
Chain,
Network,
PlatformContext,
ProtocolImplementation,
ProtocolInitializer,
ProtocolName,
WormholeCore,
WormholeMessageId,
getProtocolInitializer,
} from "@wormhole-foundation/connect-sdk";
import { Chain, Network, PlatformContext } from "@wormhole-foundation/connect-sdk";
import { AptosClient } from "aptos";
import { AptosChain } from "./chain";
import { AptosChains, AptosPlatformType, _platform } from "./types";
Expand Down Expand Up @@ -43,22 +33,6 @@ export class AptosPlatform<N extends Network> extends PlatformContext<N, AptosPl
throw new Error("No configuration available for chain: " + chain);
}

async getProtocol<PN extends ProtocolName>(
protocol: PN,
rpc: AptosClient,
): Promise<ProtocolImplementation<AptosPlatformType, PN>> {
return AptosPlatform.getProtocolInitializer(protocol).fromRpc(rpc, this.config);
}

async parseTransaction<C extends AptosChains>(
chain: C,
rpc: AptosClient,
tx: string,
): Promise<WormholeMessageId[]> {
const core: WormholeCore<N, AptosPlatformType, C> = await this.getProtocol("WormholeCore", rpc);
return core.parseTransaction(tx);
}

static nativeTokenId<N extends Network, C extends AptosChains>(network: N, chain: C): TokenId<C> {
if (!this.isSupportedChain(chain)) throw new Error(`invalid chain: ${chain}`);
return nativeChainAddress(chain, APTOS_COIN);
Expand All @@ -80,7 +54,7 @@ export class AptosPlatform<N extends Network> extends PlatformContext<N, AptosPl
return platform === AptosPlatform._platform;
}

async getDecimals(
static async getDecimals(
chain: Chain,
rpc: AptosClient,
token: AnyAptosAddress | "native",
Expand All @@ -96,7 +70,7 @@ export class AptosPlatform<N extends Network> extends PlatformContext<N, AptosPl
return decimals;
}

async getBalance(
static async getBalance(
chain: Chain,
rpc: AptosClient,
walletAddress: string,
Expand All @@ -120,7 +94,7 @@ export class AptosPlatform<N extends Network> extends PlatformContext<N, AptosPl
}
}

async getBalances(
static async getBalances(
chain: Chain,
rpc: AptosClient,
walletAddress: string,
Expand Down Expand Up @@ -193,10 +167,4 @@ export class AptosPlatform<N extends Network> extends PlatformContext<N, AptosPl
const ci = await conn.getChainId();
return this.chainFromChainId(ci.toString());
}

static getProtocolInitializer<PN extends ProtocolName>(
protocol: PN,
): ProtocolInitializer<AptosPlatformType, PN> {
return getProtocolInitializer(this._platform, protocol);
}
}
31 changes: 0 additions & 31 deletions platforms/cosmwasm/src/platform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,9 @@ import {
ChainsConfig,
Network,
PlatformContext,
ProtocolImplementation,
ProtocolInitializer,
ProtocolName,
SignedTx,
TxHash,
WormholeCore,
WormholeMessageId,
decimals,
getProtocolInitializer,
nativeChainIds,
networkPlatformConfigs,
} from "@wormhole-foundation/connect-sdk";
Expand Down Expand Up @@ -67,25 +61,6 @@ export class CosmwasmPlatform<N extends Network> extends PlatformContext<N, Cosm
throw new Error("No configuration available for chain: " + chain);
}

async getProtocol<PN extends ProtocolName>(
protocol: PN,
rpc: CosmWasmClient,
): Promise<ProtocolImplementation<CosmwasmPlatformType, PN>> {
return CosmwasmPlatform.getProtocolInitializer(protocol).fromRpc(rpc, this.config);
}

async parseTransaction<C extends CosmwasmChains>(
chain: C,
rpc: CosmWasmClient,
txid: TxHash,
): Promise<WormholeMessageId[]> {
const core: WormholeCore<N, CosmwasmPlatformType, C> = await this.getProtocol(
"WormholeCore",
rpc,
);
return core.parseTransaction(txid);
}

static getQueryClient = (rpc: CosmWasmClient): QueryClient & BankExtension & IbcExtension => {
// @ts-ignore -- access private attribute
const tmClient: TendermintClient = rpc.getTmClient()!;
Expand Down Expand Up @@ -225,10 +200,4 @@ export class CosmwasmPlatform<N extends Network> extends PlatformContext<N, Cosm
const conn = await queryClient.ibc.channel.channel(IBC_TRANSFER_PORT, sourceChannel);
return conn.channel?.counterparty?.channelId ?? null;
}

static getProtocolInitializer<PN extends ProtocolName>(
protocol: PN,
): ProtocolInitializer<CosmwasmPlatformType, PN> {
return getProtocolInitializer(this._platform, protocol);
}
}
Loading

0 comments on commit 8f8b035

Please sign in to comment.