diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecGenerator.java
index 5e0fb10d..1dbf71d4 100644
--- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecGenerator.java
+++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecGenerator.java
@@ -77,12 +77,12 @@ public void generate(Protobuf3Parser.MessageDefContext msgDef, final File destin
                  * Protobuf Codec for $modelClass model object. Generated based on protobuf schema.
                  */
                 public final class $codecClass implements Codec<$modelClass> {
-                    $unsetOneOfConstants
-                    $parseMethod
-                    $writeMethod
-                    $measureDataMethod
-                    $measureRecordMethod
-                    $fastEqualsMethod
+                $unsetOneOfConstants
+                $parseMethod
+                $writeMethod
+                $measureDataMethod
+                $measureRecordMethod
+                $fastEqualsMethod
                 }
                 """
                 .replace("$package", codecPackage)
diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecMeasureRecordMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecMeasureRecordMethodGenerator.java
index 622aaaa6..7326df0d 100644
--- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecMeasureRecordMethodGenerator.java
+++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecMeasureRecordMethodGenerator.java
@@ -102,7 +102,7 @@ private static String generateFieldSizeOfLines(final Field field, final String m
         return prefix + switch (field.type()) {
             case ENUM -> "size += sizeOfEnumList(%s, %s);"
                     .formatted(fieldDef, getValueCode);
-            case MESSAGE -> "size += sizeOfMessageList($fieldDef, $valueCode, $codec::measureRecord);"
+            case MESSAGE -> "size += sizeOfMessageList($fieldDef, $valueCode, $codec);"
                     .replace("$fieldDef", fieldDef)
                     .replace("$valueCode", getValueCode)
                     .replace("$codec", ((SingleField) field).messageTypeModelPackage() + "." +
@@ -147,7 +147,7 @@ private static String generateFieldSizeOfLines(final Field field, final String m
                     .formatted(fieldDef, getValueCode);
             case STRING -> "size += sizeOfString(%s, %s, %s);"
                     .formatted(fieldDef, getValueCode, skipDefault);
-            case MESSAGE -> "size += sizeOfMessage($fieldDef, $valueCode, $codec::measureRecord);"
+            case MESSAGE -> "size += sizeOfMessage($fieldDef, $valueCode, $codec);"
                     .replace("$fieldDef", fieldDef)
                     .replace("$valueCode", getValueCode)
                     .replace("$codec", ((SingleField)field).messageTypeModelPackage() + "." +
diff --git a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java
index 415e0a94..1ba5aff1 100644
--- a/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java
+++ b/pbj-core/pbj-compiler/src/main/java/com/hedera/pbj/compiler/impl/generators/protobuf/CodecWriteMethodGenerator.java
@@ -102,7 +102,7 @@ private static String generateFieldWriteLines(final Field field, final String mo
         return prefix + switch(field.type()) {
             case ENUM -> "writeEnumList(out, %s, %s);"
                     .formatted(fieldDef, getValueCode);
-            case MESSAGE -> "writeMessageList(out, $fieldDef, $valueCode, $codec::write, $codec::measureRecord);"
+            case MESSAGE -> "writeMessageList(out, $fieldDef, $valueCode, $codec);"
                     .replace("$fieldDef", fieldDef)
                     .replace("$valueCode", getValueCode)
                     .replace("$codec", ((SingleField)field).messageTypeModelPackage() + "." +
@@ -165,7 +165,7 @@ private static String generateFieldWriteLines(final Field field, final String mo
                     .formatted(fieldDef, getValueCode);
             case STRING -> "writeString(out, %s, %s, %s);"
                     .formatted(fieldDef, getValueCode, skipDefault);
-            case MESSAGE -> "writeMessage(out, $fieldDef, $valueCode, $codec::write, $codec::measureRecord);"
+            case MESSAGE -> "writeMessage(out, $fieldDef, $valueCode, $codec);"
                     .replace("$fieldDef", fieldDef)
                     .replace("$valueCode", getValueCode)
                     .replace("$codec", ((SingleField)field).messageTypeModelPackage() + "." +
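The generator changes above are the heart of the PR: generated codecs now hand the nested type's codec singleton to the write/measure helpers instead of a pair of method references. For a hypothetical repeated message field, the emitted line changes roughly as follows (field and type names are illustrative, not taken from this diff):

    // before: two method references, each a separate object bound to the same codec
    writeMessageList(out, TRANSFERS, data.transfers(),
            com.hedera.hapi.node.base.AccountAmount.PROTOBUF::write,
            com.hedera.hapi.node.base.AccountAmount.PROTOBUF::measureRecord);
    // after: the codec singleton itself; the helper calls write() and
    // measureRecord() on that one object as needed
    writeMessageList(out, TRANSFERS, data.transfers(),
            com.hedera.hapi.node.base.AccountAmount.PROTOBUF);

Passing one constant keeps the generated call sites shorter, and arguably gives the JIT a single receiver to devirtualize rather than two capture objects.
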
diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/Codec.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/Codec.java
index f7996b3f..54e5471e 100644
--- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/Codec.java
+++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/Codec.java
@@ -1,14 +1,13 @@
 package com.hedera.pbj.runtime;
 
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.io.UncheckedIOException;
-
 import com.hedera.pbj.runtime.io.ReadableSequentialData;
 import com.hedera.pbj.runtime.io.WritableSequentialData;
+import com.hedera.pbj.runtime.io.buffer.BufferedData;
 import com.hedera.pbj.runtime.io.buffer.Bytes;
 import com.hedera.pbj.runtime.io.stream.WritableStreamingData;
 import edu.umd.cs.findbugs.annotations.NonNull;
+import java.io.IOException;
+import java.io.UncheckedIOException;
 
 /**
  * Encapsulates Serialization, Deserialization and other IO operations.
@@ -166,11 +165,11 @@ public interface Codec<T> {
      * to write to the {@link WritableStreamingData}
      */
    default Bytes toBytes(@NonNull T item) {
-        byte[] bytes;
-        try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
-                WritableStreamingData writableStreamingData = new WritableStreamingData(byteArrayOutputStream)) {
-            write(item, writableStreamingData);
-            bytes = byteArrayOutputStream.toByteArray();
+        // it is cheaper, performance-wise, to measure the size of the object first than to grow a buffer as needed
+        final byte[] bytes = new byte[measureRecord(item)];
+        final BufferedData bufferedData = BufferedData.wrap(bytes);
+        try {
+            write(item, bufferedData);
         } catch (IOException e) {
             throw new UncheckedIOException(e);
         }
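The new toBytes() trades a second traversal of the model (measureRecord) for a single exact-size allocation, avoiding ByteArrayOutputStream's grow-and-copy cycles plus the final toByteArray() copy. It leans on the contract that measureRecord(item) returns exactly the number of bytes write(item, out) will produce; if they ever disagreed, write would overflow or underfill the buffer. A caller-side sketch of that contract, reusing the Timestamp model imported elsewhere in this diff (field values are arbitrary):

    import com.hedera.hapi.node.base.Timestamp; // PBJ-generated model
    import com.hedera.pbj.runtime.io.buffer.Bytes;

    class ToBytesContractDemo {
        public static void main(String[] args) {
            final Timestamp ts = new Timestamp(1_000L, 42); // arbitrary values
            final int size = Timestamp.PROTOBUF.measureRecord(ts); // measured without serializing
            final Bytes bytes = Timestamp.PROTOBUF.toBytes(ts);    // allocates exactly size bytes, then writes
            if (bytes.length() != size) throw new AssertionError("measureRecord must agree with write");
        }
    }
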
diff --git a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoWriterTools.java b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoWriterTools.java
index b7ffa004..c432069f 100644
--- a/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoWriterTools.java
+++ b/pbj-core/pbj-runtime/src/main/java/com/hedera/pbj/runtime/ProtoWriterTools.java
@@ -1,10 +1,5 @@
 package com.hedera.pbj.runtime;
 
-import static com.hedera.pbj.runtime.ProtoConstants.WIRE_TYPE_DELIMITED;
-import static com.hedera.pbj.runtime.ProtoConstants.WIRE_TYPE_FIXED_32_BIT;
-import static com.hedera.pbj.runtime.ProtoConstants.WIRE_TYPE_FIXED_64_BIT;
-import static com.hedera.pbj.runtime.ProtoConstants.WIRE_TYPE_VARINT_OR_ZIGZAG;
-
 import com.hedera.pbj.runtime.io.WritableSequentialData;
 import com.hedera.pbj.runtime.io.buffer.Bytes;
 import com.hedera.pbj.runtime.io.buffer.RandomAccessData;
@@ -18,6 +13,8 @@
 import java.util.function.Consumer;
 import java.util.function.ToIntFunction;
 
+import static com.hedera.pbj.runtime.ProtoConstants.*;
+
 /**
  * Static helper methods for Writers
  */
@@ -423,37 +420,35 @@ private static void writeBytesNoChecks(final WritableSequentialData out, final F
      *
      * @param out The data output to write to
      * @param field the descriptor for the field we are writing, the field must not be repeated
      * @param message the message to write
-     * @param writer method reference to writer for the given message type
-     * @param sizeOf method reference to sizeOf measure method for the given message type
+     * @param codec the codec for the given message type
      * @throws IOException If a I/O error occurs
      * @param <T> type of message
      */
     public static <T> void writeMessage(final WritableSequentialData out, final FieldDefinition field,
-            final T message, final ProtoWriter<T> writer, final ToIntFunction<T> sizeOf) throws IOException {
+            final T message, final Codec<T> codec) throws IOException {
         assert field.type() == FieldType.MESSAGE : "Not a message type " + field;
         assert !field.repeated() : "Use writeMessageList with repeated types";
-        writeMessageNoChecks(out, field, message, writer, sizeOf);
+        writeMessageNoChecks(out, field, message, codec);
     }
 
     /**
      * Write a message to data output, assuming the corresponding field is repeated. Usually this method is
     * called multiple times, one for every repeated value. If all values are available immediately, {@link
-     * #writeMessageList(WritableSequentialData, FieldDefinition, List, ProtoWriter, ToIntFunction)} should
+     * #writeMessageList(WritableSequentialData, FieldDefinition, List, Codec)} should
      * be used instead.
      *
      * @param out The data output to write to
      * @param field the descriptor for the field we are writing, the field must be repeated
      * @param message the message to write
-     * @param writer method reference to writer for the given message type
-     * @param sizeOf method reference to sizeOf measure method for the given message type
+     * @param codec the codec for the given message type
      * @throws IOException If a I/O error occurs
      * @param <T> type of message
      */
     public static <T> void writeOneRepeatedMessage(final WritableSequentialData out, final FieldDefinition field,
-            final T message, final ProtoWriter<T> writer, final ToIntFunction<T> sizeOf) throws IOException {
+            final T message, final Codec<T> codec) throws IOException {
         assert field.type() == FieldType.MESSAGE : "Not a message type " + field;
         assert field.repeated() : "writeOneRepeatedMessage can only be used with repeated fields";
-        writeMessageNoChecks(out, field, message, writer, sizeOf);
+        writeMessageNoChecks(out, field, message, codec);
     }
 
     /**
@@ -462,23 +457,22 @@ public static <T> void writeOneRepeatedMessage(final WritableSequentialData out,
      * @param out The data output to write to
      * @param field the descriptor for the field we are writing
      * @param message the message to write
-     * @param writer method reference to writer for the given message type
-     * @param sizeOf method reference to sizeOf measure method for the given message type
+     * @param codec the codec for the given message type
      * @throws IOException If a I/O error occurs
      * @param <T> type of message
      */
     private static <T> void writeMessageNoChecks(final WritableSequentialData out, final FieldDefinition field,
-            final T message, final ProtoWriter<T> writer, final ToIntFunction<T> sizeOf) throws IOException {
+            final T message, final Codec<T> codec) throws IOException {
         // When not a oneOf don't write default value
         if (field.oneOf() && message == null) {
             writeTag(out, field, WIRE_TYPE_DELIMITED);
             out.writeVarInt(0, false);
         } else if (message != null) {
             writeTag(out, field, WIRE_TYPE_DELIMITED);
-            final int size = sizeOf.applyAsInt(message);
+            final int size = codec.measureRecord(message);
             out.writeVarInt(size, false);
             if (size > 0) {
-                writer.write(message, out);
+                codec.write(message, out);
             }
         }
     }
@@ -900,12 +894,11 @@ public static void writeStringList(WritableSequentialData out, FieldDefinition f
      * @param out The data output to write to
      * @param field the descriptor for the field we are writing
      * @param list the list of messages value to write
-     * @param writer method reference to writer method for message type
-     * @param sizeOf method reference to size of method for message type
+     * @param codec the codec for the message type
      * @throws IOException If a I/O error occurs
      * @param <T> type of message
     */
-    public static <T> void writeMessageList(WritableSequentialData out, FieldDefinition field, List<T> list, ProtoWriter<T> writer, ToIntFunction<T> sizeOf) throws IOException {
+    public static <T> void writeMessageList(WritableSequentialData out, FieldDefinition field, List<T> list, Codec<T> codec) throws IOException {
         assert field.type() == FieldType.MESSAGE : "Not a message type " + field;
         assert field.repeated() : "Use writeMessage with non-repeated types";
         // When not a oneOf don't write default value
@@ -914,7 +907,7 @@ public static <T> void writeMessageList(WritableSequentialData out, FieldDefinit
         }
         final int listSize = list.size();
         for (int i = 0; i < listSize; i++) {
-            writeMessageNoChecks(out, field, list.get(i), writer, sizeOf);
+            writeMessageNoChecks(out, field, list.get(i), codec);
         }
     }
@@ -1367,16 +1360,16 @@ public static int sizeOfBytes(FieldDefinition field, RandomAccessData value, boo
      *
      * @param field descriptor of field
      * @param message message value to get encoded size for
-     * @param sizeOf method reference to sizeOf measure function for message type
+     * @param codec the protobuf codec for message type
      * @return the number of bytes for encoded value
      * @param <T> The type of the message
      */
-    public static <T> int sizeOfMessage(FieldDefinition field, T message, ToIntFunction<T> sizeOf) {
+    public static <T> int sizeOfMessage(FieldDefinition field, T message, Codec<T> codec) {
         // When not a oneOf don't write default value
         if (field.oneOf() && message == null) {
             return sizeOfTag(field, WIRE_TYPE_DELIMITED) + 1;
         } else if (message != null) {
-            final int size = sizeOf.applyAsInt(message);
+            final int size = codec.measureRecord(message);
             return sizeOfTag(field, WIRE_TYPE_DELIMITED) + sizeOfVarInt32(size) + size;
         } else {
             return 0;
@@ -1396,20 +1389,22 @@ public static int sizeOfIntegerList(FieldDefinition field, List<Integer> list) {
             return 0;
         }
         int size = 0;
+        final int listSize = list.size();
         switch (field.type()) {
             case INT32 -> {
-                for (final int i : list) {
-                    size += sizeOfVarInt32(i);
+                for (int i = 0; i < listSize; i++) {
+                    size += sizeOfVarInt32(list.get(i));
                 }
             }
             case UINT32 -> {
-                for (final int i : list) {
-                    size += sizeOfUnsignedVarInt32(i);
+                for (int i = 0; i < listSize; i++) {
+                    size += sizeOfUnsignedVarInt32(list.get(i));
                 }
             }
             case SINT32 -> {
-                for (final int i : list) {
-                    size += sizeOfUnsignedVarInt64(((long)i << 1) ^ ((long)i >> 63));
+                for (int i = 0; i < listSize; i++) {
+                    final long val = list.get(i);
+                    size += sizeOfUnsignedVarInt64((val << 1) ^ (val >> 63));
                 }
             }
             case SFIXED32, FIXED32 -> size += FIXED32_SIZE * list.size();
@@ -1431,15 +1426,17 @@ public static int sizeOfLongList(FieldDefinition field, List<Long> list) {
             return 0;
         }
         int size = 0;
+        final int listSize = list.size();
         switch (field.type()) {
             case INT64, UINT64 -> {
-                for (final long i : list) {
-                    size += sizeOfUnsignedVarInt64(i);
+                for (int i = 0; i < listSize; i++) {
+                    size += sizeOfUnsignedVarInt64(list.get(i));
                 }
             }
             case SINT64 -> {
-                for (final long i : list) {
-                    size += sizeOfUnsignedVarInt64((i << 1) ^ (i >> 63));
+                for (int i = 0; i < listSize; i++) {
+                    final long val = list.get(i);
+                    size += sizeOfUnsignedVarInt64((val << 1) ^ (val >> 63));
                 }
             }
             case SFIXED64, FIXED64 -> size += FIXED64_SIZE * list.size();
@@ -1509,8 +1506,9 @@ public static int sizeOfEnumList(FieldDefinition field, List<? extends EnumWithProtoMetadata> list) {
         int size = 0;
-        for (final String value : list) {
-            size += sizeOfDelimited(field, sizeOfStringNoTag(value));
+        final int listSize = list.size();
+        for (int i = 0; i < listSize; i++) {
+            size += sizeOfDelimited(field, sizeOfStringNoTag(list.get(i)));
         }
         return size;
     }
@@ -1535,14 +1534,15 @@ public static int sizeOfStringList(FieldDefinition field, List<String> list) {
      *
      * @param field descriptor of field
      * @param list message list value to get encoded size for
-     * @param sizeOf method reference to sizeOf measure function for message type
+     * @param codec the protobuf codec for message type
      * @return the number of bytes for encoded value
      * @param <T> type for message
      */
-    public static <T> int sizeOfMessageList(FieldDefinition field, List<T> list, ToIntFunction<T> sizeOf) {
+    public static <T> int sizeOfMessageList(FieldDefinition field, List<T> list, Codec<T> codec) {
         int size = 0;
-        for (final T value : list) {
-            size += sizeOfMessage(field, value, sizeOf);
+        final int listSize = list.size();
+        for (int i = 0; i < listSize; i++) {
+            size += sizeOfMessage(field, list.get(i), codec);
         }
         return size;
     }
@@ -1556,12 +1556,21 @@ public static int sizeOfMessageList(FieldDefinition field, List<T> list, ToI
      */
     public static int sizeOfBytesList(FieldDefinition field, List<? extends RandomAccessData> list) {
         int size = 0;
-        for (final RandomAccessData value : list) {
-            size += Math.toIntExact(sizeOfTag(field, WIRE_TYPE_DELIMITED) + sizeOfVarInt32(Math.toIntExact(value.length())) + value.length());
+        final int listSize = list.size();
+        for (int i = 0; i < listSize; i++) {
+            final long valueLength = list.get(i).length();
+            size += sizeOfDelimited(field, Math.toIntExact(valueLength));
         }
         return size;
     }
 
+    /**
+     * Get number of bytes that would be needed to encode a field of wire type delimited
+     *
+     * @param field The field definition of the field to be measured
+     * @param length The length of the delimited field data in bytes
+     * @return the number of bytes for encoded value
+     */
     public static int sizeOfDelimited(final FieldDefinition field, final int length) {
         return Math.toIntExact(sizeOfTag(field, WIRE_TYPE_DELIMITED) + sizeOfVarInt32(length) + length);
     }
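Two independent micro-optimizations run through the sizeOf* methods above: indexed for-loops replace for-each so no Iterator is allocated per call on these hot measuring paths (PBJ's lists are random-access, so get(i) is cheap), and the sint32 path now unboxes and widens once into a local before the zigzag transform. The zigzag mapping these size computations rely on folds negatives onto odd positives so small magnitudes stay small varints; a standalone sketch, not part of the patch:

    public class ZigZagDemo {
        static long zigzag(final long val) {
            // arithmetic >> replicates the sign bit; XOR folds negatives onto odds:
            // 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, 2 -> 4, ...
            return (val << 1) ^ (val >> 63);
        }

        public static void main(String[] args) {
            for (long v : new long[] {0, -1, 1, -2, 2}) {
                System.out.println(v + " -> " + zigzag(v));
            }
        }
    }
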
diff --git a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/CodecWrapper.java b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/CodecWrapper.java
new file mode 100644
index 00000000..45692e7a
--- /dev/null
+++ b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/CodecWrapper.java
@@ -0,0 +1,50 @@
+package com.hedera.pbj.runtime;
+
+import com.hedera.pbj.runtime.io.ReadableSequentialData;
+import com.hedera.pbj.runtime.io.WritableSequentialData;
+import edu.umd.cs.findbugs.annotations.NonNull;
+import java.io.IOException;
+import java.util.function.ToIntFunction;
+
+/**
+ * Now that ProtoWriterTools and ProtoParserTools take Codecs, this wraps the old test lambdas, ProtoWriter and
+ * ToIntFunction, into a codec class. Only write and measureRecord are implemented; the rest throw
+ * UnsupportedOperationException.
+ *
+ * @param <T> The type of the object to be encoded/decoded
+ */
+class CodecWrapper<T> implements Codec<T> {
+    private final ProtoWriter<T> writer;
+    private final ToIntFunction<T> sizeOf;
+
+    CodecWrapper(ProtoWriter<T> writer, ToIntFunction<T> sizeOf) {
+        this.writer = writer;
+        this.sizeOf = sizeOf;
+    }
+
+    @NonNull
+    @Override
+    public T parse(@NonNull ReadableSequentialData input, boolean strictMode, int maxDepth)
+            throws ParseException {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public void write(@NonNull T item, @NonNull WritableSequentialData output) throws IOException {
+        writer.write(item, output);
+    }
+
+    @Override
+    public int measure(@NonNull ReadableSequentialData input) throws ParseException {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public int measureRecord(T item) {
+        return sizeOf.applyAsInt(item);
+    }
+
+    @Override
+    public boolean fastEquals(@NonNull T item, @NonNull ReadableSequentialData input) throws ParseException {
+        throw new UnsupportedOperationException();
+    }
+}
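The wrapper exists purely so existing tests can keep their lambda pairs; the adaptation is one constructor call. A minimal sketch of the shape (CodecWrapper is package-private, so the demo sits in the same package; the helper name is hypothetical):

    package com.hedera.pbj.runtime; // CodecWrapper is package-private

    import java.util.function.ToIntFunction;

    class CodecWrapperDemo {
        // any old-style (writer, sizeOf) pair becomes the single Codec argument
        // that the reworked ProtoWriterTools methods expect
        static <T> Codec<T> asCodec(final ProtoWriter<T> writer, final ToIntFunction<T> sizeOf) {
            return new CodecWrapper<>(writer, sizeOf);
        }
    }
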
diff --git a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoParserToolsTest.java b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoParserToolsTest.java
index 218247a8..f1ced789 100644
--- a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoParserToolsTest.java
+++ b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoParserToolsTest.java
@@ -264,7 +264,8 @@ void testReadNextFieldNumber() throws IOException {
         final String appleStr = randomVarSizeString();
         final Apple apple = Apple.newBuilder().setVariety(appleStr).build();
 
-        writeMessage(bufferedData, definition, apple, (data, out) -> out.writeBytes(data.toByteArray()), Apple::getSerializedSize);
+        writeMessage(bufferedData, definition, apple,
+                new CodecWrapper<>((data, out) -> out.writeBytes(data.toByteArray()), Apple::getSerializedSize));
 
         bufferedData.flip();
         assertEquals(definition.number(), readNextFieldNumber(bufferedData));
diff --git a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoWriterToolsTest.java b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoWriterToolsTest.java
index 7efa375c..cfd61d5e 100644
--- a/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoWriterToolsTest.java
+++ b/pbj-core/pbj-runtime/src/test/java/com/hedera/pbj/runtime/ProtoWriterToolsTest.java
@@ -448,7 +448,7 @@ void testWriteMessage() throws IOException {
         FieldDefinition definition = createFieldDefinition(MESSAGE);
         String appleStr = RANDOM_STRING.nextString();
         Apple apple = Apple.newBuilder().setVariety(appleStr).build();
-        writeMessage(bufferedData, definition, apple, (data, out) -> out.writeBytes(data.toByteArray()), Apple::getSerializedSize);
+        writeMessage(bufferedData, definition, apple, new CodecWrapper<>((data, out) -> out.writeBytes(data.toByteArray()), Apple::getSerializedSize));
         bufferedData.flip();
         assertEquals((definition.number() << TAG_TYPE_BITS) | WIRE_TYPE_DELIMITED.ordinal(), bufferedData.readVarInt(false));
         int length = bufferedData.readVarInt(false);
@@ -464,11 +464,11 @@ void testWriteOneRepeatedMessage() throws IOException {
         final Apple apple2 = Apple.newBuilder().setVariety(appleStr2).build();
         final BufferedData buf1 = BufferedData.allocate(256);
         final ProtoWriter<Apple> writer = (data, out) -> out.writeBytes(data.toByteArray());
-        ProtoWriterTools.writeMessageList(buf1, definition, List.of(apple1, apple2), writer, Apple::getSerializedSize);
+        ProtoWriterTools.writeMessageList(buf1, definition, List.of(apple1, apple2), new CodecWrapper<>(writer, Apple::getSerializedSize));
         final Bytes writtenBytes1 = buf1.getBytes(0, buf1.position());
         final BufferedData buf2 = BufferedData.allocate(256);
-        ProtoWriterTools.writeOneRepeatedMessage(buf2, definition, apple1, writer, Apple::getSerializedSize);
-        ProtoWriterTools.writeOneRepeatedMessage(buf2, definition, apple2, writer, Apple::getSerializedSize);
+        ProtoWriterTools.writeOneRepeatedMessage(buf2, definition, apple1, new CodecWrapper<>(writer, Apple::getSerializedSize));
+        ProtoWriterTools.writeOneRepeatedMessage(buf2, definition, apple2, new CodecWrapper<>(writer, Apple::getSerializedSize));
         final Bytes writtenBytes2 = buf2.getBytes(0, buf2.position());
         assertEquals(writtenBytes1, writtenBytes2);
     }
@@ -476,7 +476,7 @@ void testWriteOneRepeatedMessage() throws IOException {
     @Test
     void testWriteOneOfMessage() throws IOException {
         FieldDefinition definition = createOneOfFieldDefinition(MESSAGE);
-        writeMessage(bufferedData, definition, null, (data, out) -> out.writeBytes(data.toByteArray()), Apple::getSerializedSize);
+        writeMessage(bufferedData, definition, null, new CodecWrapper<>((data, out) -> out.writeBytes(data.toByteArray()), Apple::getSerializedSize));
         bufferedData.flip();
         assertEquals((definition.number() << TAG_TYPE_BITS) | WIRE_TYPE_DELIMITED.ordinal(), bufferedData.readVarInt(false));
         int length = bufferedData.readVarInt(false);
@@ -991,14 +991,14 @@ void testSizeOfMessageList() {
 
         assertEquals(
                 MIN_LENGTH_VAR_SIZE * 2 + TAG_SIZE * 2 + appleStr1.length() + appleStr2.length(),
-                sizeOfMessageList(definition, Arrays.asList(apple1, apple2), v -> v.getVariety().length()));
+                sizeOfMessageList(definition, Arrays.asList(apple1, apple2), new CodecWrapper<>(null, v -> v.getVariety().length())));
     }
 
     @Test
     void testSizeOfMessageList_empty() {
         assertEquals(
                 0,
-                sizeOfMessageList(createFieldDefinition(MESSAGE), emptyList(), v -> RNG.nextInt()));
+                sizeOfMessageList(createFieldDefinition(MESSAGE), emptyList(), new CodecWrapper<>(null, v -> RNG.nextInt())));
     }
 
     @Test
@@ -1108,19 +1108,20 @@ void testSizeOfMessage(){
         final String appleStr = randomVarSizeString();
         final Apple apple = Apple.newBuilder().setVariety(appleStr).build();
 
-        assertEquals(MIN_LENGTH_VAR_SIZE + TAG_SIZE + appleStr.length(), sizeOfMessage(definition, apple, v -> v.getVariety().length()));
+        assertEquals(MIN_LENGTH_VAR_SIZE + TAG_SIZE + appleStr.length(), sizeOfMessage(definition, apple,
+                new CodecWrapper<>(null, v -> v.getVariety().length())));
     }
 
     @Test
     void testSizeOfMessage_oneOf_null() {
         final FieldDefinition definition = createOneOfFieldDefinition(MESSAGE);
-        assertEquals(MIN_LENGTH_VAR_SIZE + TAG_SIZE, sizeOfMessage(definition, null, v -> RNG.nextInt()));
+        assertEquals(MIN_LENGTH_VAR_SIZE + TAG_SIZE, sizeOfMessage(definition, null, new CodecWrapper<>(null, v -> RNG.nextInt())));
     }
 
     @Test
     void testSizeOfMessage_null() {
         final FieldDefinition definition = createFieldDefinition(MESSAGE);
-        assertEquals(0, sizeOfMessage(definition, null, v -> RNG.nextInt()));
+        assertEquals(0, sizeOfMessage(definition, null, new CodecWrapper<>(null, v -> RNG.nextInt())));
     }
 
     @Test
diff --git a/pbj-integration-tests/build.gradle.kts b/pbj-integration-tests/build.gradle.kts
index 14e40115..76b4b93e 100644
--- a/pbj-integration-tests/build.gradle.kts
+++ b/pbj-integration-tests/build.gradle.kts
@@ -86,11 +86,13 @@ val cloneHederaProtobufs =
 sourceSets {
     main {
         pbj {
+            srcDir(cloneHederaProtobufs.flatMap { it.localCloneDirectory.dir("block") })
             srcDir(cloneHederaProtobufs.flatMap { it.localCloneDirectory.dir("platform") })
             srcDir(cloneHederaProtobufs.flatMap { it.localCloneDirectory.dir("services") })
             srcDir(cloneHederaProtobufs.flatMap { it.localCloneDirectory.dir("streams") })
         }
         proto {
+            srcDir(cloneHederaProtobufs.flatMap { it.localCloneDirectory.dir("block") })
             srcDir(cloneHederaProtobufs.flatMap { it.localCloneDirectory.dir("platform") })
             srcDir(cloneHederaProtobufs.flatMap { it.localCloneDirectory.dir("services") })
             srcDir(cloneHederaProtobufs.flatMap { it.localCloneDirectory.dir("streams") })
diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/ProtobufObjectBench.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/ProtobufObjectBench.java
index 18268a90..a6d24703 100644
--- a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/ProtobufObjectBench.java
+++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/ProtobufObjectBench.java
@@ -2,6 +2,7 @@
 
 import com.google.protobuf.CodedOutputStream;
 import com.google.protobuf.GeneratedMessage;
+import com.hedera.hapi.block.stream.protoc.Block;
 import com.hedera.hapi.node.base.Timestamp;
 import com.hedera.hapi.node.token.AccountDetails;
 import com.hedera.pbj.integration.AccountDetailsPbj;
@@ -28,15 +29,17 @@
 import org.openjdk.jmh.annotations.Warmup;
 import org.openjdk.jmh.infra.Blackhole;
 
+import java.io.BufferedInputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.nio.ByteBuffer;
 import java.util.concurrent.TimeUnit;
+import java.util.zip.GZIPInputStream;
 
 @SuppressWarnings("unused")
 @Fork(1)
-@Warmup(iterations = 2, time = 2)
-@Measurement(iterations = 5, time = 2)
+@Warmup(iterations = 3, time = 2)
+@Measurement(iterations = 7, time = 2)
 @OutputTimeUnit(TimeUnit.NANOSECONDS)
 @BenchmarkMode(Mode.AverageTime)
 public abstract class ProtobufObjectBench<P extends Record, G extends GeneratedMessage> {
@@ -113,7 +116,7 @@ public void configure(P pbjModelObject, Codec<P> pbjCodec,
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void parsePbjByteArray(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) throws ParseException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             benchmarkState.protobufDataBuffer.resetPosition();
             blackhole.consume(benchmarkState.pbjCodec.parse(benchmarkState.protobufDataBuffer));
         }
@@ -122,7 +125,7 @@ public void parsePbjByteArray(BenchmarkState<P, G> benchmarkState, Blackhole blac
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void parsePbjByteBuffer(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) throws ParseException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             benchmarkState.protobufDataBuffer.resetPosition();
             blackhole.consume(benchmarkState.pbjCodec.parse(benchmarkState.protobufDataBuffer));
         }
@@ -132,7 +135,7 @@ public void parsePbjByteBuffer(BenchmarkState<P, G> benchmarkState, Blackhole bla
     @OperationsPerInvocation(OPERATION_COUNT)
     public void parsePbjByteBufferDirect(BenchmarkState<P, G> benchmarkState, Blackhole blackhole)
             throws ParseException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             benchmarkState.protobufDataBufferDirect.resetPosition();
             blackhole.consume(benchmarkState.pbjCodec.parse(benchmarkState.protobufDataBufferDirect));
         }
@@ -140,7 +143,7 @@ public void parsePbjByteBufferDirect(BenchmarkState<P, G> benchmarkState, Blackho
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void parsePbjInputStream(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) throws ParseException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             benchmarkState.bin.resetPosition();
             blackhole.consume(benchmarkState.pbjCodec.parse(new ReadableStreamingData(benchmarkState.bin)));
         }
@@ -149,14 +152,14 @@ public void parsePbjInputStream(BenchmarkState<P, G> benchmarkState, Blackhole bl
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void parseProtoCByteArray(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) throws IOException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             blackhole.consume(benchmarkState.googleByteArrayParseMethod.parse(benchmarkState.protobuf));
         }
     }
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void parseProtoCByteBufferDirect(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) throws IOException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             benchmarkState.protobufByteBufferDirect.position(0);
             blackhole.consume(benchmarkState.googleByteBufferParseMethod.parse(benchmarkState.protobufByteBufferDirect));
         }
@@ -164,14 +167,14 @@ public void parseProtoCByteBufferDirect(BenchmarkState<P, G> benchmarkState, Blac
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void parseProtoCByteBuffer(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) throws IOException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             blackhole.consume(benchmarkState.googleByteBufferParseMethod.parse(benchmarkState.protobufByteBuffer));
         }
     }
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void parseProtoCInputStream(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) throws IOException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             benchmarkState.bin.resetPosition();
             blackhole.consume(benchmarkState.googleInputStreamParseMethod.parse(benchmarkState.bin));
         }
@@ -181,7 +184,7 @@ public void parseProtoCInputStream(BenchmarkState<P, G> benchmarkState, Blackhole
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void writePbjByteArray(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) throws IOException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             benchmarkState.outDataBuffer.reset();
             benchmarkState.pbjCodec.write(benchmarkState.pbjModelObject, benchmarkState.outDataBuffer);
             blackhole.consume(benchmarkState.outDataBuffer);
@@ -191,7 +194,7 @@ public void writePbjByteArray(BenchmarkState<P, G> benchmarkState, Blackhole blac
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void writePbjByteBuffer(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) throws IOException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             benchmarkState.outDataBuffer.reset();
             benchmarkState.pbjCodec.write(benchmarkState.pbjModelObject, benchmarkState.outDataBuffer);
             blackhole.consume(benchmarkState.outDataBuffer);
@@ -200,7 +203,7 @@ public void writePbjByteBuffer(BenchmarkState<P, G> benchmarkState, Blackhole bla
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void writePbjByteDirect(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) throws IOException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             benchmarkState.outDataBufferDirect.reset();
             benchmarkState.pbjCodec.write(benchmarkState.pbjModelObject, benchmarkState.outDataBufferDirect);
             blackhole.consume(benchmarkState.outDataBufferDirect);
@@ -209,7 +212,7 @@ public void writePbjByteDirect(BenchmarkState<P, G> benchmarkState, Blackhole bla
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void writePbjOutputStream(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) throws IOException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             benchmarkState.bout.reset();
             benchmarkState.pbjCodec.write(benchmarkState.pbjModelObject, new WritableStreamingData(benchmarkState.bout));
             blackhole.consume(benchmarkState.bout.toByteArray());
@@ -219,7 +222,7 @@ public void writePbjOutputStream(BenchmarkState<P, G> benchmarkState, Blackhole b
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void writeProtoCByteArray(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             blackhole.consume(benchmarkState.googleModelObject.toByteArray());
         }
     }
@@ -227,7 +230,7 @@ public void writeProtoCByteArray(BenchmarkState<P, G> benchmarkState, Blackhole b
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void writeProtoCByteBuffer(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) throws IOException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             CodedOutputStream cout = CodedOutputStream.newInstance(benchmarkState.bbout);
             benchmarkState.googleModelObject.writeTo(cout);
             blackhole.consume(benchmarkState.bbout);
@@ -237,7 +240,7 @@ public void writeProtoCByteBuffer(BenchmarkState<P, G> benchmarkState, Blackhole
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void writeProtoCByteBufferDirect(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) throws IOException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             CodedOutputStream cout = CodedOutputStream.newInstance(benchmarkState.bboutDirect);
             benchmarkState.googleModelObject.writeTo(cout);
             blackhole.consume(benchmarkState.bbout);
@@ -247,7 +250,7 @@ public void writeProtoCByteBufferDirect(BenchmarkState<P, G> benchmarkState, Blac
     @Benchmark
     @OperationsPerInvocation(OPERATION_COUNT)
     public void writeProtoCOutputStream(BenchmarkState<P, G> benchmarkState, Blackhole blackhole) throws IOException {
-        for (int i = 0; i < 1000; i++) {
+        for (int i = 0; i < OPERATION_COUNT; i++) {
             benchmarkState.bout.reset();
             benchmarkState.googleModelObject.writeTo(benchmarkState.bout);
             blackhole.consume(benchmarkState.bout.toByteArray());
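Replacing the hard-coded 1000 with OPERATION_COUNT fixes a silent scaling bug: @OperationsPerInvocation(OPERATION_COUNT) tells JMH to divide the measured invocation time by that constant, so the reported per-op numbers are only correct when the loop bound is the very same constant. A minimal self-contained illustration (class name and constant value are arbitrary):

    import java.util.concurrent.TimeUnit;
    import org.openjdk.jmh.annotations.*;
    import org.openjdk.jmh.infra.Blackhole;

    @State(Scope.Benchmark)
    @BenchmarkMode(Mode.AverageTime)
    @OutputTimeUnit(TimeUnit.NANOSECONDS)
    public class OpsPerInvocationDemo {
        private static final int OPERATION_COUNT = 1024; // illustrative value

        @Benchmark
        @OperationsPerInvocation(OPERATION_COUNT) // JMH divides measured time by this constant
        public void manyOpsPerInvocation(Blackhole blackhole) {
            // the loop bound must be the same constant the annotation names,
            // otherwise every reported per-op time is scaled by the wrong factor
            for (int i = 0; i < OPERATION_COUNT; i++) {
                blackhole.consume(i);
            }
        }
    }
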
diff --git a/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/SampleBlockBench.java b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/SampleBlockBench.java
new file mode 100644
index 00000000..87ddcc9e
--- /dev/null
+++ b/pbj-integration-tests/src/jmh/java/com/hedera/pbj/intergration/jmh/SampleBlockBench.java
@@ -0,0 +1,207 @@
+package com.hedera.pbj.intergration.jmh;
+
+import com.google.protobuf.CodedOutputStream;
+import com.hedera.hapi.block.stream.BlockItem;
+import com.hedera.hapi.block.stream.protoc.Block;
+import com.hedera.pbj.integration.NonSynchronizedByteArrayInputStream;
+import com.hedera.pbj.integration.NonSynchronizedByteArrayOutputStream;
+import com.hedera.pbj.runtime.ParseException;
+import com.hedera.pbj.runtime.io.buffer.BufferedData;
+import com.hedera.pbj.runtime.io.buffer.Bytes;
+import com.hedera.pbj.runtime.io.stream.ReadableStreamingData;
+import com.hedera.pbj.runtime.io.stream.WritableStreamingData;
+import java.util.Comparator;
+import org.openjdk.jmh.annotations.*;
+import org.openjdk.jmh.infra.Blackhole;
+
+import java.io.BufferedInputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+import java.util.zip.GZIPInputStream;
+
+/**
+ * Benchmarks for parsing and writing a sample block using PBJ and Google Protobuf.
+ */
+@SuppressWarnings("unused")
+@Fork(1)
+@Warmup(iterations = 3, time = 2)
+@Measurement(iterations = 5, time = 2)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+@BenchmarkMode(Mode.AverageTime)
+@State(Scope.Benchmark)
+public class SampleBlockBench {
+    // test block
+    private static final com.hedera.hapi.block.stream.Block TEST_BLOCK;
+    private static final Block TEST_BLOCK_GOOGLE;
+    // input bytes
+    private static final byte[] TEST_BLOCK_PROTOBUF_BYTES;
+    private static final ByteBuffer PROTOBUF_BYTE_BUFFER;
+    private static final BufferedData PROTOBUF_DATA_BUFFER;
+    private static final ByteBuffer PROTOBUF_BYTE_BUFFER_DIRECT;
+    private static final BufferedData PROTOBUF_DATA_BUFFER_DIRECT;
+    private static final NonSynchronizedByteArrayInputStream PROTOBUF_INPUT_STREAM;
+    // load test block from resources
+    static {
+        // load the protobuf bytes
+        try (var in = new BufferedInputStream(new GZIPInputStream(
+                Objects.requireNonNull(SampleBlockBench.class.getResourceAsStream("/000000000000000000000000000000497558.blk.gz"))))) {
+            TEST_BLOCK_PROTOBUF_BYTES = in.readAllBytes();
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+        // load using PBJ
+        try {
+            TEST_BLOCK = com.hedera.hapi.block.stream.Block.PROTOBUF.parse(Bytes.wrap(TEST_BLOCK_PROTOBUF_BYTES));
+        } catch (ParseException e) {
+            throw new RuntimeException(e);
+        }
+        // load using google protoc as well
+        try {
+            TEST_BLOCK_GOOGLE = Block.parseFrom(TEST_BLOCK_PROTOBUF_BYTES);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+        // input buffers
+        PROTOBUF_BYTE_BUFFER = ByteBuffer.wrap(TEST_BLOCK_PROTOBUF_BYTES);
+        PROTOBUF_DATA_BUFFER = BufferedData.wrap(TEST_BLOCK_PROTOBUF_BYTES);
+        PROTOBUF_BYTE_BUFFER_DIRECT = ByteBuffer.allocateDirect(TEST_BLOCK_PROTOBUF_BYTES.length);
+        PROTOBUF_BYTE_BUFFER_DIRECT.put(TEST_BLOCK_PROTOBUF_BYTES);
+        PROTOBUF_DATA_BUFFER_DIRECT = BufferedData.wrap(PROTOBUF_BYTE_BUFFER_DIRECT);
+        PROTOBUF_INPUT_STREAM = new NonSynchronizedByteArrayInputStream(TEST_BLOCK_PROTOBUF_BYTES);
+        ReadableStreamingData din = new ReadableStreamingData(PROTOBUF_INPUT_STREAM);
+    }
+
+    // output buffers
+    private final NonSynchronizedByteArrayOutputStream bout = new NonSynchronizedByteArrayOutputStream();
+    private final BufferedData outDataBuffer = BufferedData.allocate(TEST_BLOCK_PROTOBUF_BYTES.length);
+    private final BufferedData outDataBufferDirect = BufferedData.allocateOffHeap(TEST_BLOCK_PROTOBUF_BYTES.length);
+    private final ByteBuffer bbout = ByteBuffer.allocate(TEST_BLOCK_PROTOBUF_BYTES.length);
+    private final ByteBuffer bboutDirect = ByteBuffer.allocateDirect(TEST_BLOCK_PROTOBUF_BYTES.length);
+
+    /** Same as parsePbjByteBuffer because DataBuffer.wrap(byte[]) uses ByteBuffer today; added because it makes result plotting easier */
+    @Benchmark
+    public void parsePbjByteArray(Blackhole blackhole) throws ParseException {
+        PROTOBUF_DATA_BUFFER.resetPosition();
+        blackhole.consume(com.hedera.hapi.block.stream.Block.PROTOBUF.parse(PROTOBUF_DATA_BUFFER));
+    }
+
+    @Benchmark
+    public void parsePbjByteBuffer(Blackhole blackhole) throws ParseException {
+        PROTOBUF_DATA_BUFFER.resetPosition();
+        blackhole.consume(com.hedera.hapi.block.stream.Block.PROTOBUF.parse(PROTOBUF_DATA_BUFFER));
+    }
+
+    @Benchmark
+    public void parsePbjByteBufferDirect(Blackhole blackhole)
+            throws ParseException {
+        PROTOBUF_DATA_BUFFER_DIRECT.resetPosition();
+        blackhole.consume(com.hedera.hapi.block.stream.Block.PROTOBUF.parse(PROTOBUF_DATA_BUFFER_DIRECT));
+    }
+
+    @Benchmark
+    public void parsePbjInputStream(Blackhole blackhole) throws ParseException {
+        PROTOBUF_INPUT_STREAM.resetPosition();
+        blackhole.consume(com.hedera.hapi.block.stream.Block.PROTOBUF.parse(new ReadableStreamingData(PROTOBUF_INPUT_STREAM)));
+    }
+
+    @Benchmark
+    public void parseProtoCByteArray(Blackhole blackhole) throws IOException {
+        blackhole.consume(Block.parseFrom(TEST_BLOCK_PROTOBUF_BYTES));
+    }
+
+    @Benchmark
+    public void parseProtoCByteBufferDirect(Blackhole blackhole) throws IOException {
+        PROTOBUF_BYTE_BUFFER_DIRECT.position(0);
+        blackhole.consume(Block.parseFrom(PROTOBUF_BYTE_BUFFER_DIRECT));
+    }
+
+    @Benchmark
+    public void parseProtoCByteBuffer(Blackhole blackhole) throws IOException {
+        blackhole.consume(Block.parseFrom(PROTOBUF_BYTE_BUFFER));
+    }
+
+    @Benchmark
+    public void parseProtoCInputStream(Blackhole blackhole) throws IOException {
+        PROTOBUF_INPUT_STREAM.resetPosition();
+        blackhole.consume(Block.parseFrom(PROTOBUF_INPUT_STREAM));
+    }
+
+    /** Same as writePbjByteBuffer because DataBuffer.wrap(byte[]) uses ByteBuffer today; added because it makes result plotting easier */
+    @Benchmark
+    public void writePbjByteArray(Blackhole blackhole) throws IOException {
+        outDataBuffer.reset();
+        com.hedera.hapi.block.stream.Block.PROTOBUF.write(TEST_BLOCK, outDataBuffer);
+        blackhole.consume(outDataBuffer);
+    }
+
+    /** Should match writePbjByteArray above, but allocates a new byte[] and does an extra measure pass; included because toBytes() is heavily used */
+    @Benchmark
+    public void writePbjToBytes(Blackhole blackhole) {
+        final Bytes bytes = com.hedera.hapi.block.stream.Block.PROTOBUF.toBytes(TEST_BLOCK);
+        blackhole.consume(bytes);
+    }
+
+    @Benchmark
+    public void writePbjByteBuffer(Blackhole blackhole) throws IOException {
+        outDataBuffer.reset();
+        com.hedera.hapi.block.stream.Block.PROTOBUF.write(TEST_BLOCK, outDataBuffer);
+        blackhole.consume(outDataBuffer);
+    }
+
+    @Benchmark
+    public void writePbjByteDirect(Blackhole blackhole) throws IOException {
+        outDataBufferDirect.reset();
+        com.hedera.hapi.block.stream.Block.PROTOBUF.write(TEST_BLOCK, outDataBufferDirect);
+        blackhole.consume(outDataBufferDirect);
+    }
+
+    @Benchmark
+    public void writePbjOutputStream(Blackhole blackhole) throws IOException {
+        bout.reset();
+        com.hedera.hapi.block.stream.Block.PROTOBUF.write(TEST_BLOCK, new WritableStreamingData(bout));
+        blackhole.consume(bout.toByteArray());
+    }
+
+    @Benchmark
+    public void writeProtoCByteArray(Blackhole blackhole) {
+        blackhole.consume(TEST_BLOCK_GOOGLE.toByteArray());
+    }
+
+    @Benchmark
+    public void writeProtoCByteBuffer(Blackhole blackhole) throws IOException {
+        CodedOutputStream cout = CodedOutputStream.newInstance(bbout);
+        TEST_BLOCK_GOOGLE.writeTo(cout);
+        blackhole.consume(bbout);
+    }
+
+    @Benchmark
+    public void writeProtoCByteBufferDirect(Blackhole blackhole) throws IOException {
+        CodedOutputStream cout = CodedOutputStream.newInstance(bboutDirect);
+        TEST_BLOCK_GOOGLE.writeTo(cout);
+        blackhole.consume(bbout);
+    }
+
+    @Benchmark
+    public void writeProtoCOutputStream(Blackhole blackhole) throws IOException {
+        bout.reset();
+        TEST_BLOCK_GOOGLE.writeTo(bout);
+        blackhole.consume(bout.toByteArray());
+    }
+
+    /**
+     * Handy test main method for performance profiling
+     *
+     * @param args no args needed
+     */
+    public static void main(String[] args) {
+        for (int i = 0; i < 1000; i++) {
+            final Bytes result = com.hedera.hapi.block.stream.Block.PROTOBUF.toBytes(TEST_BLOCK);
+//            TEST_BLOCK_GOOGLE.toByteArray();
+        }
+//        var biggestItem = TEST_BLOCK.items().stream().sorted(Comparator.comparingLong(BlockItem.PROTOBUF::measureRecord)).toList().getLast();
+//        final Bytes result = com.hedera.hapi.block.stream.BlockItem.PROTOBUF.toBytes(biggestItem);
+
+    }
+}
diff --git a/pbj-integration-tests/src/jmh/resources/000000000000000000000000000000497558.blk.gz b/pbj-integration-tests/src/jmh/resources/000000000000000000000000000000497558.blk.gz
new file mode 100644
index 00000000..6db51b1b
Binary files /dev/null and b/pbj-integration-tests/src/jmh/resources/000000000000000000000000000000497558.blk.gz differ
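To run only the new benchmark, JMH's standard Runner API is enough; a sketch, assuming the JMH runner artifact is on the classpath (the launcher class is hypothetical, not part of this PR):

    package com.hedera.pbj.intergration.jmh;

    import org.openjdk.jmh.runner.Runner;
    import org.openjdk.jmh.runner.RunnerException;
    import org.openjdk.jmh.runner.options.OptionsBuilder;

    /** Hypothetical launcher for SampleBlockBench. */
    public class SampleBlockBenchRunner {
        public static void main(String[] args) throws RunnerException {
            new Runner(new OptionsBuilder()
                    .include(SampleBlockBench.class.getSimpleName()) // regex over benchmark names
                    .build()) // fork/warmup/measurement come from the class annotations
                    .run();
        }
    }
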