Skip to content

Commit

Permalink
Fix gas meter and setNetwork (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lohann authored Aug 9, 2024
1 parent d417c79 commit 690a689
Show file tree
Hide file tree
Showing 7 changed files with 522 additions and 335 deletions.
134 changes: 87 additions & 47 deletions src/Gateway.sol
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,24 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
bytes32 internal constant FIRST_MESSAGE_PLACEHOLDER = bytes32(uint256(2 ** 256 - 1));

// Shard data, maps the pubkey coordX (which is already collision resistant) to shard info.
mapping(bytes32 => KeyInfo) _shards;
mapping(bytes32 => KeyInfo) private _shards;

// GMP message status
mapping(bytes32 => GmpInfo) _messages;
mapping(bytes32 => GmpInfo) private _messages;

// GAP necessary for migration purposes
mapping(GmpSender => mapping(uint16 => uint256)) private _deprecated_Deposits;
mapping(uint16 => bytes32) private _deprecated_Networks;

// Hash of the previous GMP message submitted.
bytes32 public prevMessageHash;

// Replay protection mechanism, stores the hash of the executed messages
// messageHash => shardId
mapping(bytes32 => bytes32) _executedMessages;
mapping(bytes32 => bytes32) private _executedMessages;

// Network ID => Source network
mapping(uint16 => NetworkInfo) _networkInfo;
mapping(uint16 => NetworkInfo) private _networkInfo;

/**
* @dev Shard info stored in the Gateway Contract
Expand Down Expand Up @@ -177,6 +181,10 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
return NETWORK_ID;
}

function networkInfo(uint16 id) external view returns (NetworkInfo memory) {
return _networkInfo[id];
}

/**
* @dev Verify if shard exists, if the TSS signature is valid then increment shard's nonce.
*/
Expand All @@ -190,7 +198,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {

// Load y parity bit, it must be 27 (even), or 28 (odd)
// ref: https://ethereum.github.io/yellowpaper/paper.pdf
uint8 yParity = uint8(BranchlessMath.select((status & SHARD_Y_PARITY) > 0, 28, 27));
uint8 yParity = BranchlessMath.ternaryU8((status & SHARD_Y_PARITY) > 0, 28, 27);

// Verify Signature
require(
Expand All @@ -206,13 +214,14 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
return bytes32(tssKey.xCoord);
}

// Converts a `TssKey` into an `KeyInfo` unique identifier
// Initialize networks
function _updateNetworks(Network[] calldata networks) private {
for (uint256 i = 0; i < networks.length; i++) {
Network calldata network = networks[i];
bytes32 domainSeparator = computeDomainSeparator(network.id, network.gateway);
NetworkInfo storage info = _networkInfo[network.id];
info.domainSeparator = domainSeparator;
require(info.domainSeparator == bytes32(0), "network already initialized");
require(network.id != NETWORK_ID || network.gateway == address(this), "wrong gateway address");
info.domainSeparator = computeDomainSeparator(network.id, network.gateway);
info.gasLimit = 15_000_000; // Default to 15M gas
info.relativeGasPrice = UFloatMath.ONE;
info.baseFee = 0;
Expand Down Expand Up @@ -248,10 +257,10 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
);

// if is a new shard shard, set its initial nonce to 1
shard.nonce = uint32(BranchlessMath.select(nonce == 0, 1, nonce));
shard.nonce = uint32(BranchlessMath.ternaryU32(nonce == 0, 1, nonce));

// enable/disable the y-parity flag
status = uint8(BranchlessMath.select(yParity > 0, status | SHARD_Y_PARITY, status & ~SHARD_Y_PARITY));
status = BranchlessMath.ternaryU8(yParity > 0, status | SHARD_Y_PARITY, status & ~SHARD_Y_PARITY);

// enable SHARD_ACTIVE bitflag
status |= SHARD_ACTIVE;
Expand Down Expand Up @@ -340,7 +349,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
// https://eips.ethereum.org/EIPS/eip-150
uint256 gasNeeded = gasLimit.saturatingMul(64).saturatingDiv(63);
// to guarantee it was provided enough gas to execute the GMP message
gasNeeded = gasNeeded.saturatingAdd(6412);
gasNeeded = gasNeeded.saturatingAdd(10000);
require(gasleft() >= gasNeeded, "insufficient gas to execute GMP message");
}

Expand Down Expand Up @@ -374,7 +383,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
}

// Update GMP status
status = GmpStatus(BranchlessMath.select(success, uint256(GmpStatus.SUCCESS), uint256(GmpStatus.REVERT)));
status = GmpStatus(BranchlessMath.ternary(success, uint256(GmpStatus.SUCCESS), uint256(GmpStatus.REVERT)));

// Persist gmp execution status on storage
gmp.status = status;
Expand Down Expand Up @@ -413,11 +422,9 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
unchecked {
// Compute GMP gas used
uint256 gasUsed = 7214;
{
gasUsed = gasUsed.saturatingAdd(GasUtils.calldataGasCost());
gasUsed = gasUsed.saturatingAdd(GasUtils.proxyOverheadGasCost(uint16(msg.data.length), 64));
gasUsed = gasUsed.saturatingAdd(initialGas - gasleft());
}
gasUsed = gasUsed.saturatingAdd(GasUtils.calldataBaseCost());
gasUsed = gasUsed.saturatingAdd(GasUtils.proxyOverheadGasCost(uint16(msg.data.length), 64));
gasUsed = gasUsed.saturatingAdd(initialGas - gasleft());

// Compute refund amount
uint256 refund = BranchlessMath.min(gasUsed.saturatingMul(tx.gasprice), address(this).balance);
Expand All @@ -429,34 +436,48 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
}
}

