implement rpc tail calls
This commit is contained in:
parent
cee3aa79ae
commit
f2df5c2191
5 changed files with 86 additions and 44 deletions
|
@ -168,7 +168,7 @@ final class RpcState<VatId> {
|
||||||
final int answerId;
|
final int answerId;
|
||||||
boolean active = false;
|
boolean active = false;
|
||||||
PipelineHook pipeline;
|
PipelineHook pipeline;
|
||||||
CompletionStage<RpcResponse> redirectedResults;
|
CompletableFuture<RpcResponse> redirectedResults;
|
||||||
RpcCallContext callContext;
|
RpcCallContext callContext;
|
||||||
int[] resultExports;
|
int[] resultExports;
|
||||||
|
|
||||||
|
@ -599,24 +599,26 @@ final class RpcState<VatId> {
|
||||||
}
|
}
|
||||||
|
|
||||||
var pap = startCall(call.getInterfaceId(), call.getMethodId(), cap, context);
|
var pap = startCall(call.getInterfaceId(), call.getMethodId(), cap, context);
|
||||||
|
|
||||||
{
|
{
|
||||||
var answer = answers.find(answerId);
|
var answer = answers.find(answerId);
|
||||||
assert answer != null;
|
assert answer != null;
|
||||||
answer.pipeline = pap.pipeline;
|
answer.pipeline = pap.pipeline;
|
||||||
|
|
||||||
if (redirectResults) {
|
if (redirectResults) {
|
||||||
answer.redirectedResults = pap.promise.thenApply(x -> {
|
answer.redirectedResults = pap.promise.thenApply(
|
||||||
return context.consumeRedirectedResponse();
|
void_ -> context.consumeRedirectedResponse());
|
||||||
});
|
|
||||||
// TODO cancellation deferral
|
// TODO cancellation deferral
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
pap.promise.thenAccept(x -> {
|
pap.promise.whenComplete((void_, exc) -> {
|
||||||
|
if (exc == null) {
|
||||||
context.sendReturn();
|
context.sendReturn();
|
||||||
}).exceptionally(exc -> {
|
}
|
||||||
|
else {
|
||||||
context.sendErrorReturn(exc);
|
context.sendErrorReturn(exc);
|
||||||
// TODO wait on the cancellation...
|
// TODO wait on the cancellation...
|
||||||
return null;
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -628,7 +630,6 @@ final class RpcState<VatId> {
|
||||||
}
|
}
|
||||||
|
|
||||||
void handleReturn(IncomingRpcMessage message, RpcProtocol.Return.Reader callReturn) {
|
void handleReturn(IncomingRpcMessage message, RpcProtocol.Return.Reader callReturn) {
|
||||||
|
|
||||||
var question = questions.find(callReturn.getAnswerId());
|
var question = questions.find(callReturn.getAnswerId());
|
||||||
if (question == null) {
|
if (question == null) {
|
||||||
assert false: "Invalid question ID in Return message.";
|
assert false: "Invalid question ID in Return message.";
|
||||||
|
@ -703,7 +704,7 @@ final class RpcState<VatId> {
|
||||||
assert false: "`Return.takeFromOtherQuestion` referenced a call that did not use `sendResultsTo.yourself`.";
|
assert false: "`Return.takeFromOtherQuestion` referenced a call that did not use `sendResultsTo.yourself`.";
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
question.response = answer.redirectedResults.toCompletableFuture();
|
question.response = answer.redirectedResults;
|
||||||
answer.redirectedResults = null;
|
answer.redirectedResults = null;
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -1230,7 +1231,7 @@ final class RpcState<VatId> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public AnyPointer.Builder getResultsBuilder() {
|
public AnyPointer.Builder getResultsBuilder() {
|
||||||
return payload.getContent().imbue(capTable);
|
return this.payload.getContent().imbue(capTable);
|
||||||
}
|
}
|
||||||
|
|
||||||
int[] send() {
|
int[] send() {
|
||||||
|
@ -1284,7 +1285,7 @@ final class RpcState<VatId> {
|
||||||
private RpcProtocol.Return.Builder returnMessage;
|
private RpcProtocol.Return.Builder returnMessage;
|
||||||
private boolean redirectResults = false;
|
private boolean redirectResults = false;
|
||||||
private boolean responseSent = false;
|
private boolean responseSent = false;
|
||||||
private CompletableFuture<PipelineHook> tailCallPipelineFuture;
|
private CompletableFuture<AnyPointer.Pipeline> tailCallPipeline;
|
||||||
|
|
||||||
private boolean cancelRequested = false;
|
private boolean cancelRequested = false;
|
||||||
private boolean cancelAllowed = false;
|
private boolean cancelAllowed = false;
|
||||||
|
@ -1336,10 +1337,10 @@ final class RpcState<VatId> {
|
||||||
@Override
|
@Override
|
||||||
public CompletableFuture<java.lang.Void> tailCall(RequestHook request) {
|
public CompletableFuture<java.lang.Void> tailCall(RequestHook request) {
|
||||||
var result = this.directTailCall(request);
|
var result = this.directTailCall(request);
|
||||||
if (this.tailCallPipelineFuture != null) {
|
if (this.tailCallPipeline != null) {
|
||||||
this.tailCallPipelineFuture.complete(result.pipeline);
|
this.tailCallPipeline.complete(new AnyPointer.Pipeline(result.pipeline));
|
||||||
}
|
}
|
||||||
return result.promise.toCompletableFuture().copy();
|
return result.promise.copy();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -1347,8 +1348,10 @@ final class RpcState<VatId> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public CompletableFuture<PipelineHook> onTailCall() {
|
public CompletableFuture<AnyPointer.Pipeline> onTailCall() {
|
||||||
return null;
|
assert this.tailCallPipeline == null: "Called onTailCall twice?";
|
||||||
|
this.tailCallPipeline = new CompletableFuture<>();
|
||||||
|
return this.tailCallPipeline.copy();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -1441,7 +1444,7 @@ final class RpcState<VatId> {
|
||||||
message.send();
|
message.send();
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanupAnswerTable(new int[0], false);
|
cleanupAnswerTable(null, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean isFirstResponder() {
|
private boolean isFirstResponder() {
|
||||||
|
@ -1453,6 +1456,10 @@ final class RpcState<VatId> {
|
||||||
}
|
}
|
||||||
|
|
||||||
private void cleanupAnswerTable(int[] resultExports, boolean shouldFreePipeline) {
|
private void cleanupAnswerTable(int[] resultExports, boolean shouldFreePipeline) {
|
||||||
|
if (resultExports == null) {
|
||||||
|
resultExports = new int[0];
|
||||||
|
}
|
||||||
|
|
||||||
if (this.cancelRequested) {
|
if (this.cancelRequested) {
|
||||||
assert resultExports.length == 0;
|
assert resultExports.length == 0;
|
||||||
answers.erase(this.answerId);
|
answers.erase(this.answerId);
|
||||||
|
|
|
@ -406,5 +406,30 @@ public class RpcTest {
|
||||||
//Assert.assertEquals(3, context.restorer.callCount);
|
//Assert.assertEquals(3, context.restorer.callCount);
|
||||||
Assert.assertEquals(2, chainedCallCount.value());
|
Assert.assertEquals(2, chainedCallCount.value());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@org.junit.Test
|
||||||
|
public void testTailCall() {
|
||||||
|
var context = new TestContext(bootstrapFactory);
|
||||||
|
var caller = new Test.TestTailCaller.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_TAIL_CALLER));
|
||||||
|
|
||||||
|
var calleeCallCount = new Counter();
|
||||||
|
var callee = new Test.TestTailCallee.Client(new RpcTestUtil.TestTailCalleeImpl(calleeCallCount));
|
||||||
|
var request = caller.fooRequest();
|
||||||
|
request.getParams().setI(456);
|
||||||
|
request.getParams().setCallee(callee);
|
||||||
|
|
||||||
|
var promise = request.send();
|
||||||
|
var dependentCall0 = promise.getC().getCallSequenceRequest().send();
|
||||||
|
var response = promise.join();
|
||||||
|
Assert.assertEquals(456, response.getI());
|
||||||
|
|
||||||
|
var dependentCall1 = promise.getC().getCallSequenceRequest().send();
|
||||||
|
Assert.assertEquals(0, dependentCall0.join().getN());
|
||||||
|
Assert.assertEquals(1, dependentCall1.join().getN());
|
||||||
|
|
||||||
|
var dependentCall2 = response.getC().getCallSequenceRequest().send();
|
||||||
|
Assert.assertEquals(2, dependentCall2.join().getN());
|
||||||
|
Assert.assertEquals(1, calleeCallCount.value());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ package org.capnproto;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
|
||||||
public interface CallContextHook {
|
public interface CallContextHook {
|
||||||
|
|
||||||
AnyPointer.Reader getParams();
|
AnyPointer.Reader getParams();
|
||||||
|
|
||||||
void releaseParams();
|
void releaseParams();
|
||||||
|
@ -17,7 +18,7 @@ public interface CallContextHook {
|
||||||
|
|
||||||
void allowCancellation();
|
void allowCancellation();
|
||||||
|
|
||||||
CompletableFuture<PipelineHook> onTailCall();
|
CompletableFuture<AnyPointer.Pipeline> onTailCall();
|
||||||
|
|
||||||
ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request);
|
ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request);
|
||||||
}
|
}
|
||||||
|
|
|
@ -185,17 +185,16 @@ public final class Capability {
|
||||||
}
|
}
|
||||||
|
|
||||||
var promise = this.whenResolved().thenCompose(
|
var promise = this.whenResolved().thenCompose(
|
||||||
x -> this.callInternal(interfaceId, methodId, ctx));
|
void_ -> this.callInternal(interfaceId, methodId, ctx));
|
||||||
|
|
||||||
CompletableFuture<PipelineHook> pipelinePromise = promise.thenApply(x -> {
|
|
||||||
|
var pipelinePromise = promise.thenApply(x -> {
|
||||||
ctx.releaseParams();
|
ctx.releaseParams();
|
||||||
return new LocalPipeline(ctx);
|
return (PipelineHook)new LocalPipeline(ctx);
|
||||||
});
|
});
|
||||||
|
|
||||||
var tailCall = ctx.onTailCall();
|
var tailCall = ctx.onTailCall().thenApply(pipeline -> pipeline.hook);
|
||||||
if (tailCall != null) {
|
|
||||||
pipelinePromise = tailCall.applyToEither(pipelinePromise, pipeline -> pipeline);
|
pipelinePromise = tailCall.applyToEither(pipelinePromise, pipeline -> pipeline);
|
||||||
}
|
|
||||||
|
|
||||||
return new VoidPromiseAndPipeline(
|
return new VoidPromiseAndPipeline(
|
||||||
promise,
|
promise,
|
||||||
|
@ -213,7 +212,7 @@ public final class Capability {
|
||||||
return CompletableFuture.completedFuture(this.resolved);
|
return CompletableFuture.completedFuture(this.resolved);
|
||||||
}
|
}
|
||||||
else if (this.resolveTask != null) {
|
else if (this.resolveTask != null) {
|
||||||
return this.resolveTask.thenApply(x -> this.resolved);
|
return this.resolveTask.thenApply(void_ -> this.resolved);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
return null;
|
return null;
|
||||||
|
@ -335,7 +334,7 @@ public final class Capability {
|
||||||
final MessageBuilder message = new MessageBuilder();
|
final MessageBuilder message = new MessageBuilder();
|
||||||
final long interfaceId;
|
final long interfaceId;
|
||||||
final short methodId;
|
final short methodId;
|
||||||
ClientHook client;
|
final ClientHook client;
|
||||||
|
|
||||||
LocalRequest(long interfaceId, short methodId, ClientHook client) {
|
LocalRequest(long interfaceId, short methodId, ClientHook client) {
|
||||||
this.interfaceId = interfaceId;
|
this.interfaceId = interfaceId;
|
||||||
|
@ -371,6 +370,7 @@ public final class Capability {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final class LocalPipeline implements PipelineHook {
|
private static final class LocalPipeline implements PipelineHook {
|
||||||
|
|
||||||
private final CallContextHook ctx;
|
private final CallContextHook ctx;
|
||||||
private final AnyPointer.Reader results;
|
private final AnyPointer.Reader results;
|
||||||
|
|
||||||
|
@ -396,7 +396,8 @@ public final class Capability {
|
||||||
|
|
||||||
private static class LocalCallContext implements CallContextHook {
|
private static class LocalCallContext implements CallContextHook {
|
||||||
|
|
||||||
final CompletableFuture<?> cancelAllowed;
|
final CompletableFuture<java.lang.Void> cancelAllowed;
|
||||||
|
CompletableFuture<AnyPointer.Pipeline> tailCallPipeline;
|
||||||
MessageBuilder request;
|
MessageBuilder request;
|
||||||
Response<AnyPointer.Reader> response;
|
Response<AnyPointer.Reader> response;
|
||||||
AnyPointer.Builder responseBuilder;
|
AnyPointer.Builder responseBuilder;
|
||||||
|
@ -404,7 +405,7 @@ public final class Capability {
|
||||||
|
|
||||||
LocalCallContext(MessageBuilder request,
|
LocalCallContext(MessageBuilder request,
|
||||||
ClientHook clientRef,
|
ClientHook clientRef,
|
||||||
CompletableFuture<?> cancelAllowed) {
|
CompletableFuture<java.lang.Void> cancelAllowed) {
|
||||||
this.request = request;
|
this.request = request;
|
||||||
this.clientRef = clientRef;
|
this.clientRef = clientRef;
|
||||||
this.cancelAllowed = cancelAllowed;
|
this.cancelAllowed = cancelAllowed;
|
||||||
|
@ -412,7 +413,7 @@ public final class Capability {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public AnyPointer.Reader getParams() {
|
public AnyPointer.Reader getParams() {
|
||||||
return request.getRoot(AnyPointer.factory).asReader();
|
return this.request.getRoot(AnyPointer.factory).asReader();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -437,20 +438,27 @@ public final class Capability {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public CompletableFuture<java.lang.Void> tailCall(RequestHook request) {
|
public CompletableFuture<java.lang.Void> tailCall(RequestHook request) {
|
||||||
// TODO implement tailCall
|
var result = this.directTailCall(request);
|
||||||
return null;
|
if (this.tailCallPipeline != null) {
|
||||||
|
this.tailCallPipeline.complete(new AnyPointer.Pipeline(result.pipeline));
|
||||||
|
}
|
||||||
|
return result.promise;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public CompletableFuture<PipelineHook> onTailCall() {
|
public CompletableFuture<AnyPointer.Pipeline> onTailCall() {
|
||||||
// TODO implement onTailCall
|
this.tailCallPipeline = new CompletableFuture<>();
|
||||||
return null;
|
return this.tailCallPipeline.copy();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request) {
|
public ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request) {
|
||||||
// TODO implement directTailCall
|
assert this.response == null: "Can't call tailCall() after initializing the results struct.";
|
||||||
return null;
|
var promise = request.send();
|
||||||
|
var voidPromise = promise._getResponse().thenAccept(tailResponse -> {
|
||||||
|
this.response = tailResponse;
|
||||||
|
});
|
||||||
|
return new ClientHook.VoidPromiseAndPipeline(voidPromise, promise.pipeline().hook);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,14 +20,15 @@ public class RemotePromise<Results>
|
||||||
|
|
||||||
public RemotePromise(CompletableFuture<Response<Results>> promise,
|
public RemotePromise(CompletableFuture<Response<Results>> promise,
|
||||||
AnyPointer.Pipeline pipeline) {
|
AnyPointer.Pipeline pipeline) {
|
||||||
super(promise.thenApply(response -> {
|
super(promise.thenApply(response -> response.getResults()));
|
||||||
//System.out.println("Got a response for remote promise " + promise.toString());
|
|
||||||
return response.getResults();
|
|
||||||
}));
|
|
||||||
this.response = promise;
|
this.response = promise;
|
||||||
this.pipeline = pipeline;
|
this.pipeline = pipeline;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CompletableFuture<Response<Results>> _getResponse() {
|
||||||
|
return this.response;
|
||||||
|
}
|
||||||
|
|
||||||
public AnyPointer.Pipeline pipeline() {
|
public AnyPointer.Pipeline pipeline() {
|
||||||
return this.pipeline;
|
return this.pipeline;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue