implement rpc tail calls

This commit is contained in:
Vaci Koblizek 2020-11-05 14:05:12 +00:00
parent cee3aa79ae
commit f2df5c2191
5 changed files with 86 additions and 44 deletions

View file

@ -168,7 +168,7 @@ final class RpcState<VatId> {
final int answerId; final int answerId;
boolean active = false; boolean active = false;
PipelineHook pipeline; PipelineHook pipeline;
CompletionStage<RpcResponse> redirectedResults; CompletableFuture<RpcResponse> redirectedResults;
RpcCallContext callContext; RpcCallContext callContext;
int[] resultExports; int[] resultExports;
@ -599,24 +599,26 @@ final class RpcState<VatId> {
} }
var pap = startCall(call.getInterfaceId(), call.getMethodId(), cap, context); var pap = startCall(call.getInterfaceId(), call.getMethodId(), cap, context);
{ {
var answer = answers.find(answerId); var answer = answers.find(answerId);
assert answer != null; assert answer != null;
answer.pipeline = pap.pipeline; answer.pipeline = pap.pipeline;
if (redirectResults) { if (redirectResults) {
answer.redirectedResults = pap.promise.thenApply(x -> { answer.redirectedResults = pap.promise.thenApply(
return context.consumeRedirectedResponse(); void_ -> context.consumeRedirectedResponse());
});
// TODO cancellation deferral // TODO cancellation deferral
} }
else { else {
pap.promise.thenAccept(x -> { pap.promise.whenComplete((void_, exc) -> {
context.sendReturn(); if (exc == null) {
}).exceptionally(exc -> { context.sendReturn();
context.sendErrorReturn(exc); }
// TODO wait on the cancellation... else {
return null; context.sendErrorReturn(exc);
// TODO wait on the cancellation...
}
}); });
} }
} }
@ -628,7 +630,6 @@ final class RpcState<VatId> {
} }
void handleReturn(IncomingRpcMessage message, RpcProtocol.Return.Reader callReturn) { void handleReturn(IncomingRpcMessage message, RpcProtocol.Return.Reader callReturn) {
var question = questions.find(callReturn.getAnswerId()); var question = questions.find(callReturn.getAnswerId());
if (question == null) { if (question == null) {
assert false: "Invalid question ID in Return message."; assert false: "Invalid question ID in Return message.";
@ -703,7 +704,7 @@ final class RpcState<VatId> {
assert false: "`Return.takeFromOtherQuestion` referenced a call that did not use `sendResultsTo.yourself`."; assert false: "`Return.takeFromOtherQuestion` referenced a call that did not use `sendResultsTo.yourself`.";
break; break;
} }
question.response = answer.redirectedResults.toCompletableFuture(); question.response = answer.redirectedResults;
answer.redirectedResults = null; answer.redirectedResults = null;
break; break;
@ -1230,7 +1231,7 @@ final class RpcState<VatId> {
@Override @Override
public AnyPointer.Builder getResultsBuilder() { public AnyPointer.Builder getResultsBuilder() {
return payload.getContent().imbue(capTable); return this.payload.getContent().imbue(capTable);
} }
int[] send() { int[] send() {
@ -1284,7 +1285,7 @@ final class RpcState<VatId> {
private RpcProtocol.Return.Builder returnMessage; private RpcProtocol.Return.Builder returnMessage;
private boolean redirectResults = false; private boolean redirectResults = false;
private boolean responseSent = false; private boolean responseSent = false;
private CompletableFuture<PipelineHook> tailCallPipelineFuture; private CompletableFuture<AnyPointer.Pipeline> tailCallPipeline;
private boolean cancelRequested = false; private boolean cancelRequested = false;
private boolean cancelAllowed = false; private boolean cancelAllowed = false;
@ -1336,10 +1337,10 @@ final class RpcState<VatId> {
@Override @Override
public CompletableFuture<java.lang.Void> tailCall(RequestHook request) { public CompletableFuture<java.lang.Void> tailCall(RequestHook request) {
var result = this.directTailCall(request); var result = this.directTailCall(request);
if (this.tailCallPipelineFuture != null) { if (this.tailCallPipeline != null) {
this.tailCallPipelineFuture.complete(result.pipeline); this.tailCallPipeline.complete(new AnyPointer.Pipeline(result.pipeline));
} }
return result.promise.toCompletableFuture().copy(); return result.promise.copy();
} }
@Override @Override
@ -1347,8 +1348,10 @@ final class RpcState<VatId> {
} }
@Override @Override
public CompletableFuture<PipelineHook> onTailCall() { public CompletableFuture<AnyPointer.Pipeline> onTailCall() {
return null; assert this.tailCallPipeline == null: "Called onTailCall twice?";
this.tailCallPipeline = new CompletableFuture<>();
return this.tailCallPipeline.copy();
} }
@Override @Override
@ -1441,7 +1444,7 @@ final class RpcState<VatId> {
message.send(); message.send();
} }
cleanupAnswerTable(new int[0], false); cleanupAnswerTable(null, false);
} }
private boolean isFirstResponder() { private boolean isFirstResponder() {
@ -1453,6 +1456,10 @@ final class RpcState<VatId> {
} }
private void cleanupAnswerTable(int[] resultExports, boolean shouldFreePipeline) { private void cleanupAnswerTable(int[] resultExports, boolean shouldFreePipeline) {
if (resultExports == null) {
resultExports = new int[0];
}
if (this.cancelRequested) { if (this.cancelRequested) {
assert resultExports.length == 0; assert resultExports.length == 0;
answers.erase(this.answerId); answers.erase(this.answerId);

View file

@ -406,5 +406,30 @@ public class RpcTest {
//Assert.assertEquals(3, context.restorer.callCount); //Assert.assertEquals(3, context.restorer.callCount);
Assert.assertEquals(2, chainedCallCount.value()); Assert.assertEquals(2, chainedCallCount.value());
} }
@org.junit.Test
public void testTailCall() {
var context = new TestContext(bootstrapFactory);
var caller = new Test.TestTailCaller.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_TAIL_CALLER));
var calleeCallCount = new Counter();
var callee = new Test.TestTailCallee.Client(new RpcTestUtil.TestTailCalleeImpl(calleeCallCount));
var request = caller.fooRequest();
request.getParams().setI(456);
request.getParams().setCallee(callee);
var promise = request.send();
var dependentCall0 = promise.getC().getCallSequenceRequest().send();
var response = promise.join();
Assert.assertEquals(456, response.getI());
var dependentCall1 = promise.getC().getCallSequenceRequest().send();
Assert.assertEquals(0, dependentCall0.join().getN());
Assert.assertEquals(1, dependentCall1.join().getN());
var dependentCall2 = response.getC().getCallSequenceRequest().send();
Assert.assertEquals(2, dependentCall2.join().getN());
Assert.assertEquals(1, calleeCallCount.value());
}
} }

View file

@ -3,6 +3,7 @@ package org.capnproto;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
public interface CallContextHook { public interface CallContextHook {
AnyPointer.Reader getParams(); AnyPointer.Reader getParams();
void releaseParams(); void releaseParams();
@ -17,7 +18,7 @@ public interface CallContextHook {
void allowCancellation(); void allowCancellation();
CompletableFuture<PipelineHook> onTailCall(); CompletableFuture<AnyPointer.Pipeline> onTailCall();
ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request); ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request);
} }

View file

@ -185,17 +185,16 @@ public final class Capability {
} }
var promise = this.whenResolved().thenCompose( var promise = this.whenResolved().thenCompose(
x -> this.callInternal(interfaceId, methodId, ctx)); void_ -> this.callInternal(interfaceId, methodId, ctx));
CompletableFuture<PipelineHook> pipelinePromise = promise.thenApply(x -> {
var pipelinePromise = promise.thenApply(x -> {
ctx.releaseParams(); ctx.releaseParams();
return new LocalPipeline(ctx); return (PipelineHook)new LocalPipeline(ctx);
}); });
var tailCall = ctx.onTailCall(); var tailCall = ctx.onTailCall().thenApply(pipeline -> pipeline.hook);
if (tailCall != null) { pipelinePromise = tailCall.applyToEither(pipelinePromise, pipeline -> pipeline);
pipelinePromise = tailCall.applyToEither(pipelinePromise, pipeline -> pipeline);
}
return new VoidPromiseAndPipeline( return new VoidPromiseAndPipeline(
promise, promise,
@ -213,7 +212,7 @@ public final class Capability {
return CompletableFuture.completedFuture(this.resolved); return CompletableFuture.completedFuture(this.resolved);
} }
else if (this.resolveTask != null) { else if (this.resolveTask != null) {
return this.resolveTask.thenApply(x -> this.resolved); return this.resolveTask.thenApply(void_ -> this.resolved);
} }
else { else {
return null; return null;
@ -335,7 +334,7 @@ public final class Capability {
final MessageBuilder message = new MessageBuilder(); final MessageBuilder message = new MessageBuilder();
final long interfaceId; final long interfaceId;
final short methodId; final short methodId;
ClientHook client; final ClientHook client;
LocalRequest(long interfaceId, short methodId, ClientHook client) { LocalRequest(long interfaceId, short methodId, ClientHook client) {
this.interfaceId = interfaceId; this.interfaceId = interfaceId;
@ -371,6 +370,7 @@ public final class Capability {
} }
private static final class LocalPipeline implements PipelineHook { private static final class LocalPipeline implements PipelineHook {
private final CallContextHook ctx; private final CallContextHook ctx;
private final AnyPointer.Reader results; private final AnyPointer.Reader results;
@ -396,7 +396,8 @@ public final class Capability {
private static class LocalCallContext implements CallContextHook { private static class LocalCallContext implements CallContextHook {
final CompletableFuture<?> cancelAllowed; final CompletableFuture<java.lang.Void> cancelAllowed;
CompletableFuture<AnyPointer.Pipeline> tailCallPipeline;
MessageBuilder request; MessageBuilder request;
Response<AnyPointer.Reader> response; Response<AnyPointer.Reader> response;
AnyPointer.Builder responseBuilder; AnyPointer.Builder responseBuilder;
@ -404,7 +405,7 @@ public final class Capability {
LocalCallContext(MessageBuilder request, LocalCallContext(MessageBuilder request,
ClientHook clientRef, ClientHook clientRef,
CompletableFuture<?> cancelAllowed) { CompletableFuture<java.lang.Void> cancelAllowed) {
this.request = request; this.request = request;
this.clientRef = clientRef; this.clientRef = clientRef;
this.cancelAllowed = cancelAllowed; this.cancelAllowed = cancelAllowed;
@ -412,7 +413,7 @@ public final class Capability {
@Override @Override
public AnyPointer.Reader getParams() { public AnyPointer.Reader getParams() {
return request.getRoot(AnyPointer.factory).asReader(); return this.request.getRoot(AnyPointer.factory).asReader();
} }
@Override @Override
@ -437,20 +438,27 @@ public final class Capability {
@Override @Override
public CompletableFuture<java.lang.Void> tailCall(RequestHook request) { public CompletableFuture<java.lang.Void> tailCall(RequestHook request) {
// TODO implement tailCall var result = this.directTailCall(request);
return null; if (this.tailCallPipeline != null) {
this.tailCallPipeline.complete(new AnyPointer.Pipeline(result.pipeline));
}
return result.promise;
} }
@Override @Override
public CompletableFuture<PipelineHook> onTailCall() { public CompletableFuture<AnyPointer.Pipeline> onTailCall() {
// TODO implement onTailCall this.tailCallPipeline = new CompletableFuture<>();
return null; return this.tailCallPipeline.copy();
} }
@Override @Override
public ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request) { public ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request) {
// TODO implement directTailCall assert this.response == null: "Can't call tailCall() after initializing the results struct.";
return null; var promise = request.send();
var voidPromise = promise._getResponse().thenAccept(tailResponse -> {
this.response = tailResponse;
});
return new ClientHook.VoidPromiseAndPipeline(voidPromise, promise.pipeline().hook);
} }
} }

View file

@ -20,14 +20,15 @@ public class RemotePromise<Results>
public RemotePromise(CompletableFuture<Response<Results>> promise, public RemotePromise(CompletableFuture<Response<Results>> promise,
AnyPointer.Pipeline pipeline) { AnyPointer.Pipeline pipeline) {
super(promise.thenApply(response -> { super(promise.thenApply(response -> response.getResults()));
//System.out.println("Got a response for remote promise " + promise.toString());
return response.getResults();
}));
this.response = promise; this.response = promise;
this.pipeline = pipeline; this.pipeline = pipeline;
} }
CompletableFuture<Response<Results>> _getResponse() {
return this.response;
}
public AnyPointer.Pipeline pipeline() { public AnyPointer.Pipeline pipeline() {
return this.pipeline; return this.pipeline;
} }