From 3e2034f45d7645c08377b19224594e6bdd267dd4 Mon Sep 17 00:00:00 2001 From: David Renshaw Date: Wed, 23 Nov 2022 10:46:22 -0500 Subject: [PATCH] fix reading of upgraded pointer lists --- .../java/org/capnproto/test/EncodingTest.java | 32 +++++++++++++ compiler/src/test/schema/test.capnp | 14 ++++-- .../main/java/org/capnproto/ListReader.java | 12 +++-- .../main/java/org/capnproto/WireHelpers.java | 48 +++++++++++++------ 4 files changed, 85 insertions(+), 21 deletions(-) diff --git a/compiler/src/test/java/org/capnproto/test/EncodingTest.java b/compiler/src/test/java/org/capnproto/test/EncodingTest.java index e927f7f..5bd55f1 100644 --- a/compiler/src/test/java/org/capnproto/test/EncodingTest.java +++ b/compiler/src/test/java/org/capnproto/test/EncodingTest.java @@ -113,6 +113,38 @@ public class EncodingTest { } } + @org.junit.Test + public void testUpgradeStructReadAsOld() { + MessageBuilder builder = new MessageBuilder(); + Test.TestAnyPointer.Builder root = builder.initRoot(Test.TestAnyPointer.factory); + + { + Test.TestNewVersion.Builder newVersion = root.getAnyPointerField().initAs(Test.TestNewVersion.factory); + newVersion.setOld1(123); + newVersion.setOld2("foo"); + Test.TestNewVersion.Builder sub = newVersion.initOld3(); + sub.setOld1(456); + sub.setOld2("bar"); + + StructList.Builder names = + newVersion.initOld4(2); + + names.get(0).setTextField("alice"); + names.get(1).setTextField("bob"); + } + + { + Test.TestOldVersion.Reader oldVersion = root.getAnyPointerField().asReader().getAs(Test.TestOldVersion.factory); + Assert.assertEquals(oldVersion.getOld1(), 123); + Assert.assertEquals(oldVersion.getOld2().toString(), "foo"); + + TextList.Reader names = oldVersion.getOld4(); + Assert.assertEquals(names.size(), 2); + Assert.assertEquals("alice", names.get(0).toString()); + Assert.assertEquals("bob", names.get(1).toString()); + } + } + @org.junit.Test public void testUpgradeStructInBuilder() { MessageBuilder builder = new MessageBuilder(); diff --git a/compiler/src/test/schema/test.capnp b/compiler/src/test/schema/test.capnp index c0332cd..5522653 100644 --- a/compiler/src/test/schema/test.capnp +++ b/compiler/src/test/schema/test.capnp @@ -319,6 +319,7 @@ struct TestOldVersion { old1 @0 :Int64; old2 @1 :Text; old3 @2 :TestOldVersion; + old4 @3 :List(Text); } struct TestNewVersion { @@ -326,9 +327,16 @@ struct TestNewVersion { old1 @0 :Int64; old2 @1 :Text; old3 @2 :TestNewVersion; - new1 @3 :Int64 = 987; - new2 @4 :Text = "baz"; - new3 @5 :Data; + + struct UpgradedFromText { + textField @0 :Text; + int32Field @1 :Int32; + dataField @2 :Data; + } + old4 @3 :List(UpgradedFromText); + new1 @4 :Int64 = 987; + new2 @5 :Text = "baz"; + new3 @6 :TestDefaults; } struct TestGenerics(Foo, Bar) { diff --git a/runtime/src/main/java/org/capnproto/ListReader.java b/runtime/src/main/java/org/capnproto/ListReader.java index f63ecfe..6fd6b4e 100644 --- a/runtime/src/main/java/org/capnproto/ListReader.java +++ b/runtime/src/main/java/org/capnproto/ListReader.java @@ -131,10 +131,14 @@ public class ListReader extends CapTableReader.ReaderContext { } protected T _getPointerElement(FromPointerReader factory, int index) { - return factory.fromPointerReader(this.segment, - this.capTable, - (this.ptr + (int)((long)index * this.step / Constants.BITS_PER_BYTE)) / Constants.BYTES_PER_WORD, - this.nestingLimit); + return factory.fromPointerReader( + this.segment, + this.capTable, + (this.ptr + + (this.structDataSize / Constants.BITS_PER_BYTE) + + (int)((long)index * this.step / Constants.BITS_PER_BYTE)) + / Constants.BYTES_PER_WORD, + this.nestingLimit); } protected T _getPointerElement(FromPointerReaderBlobDefault factory, int index, diff --git a/runtime/src/main/java/org/capnproto/WireHelpers.java b/runtime/src/main/java/org/capnproto/WireHelpers.java index 8192f77..4cff56f 100644 --- a/runtime/src/main/java/org/capnproto/WireHelpers.java +++ b/runtime/src/main/java/org/capnproto/WireHelpers.java @@ -1255,8 +1255,8 @@ final class WireHelpers { throw new DecodeException("Message contains non-list pointer where list was expected."); } - byte elementSize = ListPointer.elementSize(resolved.ref); - switch (elementSize) { + byte oldSize = ListPointer.elementSize(resolved.ref); + switch (oldSize) { case ElementSize.INLINE_COMPOSITE : { int wordCount = ListPointer.inlineCompositeWordCount(resolved.ref); @@ -1269,7 +1269,8 @@ final class WireHelpers { } int size = WirePointer.inlineCompositeListElementCount(tag); - + int dataSize = StructPointer.dataSize(tag); + short ptrCount = (short)StructPointer.ptrCount(tag); int wordsPerElement = StructPointer.wordSize(tag); if ((long)size * wordsPerElement > wordCount) { @@ -1282,23 +1283,42 @@ final class WireHelpers { resolved.segment.arena.checkReadLimit(size); } - // TODO check whether the size is compatible + switch (expectedElementSize) { + case ElementSize.VOID: break; + case ElementSize.BIT: { + throw new DecodeException("Found struct list where bit list was expected"); + } + case ElementSize.BYTE: + case ElementSize.TWO_BYTES: + case ElementSize.FOUR_BYTES: + case ElementSize.EIGHT_BYTES: + if (dataSize == 0) { + throw new DecodeException( + "Expected a primitive list, but got a list of pointer-only structs"); + } + case ElementSize.POINTER: + if (ptrCount == 0) { + throw new DecodeException( + "Expected a pointer list, but got a list of data-only structs"); + } + default: break; + } - return factory.constructReader(resolved.segment, capTable, - ptr * Constants.BYTES_PER_WORD, - size, - wordsPerElement * Constants.BITS_PER_WORD, - StructPointer.dataSize(tag) * Constants.BITS_PER_WORD, - (short)StructPointer.ptrCount(tag), - nestingLimit - 1); + return factory.constructReader(resolved.segment, + ptr * Constants.BYTES_PER_WORD, + size, + wordsPerElement * Constants.BITS_PER_WORD, + dataSize * Constants.BITS_PER_WORD, + ptrCount, + nestingLimit - 1); } default : { //# This is a primitive or pointer list, but all such //# lists can also be interpreted as struct lists. We //# need to compute the data size and pointer count for //# such structs. - int dataSize = ElementSize.dataBitsPerElement(elementSize); - int pointerCount = ElementSize.pointersPerElement(elementSize); + int dataSize = ElementSize.dataBitsPerElement(oldSize); + int pointerCount = ElementSize.pointersPerElement(oldSize); int elementCount = ListPointer.elementCount(resolved.ref); int step = dataSize + pointerCount * Constants.BITS_PER_POINTER; @@ -1309,7 +1329,7 @@ final class WireHelpers { throw new DecodeException("Message contains out-of-bounds list pointer"); } - if (elementSize == ElementSize.VOID) { + if (oldSize == ElementSize.VOID) { // Watch out for lists of void, which can claim to be arbitrarily large without // having sent actual data. resolved.segment.arena.checkReadLimit(elementCount);