From 0ce52fe1357ac7308ef57694fd055e05678506e8 Mon Sep 17 00:00:00 2001 From: Vaci Koblizek Date: Fri, 6 Nov 2020 15:32:20 +0000 Subject: [PATCH] add resolve test and fix handleResolve bugs --- .../main/java/org/capnproto/Capability.java | 2 +- .../src/main/java/org/capnproto/RpcState.java | 49 +++++++++++++------ .../java/org/capnproto/CapabilityTest.java | 3 +- .../src/test/java/org/capnproto/RpcTest.java | 34 +++++++++++++ .../src/test/java/org/capnproto/TestUtil.java | 40 +++++++++++++++ 5 files changed, 110 insertions(+), 18 deletions(-) diff --git a/runtime/src/main/java/org/capnproto/Capability.java b/runtime/src/main/java/org/capnproto/Capability.java index 9a45e7d..ce1b4b1 100644 --- a/runtime/src/main/java/org/capnproto/Capability.java +++ b/runtime/src/main/java/org/capnproto/Capability.java @@ -57,7 +57,7 @@ public final class Capability { ClientHook getHook(); - default CompletionStage whenResolved() { + default CompletableFuture whenResolved() { return this.getHook().whenResolved(); } diff --git a/runtime/src/main/java/org/capnproto/RpcState.java b/runtime/src/main/java/org/capnproto/RpcState.java index b9657ee..8f1c3a8 100644 --- a/runtime/src/main/java/org/capnproto/RpcState.java +++ b/runtime/src/main/java/org/capnproto/RpcState.java @@ -434,7 +434,7 @@ final class RpcState { private void handleMessage(IncomingRpcMessage message) throws RpcException { var reader = message.getBody().getAs(RpcProtocol.Message.factory); - //System.out.println(reader.which()); + //System.out.println(this + ": Received message: " + reader.which()); switch (reader.which()) { case UNIMPLEMENTED: handleUnimplemented(reader.getUnimplemented()); @@ -744,26 +744,44 @@ final class RpcState { } private void handleResolve(IncomingRpcMessage message, RpcProtocol.Resolve.Reader resolve) { + ClientHook cap = null; + Throwable exc = null; + + switch (resolve.which()) { + case CAP: + cap = receiveCap(resolve.getCap(), message.getAttachedFds()); + break; + case EXCEPTION: + exc = RpcException.toException(resolve.getException()); + break; + default: + assert false: "Unknown 'Resolve' type."; + return; + } + + var importId = resolve.getPromiseId(); var imp = this.imports.find(resolve.getPromiseId()); if (imp == null) { return; } + if (imp.promise != null) { + assert !imp.promise.isDone(); + + // This import is an unfulfilled promise. + if (exc != null) { + imp.promise.completeExceptionally(exc); + } + else { + imp.promise.complete(cap); + } + return; + } + + // It appears this is a valid entry on the import table, but was not expected to be a + // promise. assert imp.importClient == null : "Import already resolved."; - switch (resolve.which()) { - case CAP: - var cap = receiveCap(resolve.getCap(), message.getAttachedFds()); - imp.promise.complete(cap); - break; - case EXCEPTION: - var exc = RpcException.toException(resolve.getException()); - imp.promise.completeExceptionally(exc); - break; - default: - assert false; - return; - } } private void handleRelease(RpcProtocol.Release.Reader release) { @@ -898,7 +916,7 @@ final class RpcState { var wrapped = inner.whenMoreResolved(); if (wrapped != null) { // This is a promise. Arrange for the `Resolve` message to be sent later. - export.resolveOp = resolveExportedPromise(export.exportId, wrapped); + export.resolveOp = this.resolveExportedPromise(export.exportId, wrapped); descriptor.setSenderPromise(export.exportId); } else { @@ -916,6 +934,7 @@ final class RpcState { resolution = this.getInnermostClient(resolution); var exp = exports.find(exportId); + assert exp != null; exportsByCap.remove(exp.clientHook); exp.clientHook = resolution; diff --git a/runtime/src/test/java/org/capnproto/CapabilityTest.java b/runtime/src/test/java/org/capnproto/CapabilityTest.java index 3397be4..cdfb2bb 100644 --- a/runtime/src/test/java/org/capnproto/CapabilityTest.java +++ b/runtime/src/test/java/org/capnproto/CapabilityTest.java @@ -24,14 +24,13 @@ package org.capnproto; import org.capnproto.test.Test; import org.junit.Assert; -import org.junit.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; class Counter { private int count = 0; - void inc() { count++; } + int inc() { return count++; } int value() { return count; } } diff --git a/runtime/src/test/java/org/capnproto/RpcTest.java b/runtime/src/test/java/org/capnproto/RpcTest.java index 0ea195a..9345cb9 100644 --- a/runtime/src/test/java/org/capnproto/RpcTest.java +++ b/runtime/src/test/java/org/capnproto/RpcTest.java @@ -369,8 +369,42 @@ public class RpcTest { handle1 = null; handle2 = null; + } + @org.junit.Test + public void testPromiseResolve() { + var context = new TestContext(bootstrapFactory); + var client = new Test.TestMoreStuff.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_MORE_STUFF)); + var chainedCallCount = new Counter(); + + var request = client.callFooRequest(); + var request2 = client.callFooWhenResolvedRequest(); + + var paf = new CompletableFuture(); + + { + request.getParams().setCap(new Test.TestInterface.Client(paf.copy())); + request2.getParams().setCap(new Test.TestInterface.Client(paf.copy())); + } + + var promise = request.send(); + var promise2 = request2.send(); + + // Make sure getCap() has been called on the server side by sending another call and waiting + // for it. + Assert.assertEquals(2, client.getCallSequenceRequest().send().join().getN()); + //Assert.assertEquals(3, context.restorer.callCount); + + // OK, now fulfill the local promise. + paf.complete(new Test.TestInterface.Client(new TestUtil.TestInterfaceImpl(chainedCallCount))); + + // We should now be able to wait for getCap() to finish. + Assert.assertEquals("bar", promise.join().getS().toString()); + Assert.assertEquals("bar", promise2.join().getS().toString()); + + //Assert.assertEquals(3, context.restorer.callCount); + Assert.assertEquals(2, chainedCallCount.value()); } } diff --git a/runtime/src/test/java/org/capnproto/TestUtil.java b/runtime/src/test/java/org/capnproto/TestUtil.java index 0936a21..4f2c8ac 100644 --- a/runtime/src/test/java/org/capnproto/TestUtil.java +++ b/runtime/src/test/java/org/capnproto/TestUtil.java @@ -114,6 +114,46 @@ class TestUtil { context.getResults().setHandle(new HandleImpl(this.handleCount)); return READY_NOW; } + + @Override + protected CompletableFuture getCallSequence(CallContext context) { + var result = context.getResults(); + result.setN(this.callCount.inc()); + return READY_NOW; + } + + @Override + protected CompletableFuture callFoo(CallContext context) { + this.callCount.inc(); + var params = context.getParams(); + var cap = params.getCap(); + var request = cap.fooRequest(); + request.getParams().setI(123); + request.getParams().setJ(true); + + return request.send().thenAccept(response -> { + Assert.assertEquals("foo", response.getX().toString()); + context.getResults().setS("bar"); + }); + } + + @Override + protected CompletableFuture callFooWhenResolved(CallContext context) { + this.callCount.inc(); + var params = context.getParams(); + var cap = params.getCap(); + + return cap.whenResolved().thenCompose(void_ -> { + var request = cap.fooRequest(); + request.getParams().setI(123); + request.getParams().setJ(true); + + return request.send().thenAccept(response -> { + Assert.assertEquals("foo", response.getX().toString()); + context.getResults().setS("bar"); + }); + }); + } } static class TestTailCalleeImpl extends Test.TestTailCallee.Server {