diff --git a/runtime-rpc/src/main/java/org/capnproto/RpcState.java b/runtime-rpc/src/main/java/org/capnproto/RpcState.java index d95e5cf..49ab7db 100644 --- a/runtime-rpc/src/main/java/org/capnproto/RpcState.java +++ b/runtime-rpc/src/main/java/org/capnproto/RpcState.java @@ -1,6 +1,8 @@ package org.capnproto; import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; import java.lang.ref.Reference; import java.lang.ref.ReferenceQueue; import java.lang.ref.WeakReference; @@ -43,13 +45,7 @@ final class RpcState { this.id = id; } - void dispose() { - var ref = questions.find(this.id); - if (ref == null) { - assert false: "Question ID no longer on table?"; - return; - } - + void finish() { if (isConnected() && !this.skipFinish) { var sizeHint = messageSizeHint(RpcProtocol.Finish.factory); var message = connection.newOutgoingMessage(sizeHint); @@ -58,12 +54,18 @@ final class RpcState { builder.setReleaseResultCaps(this.isAwaitingReturn); message.send(); } + this.skipFinish = true; + } + + void dispose() { + this.finish(); // Check if the question has returned and, if so, remove it from the table. // Remove question ID from the table. Must do this *after* sending `Finish` to ensure that // the ID is not re-allocated before the `Finish` message can be sent. - assert !this.isAwaitingReturn; - questions.erase(id); + if (!this.isAwaitingReturn) { + questions.erase(this.id); + } } } @@ -115,6 +117,10 @@ final class RpcState { void setSkipFinish(boolean value) { this.disposer.skipFinish = value; } + + public void finish() { + this.disposer.finish(); + } } class QuestionExportTable { @@ -514,7 +520,7 @@ final class RpcState { final var answerId = bootstrap.getQuestionId(); var answer = answers.put(answerId); if (answer.active) { - assert false: "questionId is already in use: " + answerId; + assert false: "bootstrap questionId is already in use: " + answerId; return; } answer.active = true; @@ -574,10 +580,9 @@ final class RpcState { var payload = call.getParams(); var capTableArray = receiveCaps(payload.getCapTable(), message.getAttachedFds()); var answerId = call.getQuestionId(); - var cancel = new CompletableFuture(); var context = new RpcCallContext( answerId, message, capTableArray, - payload.getContent(), redirectResults, cancel, + payload.getContent(), redirectResults, call.getInterfaceId(), call.getMethodId()); { @@ -593,28 +598,35 @@ final class RpcState { var pap = startCall(call.getInterfaceId(), call.getMethodId(), cap, context); + // Things may have changed -- in particular if startCall() immediately called + // context->directTailCall(). + { var answer = answers.find(answerId); assert answer != null; assert answer.pipeline == null; answer.pipeline = pap.pipeline; + var callReady = pap.promise; + if (redirectResults) { - answer.redirectedResults = pap.promise.thenApply( - void_ -> context.consumeRedirectedResponse()); - // TODO cancellation deferral + answer.redirectedResults = callReady.thenApply(void_ -> + context.consumeRedirectedResponse()); } else { - pap.promise.whenComplete((void_, exc) -> { + callReady.whenComplete((void_, exc) -> { if (exc == null) { context.sendReturn(); } else { context.sendErrorReturn(exc); - // TODO wait on the cancellation... } }); } + + context.whenCancelled().thenRun(() -> { + callReady.cancel(false); + }); } } @@ -636,10 +648,10 @@ final class RpcState { } question.setAwaitingReturn(false); - var exportsToRelease = new int[0]; + int[] exportsToRelease = null; if (callReturn.getReleaseParamCaps()) { exportsToRelease = question.paramExports; - question.paramExports = new int[0]; + question.paramExports = null; } if (callReturn.isTakeFromOtherQuestion()) { @@ -647,8 +659,9 @@ final class RpcState { if (answer != null) { answer.redirectedResults = null; } - //this.questions.erase(callReturn.getAnswerId()); - this.releaseExports(exportsToRelease); + if (exportsToRelease != null) { + this.releaseExports(exportsToRelease); + } return; } @@ -661,8 +674,7 @@ final class RpcState { var payload = callReturn.getResults(); var capTable = receiveCaps(payload.getCapTable(), message.getAttachedFds()); - // TODO question, message unused in RpcResponseImpl - var response = new RpcResponseImpl(question, message, capTable, payload.getContent()); + var response = new RpcResponseImpl(capTable, payload.getContent()); question.answer(response); break; @@ -707,7 +719,9 @@ final class RpcState { break; } - this.releaseExports(exportsToRelease); + if (exportsToRelease != null) { + this.releaseExports(exportsToRelease); + } } void handleFinish(RpcProtocol.Finish.Reader finish) { @@ -734,7 +748,9 @@ final class RpcState { answers.erase(questionId); } - this.releaseExports(exportsToRelease); + if (exportsToRelease != null) { + this.releaseExports(exportsToRelease); + } } private void handleResolve(IncomingRpcMessage message, RpcProtocol.Resolve.Reader resolve) { @@ -760,14 +776,14 @@ final class RpcState { } if (imp.promise != null) { - assert !imp.promise.isDone(); - // This import is an unfulfilled promise. - if (exc != null) { - imp.promise.completeExceptionally(exc); + + assert !imp.promise.isDone(); + if (exc == null) { + imp.promise.complete(cap); } else { - imp.promise.complete(cap); + imp.promise.completeExceptionally(exc); } return; } @@ -980,7 +996,7 @@ final class RpcState { } void releaseExports(int[] exports) { - for (var exportId : exports) { + for (var exportId: exports) { this.releaseExport(exportId, 1); } } @@ -1190,16 +1206,11 @@ final class RpcState { } class RpcResponseImpl implements RpcResponse { - private final Question question; - private final IncomingRpcMessage message; + private final AnyPointer.Reader results; - RpcResponseImpl(Question question, - IncomingRpcMessage message, - List capTable, + RpcResponseImpl(List capTable, AnyPointer.Reader results) { - this.question = question; - this.message = message; this.results = results.imbue(new ReaderCapabilityTable(capTable)); } @@ -1280,11 +1291,10 @@ final class RpcState { private boolean cancelRequested = false; private boolean cancelAllowed = false; - private final CompletableFuture whenCancelled; + private final CompletableFuture canceller = new CompletableFuture<>(); RpcCallContext(int answerId, IncomingRpcMessage request, List capTable, AnyPointer.Reader params, boolean redirectResults, - CompletableFuture whenCancelled, long interfaceId, short methodId) { this.answerId = answerId; this.interfaceId = interfaceId; @@ -1292,7 +1302,6 @@ final class RpcState { this.request = request; this.params = params.imbue(new ReaderCapabilityTable(capTable)); this.redirectResults = redirectResults; - this.whenCancelled = whenCancelled; } @Override @@ -1335,6 +1344,12 @@ final class RpcState { @Override public void allowCancellation() { + boolean previouslyRequestedButNotAllowed = (this.cancelAllowed == false && this.cancelRequested == true); + this.cancelAllowed = true; + + if (previouslyRequestedButNotAllowed) { + this.canceller.complete(null); + } } @Override @@ -1380,7 +1395,7 @@ final class RpcState { return new ClientHook.VoidPromiseAndPipeline(promise, response.pipeline().hook); } - private RpcResponse consumeRedirectedResponse() { + RpcResponse consumeRedirectedResponse() { assert this.redirectResults; if (this.response == null) { @@ -1446,16 +1461,19 @@ final class RpcState { private void cleanupAnswerTable(int[] resultExports) { if (this.cancelRequested) { assert resultExports == null || resultExports.length == 0; + // Already received `Finish` so it's our job to erase the table entry. We shouldn't have + // sent results if canceled, so we shouldn't have an export list to deal with. answers.erase(this.answerId); } else { + // We just have to null out callContext and set the exports. var answer = answers.find(answerId); answer.callContext = null; answer.resultExports = resultExports; } } - public void requestCancel() { + void requestCancel() { // Hints that the caller wishes to cancel this call. At the next time when cancellation is // deemed safe, the RpcCallContext shall send a canceled Return -- or if it never becomes // safe, the RpcCallContext will send a normal return when the call completes. Either way @@ -1468,9 +1486,15 @@ final class RpcState { if (previouslyAllowedButNotRequested) { // We just set CANCEL_REQUESTED, and CANCEL_ALLOWED was already set previously. Initiate // the cancellation. - this.whenCancelled.complete(null); + this.canceller.complete(null); } - // TODO do we care about cancelRequested if further completions are effectively ignored? + } + + /** Completed by the call context when a cancellation has been + * requested and cancellation is allowed + */ + CompletableFuture whenCancelled() { + return this.canceller; } } @@ -1544,6 +1568,7 @@ final class RpcState { var params = ctx.getParams(); var request = newCallNoIntercept(interfaceId, methodId); ctx.allowCancellation(); + ctx.releaseParams(); return ctx.directTailCall(request.getHook()); } @@ -1606,7 +1631,6 @@ final class RpcState { if (redirect != null) { var redirected = redirect.newCall( this.callBuilder.getInterfaceId(), this.callBuilder.getMethodId()); - //replacement.params = paramsBuilder; var replacement = new AnyPointer.Request(paramsBuilder, redirected.getHook()); return replacement.send(); } @@ -1619,12 +1643,7 @@ final class RpcState { var appPromise = question.response.thenApply( hook -> new Response<>(hook.getResults(), hook)); - // complete when either the message loop completes (exceptionally) or - // the appPromise is fulfilled - var loop = CompletableFuture.anyOf( - getMessageLoop(), appPromise).thenCompose(x -> appPromise); - - return new RemotePromise<>(loop, new AnyPointer.Pipeline(pipeline)); + return new RemotePromise<>(appPromise, new AnyPointer.Pipeline(pipeline)); } @Override @@ -1759,7 +1778,6 @@ final class RpcState { private final ClientHook cap; private final Integer importId; - private final CompletableFuture promise; private boolean receivedCall = false; private ResolutionType resolutionType = ResolutionType.UNRESOLVED; @@ -1768,7 +1786,14 @@ final class RpcState { Integer importId) { this.cap = initial; this.importId = importId; - this.promise = eventual.thenApply(resolution -> resolve(resolution)); + eventual.whenComplete((resolution, exc) -> { + if (exc != null) { + resolve(Capability.newBrokenCap(exc)); + } + else { + resolve(resolution); + } + }); } public boolean isResolved() { @@ -1932,25 +1957,32 @@ final class RpcState { } static void FromException(Throwable exc, RpcProtocol.Exception.Builder builder) { - builder.setReason(exc.getMessage()); - builder.setType(RpcProtocol.Exception.Type.FAILED); + var type = RpcProtocol.Exception.Type.FAILED; + if (exc instanceof RpcException) { + var rpcExc = (RpcException) exc; + type = switch (rpcExc.getType()) { + case FAILED -> RpcProtocol.Exception.Type.FAILED; + case OVERLOADED -> RpcProtocol.Exception.Type.OVERLOADED; + case DISCONNECTED -> RpcProtocol.Exception.Type.DISCONNECTED; + case UNIMPLEMENTED -> RpcProtocol.Exception.Type.UNIMPLEMENTED; + default -> RpcProtocol.Exception.Type.FAILED; + }; + } + builder.setType(type); + + var writer = new StringWriter(); + exc.printStackTrace(new PrintWriter(writer)); + builder.setReason(writer.toString()); } static RpcException ToException(RpcProtocol.Exception.Reader reader) { - var type = RpcException.Type.UNKNOWN; - - switch (reader.getType()) { - case UNIMPLEMENTED: - type = RpcException.Type.UNIMPLEMENTED; - break; - case FAILED: - type = RpcException.Type.FAILED; - break; - case DISCONNECTED: - case OVERLOADED: - default: - break; - } + var type = switch (reader.getType()) { + case FAILED -> RpcException.Type.FAILED; + case OVERLOADED -> RpcException.Type.OVERLOADED; + case DISCONNECTED -> RpcException.Type.DISCONNECTED; + case UNIMPLEMENTED -> RpcException.Type.UNIMPLEMENTED; + default -> RpcException.Type.FAILED; + }; return new RpcException(type, reader.getReason().toString()); } diff --git a/runtime-rpc/src/test/java/org/capnproto/RpcTest.java b/runtime-rpc/src/test/java/org/capnproto/RpcTest.java index 4c8faa9..f0a3ae4 100644 --- a/runtime-rpc/src/test/java/org/capnproto/RpcTest.java +++ b/runtime-rpc/src/test/java/org/capnproto/RpcTest.java @@ -25,11 +25,15 @@ import org.capnproto.rpctest.Test; import org.junit.Assert; +import java.lang.ref.ReferenceQueue; +import java.lang.ref.WeakReference; import java.util.ArrayDeque; import java.util.HashMap; import java.util.Map; import java.util.Queue; +import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; @@ -505,5 +509,32 @@ public class RpcTest { // Verify that we are still connected getCallSequence(client, 1).get(); } + + @org.junit.Test + public void testCallCancel() { + var context = new TestContext(bootstrapFactory); + var client = new Test.TestMoreStuff.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_MORE_STUFF)); + + var request = client.expectCancelRequest(); + var cap = new RpcTestUtil.TestCapDestructor(); + request.getParams().setCap(cap); + + // auto-close the request without waiting for a response, triggering a cancellation request. + try (var response = request.send()) { + response.thenRun(() -> Assert.fail("Never completing call returned?")); + } + catch (CompletionException exc) { + Assert.assertTrue(exc instanceof CompletionException); + Assert.assertNotNull(exc.getCause()); + Assert.assertTrue(exc.getCause() instanceof RpcException); + Assert.assertTrue(((RpcException)exc.getCause()).getType() == RpcException.Type.FAILED); + } + catch (Exception exc) { + Assert.fail(exc.toString()); + } + + // check that the connection is still open + getCallSequence(client, 1); + } } diff --git a/runtime-rpc/src/test/java/org/capnproto/RpcTestUtil.java b/runtime-rpc/src/test/java/org/capnproto/RpcTestUtil.java index 73e1d45..14f5c28 100644 --- a/runtime-rpc/src/test/java/org/capnproto/RpcTestUtil.java +++ b/runtime-rpc/src/test/java/org/capnproto/RpcTestUtil.java @@ -3,7 +3,11 @@ package org.capnproto; import org.capnproto.rpctest.Test; import org.junit.Assert; +import java.awt.desktop.SystemEventListener; +import java.lang.ref.ReferenceQueue; +import java.lang.ref.WeakReference; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; class RpcTestUtil { @@ -110,7 +114,7 @@ class RpcTestUtil { this.callCount = callCount; this.handleCount = handleCount; } - + @Override protected CompletableFuture echo(CallContext context) { this.callCount.inc(); @@ -120,6 +124,17 @@ class RpcTestUtil { return READY_NOW; } + @Override + protected CompletableFuture expectCancel(CallContext context) { + var cap = context.getParams().getCap(); + context.allowCancellation(); + return new CompletableFuture().whenComplete((void_, exc) -> { + if (exc != null) { + System.out.println("expectCancel completed exceptionally: " + exc.getMessage()); + } + }); // never completes, just await doom... + } + @Override protected CompletableFuture getHandle(CallContext context) { context.getResults().setHandle(new HandleImpl(this.handleCount)); @@ -193,6 +208,18 @@ class RpcTestUtil { result.setCap(this.clientToHold); return READY_NOW; } + + @Override + protected CompletableFuture neverReturn(CallContext context) { + this.callCount.inc(); + var cap = context.getParams().getCap(); + context.getResults().setCapCopy(cap); + context.allowCancellation(); + return new CompletableFuture<>().thenAccept(void_ -> { + // Ensure that the cap is used inside the lambda. + System.out.println(cap); + }); + } } static class TestTailCalleeImpl extends Test.TestTailCallee.Server { @@ -262,12 +289,23 @@ class RpcTestUtil { request.getParams().setJ(true); return request.send().thenAccept(response -> { - Assert.assertEquals("foo", response.getX().toString()); + Assert.assertEquals("foo", response.getX().toString()); - var result = context.getResults(); - result.setS("bar"); - result.initOutBox().setCap(new TestExtendsImpl(callCount)); + var result = context.getResults(); + result.setS("bar"); + result.initOutBox().setCap(new TestExtendsImpl(callCount)); }); } } + + static class TestCapDestructor extends Test.TestInterface.Server { + private final Counter dummy = new Counter(); + private final TestInterfaceImpl impl = new TestInterfaceImpl(dummy); + + @Override + protected CompletableFuture foo(CallContext context) { + return this.impl.foo(context); + } + } } + diff --git a/runtime/src/main/java/org/capnproto/CallContext.java b/runtime/src/main/java/org/capnproto/CallContext.java index 553d4a7..5f955d3 100644 --- a/runtime/src/main/java/org/capnproto/CallContext.java +++ b/runtime/src/main/java/org/capnproto/CallContext.java @@ -2,7 +2,7 @@ package org.capnproto; import java.util.concurrent.CompletableFuture; -public class CallContext { +public final class CallContext { private final FromPointerReader params; private final FromPointerBuilder results; diff --git a/runtime/src/main/java/org/capnproto/Capability.java b/runtime/src/main/java/org/capnproto/Capability.java index 75bb986..821960d 100644 --- a/runtime/src/main/java/org/capnproto/Capability.java +++ b/runtime/src/main/java/org/capnproto/Capability.java @@ -344,14 +344,18 @@ public final class Capability { @Override public RemotePromise send() { - var cancelPaf = new CompletableFuture(); - var context = new LocalCallContext(message, client, cancelPaf); + var cancel = new CompletableFuture(); + var context = new LocalCallContext(message, client, cancel); var promiseAndPipeline = client.call(interfaceId, methodId, context); var promise = promiseAndPipeline.promise.thenApply(x -> { context.getResults(); // force allocation return context.response; }); + cancel.whenComplete((void_, exc) -> { + promiseAndPipeline.promise.cancel(false); + }); + assert promiseAndPipeline.pipeline != null; return new RemotePromise<>(promise, new AnyPointer.Pipeline(promiseAndPipeline.pipeline)); } @@ -383,6 +387,11 @@ public final class Capability { public final ClientHook getPipelinedCap(PipelineOp[] ops) { return this.results.getPipelinedCap(ops); } + + @Override + public void close() { + this.ctx.allowCancellation(); + } } private static final class LocalResponse implements ResponseHook { @@ -532,6 +541,16 @@ public final class Capability { : new QueuedClient(this.promise.thenApply( pipeline -> pipeline.getPipelinedCap(ops))); } + + @Override + public void close() { + if (this.redirect != null) { + this.redirect.close(); + } + else { + this.promise.cancel(false); + } + } } // A ClientHook which simply queues calls while waiting for a ClientHook to which to forward them. diff --git a/runtime/src/main/java/org/capnproto/CompletableFutureWrapper.java b/runtime/src/main/java/org/capnproto/CompletableFutureWrapper.java index 09401ae..bfe3ed4 100644 --- a/runtime/src/main/java/org/capnproto/CompletableFutureWrapper.java +++ b/runtime/src/main/java/org/capnproto/CompletableFutureWrapper.java @@ -5,8 +5,10 @@ import java.util.concurrent.CompletionStage; public class CompletableFutureWrapper extends CompletableFuture { + private final CompletableFuture other; + public CompletableFutureWrapper(CompletionStage other) { - other.toCompletableFuture().whenComplete((value, exc) -> { + this.other = other.toCompletableFuture().whenComplete((value, exc) -> { if (exc == null) { this.complete(value); } @@ -15,4 +17,9 @@ public class CompletableFutureWrapper extends CompletableFuture { } }); } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return this.other.cancel(mayInterruptIfRunning); + } } \ No newline at end of file diff --git a/runtime/src/main/java/org/capnproto/PipelineHook.java b/runtime/src/main/java/org/capnproto/PipelineHook.java index e42a5db..274117e 100644 --- a/runtime/src/main/java/org/capnproto/PipelineHook.java +++ b/runtime/src/main/java/org/capnproto/PipelineHook.java @@ -1,6 +1,6 @@ package org.capnproto; -public interface PipelineHook { +public interface PipelineHook extends AutoCloseable { ClientHook getPipelinedCap(PipelineOp[] ops); @@ -12,4 +12,8 @@ public interface PipelineHook { } }; } + + @Override + default void close() { + } } diff --git a/runtime/src/main/java/org/capnproto/RemotePromise.java b/runtime/src/main/java/org/capnproto/RemotePromise.java index 0bbdc06..86bd8ef 100644 --- a/runtime/src/main/java/org/capnproto/RemotePromise.java +++ b/runtime/src/main/java/org/capnproto/RemotePromise.java @@ -3,7 +3,8 @@ package org.capnproto; import java.util.concurrent.CompletableFuture; public class RemotePromise - extends CompletableFutureWrapper { + extends CompletableFutureWrapper + implements AutoCloseable { private final CompletableFuture> response; private final AnyPointer.Pipeline pipeline; @@ -20,11 +21,17 @@ public class RemotePromise public RemotePromise(CompletableFuture> promise, AnyPointer.Pipeline pipeline) { - super(promise.thenApply(response -> response.getResults())); + super(promise.thenApply(Response::getResults)); this.response = promise; this.pipeline = pipeline; } + @Override + public void close() throws Exception { + this.pipeline.hook.close(); + this.join(); + } + CompletableFuture> _getResponse() { return this.response; } diff --git a/runtime/src/main/java/org/capnproto/RpcException.java b/runtime/src/main/java/org/capnproto/RpcException.java index f748228..5f0184a 100644 --- a/runtime/src/main/java/org/capnproto/RpcException.java +++ b/runtime/src/main/java/org/capnproto/RpcException.java @@ -3,13 +3,13 @@ package org.capnproto; public final class RpcException extends java.lang.Exception { public enum Type { - UNKNOWN, - UNIMPLEMENTED, FAILED, - DISCONNECTED + OVERLOADED, + DISCONNECTED, + UNIMPLEMENTED } - private Type type; + private final Type type; public RpcException(Type type, String message) { super(message);