function _setNetworkInfo(bytes32 executor, bytes32 messageHash, UpdateNetworkInfo calldata data) private {
require(data.mortality >= block.number, "message expired");
function _setNetworkInfo(bytes32 executor, bytes32 messageHash, UpdateNetworkInfo calldata info) private {
require(info.mortality >= block.number, "message expired");
require(executor != bytes32(0), "executor cannot be zero");

// Verify signature and if the message was already executed
require(_executedMessages[messageHash] == bytes32(0), "message already executed");

// Update network info and store the message hash to prevent replay attacks
NetworkInfo storage networkInfo = _networkInfo[data.networkId];

// Verify if the domain separator is not zero
require((networkInfo.domainSeparator | data.domainSeparator) != bytes32(0), "domain separator cannot be zero");
// Update network info
NetworkInfo memory stored = _networkInfo[info.networkId];

// Update domain separator if it's not zero
if (data.domainSeparator != bytes32(0)) {
networkInfo.domainSeparator = messageHash;
}
// Verify and update domain separator if it's not zero
stored.domainSeparator =
BranchlessMath.ternary(info.domainSeparator != bytes32(0), info.domainSeparator, stored.domainSeparator);
require(stored.domainSeparator != bytes32(0), "domain separator cannot be zero");

// Update gas limit if it's not zero
networkInfo.gasLimit = uint64(BranchlessMath.select(data.gasLimit > 0, data.gasLimit, networkInfo.gasLimit));
if (UFloat9x56.unwrap(data.relativeGasPrice) > 0 || data.baseFee > 0) {
networkInfo.relativeGasPrice = networkInfo.relativeGasPrice;
networkInfo.baseFee = networkInfo.baseFee;
stored.gasLimit = BranchlessMath.ternaryU64(info.gasLimit > 0, info.gasLimit, stored.gasLimit);

// Update relative gas price and base fee if any of them are greater than zero
{
bool shouldUpdate = UFloat9x56.unwrap(info.relativeGasPrice) > 0 || info.baseFee > 0;
stored.relativeGasPrice = UFloat9x56.wrap(
BranchlessMath.ternaryU64(
shouldUpdate, UFloat9x56.unwrap(info.relativeGasPrice), UFloat9x56.unwrap(stored.relativeGasPrice)
)
);
stored.baseFee = BranchlessMath.ternaryU128(shouldUpdate, info.baseFee, stored.baseFee);
}

// Save the message hash to prevent replay attacks
_executedMessages[messageHash] = executor;

// Update network info
_networkInfo[info.networkId] = stored;

emit NetworkUpdated(
messageHash, data.networkId, data.domainSeparator, data.relativeGasPrice, data.baseFee, data.gasLimit
messageHash,
info.networkId,
stored.domainSeparator,
stored.relativeGasPrice,
stored.baseFee,
stored.gasLimit
);
}

Expand Down Expand Up @@ -494,10 +515,10 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
address destinationAddress,
uint16 destinationNetwork,
uint256 executionGasLimit,
bytes memory data
bytes calldata data
) external payable returns (bytes32) {
// Check if the message data is too large
require(data.length < MAX_PAYLOAD_SIZE, "msg data too large");
require(data.length <= MAX_PAYLOAD_SIZE, "msg data too large");

// Check if the destination network is supported
NetworkInfo storage info = _networkInfo[destinationNetwork];
Expand All @@ -506,7 +527,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {

// Check if the sender has deposited enougth funds to execute the GMP message
{
uint256 nonZeros = GasUtils.countNonZeros(data);
uint256 nonZeros = GasUtils.countNonZerosCalldata(data);
uint256 zeros = data.length - nonZeros;
uint256 msgPrice = GasUtils.estimateWeiCost(
info.relativeGasPrice, info.baseFee, uint16(nonZeros), uint16(zeros), executionGasLimit
Expand All @@ -521,18 +542,37 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
bytes32 prevHash = prevMessageHash;

// if the messageHash is the first message, we use a zero salt
uint256 salt = BranchlessMath.select(prevHash == FIRST_MESSAGE_PLACEHOLDER, 0, uint256(prevHash));
uint256 salt = BranchlessMath.ternary(prevHash == FIRST_MESSAGE_PLACEHOLDER, 0, uint256(prevHash));

// Create GMP message and update prevMessageHash
GmpMessage memory message =
GmpMessage(source, NETWORK_ID, destinationAddress, destinationNetwork, executionGasLimit, salt, data);
prevHash = message.eip712TypedHash(domainSeparator);
prevMessageHash = prevHash;
bytes memory payload;
{
GmpMessage memory message =
GmpMessage(source, NETWORK_ID, destinationAddress, destinationNetwork, executionGasLimit, salt, data);
prevHash = message.eip712TypedHash(domainSeparator);
prevMessageHash = prevHash;
payload = message.data;
}

emit GmpCreated(
prevHash, GmpSender.unwrap(source), destinationAddress, destinationNetwork, executionGasLimit, salt, data
);
return prevHash;
// Emit `GmpCreated` event without copy the data, to simplify the gas estimation.
// the assembly code below is equivalent to:
// ```solidity
// emit GmpCreated(prevHash, source, destinationAddress, destinationNetwork, executionGasLimit, salt, data);
// return prevHash;
// ```
bytes32 eventSelector = GmpCreated.selector;
assembly {
let ptr := sub(payload, 0x80)
mstore(ptr, destinationNetwork) // dest network
mstore(add(ptr, 0x20), executionGasLimit) // gas limit
mstore(add(ptr, 0x40), salt) // salt
mstore(add(ptr, 0x60), 0x80) // data offset
let size := and(add(mload(payload), 31), 0xffffffe0)
size := add(size, 160)
log4(ptr, size, eventSelector, prevHash, source, destinationAddress)
mstore(0, prevHash)
return(0, 32)
}
}

/**
Expand All @@ -555,7 +595,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
require(baseFee > 0 || UFloat9x56.unwrap(relativeGasPrice) > 0, "unsupported network");

// if the message data is too large, we use the maximum base fee.
baseFee = BranchlessMath.select(messageSize > MAX_PAYLOAD_SIZE, 2 ** 256 - 1, baseFee);
baseFee = BranchlessMath.ternary(messageSize > MAX_PAYLOAD_SIZE, 2 ** 256 - 1, baseFee);

// Estimate the cost
return GasUtils.estimateWeiCost(relativeGasPrice, baseFee, uint16(messageSize), 0, gasLimit);
Expand All @@ -573,7 +613,7 @@ contract Gateway is IGateway, IExecutor, IUpgradable, GatewayEIP712 {
function _getAdmin() private view returns (address admin) {
admin = ERC1967.getAdmin();
// If the admin slot is empty, then the 0xd4833be6144AF48d4B09E5Ce41f826eEcb7706D6 is the admin
admin = BranchlessMath.select(admin == address(0x0), 0xd4833be6144AF48d4B09E5Ce41f826eEcb7706D6, admin);
admin = BranchlessMath.ternary(admin == address(0x0), 0xd4833be6144AF48d4B09E5Ce41f826eEcb7706D6, admin);
}

function setAdmin(address newAdmin) external payable {
Expand Down
64 changes: 59 additions & 5 deletions src/utils/BranchlessMath.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,20 @@ library BranchlessMath {
* @dev Returns the smallest of two numbers.
*/
function min(uint256 x, uint256 y) internal pure returns (uint256) {
return select(x < y, x, y);
return ternary(x < y, x, y);
}

/**
* @dev Returns the largest of two numbers.
*/
function max(uint256 x, uint256 y) internal pure returns (uint256) {
return select(x > y, x, y);
return ternary(x > y, x, y);
}

/**
* @dev If `condition` is true returns `a`, otherwise returns `b`.
*/
function select(bool condition, uint256 a, uint256 b) internal pure returns (uint256) {
function ternary(bool condition, uint256 a, uint256 b) internal pure returns (uint256) {
unchecked {
// branchless select, works because:
// b ^ (a ^ b) == a
Expand All @@ -40,13 +40,67 @@ library BranchlessMath {

/**
* @dev If `condition` is true returns `a`, otherwise returns `b`.
* see `BranchlessMath.ternary`
*/
function select(bool condition, address a, address b) internal pure returns (address) {
return address(uint160(select(condition, uint256(uint160(a)), uint256(uint160(b)))));
function ternary(bool condition, address a, address b) internal pure returns (address r) {
assembly {
r := xor(b, mul(xor(a, b), condition))
}
}

/**
* @dev If `condition` is true returns `a`, otherwise returns `b`.
* see `BranchlessMath.ternary`
*/
function ternary(bool condition, bytes32 a, bytes32 b) internal pure returns (bytes32 r) {
assembly {
r := xor(b, mul(xor(a, b), condition))
}
}

/**
* @dev If `condition` is true returns `a`, otherwise returns `b`.
* see `BranchlessMath.ternary`
*/
function ternaryU128(bool condition, uint128 a, uint128 b) internal pure returns (uint128 r) {
assembly {
r := xor(b, mul(xor(a, b), condition))
}
}

/**
* @dev If `condition` is true returns `a`, otherwise returns `b`.
* see `BranchlessMath.ternary`
*/
function ternaryU64(bool condition, uint64 a, uint64 b) internal pure returns (uint64 r) {
assembly {
r := xor(b, mul(xor(a, b), condition))
}
}

/**
* @dev If `condition` is true returns `a`, otherwise returns `b`.
* see `BranchlessMath.ternary`
*/
function ternaryU32(bool condition, uint32 a, uint32 b) internal pure returns (uint32 r) {
assembly {
r := xor(b, mul(xor(a, b), condition))
}
}

/**
* @dev If `condition` is true returns `a`, otherwise returns `b`.
* see `BranchlessMath.ternary`
*/
function ternaryU8(bool condition, uint8 a, uint8 b) internal pure returns (uint8 r) {
assembly {
r := xor(b, mul(xor(a, b), condition))
}
}

/**
* @dev If `condition` is true return `value`, otherwise return zero.
* see `BranchlessMath.ternary`
*/
function selectIf(bool condition, uint256 value) internal pure returns (uint256) {
unchecked {
Expand Down
Loading

0 comments on commit 690a689

Please sign in to comment.