Skip to content

Commit 6b6384a

Browse files
committed
Improve WebFlux Protobuf support
- Update javadoc for decoding default instances - Refactor and simplify tests - Add missing tests - Refactor decoding with flatMapIterable instead of concatMap and avoid recursive call Issue: SPR-15776
1 parent 8e571de commit 6b6384a

File tree

4 files changed

+135
-95
lines changed

4 files changed

+135
-95
lines changed

spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.io.IOException;
2020
import java.lang.reflect.Method;
21+
import java.util.ArrayList;
2122
import java.util.List;
2223
import java.util.Map;
2324
import java.util.concurrent.ConcurrentMap;
@@ -44,13 +45,17 @@
4445
* A {@code Decoder} that reads {@link com.google.protobuf.Message}s
4546
* using <a href="https://developers.google.com/protocol-buffers/">Google Protocol Buffers</a>.
4647
*
47-
* Flux deserialized via
48+
* <p>Flux deserialized via
4849
* {@link #decode(Publisher, ResolvableType, MimeType, Map)} are expected to use
4950
* <a href="https://developers.google.com/protocol-buffers/docs/techniques?hl=en#streaming">delimited Protobuf messages</a>
5051
* with the size of each message specified before the message itself. Single values deserialized
5152
* via {@link #decodeToMono(Publisher, ResolvableType, MimeType, Map)} are expected to use
5253
* regular Protobuf message format (without the size prepended before the message).
5354
*
55+
* <p>Notice that default instance of Protobuf message produces empty byte array, so
56+
* {@code Mono.just(Msg.getDefaultInstance())} sent over the network will be deserialized
57+
* as an empty {@link Mono}.
58+
*
5459
* <p>To generate {@code Message} Java classes, you need to install the {@code protoc} binary.
5560
*
5661
* <p>This decoder requires Protobuf 3 or higher, and supports
@@ -108,7 +113,7 @@ public Flux<Message> decode(Publisher<DataBuffer> inputStream, ResolvableType el
108113
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
109114

110115
return Flux.from(inputStream)
111-
.concatMap(new MessageDecoderFunction(elementType, this.maxMessageSize));
116+
.flatMapIterable(new MessageDecoderFunction(elementType, this.maxMessageSize));
112117
}
113118

114119
@Override
@@ -152,7 +157,7 @@ public List<MimeType> getDecodableMimeTypes() {
152157
}
153158

154159

155-
private class MessageDecoderFunction implements Function<DataBuffer, Publisher<? extends Message>> {
160+
private class MessageDecoderFunction implements Function<DataBuffer, Iterable<? extends Message>> {
156161

157162
private final ResolvableType elementType;
158163

@@ -163,55 +168,59 @@ private class MessageDecoderFunction implements Function<DataBuffer, Publisher<?
163168

164169
private int messageBytesToRead;
165170

171+
166172
public MessageDecoderFunction(ResolvableType elementType, int maxMessageSize) {
167173
this.elementType = elementType;
168174
this.maxMessageSize = maxMessageSize;
169175
}
170176

171-
// TODO Instead of the recursive call, loop over the current DataBuffer,
172-
// produce a list of as many messages as are contained, and save any remaining bytes with flatMapIterable
177+
173178
@Override
174-
public Publisher<? extends Message> apply(DataBuffer input) {
179+
public Iterable<? extends Message> apply(DataBuffer input) {
175180
try {
176-
if (this.output == null) {
177-
int firstByte = input.read();
178-
if (firstByte == -1) {
179-
return Flux.error(new DecodingException("Can't parse message size"));
181+
List<Message> messages = new ArrayList<>();
182+
int remainingBytesToRead;
183+
int chunkBytesToRead;
184+
185+
do {
186+
if (this.output == null) {
187+
int firstByte = input.read();
188+
if (firstByte == -1) {
189+
throw new DecodingException("Can't parse message size");
190+
}
191+
this.messageBytesToRead = CodedInputStream.readRawVarint32(firstByte, input.asInputStream());
192+
if (this.messageBytesToRead > this.maxMessageSize) {
193+
throw new DecodingException(
194+
"The number of bytes to read parsed in the incoming stream (" +
195+
this.messageBytesToRead + ") exceeds the configured limit (" + this.maxMessageSize + ")");
196+
}
197+
this.output = input.factory().allocateBuffer(this.messageBytesToRead);
180198
}
181-
this.messageBytesToRead = CodedInputStream.readRawVarint32(firstByte, input.asInputStream());
182-
if (this.messageBytesToRead > this.maxMessageSize) {
183-
return Flux.error(new DecodingException(
184-
"The number of bytes to read parsed in the incoming stream (" +
185-
this.messageBytesToRead + ") exceeds the configured limit (" + this.maxMessageSize + ")"));
199+
200+
chunkBytesToRead = this.messageBytesToRead >= input.readableByteCount() ?
201+
input.readableByteCount() : this.messageBytesToRead;
202+
remainingBytesToRead = input.readableByteCount() - chunkBytesToRead;
203+
204+
byte[] bytesToWrite = new byte[chunkBytesToRead];
205+
input.read(bytesToWrite, 0, chunkBytesToRead);
206+
this.output.write(bytesToWrite);
207+
this.messageBytesToRead -= chunkBytesToRead;
208+
209+
if (this.messageBytesToRead == 0) {
210+
Message.Builder builder = getMessageBuilder(this.elementType.toClass());
211+
builder.mergeFrom(CodedInputStream.newInstance(this.output.asByteBuffer()), extensionRegistry);
212+
messages.add(builder.build());
213+
DataBufferUtils.release(this.output);
214+
this.output = null;
186215
}
187-
this.output = input.factory().allocateBuffer(this.messageBytesToRead);
188-
}
189-
int chunkBytesToRead = this.messageBytesToRead >= input.readableByteCount() ?
190-
input.readableByteCount() : this.messageBytesToRead;
191-
int remainingBytesToRead = input.readableByteCount() - chunkBytesToRead;
192-
this.output.write(input.slice(input.readPosition(), chunkBytesToRead));
193-
this.messageBytesToRead -= chunkBytesToRead;
194-
Message message = null;
195-
if (this.messageBytesToRead == 0) {
196-
Message.Builder builder = getMessageBuilder(this.elementType.toClass());
197-
builder.mergeFrom(CodedInputStream.newInstance(this.output.asByteBuffer()), extensionRegistry);
198-
message = builder.build();
199-
DataBufferUtils.release(this.output);
200-
this.output = null;
201-
}
202-
if (remainingBytesToRead > 0) {
203-
return Mono.justOrEmpty(message).concatWith(
204-
apply(input.slice(input.readPosition() + chunkBytesToRead, remainingBytesToRead)));
205-
}
206-
else {
207-
return Mono.justOrEmpty(message);
208-
}
216+
} while (remainingBytesToRead > 0);
217+
return messages;
209218
}
210219
catch (IOException ex) {
211-
return Flux.error(new DecodingException("I/O error while parsing input stream", ex));
220+
throw new DecodingException("I/O error while parsing input stream", ex);
212221
}
213222
catch (Exception ex) {
214-
return Flux.error(new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex));
223+
throw new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex);
215224
}
216225
}
217226
}

spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufEncoder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
* An {@code Encoder} that writes {@link com.google.protobuf.Message}s
4141
* using <a href="https://developers.google.com/protocol-buffers/">Google Protocol Buffers</a>.
4242
*
43-
* Flux are serialized using
43+
* <p>Flux are serialized using
4444
* <a href="https://developers.google.com/protocol-buffers/docs/techniques?hl=en#streaming">delimited Protobuf messages</a>
4545
* with the size of each message specified before the message itself. Single values are
4646
* serialized using regular Protobuf message format (without the size prepended before the message).

spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java

Lines changed: 70 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,7 @@
1616

1717
package org.springframework.http.codec.protobuf;
1818

19-
import java.io.ByteArrayOutputStream;
2019
import java.io.IOException;
21-
import java.io.OutputStream;
22-
import java.util.ArrayList;
23-
import java.util.Arrays;
24-
import java.util.List;
2520

2621
import com.google.protobuf.Message;
2722
import org.junit.Before;
@@ -47,16 +42,20 @@
4742

4843
/**
4944
* Unit tests for {@link ProtobufDecoder}.
50-
* TODO Make tests more readable
51-
* TODO Add a test where an input DataBuffer is larger than a message
5245
*
5346
* @author Sebastien Deleuze
5447
*/
5548
public class ProtobufDecoderTests extends AbstractDataBufferAllocatingTestCase {
5649

5750
private final static MimeType PROTOBUF_MIME_TYPE = new MimeType("application", "x-protobuf");
5851

59-
private final Msg testMsg = Msg.newBuilder().setFoo("Foo").setBlah(SecondMsg.newBuilder().setBlah(123).build()).build();
52+
private final SecondMsg secondMsg = SecondMsg.newBuilder().setBlah(123).build();
53+
54+
private final Msg testMsg = Msg.newBuilder().setFoo("Foo").setBlah(secondMsg).build();
55+
56+
private final SecondMsg secondMsg2 = SecondMsg.newBuilder().setBlah(456).build();
57+
58+
private final Msg testMsg2 = Msg.newBuilder().setFoo("Bar").setBlah(secondMsg2).build();
6059

6160
private ProtobufDecoder decoder;
6261

@@ -82,51 +81,59 @@ public void canDecode() {
8281

8382
@Test
8483
public void decodeToMono() {
85-
byte[] body = this.testMsg.toByteArray();
86-
Flux<DataBuffer> source = Flux.just(this.bufferFactory.wrap(body));
84+
DataBuffer data = this.bufferFactory.wrap(testMsg.toByteArray());
8785
ResolvableType elementType = forClass(Msg.class);
88-
Mono<Message> mono = this.decoder.decodeToMono(source, elementType, null,
89-
emptyMap());
86+
87+
Mono<Message> mono = this.decoder.decodeToMono(Flux.just(data), elementType, null, emptyMap());
9088

9189
StepVerifier.create(mono)
92-
.expectNext(this.testMsg)
90+
.expectNext(testMsg)
91+
.verifyComplete();
92+
}
93+
94+
@Test
95+
public void decodeToMonoWithLargerDataBuffer() {
96+
DataBuffer buffer = this.bufferFactory.allocateBuffer(1024);
97+
buffer.write(testMsg.toByteArray());
98+
ResolvableType elementType = forClass(Msg.class);
99+
100+
Mono<Message> mono = this.decoder.decodeToMono(Flux.just(buffer), elementType, null, emptyMap());
101+
102+
StepVerifier.create(mono)
103+
.expectNext(testMsg)
93104
.verifyComplete();
94105
}
95106

96107
@Test
97108
public void decodeChunksToMono() {
98-
byte[] body = this.testMsg.toByteArray();
99-
List<DataBuffer> chunks = new ArrayList<>();
100-
chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(body, 0, 4)));
101-
chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(body, 4, body.length)));
102-
Flux<DataBuffer> source = Flux.fromIterable(chunks);
109+
DataBuffer buffer = this.bufferFactory.wrap(testMsg.toByteArray());
110+
Flux<DataBuffer> chunks = Flux.just(
111+
buffer.slice(0, 4),
112+
buffer.slice(4, buffer.readableByteCount() - 4));
113+
DataBufferUtils.retain(buffer);
103114
ResolvableType elementType = forClass(Msg.class);
104-
Mono<Message> mono = this.decoder.decodeToMono(source, elementType, null,
115+
116+
Mono<Message> mono = this.decoder.decodeToMono(chunks, elementType, null,
105117
emptyMap());
106118

107119
StepVerifier.create(mono)
108-
.expectNext(this.testMsg)
120+
.expectNext(testMsg)
109121
.verifyComplete();
110122
}
111123

112124
@Test
113125
public void decode() throws IOException {
114-
Msg testMsg2 = Msg.newBuilder().setFoo("Bar").setBlah(SecondMsg.newBuilder().setBlah(456).build()).build();
115-
116126
DataBuffer buffer = bufferFactory.allocateBuffer();
117-
OutputStream outputStream = buffer.asOutputStream();
118-
this.testMsg.writeDelimitedTo(outputStream);
119-
127+
testMsg.writeDelimitedTo(buffer.asOutputStream());
120128
DataBuffer buffer2 = bufferFactory.allocateBuffer();
121-
OutputStream outputStream2 = buffer2.asOutputStream();
122-
testMsg2.writeDelimitedTo(outputStream2);
123-
129+
testMsg2.writeDelimitedTo(buffer2.asOutputStream());
124130
Flux<DataBuffer> source = Flux.just(buffer, buffer2);
125131
ResolvableType elementType = forClass(Msg.class);
132+
126133
Flux<Message> messages = this.decoder.decode(source, elementType, null, emptyMap());
127134

128135
StepVerifier.create(messages)
129-
.expectNext(this.testMsg)
136+
.expectNext(testMsg)
130137
.expectNext(testMsg2)
131138
.verifyComplete();
132139

@@ -135,42 +142,50 @@ public void decode() throws IOException {
135142
}
136143

137144
@Test
138-
public void decodeChunks() throws IOException {
139-
Msg testMsg2 = Msg.newBuilder().setFoo("Bar").setBlah(SecondMsg.newBuilder().setBlah(456).build()).build();
140-
List<DataBuffer> chunks = new ArrayList<>();
141-
142-
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
143-
this.testMsg.writeDelimitedTo(outputStream);
144-
byte[] byteArray = outputStream.toByteArray();
145-
ByteArrayOutputStream outputStream2 = new ByteArrayOutputStream();
146-
testMsg2.writeDelimitedTo(outputStream2);
147-
byte[] byteArray2 = outputStream2.toByteArray();
148-
149-
chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(byteArray, 0, 4)));
150-
byte[] chunk2 = Arrays.copyOfRange(byteArray, 4, byteArray.length);
151-
byte[] chunk3 = Arrays.copyOfRange(byteArray2, 0, 4);
152-
byte[] combined = new byte[chunk2.length + chunk3.length];
153-
for (int i = 0; i < combined.length; ++i)
154-
{
155-
combined[i] = i < chunk2.length ? chunk2[i] : chunk3[i - chunk2.length];
156-
}
157-
chunks.add(this.bufferFactory.wrap(combined));
158-
chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(byteArray2, 4, byteArray2.length)));
159-
160-
Flux<DataBuffer> source = Flux.fromIterable(chunks);
145+
public void decodeSplitChunks() throws IOException {
146+
DataBuffer buffer = bufferFactory.allocateBuffer();
147+
testMsg.writeDelimitedTo(buffer.asOutputStream());
148+
DataBuffer buffer2 = bufferFactory.allocateBuffer();
149+
testMsg2.writeDelimitedTo(buffer2.asOutputStream());
150+
Flux<DataBuffer> chunks = Flux.just(
151+
buffer.slice(0, 4),
152+
buffer.slice(4, buffer.readableByteCount() - 4),
153+
buffer2.slice(0, 2),
154+
buffer2.slice(2, buffer2.readableByteCount() - 2));
155+
161156
ResolvableType elementType = forClass(Msg.class);
162-
Flux<Message> messages = this.decoder.decode(source, elementType, null, emptyMap());
157+
Flux<Message> messages = this.decoder.decode(chunks, elementType, null, emptyMap());
163158

164159
StepVerifier.create(messages)
165-
.expectNext(this.testMsg)
160+
.expectNext(testMsg)
166161
.expectNext(testMsg2)
167162
.verifyComplete();
163+
164+
DataBufferUtils.release(buffer);
165+
DataBufferUtils.release(buffer2);
166+
}
167+
168+
@Test
169+
public void decodeMergedChunks() throws IOException {
170+
DataBuffer buffer = bufferFactory.allocateBuffer();
171+
testMsg.writeDelimitedTo(buffer.asOutputStream());
172+
testMsg.writeDelimitedTo(buffer.asOutputStream());
173+
174+
ResolvableType elementType = forClass(Msg.class);
175+
Flux<Message> messages = this.decoder.decode(Mono.just(buffer), elementType, null, emptyMap());
176+
177+
StepVerifier.create(messages)
178+
.expectNext(testMsg)
179+
.expectNext(testMsg)
180+
.verifyComplete();
181+
182+
DataBufferUtils.release(buffer);
168183
}
169184

170185
@Test
171186
public void exceedMaxSize() {
172187
this.decoder.setMaxMessageSize(1);
173-
byte[] body = this.testMsg.toByteArray();
188+
byte[] body = testMsg.toByteArray();
174189
Flux<DataBuffer> source = Flux.just(this.bufferFactory.wrap(body));
175190
ResolvableType elementType = forClass(Msg.class);
176191
Flux<Message> messages = this.decoder.decode(source, elementType, null,

spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ProtobufIntegrationTests.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,17 @@ public void empty() {
129129
.verifyComplete();
130130
}
131131

132+
@Test
133+
public void defaultInstance() {
134+
Mono<Msg> result = this.webClient.get()
135+
.uri("/default-instance")
136+
.retrieve()
137+
.bodyToMono(Msg.class);
138+
139+
StepVerifier.create(result)
140+
.verifyComplete();
141+
}
142+
132143
@RestController
133144
@SuppressWarnings("unused")
134145
static class ProtobufController {
@@ -153,6 +164,11 @@ Mono<Msg> empty() {
153164
return Mono.empty();
154165
}
155166

167+
@GetMapping("default-instance")
168+
Mono<Msg> defaultInstance() {
169+
return Mono.just(Msg.getDefaultInstance());
170+
}
171+
156172
}
157173

158174
@Configuration

0 commit comments

Comments
 (0)