diff --git a/runtime/src/main/java/org/capnproto/Serialize.java b/runtime/src/main/java/org/capnproto/Serialize.java index 4d9aeaa..34212d0 100644 --- a/runtime/src/main/java/org/capnproto/Serialize.java +++ b/runtime/src/main/java/org/capnproto/Serialize.java @@ -37,6 +37,16 @@ public final class Serialize { return result; } + static int MAX_SEGMENT_WORDS = (1 << 28) - 1; + + static ByteBuffer makeByteBufferForWords(int words) throws IOException { + if (words > MAX_SEGMENT_WORDS) { + // Trying to construct the segment would cause overflow. + throw new DecodeException("segment has too many words (" + words + ")"); + } + return makeByteBuffer(words * Constants.BYTES_PER_WORD); + } + public static void fillBuffer(ByteBuffer buffer, ReadableByteChannel bc) throws IOException { while(buffer.hasRemaining()) { int r = bc.read(buffer); @@ -53,21 +63,22 @@ public final class Serialize { } public static MessageReader read(ReadableByteChannel bc, ReaderOptions options) throws IOException { - ByteBuffer firstWord = makeByteBuffer(Constants.BYTES_PER_WORD); + ByteBuffer firstWord = makeByteBufferForWords(1); fillBuffer(firstWord, bc); - int segmentCount = 1 + firstWord.getInt(0); + int rawSegmentCount = firstWord.getInt(0); + if (rawSegmentCount < 0 || rawSegmentCount > 511) { + throw new DecodeException("segment count must be between 0 and 512"); + } + + int segmentCount = 1 + rawSegmentCount; int segment0Size = 0; if (segmentCount > 0) { segment0Size = firstWord.getInt(4); } - int totalWords = segment0Size; - - if (segmentCount > 512) { - throw new IOException("too many segments"); - } + long totalWords = segment0Size; // in words ArrayList moreSizes = new ArrayList(); @@ -86,23 +97,17 @@ public final class Serialize { throw new DecodeException("Message size exceeds traversal limit."); } - ByteBuffer allSegments = makeByteBuffer(totalWords * Constants.BYTES_PER_WORD); - fillBuffer(allSegments, bc); - ByteBuffer[] segmentSlices = new ByteBuffer[segmentCount]; - allSegments.rewind(); - segmentSlices[0] = allSegments.slice(); - segmentSlices[0].limit(segment0Size * Constants.BYTES_PER_WORD); - segmentSlices[0].order(ByteOrder.LITTLE_ENDIAN); + segmentSlices[0] = makeByteBufferForWords(segment0Size); + fillBuffer(segmentSlices[0], bc); + segmentSlices[0].rewind(); int offset = segment0Size; for (int ii = 1; ii < segmentCount; ++ii) { - allSegments.position(offset * Constants.BYTES_PER_WORD); - segmentSlices[ii] = allSegments.slice(); - segmentSlices[ii].limit(moreSizes.get(ii - 1) * Constants.BYTES_PER_WORD); - segmentSlices[ii].order(ByteOrder.LITTLE_ENDIAN); - offset += moreSizes.get(ii - 1); + segmentSlices[ii] = makeByteBufferForWords(moreSizes.get(ii - 1)); + fillBuffer(segmentSlices[ii], bc); + segmentSlices[ii].rewind(); } return new MessageReader(segmentSlices, options); @@ -112,15 +117,16 @@ public final class Serialize { return read(bb, ReaderOptions.DEFAULT_READER_OPTIONS); } - /* + /** * Upon return, `bb.position()` will be at the end of the message. */ public static MessageReader read(ByteBuffer bb, ReaderOptions options) throws IOException { bb.order(ByteOrder.LITTLE_ENDIAN); - int segmentCount = 1 + bb.getInt(); - if (segmentCount > 512) { - throw new IOException("too many segments"); + int rawSegmentCount = bb.getInt(); + int segmentCount = 1 + rawSegmentCount; + if (rawSegmentCount < 0 || rawSegmentCount > 511) { + throw new DecodeException("segment count must be between 0 and 512"); } ByteBuffer[] segmentSlices = new ByteBuffer[segmentCount]; @@ -135,6 +141,10 @@ public final class Serialize { for (int ii = 0; ii < segmentCount; ++ii) { int segmentSize = bb.getInt(segmentSizesBase + ii * 4); + if (segmentSize > MAX_SEGMENT_WORDS - + (totalWords + segmentBase / Constants.BYTES_PER_WORD)) { + throw new DecodeException("segment size is too large"); + } bb.position(segmentBase + totalWords * Constants.BYTES_PER_WORD); segmentSlices[ii] = bb.slice();