diff --git a/runtime-rpc/src/main/java/org/capnproto/RpcState.java b/runtime-rpc/src/main/java/org/capnproto/RpcState.java index 28a1315..369cbb6 100644 --- a/runtime-rpc/src/main/java/org/capnproto/RpcState.java +++ b/runtime-rpc/src/main/java/org/capnproto/RpcState.java @@ -168,7 +168,7 @@ final class RpcState { final int answerId; boolean active = false; PipelineHook pipeline; - CompletionStage redirectedResults; + CompletableFuture redirectedResults; RpcCallContext callContext; int[] resultExports; @@ -599,24 +599,26 @@ final class RpcState { } var pap = startCall(call.getInterfaceId(), call.getMethodId(), cap, context); + { var answer = answers.find(answerId); assert answer != null; answer.pipeline = pap.pipeline; if (redirectResults) { - answer.redirectedResults = pap.promise.thenApply(x -> { - return context.consumeRedirectedResponse(); - }); + answer.redirectedResults = pap.promise.thenApply( + void_ -> context.consumeRedirectedResponse()); // TODO cancellation deferral } else { - pap.promise.thenAccept(x -> { - context.sendReturn(); - }).exceptionally(exc -> { - context.sendErrorReturn(exc); - // TODO wait on the cancellation... - return null; + pap.promise.whenComplete((void_, exc) -> { + if (exc == null) { + context.sendReturn(); + } + else { + context.sendErrorReturn(exc); + // TODO wait on the cancellation... + } }); } } @@ -628,7 +630,6 @@ final class RpcState { } void handleReturn(IncomingRpcMessage message, RpcProtocol.Return.Reader callReturn) { - var question = questions.find(callReturn.getAnswerId()); if (question == null) { assert false: "Invalid question ID in Return message."; @@ -703,7 +704,7 @@ final class RpcState { assert false: "`Return.takeFromOtherQuestion` referenced a call that did not use `sendResultsTo.yourself`."; break; } - question.response = answer.redirectedResults.toCompletableFuture(); + question.response = answer.redirectedResults; answer.redirectedResults = null; break; @@ -1230,7 +1231,7 @@ final class RpcState { @Override public AnyPointer.Builder getResultsBuilder() { - return payload.getContent().imbue(capTable); + return this.payload.getContent().imbue(capTable); } int[] send() { @@ -1284,7 +1285,7 @@ final class RpcState { private RpcProtocol.Return.Builder returnMessage; private boolean redirectResults = false; private boolean responseSent = false; - private CompletableFuture tailCallPipelineFuture; + private CompletableFuture tailCallPipeline; private boolean cancelRequested = false; private boolean cancelAllowed = false; @@ -1336,10 +1337,10 @@ final class RpcState { @Override public CompletableFuture tailCall(RequestHook request) { var result = this.directTailCall(request); - if (this.tailCallPipelineFuture != null) { - this.tailCallPipelineFuture.complete(result.pipeline); + if (this.tailCallPipeline != null) { + this.tailCallPipeline.complete(new AnyPointer.Pipeline(result.pipeline)); } - return result.promise.toCompletableFuture().copy(); + return result.promise.copy(); } @Override @@ -1347,8 +1348,10 @@ final class RpcState { } @Override - public CompletableFuture onTailCall() { - return null; + public CompletableFuture onTailCall() { + assert this.tailCallPipeline == null: "Called onTailCall twice?"; + this.tailCallPipeline = new CompletableFuture<>(); + return this.tailCallPipeline.copy(); } @Override @@ -1441,7 +1444,7 @@ final class RpcState { message.send(); } - cleanupAnswerTable(new int[0], false); + cleanupAnswerTable(null, false); } private boolean isFirstResponder() { @@ -1453,6 +1456,10 @@ final class RpcState { } private void cleanupAnswerTable(int[] resultExports, boolean shouldFreePipeline) { + if (resultExports == null) { + resultExports = new int[0]; + } + if (this.cancelRequested) { assert resultExports.length == 0; answers.erase(this.answerId); diff --git a/runtime-rpc/src/test/java/org/capnproto/RpcTest.java b/runtime-rpc/src/test/java/org/capnproto/RpcTest.java index 0e73711..9fe93cf 100644 --- a/runtime-rpc/src/test/java/org/capnproto/RpcTest.java +++ b/runtime-rpc/src/test/java/org/capnproto/RpcTest.java @@ -406,5 +406,30 @@ public class RpcTest { //Assert.assertEquals(3, context.restorer.callCount); Assert.assertEquals(2, chainedCallCount.value()); } + + @org.junit.Test + public void testTailCall() { + var context = new TestContext(bootstrapFactory); + var caller = new Test.TestTailCaller.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_TAIL_CALLER)); + + var calleeCallCount = new Counter(); + var callee = new Test.TestTailCallee.Client(new RpcTestUtil.TestTailCalleeImpl(calleeCallCount)); + var request = caller.fooRequest(); + request.getParams().setI(456); + request.getParams().setCallee(callee); + + var promise = request.send(); + var dependentCall0 = promise.getC().getCallSequenceRequest().send(); + var response = promise.join(); + Assert.assertEquals(456, response.getI()); + + var dependentCall1 = promise.getC().getCallSequenceRequest().send(); + Assert.assertEquals(0, dependentCall0.join().getN()); + Assert.assertEquals(1, dependentCall1.join().getN()); + + var dependentCall2 = response.getC().getCallSequenceRequest().send(); + Assert.assertEquals(2, dependentCall2.join().getN()); + Assert.assertEquals(1, calleeCallCount.value()); + } } diff --git a/runtime/src/main/java/org/capnproto/CallContextHook.java b/runtime/src/main/java/org/capnproto/CallContextHook.java index 461db7c..9b00255 100644 --- a/runtime/src/main/java/org/capnproto/CallContextHook.java +++ b/runtime/src/main/java/org/capnproto/CallContextHook.java @@ -3,6 +3,7 @@ package org.capnproto; import java.util.concurrent.CompletableFuture; public interface CallContextHook { + AnyPointer.Reader getParams(); void releaseParams(); @@ -17,7 +18,7 @@ public interface CallContextHook { void allowCancellation(); - CompletableFuture onTailCall(); + CompletableFuture onTailCall(); ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request); } diff --git a/runtime/src/main/java/org/capnproto/Capability.java b/runtime/src/main/java/org/capnproto/Capability.java index ce1b4b1..cfb9f68 100644 --- a/runtime/src/main/java/org/capnproto/Capability.java +++ b/runtime/src/main/java/org/capnproto/Capability.java @@ -185,17 +185,16 @@ public final class Capability { } var promise = this.whenResolved().thenCompose( - x -> this.callInternal(interfaceId, methodId, ctx)); + void_ -> this.callInternal(interfaceId, methodId, ctx)); - CompletableFuture pipelinePromise = promise.thenApply(x -> { + + var pipelinePromise = promise.thenApply(x -> { ctx.releaseParams(); - return new LocalPipeline(ctx); + return (PipelineHook)new LocalPipeline(ctx); }); - var tailCall = ctx.onTailCall(); - if (tailCall != null) { - pipelinePromise = tailCall.applyToEither(pipelinePromise, pipeline -> pipeline); - } + var tailCall = ctx.onTailCall().thenApply(pipeline -> pipeline.hook); + pipelinePromise = tailCall.applyToEither(pipelinePromise, pipeline -> pipeline); return new VoidPromiseAndPipeline( promise, @@ -213,7 +212,7 @@ public final class Capability { return CompletableFuture.completedFuture(this.resolved); } else if (this.resolveTask != null) { - return this.resolveTask.thenApply(x -> this.resolved); + return this.resolveTask.thenApply(void_ -> this.resolved); } else { return null; @@ -335,7 +334,7 @@ public final class Capability { final MessageBuilder message = new MessageBuilder(); final long interfaceId; final short methodId; - ClientHook client; + final ClientHook client; LocalRequest(long interfaceId, short methodId, ClientHook client) { this.interfaceId = interfaceId; @@ -371,6 +370,7 @@ public final class Capability { } private static final class LocalPipeline implements PipelineHook { + private final CallContextHook ctx; private final AnyPointer.Reader results; @@ -396,7 +396,8 @@ public final class Capability { private static class LocalCallContext implements CallContextHook { - final CompletableFuture cancelAllowed; + final CompletableFuture cancelAllowed; + CompletableFuture tailCallPipeline; MessageBuilder request; Response response; AnyPointer.Builder responseBuilder; @@ -404,7 +405,7 @@ public final class Capability { LocalCallContext(MessageBuilder request, ClientHook clientRef, - CompletableFuture cancelAllowed) { + CompletableFuture cancelAllowed) { this.request = request; this.clientRef = clientRef; this.cancelAllowed = cancelAllowed; @@ -412,7 +413,7 @@ public final class Capability { @Override public AnyPointer.Reader getParams() { - return request.getRoot(AnyPointer.factory).asReader(); + return this.request.getRoot(AnyPointer.factory).asReader(); } @Override @@ -437,20 +438,27 @@ public final class Capability { @Override public CompletableFuture tailCall(RequestHook request) { - // TODO implement tailCall - return null; + var result = this.directTailCall(request); + if (this.tailCallPipeline != null) { + this.tailCallPipeline.complete(new AnyPointer.Pipeline(result.pipeline)); + } + return result.promise; } @Override - public CompletableFuture onTailCall() { - // TODO implement onTailCall - return null; + public CompletableFuture onTailCall() { + this.tailCallPipeline = new CompletableFuture<>(); + return this.tailCallPipeline.copy(); } @Override public ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request) { - // TODO implement directTailCall - return null; + assert this.response == null: "Can't call tailCall() after initializing the results struct."; + var promise = request.send(); + var voidPromise = promise._getResponse().thenAccept(tailResponse -> { + this.response = tailResponse; + }); + return new ClientHook.VoidPromiseAndPipeline(voidPromise, promise.pipeline().hook); } } diff --git a/runtime/src/main/java/org/capnproto/RemotePromise.java b/runtime/src/main/java/org/capnproto/RemotePromise.java index 9b60a34..0bbdc06 100644 --- a/runtime/src/main/java/org/capnproto/RemotePromise.java +++ b/runtime/src/main/java/org/capnproto/RemotePromise.java @@ -20,14 +20,15 @@ public class RemotePromise public RemotePromise(CompletableFuture> promise, AnyPointer.Pipeline pipeline) { - super(promise.thenApply(response -> { - //System.out.println("Got a response for remote promise " + promise.toString()); - return response.getResults(); - })); + super(promise.thenApply(response -> response.getResults())); this.response = promise; this.pipeline = pipeline; } + CompletableFuture> _getResponse() { + return this.response; + } + public AnyPointer.Pipeline pipeline() { return this.pipeline; }