implement rpc tail calls
This commit is contained in:
parent
cee3aa79ae
commit
f2df5c2191
5 changed files with 86 additions and 44 deletions
|
@ -168,7 +168,7 @@ final class RpcState<VatId> {
|
|||
final int answerId;
|
||||
boolean active = false;
|
||||
PipelineHook pipeline;
|
||||
CompletionStage<RpcResponse> redirectedResults;
|
||||
CompletableFuture<RpcResponse> redirectedResults;
|
||||
RpcCallContext callContext;
|
||||
int[] resultExports;
|
||||
|
||||
|
@ -599,24 +599,26 @@ final class RpcState<VatId> {
|
|||
}
|
||||
|
||||
var pap = startCall(call.getInterfaceId(), call.getMethodId(), cap, context);
|
||||
|
||||
{
|
||||
var answer = answers.find(answerId);
|
||||
assert answer != null;
|
||||
answer.pipeline = pap.pipeline;
|
||||
|
||||
if (redirectResults) {
|
||||
answer.redirectedResults = pap.promise.thenApply(x -> {
|
||||
return context.consumeRedirectedResponse();
|
||||
});
|
||||
answer.redirectedResults = pap.promise.thenApply(
|
||||
void_ -> context.consumeRedirectedResponse());
|
||||
// TODO cancellation deferral
|
||||
}
|
||||
else {
|
||||
pap.promise.thenAccept(x -> {
|
||||
pap.promise.whenComplete((void_, exc) -> {
|
||||
if (exc == null) {
|
||||
context.sendReturn();
|
||||
}).exceptionally(exc -> {
|
||||
}
|
||||
else {
|
||||
context.sendErrorReturn(exc);
|
||||
// TODO wait on the cancellation...
|
||||
return null;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -628,7 +630,6 @@ final class RpcState<VatId> {
|
|||
}
|
||||
|
||||
void handleReturn(IncomingRpcMessage message, RpcProtocol.Return.Reader callReturn) {
|
||||
|
||||
var question = questions.find(callReturn.getAnswerId());
|
||||
if (question == null) {
|
||||
assert false: "Invalid question ID in Return message.";
|
||||
|
@ -703,7 +704,7 @@ final class RpcState<VatId> {
|
|||
assert false: "`Return.takeFromOtherQuestion` referenced a call that did not use `sendResultsTo.yourself`.";
|
||||
break;
|
||||
}
|
||||
question.response = answer.redirectedResults.toCompletableFuture();
|
||||
question.response = answer.redirectedResults;
|
||||
answer.redirectedResults = null;
|
||||
break;
|
||||
|
||||
|
@ -1230,7 +1231,7 @@ final class RpcState<VatId> {
|
|||
|
||||
@Override
|
||||
public AnyPointer.Builder getResultsBuilder() {
|
||||
return payload.getContent().imbue(capTable);
|
||||
return this.payload.getContent().imbue(capTable);
|
||||
}
|
||||
|
||||
int[] send() {
|
||||
|
@ -1284,7 +1285,7 @@ final class RpcState<VatId> {
|
|||
private RpcProtocol.Return.Builder returnMessage;
|
||||
private boolean redirectResults = false;
|
||||
private boolean responseSent = false;
|
||||
private CompletableFuture<PipelineHook> tailCallPipelineFuture;
|
||||
private CompletableFuture<AnyPointer.Pipeline> tailCallPipeline;
|
||||
|
||||
private boolean cancelRequested = false;
|
||||
private boolean cancelAllowed = false;
|
||||
|
@ -1336,10 +1337,10 @@ final class RpcState<VatId> {
|
|||
@Override
|
||||
public CompletableFuture<java.lang.Void> tailCall(RequestHook request) {
|
||||
var result = this.directTailCall(request);
|
||||
if (this.tailCallPipelineFuture != null) {
|
||||
this.tailCallPipelineFuture.complete(result.pipeline);
|
||||
if (this.tailCallPipeline != null) {
|
||||
this.tailCallPipeline.complete(new AnyPointer.Pipeline(result.pipeline));
|
||||
}
|
||||
return result.promise.toCompletableFuture().copy();
|
||||
return result.promise.copy();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -1347,8 +1348,10 @@ final class RpcState<VatId> {
|
|||
}
|
||||
|
||||
@Override
|
||||
public CompletableFuture<PipelineHook> onTailCall() {
|
||||
return null;
|
||||
public CompletableFuture<AnyPointer.Pipeline> onTailCall() {
|
||||
assert this.tailCallPipeline == null: "Called onTailCall twice?";
|
||||
this.tailCallPipeline = new CompletableFuture<>();
|
||||
return this.tailCallPipeline.copy();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -1441,7 +1444,7 @@ final class RpcState<VatId> {
|
|||
message.send();
|
||||
}
|
||||
|
||||
cleanupAnswerTable(new int[0], false);
|
||||
cleanupAnswerTable(null, false);
|
||||
}
|
||||
|
||||
private boolean isFirstResponder() {
|
||||
|
@ -1453,6 +1456,10 @@ final class RpcState<VatId> {
|
|||
}
|
||||
|
||||
private void cleanupAnswerTable(int[] resultExports, boolean shouldFreePipeline) {
|
||||
if (resultExports == null) {
|
||||
resultExports = new int[0];
|
||||
}
|
||||
|
||||
if (this.cancelRequested) {
|
||||
assert resultExports.length == 0;
|
||||
answers.erase(this.answerId);
|
||||
|
|
|
@ -406,5 +406,30 @@ public class RpcTest {
|
|||
//Assert.assertEquals(3, context.restorer.callCount);
|
||||
Assert.assertEquals(2, chainedCallCount.value());
|
||||
}
|
||||
|
||||
@org.junit.Test
|
||||
public void testTailCall() {
|
||||
var context = new TestContext(bootstrapFactory);
|
||||
var caller = new Test.TestTailCaller.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_TAIL_CALLER));
|
||||
|
||||
var calleeCallCount = new Counter();
|
||||
var callee = new Test.TestTailCallee.Client(new RpcTestUtil.TestTailCalleeImpl(calleeCallCount));
|
||||
var request = caller.fooRequest();
|
||||
request.getParams().setI(456);
|
||||
request.getParams().setCallee(callee);
|
||||
|
||||
var promise = request.send();
|
||||
var dependentCall0 = promise.getC().getCallSequenceRequest().send();
|
||||
var response = promise.join();
|
||||
Assert.assertEquals(456, response.getI());
|
||||
|
||||
var dependentCall1 = promise.getC().getCallSequenceRequest().send();
|
||||
Assert.assertEquals(0, dependentCall0.join().getN());
|
||||
Assert.assertEquals(1, dependentCall1.join().getN());
|
||||
|
||||
var dependentCall2 = response.getC().getCallSequenceRequest().send();
|
||||
Assert.assertEquals(2, dependentCall2.join().getN());
|
||||
Assert.assertEquals(1, calleeCallCount.value());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ package org.capnproto;
|
|||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
public interface CallContextHook {
|
||||
|
||||
AnyPointer.Reader getParams();
|
||||
|
||||
void releaseParams();
|
||||
|
@ -17,7 +18,7 @@ public interface CallContextHook {
|
|||
|
||||
void allowCancellation();
|
||||
|
||||
CompletableFuture<PipelineHook> onTailCall();
|
||||
CompletableFuture<AnyPointer.Pipeline> onTailCall();
|
||||
|
||||
ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request);
|
||||
}
|
||||
|
|
|
@ -185,17 +185,16 @@ public final class Capability {
|
|||
}
|
||||
|
||||
var promise = this.whenResolved().thenCompose(
|
||||
x -> this.callInternal(interfaceId, methodId, ctx));
|
||||
void_ -> this.callInternal(interfaceId, methodId, ctx));
|
||||
|
||||
CompletableFuture<PipelineHook> pipelinePromise = promise.thenApply(x -> {
|
||||
|
||||
var pipelinePromise = promise.thenApply(x -> {
|
||||
ctx.releaseParams();
|
||||
return new LocalPipeline(ctx);
|
||||
return (PipelineHook)new LocalPipeline(ctx);
|
||||
});
|
||||
|
||||
var tailCall = ctx.onTailCall();
|
||||
if (tailCall != null) {
|
||||
var tailCall = ctx.onTailCall().thenApply(pipeline -> pipeline.hook);
|
||||
pipelinePromise = tailCall.applyToEither(pipelinePromise, pipeline -> pipeline);
|
||||
}
|
||||
|
||||
return new VoidPromiseAndPipeline(
|
||||
promise,
|
||||
|
@ -213,7 +212,7 @@ public final class Capability {
|
|||
return CompletableFuture.completedFuture(this.resolved);
|
||||
}
|
||||
else if (this.resolveTask != null) {
|
||||
return this.resolveTask.thenApply(x -> this.resolved);
|
||||
return this.resolveTask.thenApply(void_ -> this.resolved);
|
||||
}
|
||||
else {
|
||||
return null;
|
||||
|
@ -335,7 +334,7 @@ public final class Capability {
|
|||
final MessageBuilder message = new MessageBuilder();
|
||||
final long interfaceId;
|
||||
final short methodId;
|
||||
ClientHook client;
|
||||
final ClientHook client;
|
||||
|
||||
LocalRequest(long interfaceId, short methodId, ClientHook client) {
|
||||
this.interfaceId = interfaceId;
|
||||
|
@ -371,6 +370,7 @@ public final class Capability {
|
|||
}
|
||||
|
||||
private static final class LocalPipeline implements PipelineHook {
|
||||
|
||||
private final CallContextHook ctx;
|
||||
private final AnyPointer.Reader results;
|
||||
|
||||
|
@ -396,7 +396,8 @@ public final class Capability {
|
|||
|
||||
private static class LocalCallContext implements CallContextHook {
|
||||
|
||||
final CompletableFuture<?> cancelAllowed;
|
||||
final CompletableFuture<java.lang.Void> cancelAllowed;
|
||||
CompletableFuture<AnyPointer.Pipeline> tailCallPipeline;
|
||||
MessageBuilder request;
|
||||
Response<AnyPointer.Reader> response;
|
||||
AnyPointer.Builder responseBuilder;
|
||||
|
@ -404,7 +405,7 @@ public final class Capability {
|
|||
|
||||
LocalCallContext(MessageBuilder request,
|
||||
ClientHook clientRef,
|
||||
CompletableFuture<?> cancelAllowed) {
|
||||
CompletableFuture<java.lang.Void> cancelAllowed) {
|
||||
this.request = request;
|
||||
this.clientRef = clientRef;
|
||||
this.cancelAllowed = cancelAllowed;
|
||||
|
@ -412,7 +413,7 @@ public final class Capability {
|
|||
|
||||
@Override
|
||||
public AnyPointer.Reader getParams() {
|
||||
return request.getRoot(AnyPointer.factory).asReader();
|
||||
return this.request.getRoot(AnyPointer.factory).asReader();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -437,20 +438,27 @@ public final class Capability {
|
|||
|
||||
@Override
|
||||
public CompletableFuture<java.lang.Void> tailCall(RequestHook request) {
|
||||
// TODO implement tailCall
|
||||
return null;
|
||||
var result = this.directTailCall(request);
|
||||
if (this.tailCallPipeline != null) {
|
||||
this.tailCallPipeline.complete(new AnyPointer.Pipeline(result.pipeline));
|
||||
}
|
||||
return result.promise;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletableFuture<PipelineHook> onTailCall() {
|
||||
// TODO implement onTailCall
|
||||
return null;
|
||||
public CompletableFuture<AnyPointer.Pipeline> onTailCall() {
|
||||
this.tailCallPipeline = new CompletableFuture<>();
|
||||
return this.tailCallPipeline.copy();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request) {
|
||||
// TODO implement directTailCall
|
||||
return null;
|
||||
assert this.response == null: "Can't call tailCall() after initializing the results struct.";
|
||||
var promise = request.send();
|
||||
var voidPromise = promise._getResponse().thenAccept(tailResponse -> {
|
||||
this.response = tailResponse;
|
||||
});
|
||||
return new ClientHook.VoidPromiseAndPipeline(voidPromise, promise.pipeline().hook);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,14 +20,15 @@ public class RemotePromise<Results>
|
|||
|
||||
public RemotePromise(CompletableFuture<Response<Results>> promise,
|
||||
AnyPointer.Pipeline pipeline) {
|
||||
super(promise.thenApply(response -> {
|
||||
//System.out.println("Got a response for remote promise " + promise.toString());
|
||||
return response.getResults();
|
||||
}));
|
||||
super(promise.thenApply(response -> response.getResults()));
|
||||
this.response = promise;
|
||||
this.pipeline = pipeline;
|
||||
}
|
||||
|
||||
CompletableFuture<Response<Results>> _getResponse() {
|
||||
return this.response;
|
||||
}
|
||||
|
||||
public AnyPointer.Pipeline pipeline() {
|
||||
return this.pipeline;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue