Skip to content

Commit 572a7af

Browse files
authored
protobuf,protobuf-lite: configurable protobuf recursion limit (#10094)
1 parent 3c01bfe commit 572a7af

File tree

4 files changed

+146
-14
lines changed

4 files changed

+146
-14
lines changed

protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java

+24-5
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,20 @@ public static void setExtensionRegistry(ExtensionRegistryLite newRegistry) {
8181
*/
8282
public static <T extends MessageLite> Marshaller<T> marshaller(T defaultInstance) {
8383
// TODO(ejona): consider changing return type to PrototypeMarshaller (assuming ABI safe)
84-
return new MessageMarshaller<>(defaultInstance);
84+
return new MessageMarshaller<>(defaultInstance, -1);
85+
}
86+
87+
/**
88+
* Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance} and a
89+
* custom limit for the recursion depth. Any negative number will leave the limit to its default
90+
* value as defined by the protobuf library.
91+
*
92+
* @since 1.56.0
93+
*/
94+
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/10108")
95+
public static <T extends MessageLite> Marshaller<T> marshallerWithRecursionLimit(
96+
T defaultInstance, int recursionLimit) {
97+
return new MessageMarshaller<>(defaultInstance, recursionLimit);
8598
}
8699

87100
/**
@@ -117,18 +130,20 @@ private ProtoLiteUtils() {
117130

118131
private static final class MessageMarshaller<T extends MessageLite>
119132
implements PrototypeMarshaller<T> {
133+
120134
private static final ThreadLocal<Reference<byte[]>> bufs = new ThreadLocal<>();
121135

122136
private final Parser<T> parser;
123137
private final T defaultInstance;
138+
private final int recursionLimit;
124139

125140
@SuppressWarnings("unchecked")
126-
MessageMarshaller(T defaultInstance) {
127-
this.defaultInstance = defaultInstance;
128-
parser = (Parser<T>) defaultInstance.getParserForType();
141+
MessageMarshaller(T defaultInstance, int recursionLimit) {
142+
this.defaultInstance = checkNotNull(defaultInstance, "defaultInstance cannot be null");
143+
this.parser = (Parser<T>) defaultInstance.getParserForType();
144+
this.recursionLimit = recursionLimit;
129145
}
130146

131-
132147
@SuppressWarnings("unchecked")
133148
@Override
134149
public Class<T> getMessageClass() {
@@ -211,6 +226,10 @@ public T parse(InputStream stream) {
211226
// when parsing.
212227
cis.setSizeLimit(Integer.MAX_VALUE);
213228

229+
if (recursionLimit >= 0) {
230+
cis.setRecursionLimit(recursionLimit);
231+
}
232+
214233
try {
215234
return parseFrom(cis);
216235
} catch (InvalidProtocolBufferException ipbe) {

protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java

+95-8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import static org.junit.Assert.assertEquals;
2121
import static org.junit.Assert.assertNotNull;
2222
import static org.junit.Assert.assertSame;
23+
import static org.junit.Assert.assertThrows;
2324
import static org.junit.Assert.fail;
2425

2526
import com.google.common.io.ByteStreams;
@@ -36,6 +37,7 @@
3637
import io.grpc.Status;
3738
import io.grpc.StatusRuntimeException;
3839
import io.grpc.internal.GrpcUtil;
40+
import io.grpc.testing.protobuf.SimpleRecursiveMessage;
3941
import java.io.ByteArrayInputStream;
4042
import java.io.ByteArrayOutputStream;
4143
import java.io.IOException;
@@ -54,7 +56,7 @@ public class ProtoLiteUtilsTest {
5456
@SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467
5557
@Rule public final ExpectedException thrown = ExpectedException.none();
5658

57-
private Marshaller<Type> marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance());
59+
private final Marshaller<Type> marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance());
5860
private Type proto = Type.newBuilder().setName("name").build();
5961

6062
@Test
@@ -85,7 +87,7 @@ public void testInvalidatedMessage() throws Exception {
8587
}
8688

8789
@Test
88-
public void parseInvalid() throws Exception {
90+
public void parseInvalid() {
8991
InputStream is = new ByteArrayInputStream(new byte[] {-127});
9092
try {
9193
marshaller.parse(is);
@@ -97,15 +99,15 @@ public void parseInvalid() throws Exception {
9799
}
98100

99101
@Test
100-
public void testMismatch() throws Exception {
102+
public void testMismatch() {
101103
Marshaller<Enum> enumMarshaller = ProtoLiteUtils.marshaller(Enum.getDefaultInstance());
102104
// Enum's name and Type's name are both strings with tag 1.
103105
Enum altProto = Enum.newBuilder().setName(proto.getName()).build();
104106
assertEquals(proto, marshaller.parse(enumMarshaller.stream(altProto)));
105107
}
106108

107109
@Test
108-
public void introspection() throws Exception {
110+
public void introspection() {
109111
Marshaller<Enum> enumMarshaller = ProtoLiteUtils.marshaller(Enum.getDefaultInstance());
110112
PrototypeMarshaller<Enum> prototypeMarshaller = (PrototypeMarshaller<Enum>) enumMarshaller;
111113
assertSame(Enum.getDefaultInstance(), prototypeMarshaller.getMessagePrototype());
@@ -219,7 +221,7 @@ public void extensionRegistry_notNull() {
219221
}
220222

221223
@Test
222-
public void parseFromKnowLengthInputStream() throws Exception {
224+
public void parseFromKnowLengthInputStream() {
223225
Marshaller<Type> marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance());
224226
Type expect = Type.newBuilder().setName("expected name").build();
225227

@@ -232,21 +234,106 @@ public void defaultMaxMessageSize() {
232234
assertEquals(GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE, ProtoLiteUtils.DEFAULT_MAX_MESSAGE_SIZE);
233235
}
234236

237+
@Test
238+
public void testNullDefaultInstance() {
239+
String expectedMessage = "defaultInstance cannot be null";
240+
assertThrows(expectedMessage, NullPointerException.class,
241+
() -> ProtoLiteUtils.marshaller(null));
242+
243+
assertThrows(expectedMessage, NullPointerException.class,
244+
() -> ProtoLiteUtils.marshallerWithRecursionLimit(null, 10)
245+
);
246+
}
247+
248+
@Test
249+
public void givenPositiveLimit_testRecursionLimitExceeded() throws IOException {
250+
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
251+
SimpleRecursiveMessage.getDefaultInstance(), 10);
252+
SimpleRecursiveMessage message = buildRecursiveMessage(12);
253+
254+
assertRecursionLimitExceeded(marshaller, message);
255+
}
256+
257+
@Test
258+
public void givenZeroLimit_testRecursionLimitExceeded() throws IOException {
259+
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
260+
SimpleRecursiveMessage.getDefaultInstance(), 0);
261+
SimpleRecursiveMessage message = buildRecursiveMessage(1);
262+
263+
assertRecursionLimitExceeded(marshaller, message);
264+
}
265+
266+
@Test
267+
public void givenPositiveLimit_testRecursionLimitNotExceeded() throws IOException {
268+
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
269+
SimpleRecursiveMessage.getDefaultInstance(), 15);
270+
SimpleRecursiveMessage message = buildRecursiveMessage(12);
271+
272+
assertRecursionLimitNotExceeded(marshaller, message);
273+
}
274+
275+
@Test
276+
public void givenZeroLimit_testRecursionLimitNotExceeded() throws IOException {
277+
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
278+
SimpleRecursiveMessage.getDefaultInstance(), 0);
279+
SimpleRecursiveMessage message = buildRecursiveMessage(0);
280+
281+
assertRecursionLimitNotExceeded(marshaller, message);
282+
}
283+
284+
@Test
285+
public void testDefaultRecursionLimit() throws IOException {
286+
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshaller(
287+
SimpleRecursiveMessage.getDefaultInstance());
288+
SimpleRecursiveMessage message = buildRecursiveMessage(100);
289+
290+
assertRecursionLimitNotExceeded(marshaller, message);
291+
}
292+
293+
private static void assertRecursionLimitExceeded(Marshaller<SimpleRecursiveMessage> marshaller,
294+
SimpleRecursiveMessage message) throws IOException {
295+
InputStream is = marshaller.stream(message);
296+
ByteArrayInputStream bais = new ByteArrayInputStream(ByteStreams.toByteArray(is));
297+
298+
assertThrows(StatusRuntimeException.class, () -> marshaller.parse(bais));
299+
}
300+
301+
private static void assertRecursionLimitNotExceeded(Marshaller<SimpleRecursiveMessage> marshaller,
302+
SimpleRecursiveMessage message) throws IOException {
303+
InputStream is = marshaller.stream(message);
304+
ByteArrayInputStream bais = new ByteArrayInputStream(ByteStreams.toByteArray(is));
305+
306+
assertEquals(message, marshaller.parse(bais));
307+
}
308+
309+
private static SimpleRecursiveMessage buildRecursiveMessage(int depth) {
310+
SimpleRecursiveMessage.Builder builder = SimpleRecursiveMessage.newBuilder()
311+
.setValue("depth-" + depth);
312+
for (int i = depth; i > 0; i--) {
313+
builder = SimpleRecursiveMessage.newBuilder()
314+
.setValue("depth-" + i)
315+
.setMessage(builder.build());
316+
}
317+
318+
return builder.build();
319+
}
320+
235321
private static class CustomKnownLengthInputStream extends InputStream implements KnownLength {
322+
236323
private int position = 0;
237-
private byte[] source;
324+
private final byte[] source;
238325

239326
private CustomKnownLengthInputStream(byte[] source) {
240327
this.source = source;
241328
}
242329

243330
@Override
244-
public int available() throws IOException {
331+
public int available() {
245332
return source.length - position;
246333
}
247334

248335
@Override
249-
public int read() throws IOException {
336+
public int read() {
250337
if (position == source.length) {
251338
return -1;
252339
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
syntax = "proto3";
2+
3+
package grpc.testing;
4+
5+
option java_package = "io.grpc.testing.protobuf";
6+
option java_outer_classname = "SimpleRecursiveProto";
7+
option java_multiple_files = true;
8+
9+
// A simple recursive message for testing purposes
10+
message SimpleRecursiveMessage {
11+
string value = 1;
12+
SimpleRecursiveMessage message = 2;
13+
}

protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java

+14-1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,19 @@ public static <T extends Message> Marshaller<T> marshaller(final T defaultInstan
5757
return ProtoLiteUtils.marshaller(defaultInstance);
5858
}
5959

60+
/**
61+
* Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance} and a
62+
* custom limit for the recursion depth. Any negative number will leave the limit to its default
63+
* value as defined by the protobuf library.
64+
*
65+
* @since 1.56.0
66+
*/
67+
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/10108")
68+
public static <T extends Message> Marshaller<T> marshallerWithRecursionLimit(T defaultInstance,
69+
int recursionLimit) {
70+
return ProtoLiteUtils.marshallerWithRecursionLimit(defaultInstance, recursionLimit);
71+
}
72+
6073
/**
6174
* Produce a metadata key for a generated protobuf type.
6275
*
@@ -70,7 +83,7 @@ public static <T extends Message> Metadata.Key<T> keyForProto(T instance) {
7083

7184
/**
7285
* Produce a metadata marshaller for a protobuf type.
73-
*
86+
*
7487
* @since 1.13.0
7588
*/
7689
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/4477")

0 commit comments

Comments
 (0)