diff --git a/compiler/src/test/java/org/capnproto/test/EncodingTest.java b/compiler/src/test/java/org/capnproto/test/EncodingTest.java index 1f10853..f52afd4 100644 --- a/compiler/src/test/java/org/capnproto/test/EncodingTest.java +++ b/compiler/src/test/java/org/capnproto/test/EncodingTest.java @@ -614,6 +614,24 @@ public class EncodingTest { root.getAnyPointerField().getAs(StructList.newFactory(Test.TestAllTypes.factory)); } + // Test that we throw an exception on out-of-bounds list pointers. + // Before v0.1.11, we were vulnerable to a cpu amplification attack: + // reading an out-of-bounds pointer to list a huge number of elements of size BIT, + // when read as a struct list, would return without error. + @org.junit.Test(expected=DecodeException.class) + public void testListPointerOutOfBounds() throws DecodeException { + byte[] bytes = new byte[] + {0,0,0,0, 0,0,1,0, // struct, one pointer + 1, 0x2f, 0, 0, 1, 0, -127, -128}; // list, points out of bounds. + ByteBuffer segment = ByteBuffer.wrap(bytes); + segment.order(ByteOrder.LITTLE_ENDIAN); + MessageReader message = new MessageReader(new ByteBuffer[]{segment}, + ReaderOptions.DEFAULT_READER_OPTIONS); + + Test.TestAnyPointer.Reader root = message.getRoot(Test.TestAnyPointer.factory); + root.getAnyPointerField().getAs(StructList.newFactory(Test.TestAllTypes.factory)); + } + @org.junit.Test public void testLongUint8List() { MessageBuilder message = new MessageBuilder(); diff --git a/runtime/src/main/java/org/capnproto/SegmentReader.java b/runtime/src/main/java/org/capnproto/SegmentReader.java index 82f8330..b407d21 100644 --- a/runtime/src/main/java/org/capnproto/SegmentReader.java +++ b/runtime/src/main/java/org/capnproto/SegmentReader.java @@ -38,4 +38,12 @@ public class SegmentReader { public final long get(int index) { return buffer.getLong(index * Constants.BYTES_PER_WORD); } + + // Verify that the `size`-long (in words) range starting at word index + // `start` is within bounds. + public final boolean in_bounds(int start, int size) { + if (start < 0 || size < 0) return false; + long sizeInWords = size * Constants.BYTES_PER_WORD; + return (long) start + sizeInWords < (long) this.buffer.capacity(); + } } diff --git a/runtime/src/main/java/org/capnproto/WireHelpers.java b/runtime/src/main/java/org/capnproto/WireHelpers.java index 2c9da87..64e6fff 100644 --- a/runtime/src/main/java/org/capnproto/WireHelpers.java +++ b/runtime/src/main/java/org/capnproto/WireHelpers.java @@ -38,6 +38,14 @@ final class WireHelpers { return (int)((bits + 63) / ((long) Constants.BITS_PER_WORD)); } + // ByteBuffer already does bounds checking, but we still want + // to check bounds in some places to avoid cpu amplification attacks. + static boolean bounds_check(SegmentReader segment, + int start, + int size) { + return segment == null || segment.in_bounds(start, size); + } + static class AllocateResult { public final int ptr; public final int refOffset; @@ -1166,6 +1174,9 @@ final class WireHelpers { int ptr = resolved.ptr + 1; resolved.segment.arena.checkReadLimit(wordCount + 1); + if (!bounds_check(resolved.segment, resolved.ptr, wordCount + 1)) { + throw new DecodeException("Message contains out-of-bounds list pointer"); + } int size = WirePointer.inlineCompositeListElementCount(tag); @@ -1201,8 +1212,12 @@ final class WireHelpers { int elementCount = ListPointer.elementCount(resolved.ref); int step = dataSize + pointerCount * Constants.BITS_PER_POINTER; - resolved.segment.arena.checkReadLimit( - roundBitsUpToWords(elementCount * step)); + int wordCount = roundBitsUpToWords((long)elementCount * step); + resolved.segment.arena.checkReadLimit(wordCount); + + if (!bounds_check(resolved.segment, resolved.ptr, wordCount)) { + throw new DecodeException("Message contains out-of-bounds list pointer"); + } if (elementSize == ElementSize.VOID) { // Watch out for lists of void, which can claim to be arbitrarily large without