20
20
import static org .junit .Assert .assertEquals ;
21
21
import static org .junit .Assert .assertNotNull ;
22
22
import static org .junit .Assert .assertSame ;
23
+ import static org .junit .Assert .assertThrows ;
23
24
import static org .junit .Assert .fail ;
24
25
25
26
import com .google .common .io .ByteStreams ;
36
37
import io .grpc .Status ;
37
38
import io .grpc .StatusRuntimeException ;
38
39
import io .grpc .internal .GrpcUtil ;
40
+ import io .grpc .testing .protobuf .SimpleRecursiveMessage ;
39
41
import java .io .ByteArrayInputStream ;
40
42
import java .io .ByteArrayOutputStream ;
41
43
import java .io .IOException ;
@@ -54,7 +56,7 @@ public class ProtoLiteUtilsTest {
54
56
@ SuppressWarnings ("deprecation" ) // https://github.com/grpc/grpc-java/issues/7467
55
57
@ Rule public final ExpectedException thrown = ExpectedException .none ();
56
58
57
- private Marshaller <Type > marshaller = ProtoLiteUtils .marshaller (Type .getDefaultInstance ());
59
+ private final Marshaller <Type > marshaller = ProtoLiteUtils .marshaller (Type .getDefaultInstance ());
58
60
private Type proto = Type .newBuilder ().setName ("name" ).build ();
59
61
60
62
@ Test
@@ -85,7 +87,7 @@ public void testInvalidatedMessage() throws Exception {
85
87
}
86
88
87
89
@ Test
88
- public void parseInvalid () throws Exception {
90
+ public void parseInvalid () {
89
91
InputStream is = new ByteArrayInputStream (new byte [] {-127 });
90
92
try {
91
93
marshaller .parse (is );
@@ -97,15 +99,15 @@ public void parseInvalid() throws Exception {
97
99
}
98
100
99
101
@ Test
100
- public void testMismatch () throws Exception {
102
+ public void testMismatch () {
101
103
Marshaller <Enum > enumMarshaller = ProtoLiteUtils .marshaller (Enum .getDefaultInstance ());
102
104
// Enum's name and Type's name are both strings with tag 1.
103
105
Enum altProto = Enum .newBuilder ().setName (proto .getName ()).build ();
104
106
assertEquals (proto , marshaller .parse (enumMarshaller .stream (altProto )));
105
107
}
106
108
107
109
@ Test
108
- public void introspection () throws Exception {
110
+ public void introspection () {
109
111
Marshaller <Enum > enumMarshaller = ProtoLiteUtils .marshaller (Enum .getDefaultInstance ());
110
112
PrototypeMarshaller <Enum > prototypeMarshaller = (PrototypeMarshaller <Enum >) enumMarshaller ;
111
113
assertSame (Enum .getDefaultInstance (), prototypeMarshaller .getMessagePrototype ());
@@ -219,7 +221,7 @@ public void extensionRegistry_notNull() {
219
221
}
220
222
221
223
@ Test
222
- public void parseFromKnowLengthInputStream () throws Exception {
224
+ public void parseFromKnowLengthInputStream () {
223
225
Marshaller <Type > marshaller = ProtoLiteUtils .marshaller (Type .getDefaultInstance ());
224
226
Type expect = Type .newBuilder ().setName ("expected name" ).build ();
225
227
@@ -232,21 +234,106 @@ public void defaultMaxMessageSize() {
232
234
assertEquals (GrpcUtil .DEFAULT_MAX_MESSAGE_SIZE , ProtoLiteUtils .DEFAULT_MAX_MESSAGE_SIZE );
233
235
}
234
236
237
+ @ Test
238
+ public void testNullDefaultInstance () {
239
+ String expectedMessage = "defaultInstance cannot be null" ;
240
+ assertThrows (expectedMessage , NullPointerException .class ,
241
+ () -> ProtoLiteUtils .marshaller (null ));
242
+
243
+ assertThrows (expectedMessage , NullPointerException .class ,
244
+ () -> ProtoLiteUtils .marshallerWithRecursionLimit (null , 10 )
245
+ );
246
+ }
247
+
248
+ @ Test
249
+ public void givenPositiveLimit_testRecursionLimitExceeded () throws IOException {
250
+ Marshaller <SimpleRecursiveMessage > marshaller = ProtoLiteUtils .marshallerWithRecursionLimit (
251
+ SimpleRecursiveMessage .getDefaultInstance (), 10 );
252
+ SimpleRecursiveMessage message = buildRecursiveMessage (12 );
253
+
254
+ assertRecursionLimitExceeded (marshaller , message );
255
+ }
256
+
257
+ @ Test
258
+ public void givenZeroLimit_testRecursionLimitExceeded () throws IOException {
259
+ Marshaller <SimpleRecursiveMessage > marshaller = ProtoLiteUtils .marshallerWithRecursionLimit (
260
+ SimpleRecursiveMessage .getDefaultInstance (), 0 );
261
+ SimpleRecursiveMessage message = buildRecursiveMessage (1 );
262
+
263
+ assertRecursionLimitExceeded (marshaller , message );
264
+ }
265
+
266
+ @ Test
267
+ public void givenPositiveLimit_testRecursionLimitNotExceeded () throws IOException {
268
+ Marshaller <SimpleRecursiveMessage > marshaller = ProtoLiteUtils .marshallerWithRecursionLimit (
269
+ SimpleRecursiveMessage .getDefaultInstance (), 15 );
270
+ SimpleRecursiveMessage message = buildRecursiveMessage (12 );
271
+
272
+ assertRecursionLimitNotExceeded (marshaller , message );
273
+ }
274
+
275
+ @ Test
276
+ public void givenZeroLimit_testRecursionLimitNotExceeded () throws IOException {
277
+ Marshaller <SimpleRecursiveMessage > marshaller = ProtoLiteUtils .marshallerWithRecursionLimit (
278
+ SimpleRecursiveMessage .getDefaultInstance (), 0 );
279
+ SimpleRecursiveMessage message = buildRecursiveMessage (0 );
280
+
281
+ assertRecursionLimitNotExceeded (marshaller , message );
282
+ }
283
+
284
+ @ Test
285
+ public void testDefaultRecursionLimit () throws IOException {
286
+ Marshaller <SimpleRecursiveMessage > marshaller = ProtoLiteUtils .marshaller (
287
+ SimpleRecursiveMessage .getDefaultInstance ());
288
+ SimpleRecursiveMessage message = buildRecursiveMessage (100 );
289
+
290
+ assertRecursionLimitNotExceeded (marshaller , message );
291
+ }
292
+
293
+ private static void assertRecursionLimitExceeded (Marshaller <SimpleRecursiveMessage > marshaller ,
294
+ SimpleRecursiveMessage message ) throws IOException {
295
+ InputStream is = marshaller .stream (message );
296
+ ByteArrayInputStream bais = new ByteArrayInputStream (ByteStreams .toByteArray (is ));
297
+
298
+ assertThrows (StatusRuntimeException .class , () -> marshaller .parse (bais ));
299
+ }
300
+
301
+ private static void assertRecursionLimitNotExceeded (Marshaller <SimpleRecursiveMessage > marshaller ,
302
+ SimpleRecursiveMessage message ) throws IOException {
303
+ InputStream is = marshaller .stream (message );
304
+ ByteArrayInputStream bais = new ByteArrayInputStream (ByteStreams .toByteArray (is ));
305
+
306
+ assertEquals (message , marshaller .parse (bais ));
307
+ }
308
+
309
+ private static SimpleRecursiveMessage buildRecursiveMessage (int depth ) {
310
+ SimpleRecursiveMessage .Builder builder = SimpleRecursiveMessage .newBuilder ()
311
+ .setValue ("depth-" + depth );
312
+ for (int i = depth ; i > 0 ; i --) {
313
+ builder = SimpleRecursiveMessage .newBuilder ()
314
+ .setValue ("depth-" + i )
315
+ .setMessage (builder .build ());
316
+ }
317
+
318
+ return builder .build ();
319
+ }
320
+
235
321
private static class CustomKnownLengthInputStream extends InputStream implements KnownLength {
322
+
236
323
private int position = 0 ;
237
- private byte [] source ;
324
+ private final byte [] source ;
238
325
239
326
private CustomKnownLengthInputStream (byte [] source ) {
240
327
this .source = source ;
241
328
}
242
329
243
330
@ Override
244
- public int available () throws IOException {
331
+ public int available () {
245
332
return source .length - position ;
246
333
}
247
334
248
335
@ Override
249
- public int read () throws IOException {
336
+ public int read () {
250
337
if (position == source .length ) {
251
338
return -1 ;
252
339
}
0 commit comments