add bounds checking in readListPointer

This commit is contained in:
David Renshaw 2021-10-01 21:11:33 -04:00
parent ddd43a491b
commit 104fb11104
3 changed files with 43 additions and 2 deletions

View file

@ -614,6 +614,24 @@ public class EncodingTest {
root.getAnyPointerField().getAs(StructList.newFactory(Test.TestAllTypes.factory)); root.getAnyPointerField().getAs(StructList.newFactory(Test.TestAllTypes.factory));
} }
// Test that we throw an exception on out-of-bounds list pointers.
// Before v0.1.11, we were vulnerable to a cpu amplification attack:
// reading an out-of-bounds pointer to list a huge number of elements of size BIT,
// when read as a struct list, would return without error.
@org.junit.Test(expected=DecodeException.class)
public void testListPointerOutOfBounds() throws DecodeException {
byte[] bytes = new byte[]
{0,0,0,0, 0,0,1,0, // struct, one pointer
1, 0x2f, 0, 0, 1, 0, -127, -128}; // list, points out of bounds.
ByteBuffer segment = ByteBuffer.wrap(bytes);
segment.order(ByteOrder.LITTLE_ENDIAN);
MessageReader message = new MessageReader(new ByteBuffer[]{segment},
ReaderOptions.DEFAULT_READER_OPTIONS);
Test.TestAnyPointer.Reader root = message.getRoot(Test.TestAnyPointer.factory);
root.getAnyPointerField().getAs(StructList.newFactory(Test.TestAllTypes.factory));
}
@org.junit.Test @org.junit.Test
public void testLongUint8List() { public void testLongUint8List() {
MessageBuilder message = new MessageBuilder(); MessageBuilder message = new MessageBuilder();

View file

@ -38,4 +38,12 @@ public class SegmentReader {
public final long get(int index) { public final long get(int index) {
return buffer.getLong(index * Constants.BYTES_PER_WORD); return buffer.getLong(index * Constants.BYTES_PER_WORD);
} }
// Verify that the `size`-long (in words) range starting at word index
// `start` is within bounds.
public final boolean in_bounds(int start, int size) {
if (start < 0 || size < 0) return false;
long sizeInWords = size * Constants.BYTES_PER_WORD;
return (long) start + sizeInWords < (long) this.buffer.capacity();
}
} }

View file

@ -38,6 +38,14 @@ final class WireHelpers {
return (int)((bits + 63) / ((long) Constants.BITS_PER_WORD)); return (int)((bits + 63) / ((long) Constants.BITS_PER_WORD));
} }
// ByteBuffer already does bounds checking, but we still want
// to check bounds in some places to avoid cpu amplification attacks.
static boolean bounds_check(SegmentReader segment,
int start,
int size) {
return segment == null || segment.in_bounds(start, size);
}
static class AllocateResult { static class AllocateResult {
public final int ptr; public final int ptr;
public final int refOffset; public final int refOffset;
@ -1166,6 +1174,9 @@ final class WireHelpers {
int ptr = resolved.ptr + 1; int ptr = resolved.ptr + 1;
resolved.segment.arena.checkReadLimit(wordCount + 1); resolved.segment.arena.checkReadLimit(wordCount + 1);
if (!bounds_check(resolved.segment, resolved.ptr, wordCount + 1)) {
throw new DecodeException("Message contains out-of-bounds list pointer");
}
int size = WirePointer.inlineCompositeListElementCount(tag); int size = WirePointer.inlineCompositeListElementCount(tag);
@ -1201,8 +1212,12 @@ final class WireHelpers {
int elementCount = ListPointer.elementCount(resolved.ref); int elementCount = ListPointer.elementCount(resolved.ref);
int step = dataSize + pointerCount * Constants.BITS_PER_POINTER; int step = dataSize + pointerCount * Constants.BITS_PER_POINTER;
resolved.segment.arena.checkReadLimit( int wordCount = roundBitsUpToWords((long)elementCount * step);
roundBitsUpToWords(elementCount * step)); resolved.segment.arena.checkReadLimit(wordCount);
if (!bounds_check(resolved.segment, resolved.ptr, wordCount)) {
throw new DecodeException("Message contains out-of-bounds list pointer");
}
if (elementSize == ElementSize.VOID) { if (elementSize == ElementSize.VOID) {
// Watch out for lists of void, which can claim to be arbitrarily large without // Watch out for lists of void, which can claim to be arbitrarily large without