add more robustness to integer wrapping problems

This commit is contained in:
David Renshaw 2021-10-08 14:24:22 -04:00
parent 059252cba5
commit ab303cbc28
4 changed files with 84 additions and 20 deletions

View file

@ -875,4 +875,40 @@ public class EncodingTest {
TestUtil.checkTestMessage(root2.getAs(Test.TestAllTypes.factory)); TestUtil.checkTestMessage(root2.getAs(Test.TestAllTypes.factory));
} }
@org.junit.Test
public void testZeroPointerUnderflow() throws DecodeException {
byte[] bytes = new byte[8 + 8 * 65535];
bytes[4] = -1;
bytes[5] = -1; // struct pointer with 65535 words of data section.
for (int ii = 0; ii < 8 * 65535; ++ii) {
bytes[8 + ii] = 101; // populate the data section with sentinel data
}
ByteBuffer segment = ByteBuffer.wrap(bytes);
segment.order(ByteOrder.LITTLE_ENDIAN);
MessageReader message1 = new MessageReader(new ByteBuffer[]{segment},
ReaderOptions.DEFAULT_READER_OPTIONS);
Test.TestAnyPointer.Reader message1RootReader = message1.getRoot(Test.TestAnyPointer.factory);
MessageBuilder message2Builder =
new MessageBuilder(3 * 65535); // ample space to avoid far pointers
Test.TestAnyPointer.Builder message2RootBuilder =
message2Builder.getRoot(Test.TestAnyPointer.factory);
// Copy the struct that has the sentinel data.
message2RootBuilder.getAnyPointerField().setAs(Test.TestAnyPointer.factory, message1RootReader);
// Now clear the struct pointer.
message2RootBuilder.getAnyPointerField().clear();
java.nio.ByteBuffer[] outputSegments = message2Builder.getSegmentsForOutput();
Assert.assertEquals(1, outputSegments.length);
Assert.assertEquals(0L, outputSegments[0].getLong(8)); // null because cleared
Assert.assertEquals(16 + 8 * 65535, outputSegments[0].limit());
for (int ii = 0; ii < 65535; ++ii) {
// All of the data should have been cleared.
Assert.assertEquals(0L, outputSegments[0].getLong((2 + ii) * 8));
}
}
} }

View file

