be more careful about integer overflow in Serialize.read()

Previously, we were attempting to read the entire message
in one call to fillBuffer(). This was doomed to fail
if the message had more than Integer.MAX_VALUE bytes.

After this diff, we will call fillBuffer() separately for each
segment. This approach turns out to be simpler, too.
It might imply a small performance hit for messages with
many small segments, but such messages are discouraged anyway.

This diff also adds more overflow checking in the surrounding
logic.
This commit is contained in:
David Renshaw 2021-05-12 20:13:52 -04:00
parent a6c5240790
commit 4ec14e39f9

View file

@ -37,6 +37,16 @@ public final class Serialize {
return result; return result;
} }
static int MAX_SEGMENT_WORDS = (1 << 28) - 1;
static ByteBuffer makeByteBufferForWords(int words) throws IOException {
if (words > MAX_SEGMENT_WORDS) {
// Trying to construct the segment would cause overflow.
throw new DecodeException("segment has too many words (" + words + ")");
}
return makeByteBuffer(words * Constants.BYTES_PER_WORD);
}
public static void fillBuffer(ByteBuffer buffer, ReadableByteChannel bc) throws IOException { public static void fillBuffer(ByteBuffer buffer, ReadableByteChannel bc) throws IOException {
while(buffer.hasRemaining()) { while(buffer.hasRemaining()) {
int r = bc.read(buffer); int r = bc.read(buffer);
@ -53,21 +63,22 @@ public final class Serialize {
} }
public static MessageReader read(ReadableByteChannel bc, ReaderOptions options) throws IOException { public static MessageReader read(ReadableByteChannel bc, ReaderOptions options) throws IOException {
ByteBuffer firstWord = makeByteBuffer(Constants.BYTES_PER_WORD); ByteBuffer firstWord = makeByteBufferForWords(1);
fillBuffer(firstWord, bc); fillBuffer(firstWord, bc);
int segmentCount = 1 + firstWord.getInt(0); int rawSegmentCount = firstWord.getInt(0);
if (rawSegmentCount < 0 || rawSegmentCount > 511) {
throw new DecodeException("segment count must be between 0 and 512");
}
int segmentCount = 1 + rawSegmentCount;
int segment0Size = 0; int segment0Size = 0;
if (segmentCount > 0) { if (segmentCount > 0) {
segment0Size = firstWord.getInt(4); segment0Size = firstWord.getInt(4);
} }
int totalWords = segment0Size; long totalWords = segment0Size;
if (segmentCount > 512) {
throw new IOException("too many segments");
}
// in words // in words
ArrayList<Integer> moreSizes = new ArrayList<Integer>(); ArrayList<Integer> moreSizes = new ArrayList<Integer>();
@ -86,23 +97,17 @@ public final class Serialize {
throw new DecodeException("Message size exceeds traversal limit."); throw new DecodeException("Message size exceeds traversal limit.");
} }
ByteBuffer allSegments = makeByteBuffer(totalWords * Constants.BYTES_PER_WORD);
fillBuffer(allSegments, bc);
ByteBuffer[] segmentSlices = new ByteBuffer[segmentCount]; ByteBuffer[] segmentSlices = new ByteBuffer[segmentCount];
allSegments.rewind(); segmentSlices[0] = makeByteBufferForWords(segment0Size);
segmentSlices[0] = allSegments.slice(); fillBuffer(segmentSlices[0], bc);
segmentSlices[0].limit(segment0Size * Constants.BYTES_PER_WORD); segmentSlices[0].rewind();
segmentSlices[0].order(ByteOrder.LITTLE_ENDIAN);
int offset = segment0Size; int offset = segment0Size;
for (int ii = 1; ii < segmentCount; ++ii) { for (int ii = 1; ii < segmentCount; ++ii) {
allSegments.position(offset * Constants.BYTES_PER_WORD); segmentSlices[ii] = makeByteBufferForWords(moreSizes.get(ii - 1));
segmentSlices[ii] = allSegments.slice(); fillBuffer(segmentSlices[ii], bc);
segmentSlices[ii].limit(moreSizes.get(ii - 1) * Constants.BYTES_PER_WORD); segmentSlices[ii].rewind();
segmentSlices[ii].order(ByteOrder.LITTLE_ENDIAN);
offset += moreSizes.get(ii - 1);
} }
return new MessageReader(segmentSlices, options); return new MessageReader(segmentSlices, options);
@ -112,15 +117,16 @@ public final class Serialize {
return read(bb, ReaderOptions.DEFAULT_READER_OPTIONS); return read(bb, ReaderOptions.DEFAULT_READER_OPTIONS);
} }
/* /**
* Upon return, `bb.position()` will be at the end of the message. * Upon return, `bb.position()` will be at the end of the message.
*/ */
public static MessageReader read(ByteBuffer bb, ReaderOptions options) throws IOException { public static MessageReader read(ByteBuffer bb, ReaderOptions options) throws IOException {
bb.order(ByteOrder.LITTLE_ENDIAN); bb.order(ByteOrder.LITTLE_ENDIAN);
int segmentCount = 1 + bb.getInt(); int rawSegmentCount = bb.getInt();
if (segmentCount > 512) { int segmentCount = 1 + rawSegmentCount;
throw new IOException("too many segments"); if (rawSegmentCount < 0 || rawSegmentCount > 511) {
throw new DecodeException("segment count must be between 0 and 512");
} }
ByteBuffer[] segmentSlices = new ByteBuffer[segmentCount]; ByteBuffer[] segmentSlices = new ByteBuffer[segmentCount];
@ -135,6 +141,10 @@ public final class Serialize {
for (int ii = 0; ii < segmentCount; ++ii) { for (int ii = 0; ii < segmentCount; ++ii) {
int segmentSize = bb.getInt(segmentSizesBase + ii * 4); int segmentSize = bb.getInt(segmentSizesBase + ii * 4);
if (segmentSize > MAX_SEGMENT_WORDS -
(totalWords + segmentBase / Constants.BYTES_PER_WORD)) {
throw new DecodeException("segment size is too large");
}
bb.position(segmentBase + totalWords * Constants.BYTES_PER_WORD); bb.position(segmentBase + totalWords * Constants.BYTES_PER_WORD);
segmentSlices[ii] = bb.slice(); segmentSlices[ii] = bb.slice();