Skip to content

Commit

Permalink
Merge branch 'main' into 14255-reusable-gradle-config
Browse files Browse the repository at this point in the history
  • Loading branch information
jjohannes committed Jan 10, 2025
2 parents 0242fd6 + 4088c86 commit 823bd7f
Show file tree
Hide file tree
Showing 12 changed files with 366 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ public void generate(Protobuf3Parser.MessageDefContext msgDef, final File destin
* Protobuf Codec for $modelClass model object. Generated based on protobuf schema.
*/
public final class $codecClass implements Codec<$modelClass> {
$unsetOneOfConstants
$parseMethod
$writeMethod
$measureDataMethod
$measureRecordMethod
$fastEqualsMethod
$unsetOneOfConstants
$parseMethod
$writeMethod
$measureDataMethod
$measureRecordMethod
$fastEqualsMethod
}
"""
.replace("$package", codecPackage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ private static String generateFieldSizeOfLines(final Field field, final String m
return prefix + switch (field.type()) {
case ENUM -> "size += sizeOfEnumList(%s, %s);"
.formatted(fieldDef, getValueCode);
case MESSAGE -> "size += sizeOfMessageList($fieldDef, $valueCode, $codec::measureRecord);"
case MESSAGE -> "size += sizeOfMessageList($fieldDef, $valueCode, $codec);"
.replace("$fieldDef", fieldDef)
.replace("$valueCode", getValueCode)
.replace("$codec", ((SingleField) field).messageTypeModelPackage() + "." +
Expand Down Expand Up @@ -147,7 +147,7 @@ private static String generateFieldSizeOfLines(final Field field, final String m
.formatted(fieldDef, getValueCode);
case STRING -> "size += sizeOfString(%s, %s, %s);"
.formatted(fieldDef, getValueCode, skipDefault);
case MESSAGE -> "size += sizeOfMessage($fieldDef, $valueCode, $codec::measureRecord);"
case MESSAGE -> "size += sizeOfMessage($fieldDef, $valueCode, $codec);"
.replace("$fieldDef", fieldDef)
.replace("$valueCode", getValueCode)
.replace("$codec", ((SingleField)field).messageTypeModelPackage() + "." +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ private static String generateFieldWriteLines(final Field field, final String mo
return prefix + switch(field.type()) {
case ENUM -> "writeEnumList(out, %s, %s);"
.formatted(fieldDef, getValueCode);
case MESSAGE -> "writeMessageList(out, $fieldDef, $valueCode, $codec::write, $codec::measureRecord);"
case MESSAGE -> "writeMessageList(out, $fieldDef, $valueCode, $codec);"
.replace("$fieldDef", fieldDef)
.replace("$valueCode", getValueCode)
.replace("$codec", ((SingleField)field).messageTypeModelPackage() + "." +
Expand Down Expand Up @@ -165,7 +165,7 @@ private static String generateFieldWriteLines(final Field field, final String mo
.formatted(fieldDef, getValueCode);
case STRING -> "writeString(out, %s, %s, %s);"
.formatted(fieldDef, getValueCode, skipDefault);
case MESSAGE -> "writeMessage(out, $fieldDef, $valueCode, $codec::write, $codec::measureRecord);"
case MESSAGE -> "writeMessage(out, $fieldDef, $valueCode, $codec);"
.replace("$fieldDef", fieldDef)
.replace("$valueCode", getValueCode)
.replace("$codec", ((SingleField)field).messageTypeModelPackage() + "." +
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
package com.hedera.pbj.runtime;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import com.hedera.pbj.runtime.io.ReadableSequentialData;
import com.hedera.pbj.runtime.io.WritableSequentialData;
import com.hedera.pbj.runtime.io.buffer.BufferedData;
import com.hedera.pbj.runtime.io.buffer.Bytes;
import com.hedera.pbj.runtime.io.stream.WritableStreamingData;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.io.IOException;
import java.io.UncheckedIOException;

/**
* Encapsulates Serialization, Deserialization and other IO operations.
Expand Down Expand Up @@ -166,11 +165,11 @@ public interface Codec<T /*extends Record*/> {
* to write to the {@link WritableStreamingData}
*/
default Bytes toBytes(@NonNull T item) {
byte[] bytes;
try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
WritableStreamingData writableStreamingData = new WritableStreamingData(byteArrayOutputStream)) {
write(item, writableStreamingData);
bytes = byteArrayOutputStream.toByteArray();
// it is cheaper performance wise to measure the size of the object first than grow a buffer as needed
final byte[] bytes = new byte[measureRecord(item)];
final BufferedData bufferedData = BufferedData.wrap(bytes);
try {
write(item, bufferedData);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
package com.hedera.pbj.runtime;

import static com.hedera.pbj.runtime.ProtoConstants.WIRE_TYPE_DELIMITED;
import static com.hedera.pbj.runtime.ProtoConstants.WIRE_TYPE_FIXED_32_BIT;
import static com.hedera.pbj.runtime.ProtoConstants.WIRE_TYPE_FIXED_64_BIT;
import static com.hedera.pbj.runtime.ProtoConstants.WIRE_TYPE_VARINT_OR_ZIGZAG;

import com.hedera.pbj.runtime.io.WritableSequentialData;
import com.hedera.pbj.runtime.io.buffer.Bytes;
import com.hedera.pbj.runtime.io.buffer.RandomAccessData;
Expand All @@ -18,6 +13,8 @@
import java.util.function.Consumer;
import java.util.function.ToIntFunction;

import static com.hedera.pbj.runtime.ProtoConstants.*;

/**
* Static helper methods for Writers
*/
Expand Down Expand Up @@ -423,37 +420,35 @@ private static void writeBytesNoChecks(final WritableSequentialData out, final F
* @param out The data output to write to
* @param field the descriptor for the field we are writing, the field must not be repeated
* @param message the message to write
* @param writer method reference to writer for the given message type
* @param sizeOf method reference to sizeOf measure method for the given message type
* @param codec the codec for the given message type
* @throws IOException If a I/O error occurs
* @param <T> type of message
*/
public static <T> void writeMessage(final WritableSequentialData out, final FieldDefinition field,
final T message, final ProtoWriter<T> writer, final ToIntFunction<T> sizeOf) throws IOException {
final T message, final Codec<T> codec) throws IOException {
assert field.type() == FieldType.MESSAGE : "Not a message type " + field;
assert !field.repeated() : "Use writeMessageList with repeated types";
writeMessageNoChecks(out, field, message, writer, sizeOf);
writeMessageNoChecks(out, field, message, codec);
}

/**
* Write a message to data output, assuming the corresponding field is repeated. Usually this method is
* called multiple times, one for every repeated value. If all values are available immediately, {@link
* #writeMessageList(WritableSequentialData, FieldDefinition, List, ProtoWriter, ToIntFunction)} should
* #writeMessageList(WritableSequentialData, FieldDefinition, List, Codec)} should
* be used instead.
*
* @param out The data output to write to
* @param field the descriptor for the field we are writing, the field must be repeated
* @param message the message to write
* @param writer method reference to writer for the given message type
* @param sizeOf method reference to sizeOf measure method for the given message type
* @param codec the codec for the given message type
* @throws IOException If a I/O error occurs
* @param <T> type of message
*/
public static <T> void writeOneRepeatedMessage(final WritableSequentialData out, final FieldDefinition field,
final T message, final ProtoWriter<T> writer, final ToIntFunction<T> sizeOf) throws IOException {
final T message, final Codec<T> codec) throws IOException {
assert field.type() == FieldType.MESSAGE : "Not a message type " + field;
assert field.repeated() : "writeOneRepeatedMessage can only be used with repeated fields";
writeMessageNoChecks(out, field, message, writer, sizeOf);
writeMessageNoChecks(out, field, message, codec);
}

/**
Expand All @@ -462,23 +457,22 @@ public static <T> void writeOneRepeatedMessage(final WritableSequentialData out,
* @param out The data output to write to
* @param field the descriptor for the field we are writing
* @param message the message to write
* @param writer method reference to writer for the given message type
* @param sizeOf method reference to sizeOf measure method for the given message type
* @param codec the codec for the given message type
* @throws IOException If a I/O error occurs
* @param <T> type of message
*/
private static <T> void writeMessageNoChecks(final WritableSequentialData out, final FieldDefinition field,
final T message, final ProtoWriter<T> writer, final ToIntFunction<T> sizeOf) throws IOException {
final T message, final Codec<T> codec) throws IOException {
// When not a oneOf don't write default value
if (field.oneOf() && message == null) {
writeTag(out, field, WIRE_TYPE_DELIMITED);
out.writeVarInt(0, false);
} else if (message != null) {
writeTag(out, field, WIRE_TYPE_DELIMITED);
final int size = sizeOf.applyAsInt(message);
final int size = codec.measureRecord(message);
out.writeVarInt(size, false);
if (size > 0) {
writer.write(message, out);
codec.write(message, out);
}
}
}
Expand Down Expand Up @@ -900,12 +894,11 @@ public static void writeStringList(WritableSequentialData out, FieldDefinition f
* @param out The data output to write to
* @param field the descriptor for the field we are writing
* @param list the list of messages value to write
* @param writer method reference to writer method for message type
* @param sizeOf method reference to size of method for message type
* @param codec the codec for the message type
* @throws IOException If a I/O error occurs
* @param <T> type of message
*/
public static <T> void writeMessageList(WritableSequentialData out, FieldDefinition field, List<T> list, ProtoWriter<T> writer, ToIntFunction<T> sizeOf) throws IOException {
public static <T> void writeMessageList(WritableSequentialData out, FieldDefinition field, List<T> list, Codec<T> codec) throws IOException {
assert field.type() == FieldType.MESSAGE : "Not a message type " + field;
assert field.repeated() : "Use writeMessage with non-repeated types";
// When not a oneOf don't write default value
Expand All @@ -914,7 +907,7 @@ public static <T> void writeMessageList(WritableSequentialData out, FieldDefinit
}
final int listSize = list.size();
for (int i = 0; i < listSize; i++) {
writeMessageNoChecks(out, field, list.get(i), writer, sizeOf);
writeMessageNoChecks(out, field, list.get(i), codec);
}
}

Expand Down Expand Up @@ -1367,16 +1360,16 @@ public static int sizeOfBytes(FieldDefinition field, RandomAccessData value, boo
*
* @param field descriptor of field
* @param message message value to get encoded size for
* @param sizeOf method reference to sizeOf measure function for message type
* @param codec the protobuf codec for message type
* @return the number of bytes for encoded value
* @param <T> The type of the message
*/
public static <T> int sizeOfMessage(FieldDefinition field, T message, ToIntFunction<T> sizeOf) {
public static <T> int sizeOfMessage(FieldDefinition field, T message, Codec<T> codec) {
// When not a oneOf don't write default value
if (field.oneOf() && message == null) {
return sizeOfTag(field, WIRE_TYPE_DELIMITED) + 1;
} else if (message != null) {
final int size = sizeOf.applyAsInt(message);
final int size = codec.measureRecord(message);
return sizeOfTag(field, WIRE_TYPE_DELIMITED) + sizeOfVarInt32(size) + size;
} else {
return 0;
Expand All @@ -1396,20 +1389,22 @@ public static int sizeOfIntegerList(FieldDefinition field, List<Integer> list) {
return 0;
}
int size = 0;
final int listSize = list.size();
switch (field.type()) {
case INT32 -> {
for (final int i : list) {
size += sizeOfVarInt32(i);
for (int i = 0; i < listSize; i++) {
size += sizeOfVarInt32(list.get(i));
}
}
case UINT32 -> {
for (final int i : list) {
size += sizeOfUnsignedVarInt32(i);
for (int i = 0; i < listSize; i++) {
size += sizeOfUnsignedVarInt32(list.get(i));
}
}
case SINT32 -> {
for (final int i : list) {
size += sizeOfUnsignedVarInt64(((long)i << 1) ^ ((long)i >> 63));
for (int i = 0; i < listSize; i++) {
final long val = list.get(i);
size += sizeOfUnsignedVarInt64((val << 1) ^ (val >> 63));
}
}
case SFIXED32, FIXED32 -> size += FIXED32_SIZE * list.size();
Expand All @@ -1431,15 +1426,17 @@ public static int sizeOfLongList(FieldDefinition field, List<Long> list) {
return 0;
}
int size = 0;
final int listSize = list.size();
switch (field.type()) {
case INT64, UINT64 -> {
for (final long i : list) {
size += sizeOfUnsignedVarInt64(i);
for (int i = 0; i < listSize; i++) {
size += sizeOfUnsignedVarInt64(list.get(i));
}
}
case SINT64 -> {
for (final long i : list) {
size += sizeOfUnsignedVarInt64((i << 1) ^ (i >> 63));
for (int i = 0; i < listSize; i++) {
final long val = list.get(i);
size += sizeOfUnsignedVarInt64((val << 1) ^ (val >> 63));
}
}
case SFIXED64, FIXED64 -> size += FIXED64_SIZE * list.size();
Expand Down Expand Up @@ -1509,8 +1506,9 @@ public static int sizeOfEnumList(FieldDefinition field, List<? extends EnumWithP
return 0;
}
int size = 0;
for (final EnumWithProtoMetadata enumValue : list) {
size += sizeOfUnsignedVarInt64(enumValue.protoOrdinal());
final int listSize = list.size();
for (int i = 0; i < listSize; i++) {
size += sizeOfUnsignedVarInt64(list.get(i).protoOrdinal());
}
return sizeOfTag(field, WIRE_TYPE_DELIMITED) + sizeOfVarInt32(size) + size;
}
Expand All @@ -1524,8 +1522,9 @@ public static int sizeOfEnumList(FieldDefinition field, List<? extends EnumWithP
*/
public static int sizeOfStringList(FieldDefinition field, List<String> list) {
int size = 0;
for (final String value : list) {
size += sizeOfDelimited(field, sizeOfStringNoTag(value));
final int listSize = list.size();
for (int i = 0; i < listSize; i++) {
size += sizeOfDelimited(field, sizeOfStringNoTag(list.get(i)));
}
return size;
}
Expand All @@ -1535,14 +1534,15 @@ public static int sizeOfStringList(FieldDefinition field, List<String> list) {
*
* @param field descriptor of field
* @param list message list value to get encoded size for
* @param sizeOf method reference to sizeOf measure function for message type
* @param codec the protobuf codec for message type
* @return the number of bytes for encoded value
* @param <T> type for message
*/
public static <T> int sizeOfMessageList(FieldDefinition field, List<T> list, ToIntFunction<T> sizeOf) {
public static <T> int sizeOfMessageList(FieldDefinition field, List<T> list, Codec<T> codec) {
int size = 0;
for (final T value : list) {
size += sizeOfMessage(field, value, sizeOf);
final int listSize = list.size();
for (int i = 0; i < listSize; i++) {
size += sizeOfMessage(field, list.get(i), codec);
}
return size;
}
Expand All @@ -1556,12 +1556,21 @@ public static <T> int sizeOfMessageList(FieldDefinition field, List<T> list, ToI
*/
public static int sizeOfBytesList(FieldDefinition field, List<? extends RandomAccessData> list) {
int size = 0;
for (final RandomAccessData value : list) {
size += Math.toIntExact(sizeOfTag(field, WIRE_TYPE_DELIMITED) + sizeOfVarInt32(Math.toIntExact(value.length())) + value.length());
final int listSize = list.size();
for (int i = 0; i < listSize; i++) {
final long valueLength = list.get(i).length();
size += sizeOfDelimited(field, Math.toIntExact(valueLength));
}
return size;
}

/**
* Get number of bytes that would be needed to encode a field of wire type delimited
*
* @param field The field definition of the field to be measured
* @param length The length of the delimited field data in bytes
* @return the number of bytes for encoded value
*/
public static int sizeOfDelimited(final FieldDefinition field, final int length) {
return Math.toIntExact(sizeOfTag(field, WIRE_TYPE_DELIMITED) + sizeOfVarInt32(length) + length);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.hedera.pbj.runtime;

import com.hedera.pbj.runtime.io.ReadableSequentialData;
import com.hedera.pbj.runtime.io.WritableSequentialData;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.io.IOException;
import java.util.function.ToIntFunction;

/**
* Now that we use Codecs in ProtoWriterTools and ProtoReaderTools, we need to wrap the old test lambdas ProtoWriter and
* ToIntFunction into a codec class. Only the two methods are implemented and the rest are left as unsupported.
*
* @param <T> The type of the object to be encoded/decoded
*/
class CodecWrapper<T> implements Codec<T> {
private final ProtoWriter<T> writer;
private final ToIntFunction<T> sizeOf;

CodecWrapper(ProtoWriter<T> writer, ToIntFunction<T> sizeOf) {
this.writer = writer;
this.sizeOf = sizeOf;
}

@NonNull
@Override
public T parse(@NonNull ReadableSequentialData input, boolean strictMode, int maxDepth)
throws ParseException {
throw new UnsupportedOperationException();
}

@Override
public void write(@NonNull T item, @NonNull WritableSequentialData output) throws IOException {
writer.write(item, output);
}

@Override
public int measure(@NonNull ReadableSequentialData input) throws ParseException {
throw new UnsupportedOperationException();
}

@Override
public int measureRecord(T item) {
return sizeOf.applyAsInt(item);
}

@Override
public boolean fastEquals(@NonNull T item, @NonNull ReadableSequentialData input) throws ParseException {
throw new UnsupportedOperationException();
}
}
Loading

0 comments on commit 823bd7f

Please sign in to comment.