@ -24,17 +24,17 @@ package org.capnproto;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
final class StructPointer{ final class StructPointer{
public static short dataSize(long ref) { public static int dataSize(long ref) {
// in words. // in words.
return (short)(WirePointer.upper32Bits(ref) & 0xffff); return WirePointer.upper32Bits(ref) & 0xffff;
} }
public static short ptrCount(long ref) { public static int ptrCount(long ref) {
return (short)(WirePointer.upper32Bits(ref) >>> 16); return WirePointer.upper32Bits(ref) >>> 16;
} }
public static int wordSize(long ref) { public static int wordSize(long ref) {
return Short.toUnsignedInt(dataSize(ref)) + Short.toUnsignedInt(ptrCount(ref)); return dataSize(ref) + ptrCount(ref);
} }
public static void setFromStructSize(ByteBuffer buffer, int offset, StructSize size) { public static void setFromStructSize(ByteBuffer buffer, int offset, StructSize size) {

View file

@ -424,8 +424,8 @@ final class WireHelpers {
} }
FollowBuilderFarsResult resolved = followBuilderFars(ref, target, segment); FollowBuilderFarsResult resolved = followBuilderFars(ref, target, segment);
short oldDataSize = StructPointer.dataSize(resolved.ref); int oldDataSize = StructPointer.dataSize(resolved.ref);
short oldPointerCount = StructPointer.ptrCount(resolved.ref); int oldPointerCount = StructPointer.ptrCount(resolved.ref);
int oldPointerSection = resolved.ptr + oldDataSize; int oldPointerSection = resolved.ptr + oldDataSize;
if (oldDataSize < size.data || oldPointerCount < size.pointers) { if (oldDataSize < size.data || oldPointerCount < size.pointers) {
@ -472,7 +472,7 @@ final class WireHelpers {
} else { } else {
return factory.constructBuilder(resolved.segment, resolved.ptr * Constants.BYTES_PER_WORD, return factory.constructBuilder(resolved.segment, resolved.ptr * Constants.BYTES_PER_WORD,
oldPointerSection, oldDataSize * Constants.BITS_PER_WORD, oldPointerSection, oldDataSize * Constants.BITS_PER_WORD,
oldPointerCount); (short)oldPointerCount);
} }
} }
@ -609,7 +609,7 @@ final class WireHelpers {
throw new DecodeException("INLINE_COMPOSITE list with non-STRUCT elements not supported."); throw new DecodeException("INLINE_COMPOSITE list with non-STRUCT elements not supported.");
} }
int oldDataSize = StructPointer.dataSize(oldTag); int oldDataSize = StructPointer.dataSize(oldTag);
short oldPointerCount = StructPointer.ptrCount(oldTag); int oldPointerCount = StructPointer.ptrCount(oldTag);
int oldStep = (oldDataSize + oldPointerCount * Constants.POINTER_SIZE_IN_WORDS); int oldStep = (oldDataSize + oldPointerCount * Constants.POINTER_SIZE_IN_WORDS);
int elementCount = WirePointer.inlineCompositeListElementCount(oldTag); int elementCount = WirePointer.inlineCompositeListElementCount(oldTag);
@ -618,7 +618,8 @@ final class WireHelpers {
return factory.constructBuilder(resolved.segment, oldPtr * Constants.BYTES_PER_WORD, return factory.constructBuilder(resolved.segment, oldPtr * Constants.BYTES_PER_WORD,
elementCount, elementCount,
oldStep * Constants.BITS_PER_WORD, oldStep * Constants.BITS_PER_WORD,
oldDataSize * Constants.BITS_PER_WORD, oldPointerCount); oldDataSize * Constants.BITS_PER_WORD,
(short)oldPointerCount);
} }
//# The structs in this list are smaller than expected, probably written using an older //# The structs in this list are smaller than expected, probably written using an older
@ -926,21 +927,21 @@ final class WireHelpers {
resolved.segment.arena.checkReadLimit(StructPointer.wordSize(resolved.ref)); resolved.segment.arena.checkReadLimit(StructPointer.wordSize(resolved.ref));
return factory.constructReader(resolved.segment, return factory.constructReader(resolved.segment,
resolved.ptr * Constants.BYTES_PER_WORD, resolved.ptr * Constants.BYTES_PER_WORD,
(resolved.ptr + dataSizeWords), (resolved.ptr + dataSizeWords),
dataSizeWords * Constants.BITS_PER_WORD, dataSizeWords * Constants.BITS_PER_WORD,
StructPointer.ptrCount(resolved.ref), (short)StructPointer.ptrCount(resolved.ref),
nestingLimit - 1); nestingLimit - 1);
} }
static SegmentBuilder setStructPointer(SegmentBuilder segment, int refOffset, StructReader value) { static SegmentBuilder setStructPointer(SegmentBuilder segment, int refOffset, StructReader value) {
short dataSize = (short)roundBitsUpToWords(value.dataSize); int dataSize = roundBitsUpToWords(value.dataSize);
int totalSize = dataSize + value.pointerCount * Constants.POINTER_SIZE_IN_WORDS; int totalSize = dataSize + value.pointerCount * Constants.POINTER_SIZE_IN_WORDS;
AllocateResult allocation = allocate(refOffset, segment, totalSize, WirePointer.STRUCT); AllocateResult allocation = allocate(refOffset, segment, totalSize, WirePointer.STRUCT);
StructPointer.set(allocation.segment.buffer, allocation.refOffset, StructPointer.set(allocation.segment.buffer, allocation.refOffset,
dataSize, value.pointerCount); (short)dataSize, value.pointerCount);
if (value.dataSize == 1) { if (value.dataSize == 1) {
throw new RuntimeException("single bit case not handled"); throw new RuntimeException("single bit case not handled");
@ -1066,7 +1067,7 @@ final class WireHelpers {
resolved.ptr * Constants.BYTES_PER_WORD, resolved.ptr * Constants.BYTES_PER_WORD,
resolved.ptr + StructPointer.dataSize(resolved.ref), resolved.ptr + StructPointer.dataSize(resolved.ref),
StructPointer.dataSize(resolved.ref) * Constants.BITS_PER_WORD, StructPointer.dataSize(resolved.ref) * Constants.BITS_PER_WORD,
StructPointer.ptrCount(resolved.ref), (short)StructPointer.ptrCount(resolved.ref),
nestingLimit - 1)); nestingLimit - 1));
case WirePointer.LIST : case WirePointer.LIST :
byte elementSize = ListPointer.elementSize(resolved.ref); byte elementSize = ListPointer.elementSize(resolved.ref);
@ -1102,7 +1103,7 @@ final class WireHelpers {
elementCount, elementCount,
wordsPerElement * Constants.BITS_PER_WORD, wordsPerElement * Constants.BITS_PER_WORD,
StructPointer.dataSize(tag) * Constants.BITS_PER_WORD, StructPointer.dataSize(tag) * Constants.BITS_PER_WORD,
StructPointer.ptrCount(tag), (short)StructPointer.ptrCount(tag),
nestingLimit - 1)); nestingLimit - 1));
} else { } else {
int dataSize = ElementSize.dataBitsPerElement(elementSize); int dataSize = ElementSize.dataBitsPerElement(elementSize);
@ -1203,7 +1204,7 @@ final class WireHelpers {
size, size,
wordsPerElement * Constants.BITS_PER_WORD, wordsPerElement * Constants.BITS_PER_WORD,
StructPointer.dataSize(tag) * Constants.BITS_PER_WORD, StructPointer.dataSize(tag) * Constants.BITS_PER_WORD,
StructPointer.ptrCount(tag), (short)StructPointer.ptrCount(tag),
nestingLimit - 1); nestingLimit - 1);
} }
default : { default : {

View file

@ -4,6 +4,33 @@ import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
public class StructPointerTest { public class StructPointerTest {
@Test
public void testDataSize() {
Assert.assertEquals(
2,
StructPointer.dataSize(0x0001000200000000L));
}
@Test
public void testDataSizeUnderflow() {
Assert.assertEquals(
0xffff,
StructPointer.dataSize(0x0001ffff00000000L));
}
@Test
public void testPtrCount() {
Assert.assertEquals(
1,
StructPointer.ptrCount(0x0001000200000000L));
}
@Test
public void testPtrCountUnderflow() {
Assert.assertEquals(
0xffff,
StructPointer.ptrCount(0xffff000200000000L));
}
@Test @Test
public void testWordSize() { public void testWordSize() {