diff --git a/compiler/src/test/scala/org/capnproto/EncodingTest.scala b/compiler/src/test/scala/org/capnproto/EncodingTest.scala index 2ac196b..6b2c3c8 100644 --- a/compiler/src/test/scala/org/capnproto/EncodingTest.scala +++ b/compiler/src/test/scala/org/capnproto/EncodingTest.scala @@ -98,6 +98,43 @@ class EncodingSuite extends FunSuite { root.getInt16Field() should equal (32767); } + test("UpgradeStructInBuilder") { + val builder = new MessageBuilder(); + val root = builder.initRoot(TestAnyPointer.factory); + + { + val oldVersion = root.getAnyPointerField().initAs(TestOldVersion.factory); + oldVersion.setOld1(123); + oldVersion.setOld2("foo"); + val sub = oldVersion.initOld3(); + sub.setOld1(456); + sub.setOld2("bar"); + } + + { + val newVersion = root.getAnyPointerField().getAs(TestNewVersion.factory); + newVersion.getOld1() should equal (123); + newVersion.getOld2().toString() should equal ("foo"); + newVersion.getNew1() should equal (987); + newVersion.getNew2().toString() should equal ("baz"); + val sub = newVersion.getOld3(); + sub.getOld1() should equal (456); + sub.getOld2().toString() should equal ("bar"); + + newVersion.setOld1(234); + newVersion.setOld2("qux"); + newVersion.setNew1(654); + newVersion.setNew2("quux"); + + } + + { + val oldVersion = root.getAnyPointerField().getAs(TestOldVersion.factory); + oldVersion.getOld1() should equal (234); + oldVersion.getOld2.toString() should equal ("qux"); + } + } + test("StructListUpgrade") { val message = new MessageBuilder(); val root = message.initRoot(TestAnyPointer.factory); @@ -254,27 +291,6 @@ class EncodingSuite extends FunSuite { } } - test("UpgradeStructInBuilder") { - val builder = new MessageBuilder(); - val root = builder.initRoot(TestAnyPointer.factory); - - val oldReader = { - val oldVersion = root.getAnyPointerField().initAs(TestOldVersion.factory); - oldVersion.setOld1(123); - oldVersion.setOld2("foo"); - val sub = oldVersion.initOld3(); - sub.setOld1(456); - sub.setOld2("bar"); - oldVersion - } - - { - //val newVersion = root.getAnyPointerField().getAsStruct(TestNewVersion.factory); - } - - //... - } - test("Constants") { assert(Void.VOID == TestConstants.VOID_CONST); assert(true == TestConstants.BOOL_CONST); diff --git a/runtime/src/main/java/org/capnproto/WireHelpers.java b/runtime/src/main/java/org/capnproto/WireHelpers.java index c546cca..eb58a14 100644 --- a/runtime/src/main/java/org/capnproto/WireHelpers.java +++ b/runtime/src/main/java/org/capnproto/WireHelpers.java @@ -378,13 +378,52 @@ final class WireHelpers { short oldDataSize = StructPointer.dataSize(resolved.ref); short oldPointerCount = StructPointer.ptrCount(resolved.ref); - int oldPointerSectionOffset = resolved.ptr + oldDataSize; + int oldPointerSection = resolved.ptr + oldDataSize; if (oldDataSize < size.data || oldPointerCount < size.pointers) { - throw new Error("unimplemented"); + //# The space allocated for this struct is too small. Unlike with readers, we can't just + //# run with it and do bounds checks at access time, because how would we handle writes? + //# Instead, we have to copy the struct to a new space now. + + short newDataSize = (short)Math.max(oldDataSize, size.data); + short newPointerCount = (short)Math.max(oldPointerCount, size.pointers); + int totalSize = newDataSize + newPointerCount * Constants.WORDS_PER_POINTER; + + //# Don't let allocate() zero out the object just yet. + zeroPointerAndFars(segment, refOffset); + + AllocateResult allocation = allocate(refOffset, segment, + totalSize, WirePointer.STRUCT); + + StructPointer.set(allocation.segment.buffer, allocation.refOffset, + newDataSize, newPointerCount); + + //# Copy data section. + memcpy(allocation.segment.buffer, allocation.ptr * Constants.BYTES_PER_WORD, + resolved.segment.buffer, resolved.ptr * Constants.BYTES_PER_WORD, + oldDataSize * Constants.BYTES_PER_WORD); + + //# Copy pointer section. + int newPointerSection = allocation.ptr + newDataSize; + for (int ii = 0; ii < oldPointerCount; ++ii) { + transferPointer(allocation.segment, newPointerSection + ii, + resolved.segment, oldPointerSection + ii); + } + + //# Zero out old location. This has two purposes: + //# 1) We don't want to leak the original contents of the struct when the message is written + //# out as it may contain secrets that the caller intends to remove from the new copy. + //# 2) Zeros will be deflated by packing, making this dead memory almost-free if it ever + //# hits the wire. + memset(resolved.segment.buffer, resolved.ptr * Constants.BYTES_PER_WORD, (byte)0, + (oldDataSize + oldPointerCount * Constants.WORDS_PER_POINTER) * Constants.BYTES_PER_WORD); + + return factory.constructBuilder(allocation.segment, allocation.ptr * Constants.BYTES_PER_WORD, + newPointerSection, newDataSize * Constants.BITS_PER_WORD, + newPointerCount); } else { return factory.constructBuilder(resolved.segment, resolved.ptr * Constants.BYTES_PER_WORD, - oldPointerSectionOffset, oldDataSize * Constants.BITS_PER_WORD, + oldPointerSection, oldDataSize * Constants.BITS_PER_WORD, oldPointerCount); }