add resolve test and fix handleResolve bugs

This commit is contained in:
Vaci Koblizek 2020-11-06 15:32:20 +00:00
parent d2d851d630
commit 0ce52fe135
5 changed files with 110 additions and 18 deletions

View file

@ -57,7 +57,7 @@ public final class Capability {
ClientHook getHook(); ClientHook getHook();
default CompletionStage<java.lang.Void> whenResolved() { default CompletableFuture<java.lang.Void> whenResolved() {
return this.getHook().whenResolved(); return this.getHook().whenResolved();
} }

View file

@ -434,7 +434,7 @@ final class RpcState<VatId> {
private void handleMessage(IncomingRpcMessage message) throws RpcException { private void handleMessage(IncomingRpcMessage message) throws RpcException {
var reader = message.getBody().getAs(RpcProtocol.Message.factory); var reader = message.getBody().getAs(RpcProtocol.Message.factory);
//System.out.println(reader.which()); //System.out.println(this + ": Received message: " + reader.which());
switch (reader.which()) { switch (reader.which()) {
case UNIMPLEMENTED: case UNIMPLEMENTED:
handleUnimplemented(reader.getUnimplemented()); handleUnimplemented(reader.getUnimplemented());
@ -744,26 +744,44 @@ final class RpcState<VatId> {
} }
private void handleResolve(IncomingRpcMessage message, RpcProtocol.Resolve.Reader resolve) { private void handleResolve(IncomingRpcMessage message, RpcProtocol.Resolve.Reader resolve) {
ClientHook cap = null;
Throwable exc = null;
switch (resolve.which()) {
case CAP:
cap = receiveCap(resolve.getCap(), message.getAttachedFds());
break;
case EXCEPTION:
exc = RpcException.toException(resolve.getException());
break;
default:
assert false: "Unknown 'Resolve' type.";
return;
}
var importId = resolve.getPromiseId();
var imp = this.imports.find(resolve.getPromiseId()); var imp = this.imports.find(resolve.getPromiseId());
if (imp == null) { if (imp == null) {
return; return;
} }
assert imp.importClient == null : "Import already resolved."; if (imp.promise != null) {
assert !imp.promise.isDone();
switch (resolve.which()) { // This import is an unfulfilled promise.
case CAP: if (exc != null) {
var cap = receiveCap(resolve.getCap(), message.getAttachedFds());
imp.promise.complete(cap);
break;
case EXCEPTION:
var exc = RpcException.toException(resolve.getException());
imp.promise.completeExceptionally(exc); imp.promise.completeExceptionally(exc);
break; }
default: else {
assert false; imp.promise.complete(cap);
}
return; return;
} }
// It appears this is a valid entry on the import table, but was not expected to be a
// promise.
assert imp.importClient == null : "Import already resolved.";
} }
private void handleRelease(RpcProtocol.Release.Reader release) { private void handleRelease(RpcProtocol.Release.Reader release) {
@ -898,7 +916,7 @@ final class RpcState<VatId> {
var wrapped = inner.whenMoreResolved(); var wrapped = inner.whenMoreResolved();
if (wrapped != null) { if (wrapped != null) {
// This is a promise. Arrange for the `Resolve` message to be sent later. // This is a promise. Arrange for the `Resolve` message to be sent later.
export.resolveOp = resolveExportedPromise(export.exportId, wrapped); export.resolveOp = this.resolveExportedPromise(export.exportId, wrapped);
descriptor.setSenderPromise(export.exportId); descriptor.setSenderPromise(export.exportId);
} }
else { else {
@ -916,6 +934,7 @@ final class RpcState<VatId> {
resolution = this.getInnermostClient(resolution); resolution = this.getInnermostClient(resolution);
var exp = exports.find(exportId); var exp = exports.find(exportId);
assert exp != null;
exportsByCap.remove(exp.clientHook); exportsByCap.remove(exp.clientHook);
exp.clientHook = resolution; exp.clientHook = resolution;

View file

@ -24,14 +24,13 @@ package org.capnproto;
import org.capnproto.test.Test; import org.capnproto.test.Test;
import org.junit.Assert; import org.junit.Assert;
import org.junit.*;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
class Counter { class Counter {
private int count = 0; private int count = 0;
void inc() { count++; } int inc() { return count++; }
int value() { return count; } int value() { return count; }
} }

View file

@ -369,8 +369,42 @@ public class RpcTest {
handle1 = null; handle1 = null;
handle2 = null; handle2 = null;
}
@org.junit.Test
public void testPromiseResolve() {
var context = new TestContext(bootstrapFactory);
var client = new Test.TestMoreStuff.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_MORE_STUFF));
var chainedCallCount = new Counter();
var request = client.callFooRequest();
var request2 = client.callFooWhenResolvedRequest();
var paf = new CompletableFuture<Test.TestInterface.Client>();
{
request.getParams().setCap(new Test.TestInterface.Client(paf.copy()));
request2.getParams().setCap(new Test.TestInterface.Client(paf.copy()));
}
var promise = request.send();
var promise2 = request2.send();
// Make sure getCap() has been called on the server side by sending another call and waiting
// for it.
Assert.assertEquals(2, client.getCallSequenceRequest().send().join().getN());
//Assert.assertEquals(3, context.restorer.callCount);
// OK, now fulfill the local promise.
paf.complete(new Test.TestInterface.Client(new TestUtil.TestInterfaceImpl(chainedCallCount)));
// We should now be able to wait for getCap() to finish.
Assert.assertEquals("bar", promise.join().getS().toString());
Assert.assertEquals("bar", promise2.join().getS().toString());
//Assert.assertEquals(3, context.restorer.callCount);
Assert.assertEquals(2, chainedCallCount.value());
} }
} }

View file

@ -114,6 +114,46 @@ class TestUtil {
context.getResults().setHandle(new HandleImpl(this.handleCount)); context.getResults().setHandle(new HandleImpl(this.handleCount));
return READY_NOW; return READY_NOW;
} }
@Override
protected CompletableFuture<java.lang.Void> getCallSequence(CallContext<Test.TestCallOrder.GetCallSequenceParams.Reader, Test.TestCallOrder.GetCallSequenceResults.Builder> context) {
var result = context.getResults();
result.setN(this.callCount.inc());
return READY_NOW;
}
@Override
protected CompletableFuture<java.lang.Void> callFoo(CallContext<Test.TestMoreStuff.CallFooParams.Reader, Test.TestMoreStuff.CallFooResults.Builder> context) {
this.callCount.inc();
var params = context.getParams();
var cap = params.getCap();
var request = cap.fooRequest();
request.getParams().setI(123);
request.getParams().setJ(true);
return request.send().thenAccept(response -> {
Assert.assertEquals("foo", response.getX().toString());
context.getResults().setS("bar");
});
}
@Override
protected CompletableFuture<java.lang.Void> callFooWhenResolved(CallContext<Test.TestMoreStuff.CallFooWhenResolvedParams.Reader, Test.TestMoreStuff.CallFooWhenResolvedResults.Builder> context) {
this.callCount.inc();
var params = context.getParams();
var cap = params.getCap();
return cap.whenResolved().thenCompose(void_ -> {
var request = cap.fooRequest();
request.getParams().setI(123);
request.getParams().setJ(true);
return request.send().thenAccept(response -> {
Assert.assertEquals("foo", response.getX().toString());
context.getResults().setS("bar");
});
});
}
} }
static class TestTailCalleeImpl extends Test.TestTailCallee.Server { static class TestTailCalleeImpl extends Test.TestTailCallee.Server {