calling thread drives client side loop

This commit is contained in:
Vaci Koblizek 2020-11-23 20:50:11 +00:00
parent cdb719eed0
commit beec84a1bc
8 changed files with 158 additions and 333 deletions

View file

@ -255,7 +255,6 @@ final class RpcState<VatId> {
this.bootstrapFactory = bootstrapFactory; this.bootstrapFactory = bootstrapFactory;
this.connection = connection; this.connection = connection;
this.disconnectFulfiller = disconnectFulfiller; this.disconnectFulfiller = disconnectFulfiller;
startMessageLoop();
} }
@Override @Override
@ -391,36 +390,41 @@ final class RpcState<VatId> {
return pipeline.getPipelinedCap(new PipelineOp[0]); return pipeline.getPipelinedCap(new PipelineOp[0]);
} }
private void startMessageLoop() { /**
* Returns a CompletableFuture that, when complete, has processed one message.
*/
public CompletableFuture<java.lang.Void> pollOnce() {
if (isDisconnected()) { if (isDisconnected()) {
this.messageLoop.completeExceptionally(this.disconnected); this.messageLoop.completeExceptionally(this.disconnected);
return; return CompletableFuture.failedFuture(this.disconnected);
} }
var messageReader = this.connection.receiveIncomingMessage() return this.connection.receiveIncomingMessage()
.thenAccept(message -> { .thenAccept(message -> {
if (message == null) { if (message == null) {
this.disconnect(RpcException.disconnected("Peer disconnected")); this.disconnect(RpcException.disconnected("Peer disconnected"));
this.messageLoop.complete(null); this.messageLoop.complete(null);
return; return;
}
try {
this.handleMessage(message);
while (!this.lastEvals.isEmpty()) {
this.lastEvals.remove().call();
} }
try {
this.handleMessage(message);
while (!this.lastEvals.isEmpty()) {
this.lastEvals.remove().call();
}
}
catch (Throwable rpcExc) {
// either we received an Abort message from peer
// or internal RpcState is bad.
this.disconnect(rpcExc);
}
});
}
} public void runMessageLoop() {
catch (Throwable rpcExc) { this.pollOnce().thenRun(this::runMessageLoop).exceptionally(exc -> {
// either we received an Abort message from peer LOGGER.warning(() -> "Event loop exited: " + exc.getMessage());
// or internal RpcState is bad. return null;
this.disconnect(rpcExc); });
}
});
messageReader.thenRunAsync(this::startMessageLoop).exceptionallyCompose(
CompletableFuture::failedFuture);
} }
private void handleMessage(IncomingRpcMessage message) throws RpcException { private void handleMessage(IncomingRpcMessage message) throws RpcException {
@ -766,7 +770,6 @@ final class RpcState<VatId> {
} }
// This import is an unfulfilled promise. // This import is an unfulfilled promise.
assert !imp.promise.isDone();
switch (resolve.which()) { switch (resolve.which()) {
case CAP -> { case CAP -> {
var cap = receiveCap(resolve.getCap(), message.getAttachedFds()); var cap = receiveCap(resolve.getCap(), message.getAttachedFds());
@ -981,10 +984,8 @@ final class RpcState<VatId> {
var resolve = message.getBody().initAs(RpcProtocol.Message.factory).initResolve(); var resolve = message.getBody().initAs(RpcProtocol.Message.factory).initResolve();
resolve.setPromiseId(exportId); resolve.setPromiseId(exportId);
FromException(exc, resolve.initException()); FromException(exc, resolve.initException());
LOGGER.log(Level.INFO, this.toString() + ": > RESOLVE", exc.getMessage()); LOGGER.info(() -> this.toString() + ": > RESOLVE FAILED export=" + exportId + " msg=" + exc.getMessage());
message.send(); message.send();
// TODO disconnect?
}); });
} }
@ -1900,6 +1901,7 @@ final class RpcState<VatId> {
var replacementBrand = replacement.getBrand(); var replacementBrand = replacement.getBrand();
boolean isSameConnection = replacementBrand == RpcState.this; boolean isSameConnection = replacementBrand == RpcState.this;
if (isSameConnection) { if (isSameConnection) {
// We resolved to some other RPC capability hosted by the same peer.
var promise = replacement.whenMoreResolved(); var promise = replacement.whenMoreResolved();
if (promise != null) { if (promise != null) {
var other = (PromiseClient)replacement; var other = (PromiseClient)replacement;
@ -1936,6 +1938,7 @@ final class RpcState<VatId> {
// TODO Flow control // TODO Flow control
if (resolutionType == ResolutionType.REFLECTED && receivedCall && !isDisconnected()) { if (resolutionType == ResolutionType.REFLECTED && receivedCall && !isDisconnected()) {
LOGGER.fine(() -> RpcState.this.toString() + ": embargoing reflected capability " + this.toString());
// The new capability is hosted locally, not on the remote machine. And, we had made calls // The new capability is hosted locally, not on the remote machine. And, we had made calls
// to the promise. We need to make sure those calls echo back to us before we allow new // to the promise. We need to make sure those calls echo back to us before we allow new
// calls to go directly to the local capability, so we need to set a local embargo and send // calls to go directly to the local capability, so we need to set a local embargo and send

View file

@ -1,14 +1,14 @@
package org.capnproto; package org.capnproto;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
public class RpcSystem<VatId extends StructReader> { public class RpcSystem<VatId extends StructReader> {
private final VatNetwork<VatId> network; private final VatNetwork<VatId> network;
private final BootstrapFactory<VatId> bootstrapFactory; private final BootstrapFactory<VatId> bootstrapFactory;
private final Map<VatNetwork.Connection<VatId>, RpcState<VatId>> connections = new ConcurrentHashMap<>(); private final Map<VatNetwork.Connection<VatId>, RpcState<VatId>> connections = new HashMap<>();
public RpcSystem(VatNetwork<VatId> network) { public RpcSystem(VatNetwork<VatId> network) {
this(network, clientId -> new Capability.Client( this(network, clientId -> new Capability.Client(
@ -29,7 +29,6 @@ public class RpcSystem<VatId extends StructReader> {
BootstrapFactory<VatId> bootstrapFactory) { BootstrapFactory<VatId> bootstrapFactory) {
this.network = network; this.network = network;
this.bootstrapFactory = bootstrapFactory; this.bootstrapFactory = bootstrapFactory;
this.startAcceptLoop();
} }
public Capability.Client bootstrap(VatId vatId) { public Capability.Client bootstrap(VatId vatId) {
@ -45,7 +44,8 @@ public class RpcSystem<VatId extends StructReader> {
} }
public void accept(VatNetwork.Connection<VatId> connection) { public void accept(VatNetwork.Connection<VatId> connection) {
getConnectionState(connection); var state = getConnectionState(connection);
state.runMessageLoop();
} }
private RpcState<VatId> getConnectionState(VatNetwork.Connection<VatId> connection) { private RpcState<VatId> getConnectionState(VatNetwork.Connection<VatId> connection) {
@ -59,10 +59,17 @@ public class RpcSystem<VatId extends StructReader> {
}); });
} }
private void startAcceptLoop() { public void runOnce() {
for (var state: this.connections.values()) {
state.pollOnce().join();
return;
}
}
public void start() {
this.network.accept() this.network.accept()
.thenAccept(this::accept) .thenAccept(this::accept)
.thenRunAsync(this::startAcceptLoop); .thenRunAsync(this::start);
} }
public static <VatId extends StructReader> public static <VatId extends StructReader>

View file

@ -1,5 +1,6 @@
package org.capnproto; package org.capnproto;
import java.io.IOException;
import java.nio.channels.AsynchronousByteChannel; import java.nio.channels.AsynchronousByteChannel;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -35,4 +36,11 @@ public class TwoPartyClient {
CompletableFuture<java.lang.Void> onDisconnect() { CompletableFuture<java.lang.Void> onDisconnect() {
return this.network.onDisconnect(); return this.network.onDisconnect();
} }
public <T> CompletableFuture<T> runUntil(CompletableFuture<T> done) {
while (!done.isDone()) {
this.rpcSystem.runOnce();
}
return done;
}
} }

View file

@ -1,8 +1,6 @@
package org.capnproto; package org.capnproto;
import java.nio.channels.AsynchronousServerSocketChannel; import java.nio.channels.*;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.CompletionHandler;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -10,41 +8,18 @@ import java.util.concurrent.CompletableFuture;
public class TwoPartyServer { public class TwoPartyServer {
private class AcceptedConnection { private class AcceptedConnection {
final AsynchronousSocketChannel connection; private final AsynchronousByteChannel connection;
final TwoPartyVatNetwork network; private final TwoPartyVatNetwork network;
final RpcSystem<RpcTwoPartyProtocol.VatId.Reader> rpcSystem; private final RpcSystem<RpcTwoPartyProtocol.VatId.Reader> rpcSystem;
AcceptedConnection(Capability.Client bootstrapInterface, AsynchronousSocketChannel connection) { AcceptedConnection(Capability.Client bootstrapInterface, AsynchronousByteChannel connection) {
this.connection = connection; this.connection = connection;
this.network = new TwoPartyVatNetwork(this.connection, RpcTwoPartyProtocol.Side.SERVER); this.network = new TwoPartyVatNetwork(this.connection, RpcTwoPartyProtocol.Side.SERVER);
this.rpcSystem = new RpcSystem<>(network, bootstrapInterface); this.rpcSystem = new RpcSystem<>(network, bootstrapInterface);
this.rpcSystem.start();
} }
} }
class ConnectionReceiver {
final AsynchronousServerSocketChannel listener;
ConnectionReceiver(AsynchronousServerSocketChannel listener) {
this.listener = listener;
}
CompletableFuture<AsynchronousSocketChannel> accept() {
CompletableFuture<AsynchronousSocketChannel> result = new CompletableFuture<>();
this.listener.accept(null, new CompletionHandler<>() {
@Override
public void completed(AsynchronousSocketChannel channel, Object attachment) {
result.complete(channel);
}
@Override
public void failed(Throwable exc, Object attachment) {
result.completeExceptionally(exc);
}
});
return result.copy();
}
}
private final Capability.Client bootstrapInterface; private final Capability.Client bootstrapInterface;
private final List<AcceptedConnection> connections = new ArrayList<>(); private final List<AcceptedConnection> connections = new ArrayList<>();
@ -65,14 +40,20 @@ public class TwoPartyServer {
} }
public CompletableFuture<java.lang.Void> listen(AsynchronousServerSocketChannel listener) { public CompletableFuture<java.lang.Void> listen(AsynchronousServerSocketChannel listener) {
return this.listen(wrapListenSocket(listener)); var result = new CompletableFuture<AsynchronousSocketChannel>();
} listener.accept(null, new CompletionHandler<>() {
@Override
public void completed(AsynchronousSocketChannel channel, Object attachment) {
accept(channel);
result.complete(null);
}
CompletableFuture<java.lang.Void> listen(ConnectionReceiver listener) { @Override
return listener.accept().thenCompose(channel -> { public void failed(Throwable exc, Object attachment) {
this.accept(channel); result.completeExceptionally(exc);
return this.listen(listener); }
}); });
return result.thenCompose(void_ -> this.listen(listener));
} }
CompletableFuture<java.lang.Void> drain() { CompletableFuture<java.lang.Void> drain() {
@ -82,8 +63,4 @@ public class TwoPartyServer {
} }
return loop; return loop;
} }
ConnectionReceiver wrapListenSocket(AsynchronousServerSocketChannel channel) {
return new ConnectionReceiver(channel);
}
} }

View file

@ -9,17 +9,12 @@ public class TwoPartyVatNetwork
implements VatNetwork<RpcTwoPartyProtocol.VatId.Reader>, implements VatNetwork<RpcTwoPartyProtocol.VatId.Reader>,
VatNetwork.Connection<RpcTwoPartyProtocol.VatId.Reader> { VatNetwork.Connection<RpcTwoPartyProtocol.VatId.Reader> {
public interface MessageTap {
void incoming(IncomingRpcMessage message, RpcTwoPartyProtocol.Side side);
}
private CompletableFuture<java.lang.Void> previousWrite = CompletableFuture.completedFuture(null); private CompletableFuture<java.lang.Void> previousWrite = CompletableFuture.completedFuture(null);
private final CompletableFuture<java.lang.Void> disconnectPromise = new CompletableFuture<>(); private final CompletableFuture<java.lang.Void> disconnectPromise = new CompletableFuture<>();
private final AsynchronousByteChannel channel; private final AsynchronousByteChannel channel;
private final RpcTwoPartyProtocol.Side side; private final RpcTwoPartyProtocol.Side side;
private final MessageBuilder peerVatId = new MessageBuilder(4); private final MessageBuilder peerVatId = new MessageBuilder(4);
private boolean accepted; private boolean accepted;
private MessageTap tap;
public TwoPartyVatNetwork(AsynchronousByteChannel channel, RpcTwoPartyProtocol.Side side) { public TwoPartyVatNetwork(AsynchronousByteChannel channel, RpcTwoPartyProtocol.Side side) {
this.channel = channel; this.channel = channel;
@ -65,26 +60,9 @@ public class TwoPartyVatNetwork
@Override @Override
public CompletableFuture<IncomingRpcMessage> receiveIncomingMessage() { public CompletableFuture<IncomingRpcMessage> receiveIncomingMessage() {
var message = Serialize.readAsync(channel) return Serialize.readAsync(channel)
.thenApply(reader -> (IncomingRpcMessage) new IncomingMessage(reader)) .thenApply(reader -> (IncomingRpcMessage) new IncomingMessage(reader))
.exceptionally(exc -> null); .exceptionally(exc -> null);
// send to message tap
if (this.tap != null) {
message = message.whenComplete((msg, exc) -> {
if (this.tap == null || msg == null) {
return;
}
var side = this.side == RpcTwoPartyProtocol.Side.CLIENT
? RpcTwoPartyProtocol.Side.SERVER
: RpcTwoPartyProtocol.Side.CLIENT;
this.tap.incoming(msg, side);
});
}
return message;
} }
@Override @Override
@ -109,10 +87,6 @@ public class TwoPartyVatNetwork
return side; return side;
} }
public void setTap(MessageTap tap) {
this.tap = tap;
}
public Connection<RpcTwoPartyProtocol.VatId.Reader> asConnection() { public Connection<RpcTwoPartyProtocol.VatId.Reader> asConnection() {
return this; return this;
} }
@ -120,8 +94,7 @@ public class TwoPartyVatNetwork
public CompletableFuture<java.lang.Void> onDisconnect() { public CompletableFuture<java.lang.Void> onDisconnect() {
return this.disconnectPromise.copy(); return this.disconnectPromise.copy();
} }
public CompletableFuture<Connection<RpcTwoPartyProtocol.VatId.Reader>> accept() { public CompletableFuture<Connection<RpcTwoPartyProtocol.VatId.Reader>> accept() {
if (side == RpcTwoPartyProtocol.Side.SERVER & !accepted) { if (side == RpcTwoPartyProtocol.Side.SERVER & !accepted) {
accepted = true; accepted = true;

View file

@ -1,161 +0,0 @@
package org.capnproto;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
public class RpcStateTest {
class TestConnection implements VatNetwork.Connection<RpcTwoPartyProtocol.VatId.Reader> {
private CompletableFuture<IncomingRpcMessage> nextIncomingMessage = new CompletableFuture<>();
private final CompletableFuture<RpcState.DisconnectInfo> disconnect = new CompletableFuture<>();
@Override
public OutgoingRpcMessage newOutgoingMessage(int firstSegmentWordSize) {
var message = new MessageBuilder();
return new OutgoingRpcMessage() {
@Override
public AnyPointer.Builder getBody() {
return message.getRoot(AnyPointer.factory);
}
@Override
public void send() {
sent.add(this);
var msg = new IncomingRpcMessage() {
@Override
public AnyPointer.Reader getBody() {
return message.getRoot(AnyPointer.factory).asReader();
}
};
if (nextIncomingMessage.isDone()) {
nextIncomingMessage = CompletableFuture.completedFuture(msg);
}
else {
nextIncomingMessage.complete(msg);
}
}
@Override
public int sizeInWords() {
return 0;
}
};
}
@Override
public CompletableFuture<IncomingRpcMessage> receiveIncomingMessage() {
return this.nextIncomingMessage;
}
@Override
public CompletableFuture<java.lang.Void> shutdown() {
this.disconnect.complete(null);
return this.disconnect.thenRun(() -> {});
}
@Override
public RpcTwoPartyProtocol.VatId.Reader getPeerVatId() {
return null;
}
@Override
public void close() {
}
}
TestConnection connection;
Capability.Client bootstrapInterface;
RpcState rpc;
final Queue<OutgoingRpcMessage> sent = new ArrayDeque<>();
@Before
public void setUp() {
this.connection = new TestConnection();
this.bootstrapInterface = new Capability.Client(Capability.newNullCap());
var bootstrapFactory = new BootstrapFactory<RpcTwoPartyProtocol.VatId.Reader>() {
@Override
public Capability.Client createFor(RpcTwoPartyProtocol.VatId.Reader clientId) {
return bootstrapInterface;
}
};
this.rpc = new RpcState<>(bootstrapFactory, connection, connection.disconnect);
}
@After
public void tearDown() {
this.connection = null;
this.rpc = null;
this.sent.clear();
}
/*
@Test
public void handleUnimplemented() {
var msg = this.connection.newOutgoingMessage(0);
var root = msg.getBody().initAs(RpcProtocol.Message.factory).initUnimplemented();
var resolve = root.initResolve();
RpcState.FromException(new Exception("foo"), resolve.initException());
msg.send();
Assert.assertFalse(sent.isEmpty());
}
*/
@Test
public void handleAbort() {
var msg = this.connection.newOutgoingMessage(0);
var builder = msg.getBody().initAs(RpcProtocol.Message.factory);
RpcState.FromException(RpcException.failed("Test abort"), builder.initAbort());
msg.send();
}
@Test
public void handleBootstrap() {
var msg = this.connection.newOutgoingMessage(0);
var bootstrap = msg.getBody().initAs(RpcProtocol.Message.factory).initBootstrap();
bootstrap.setQuestionId(0);
msg.send();
Assert.assertEquals(2, sent.size());
sent.remove(); // bootstrap
var reply = sent.remove(); // return
var rpcMsg = reply.getBody().getAs(RpcProtocol.Message.factory);
Assert.assertEquals(RpcProtocol.Message.Which.RETURN, rpcMsg.which());
var ret = rpcMsg.getReturn();
Assert.assertEquals(ret.getAnswerId(), 0);
Assert.assertEquals(RpcProtocol.Return.Which.RESULTS, ret.which());
var results = ret.getResults();
Assert.assertEquals(results.getCapTable().size(), 1); // got a capability!
Assert.assertTrue(results.hasContent());
}
@Test
public void handleCall() {
}
@Test
public void handleReturn() {
}
@Test
public void handleFinish() {
}
@Test
public void handleResolve() {
}
@Test
public void handleDisembargo() {
}
}

View file

@ -31,7 +31,6 @@ import java.util.Map;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
public class RpcTest { public class RpcTest {
@ -268,6 +267,7 @@ public class RpcTest {
this.serverNetwork = this.network.add("server"); this.serverNetwork = this.network.add("server");
this.rpcClient = RpcSystem.makeRpcClient(this.clientNetwork); this.rpcClient = RpcSystem.makeRpcClient(this.clientNetwork);
this.rpcServer = RpcSystem.makeRpcServer(this.serverNetwork, bootstrapFactory); this.rpcServer = RpcSystem.makeRpcServer(this.serverNetwork, bootstrapFactory);
this.rpcServer.start();
} }
Capability.Client connect(Test.TestSturdyRefObjectId.Tag tag) { Capability.Client connect(Test.TestSturdyRefObjectId.Tag tag) {
@ -278,6 +278,13 @@ public class RpcTest {
ref.getObjectId().initAs(Test.TestSturdyRefObjectId.factory).setTag(tag); ref.getObjectId().initAs(Test.TestSturdyRefObjectId.factory).setTag(tag);
return rpcClient.bootstrap(ref.asReader()); return rpcClient.bootstrap(ref.asReader());
} }
public <T> CompletableFuture<T> runUntil(CompletableFuture<T> done) {
while (!done.isDone()) {
this.rpcClient.runOnce();
}
return done;
}
} }
static BootstrapFactory<Test.TestSturdyRef.Reader> bootstrapFactory = new BootstrapFactory<>() { static BootstrapFactory<Test.TestSturdyRef.Reader> bootstrapFactory = new BootstrapFactory<>() {
@ -321,7 +328,6 @@ public class RpcTest {
this.context = null; this.context = null;
} }
@org.junit.Test @org.junit.Test
public void testBasic() { public void testBasic() {
var client = new Test.TestInterface.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_INTERFACE)); var client = new Test.TestInterface.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_INTERFACE));
@ -343,12 +349,15 @@ public class RpcTest {
RpcTestUtil.initTestMessage(request2.getParams().initS()); RpcTestUtil.initTestMessage(request2.getParams().initS());
var promise2 = request2.send(); var promise2 = request2.send();
var response1 = promise1.join(); var response1 = this.context.runUntil(promise1).join();
Assert.assertEquals("foo", response1.getX().toString()); Assert.assertEquals("foo", response1.getX().toString());
var response2 = promise2.join(); while (!promise2.isDone()) {
promise3.join(); this.context.rpcClient.runOnce();
}
var response2 = this.context.runUntil(promise2).join();
this.context.runUntil(promise3).join();
Assert.assertTrue(ref.barFailed); Assert.assertTrue(ref.barFailed);
} }
@ -376,10 +385,10 @@ public class RpcTest {
//Assert.assertEquals(0, chainedCallCount.value()); //Assert.assertEquals(0, chainedCallCount.value());
var response = pipelinePromise.join(); var response = this.context.runUntil(pipelinePromise).join();
Assert.assertEquals("bar", response.getX().toString()); Assert.assertEquals("bar", response.getX().toString());
var response2 = pipelinePromise2.join(); var response2 = this.context.runUntil(pipelinePromise2).join();
RpcTestUtil.checkTestMessage(response2); RpcTestUtil.checkTestMessage(response2);
Assert.assertEquals(1, chainedCallCount.value()); Assert.assertEquals(1, chainedCallCount.value());
@ -389,15 +398,15 @@ public class RpcTest {
public void testRelease() { public void testRelease() {
var client = new Test.TestMoreStuff.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_MORE_STUFF)); var client = new Test.TestMoreStuff.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_MORE_STUFF));
var handle1 = client.getHandleRequest().send().join().getHandle(); var handle1 = this.context.runUntil(client.getHandleRequest().send()).join().getHandle();
var promise = client.getHandleRequest().send(); var promise = client.getHandleRequest().send();
var handle2 = promise.join().getHandle(); var handle2 = this.context.runUntil(promise).join().getHandle();
handle1 = null; handle1 = null;
handle2 = null; handle2 = null;
System.gc(); System.gc();
client.echoRequest().send().join(); this.context.runUntil(client.echoRequest().send()).join();
} }
@org.junit.Test @org.junit.Test
@ -421,15 +430,15 @@ public class RpcTest {
// Make sure getCap() has been called on the server side by sending another call and waiting // Make sure getCap() has been called on the server side by sending another call and waiting
// for it. // for it.
Assert.assertEquals(2, client.getCallSequenceRequest().send().join().getN()); Assert.assertEquals(2, this.context.runUntil(client.getCallSequenceRequest().send()).join().getN());
//Assert.assertEquals(3, context.restorer.callCount); //Assert.assertEquals(3, context.restorer.callCount);
// OK, now fulfill the local promise. // OK, now fulfill the local promise.
paf.complete(new Test.TestInterface.Client(new RpcTestUtil.TestInterfaceImpl(chainedCallCount))); paf.complete(new Test.TestInterface.Client(new RpcTestUtil.TestInterfaceImpl(chainedCallCount)));
// We should now be able to wait for getCap() to finish. // We should now be able to wait for getCap() to finish.
Assert.assertEquals("bar", promise.join().getS().toString()); Assert.assertEquals("bar", this.context.runUntil(promise).join().getS().toString());
Assert.assertEquals("bar", promise2.join().getS().toString()); Assert.assertEquals("bar", this.context.runUntil(promise2).join().getS().toString());
//Assert.assertEquals(3, context.restorer.callCount); //Assert.assertEquals(3, context.restorer.callCount);
Assert.assertEquals(2, chainedCallCount.value()); Assert.assertEquals(2, chainedCallCount.value());
@ -447,16 +456,16 @@ public class RpcTest {
var promise = request.send(); var promise = request.send();
var dependentCall0 = promise.getC().getCallSequenceRequest().send(); var dependentCall0 = promise.getC().getCallSequenceRequest().send();
var response = promise.join(); var response = this.context.runUntil(promise).join();
Assert.assertEquals(456, response.getI()); Assert.assertEquals(456, response.getI());
var dependentCall1 = promise.getC().getCallSequenceRequest().send(); var dependentCall1 = promise.getC().getCallSequenceRequest().send();
Assert.assertEquals(0, dependentCall0.join().getN()); Assert.assertEquals(0, this.context.runUntil(dependentCall0).join().getN());
Assert.assertEquals(1, dependentCall1.join().getN()); Assert.assertEquals(1, this.context.runUntil(dependentCall1).join().getN());
var dependentCall2 = response.getC().getCallSequenceRequest().send(); var dependentCall2 = response.getC().getCallSequenceRequest().send();
Assert.assertEquals(2, dependentCall2.join().getN()); Assert.assertEquals(2, this.context.runUntil(dependentCall2).join().getN());
Assert.assertEquals(1, calleeCallCount.value()); Assert.assertEquals(1, calleeCallCount.value());
} }
@ -482,26 +491,26 @@ public class RpcTest {
var call0 = getCallSequence(pipeline, 0); var call0 = getCallSequence(pipeline, 0);
var call1 = getCallSequence(pipeline, 1); var call1 = getCallSequence(pipeline, 1);
earlyCall.join(); this.context.runUntil(earlyCall).join();
var call2 = getCallSequence(pipeline, 2); var call2 = getCallSequence(pipeline, 2);
var resolved = echo.join().getCap(); var resolved = this.context.runUntil(echo).join().getCap();
var call3 = getCallSequence(pipeline, 3); var call3 = getCallSequence(pipeline, 3);
var call4 = getCallSequence(pipeline, 4); var call4 = getCallSequence(pipeline, 4);
var call5 = getCallSequence(pipeline, 5); var call5 = getCallSequence(pipeline, 5);
Assert.assertEquals(0, call0.join().getN()); Assert.assertEquals(0, this.context.runUntil(call0).join().getN());
Assert.assertEquals(1, call1.join().getN()); Assert.assertEquals(1, this.context.runUntil(call1).join().getN());
Assert.assertEquals(2, call2.join().getN()); Assert.assertEquals(2, this.context.runUntil(call2).join().getN());
Assert.assertEquals(3, call3.join().getN()); Assert.assertEquals(3, this.context.runUntil(call3).join().getN());
Assert.assertEquals(4, call4.join().getN()); Assert.assertEquals(4, this.context.runUntil(call4).join().getN());
Assert.assertEquals(5, call5.join().getN()); Assert.assertEquals(5, this.context.runUntil(call5).join().getN());
} }
@org.junit.Test @org.junit.Test
public void testCallBrokenPromise() throws ExecutionException, InterruptedException { public void testCallBrokenPromise() {
var client = new Test.TestMoreStuff.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_MORE_STUFF)); var client = new Test.TestMoreStuff.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_MORE_STUFF));
var paf = new CompletableFuture<Test.TestInterface.Client>(); var paf = new CompletableFuture<Test.TestInterface.Client>();
@ -509,7 +518,7 @@ public class RpcTest {
{ {
var req = client.holdRequest(); var req = client.holdRequest();
req.getParams().setCap(paf); req.getParams().setCap(paf);
req.send().join(); this.context.runUntil(req.send()).join();
} }
AtomicBoolean returned = new AtomicBoolean(false); AtomicBoolean returned = new AtomicBoolean(false);
@ -524,10 +533,11 @@ public class RpcTest {
Assert.assertFalse(returned.get()); Assert.assertFalse(returned.get());
paf.completeExceptionally(new Exception("foo")); paf.completeExceptionally(new Exception("foo"));
this.context.runUntil(req);
Assert.assertTrue(returned.get()); Assert.assertTrue(returned.get());
// Verify that we are still connected // Verify that we are still connected
getCallSequence(client, 1).get(); this.context.runUntil(getCallSequence(client, 1)).join();
} }
@org.junit.Test @org.junit.Test
@ -581,24 +591,24 @@ public class RpcTest {
var call0 = getCallSequence(pipeline, 0); var call0 = getCallSequence(pipeline, 0);
var call1 = getCallSequence(pipeline, 1); var call1 = getCallSequence(pipeline, 1);
earlyCall.join(); this.context.runUntil(earlyCall).join();
var call2 = getCallSequence(pipeline, 2); var call2 = getCallSequence(pipeline, 2);
var resolved = echo.join().getCap(); var resolved = this.context.runUntil(echo).join().getCap();
var call3 = getCallSequence(pipeline, 3); var call3 = getCallSequence(pipeline, 3);
var call4 = getCallSequence(pipeline, 4); var call4 = getCallSequence(pipeline, 4);
var call5 = getCallSequence(pipeline, 5); var call5 = getCallSequence(pipeline, 5);
Assert.assertEquals(0, call0.join().getN()); Assert.assertEquals(0, this.context.runUntil(call0).join().getN());
Assert.assertEquals(1, call1.join().getN()); Assert.assertEquals(1, this.context.runUntil(call1).join().getN());
Assert.assertEquals(2, call2.join().getN()); Assert.assertEquals(2, this.context.runUntil(call2).join().getN());
Assert.assertEquals(3, call3.join().getN()); Assert.assertEquals(3, this.context.runUntil(call3).join().getN());
Assert.assertEquals(4, call4.join().getN()); Assert.assertEquals(4, this.context.runUntil(call4).join().getN());
Assert.assertEquals(5, call5.join().getN()); Assert.assertEquals(5, this.context.runUntil(call5).join().getN());
int unwrappedAt = unwrap.join(); int unwrappedAt = this.context.runUntil(unwrap).join();
Assert.assertTrue(unwrappedAt >= 0); Assert.assertTrue(unwrappedAt >= 0);
} }
} }

View file

@ -7,39 +7,44 @@ import org.junit.Before;
import java.io.IOException; import java.io.IOException;
import java.nio.channels.AsynchronousByteChannel; import java.nio.channels.AsynchronousByteChannel;
import java.nio.channels.AsynchronousChannelGroup;
import java.nio.channels.AsynchronousServerSocketChannel; import java.nio.channels.AsynchronousServerSocketChannel;
import java.nio.channels.AsynchronousSocketChannel; import java.nio.channels.AsynchronousSocketChannel;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Consumer; import java.util.function.Consumer;
@SuppressWarnings({"OverlyCoupledMethod", "OverlyLongMethod"})
public class TwoPartyTest { public class TwoPartyTest {
static final class PipeThread { static final class PipeThread {
Thread thread; Thread thread;
AsynchronousByteChannel channel; AsynchronousSocketChannel channel;
static PipeThread newPipeThread(Consumer<AsynchronousByteChannel> startFunc) throws Exception { }
var pipeThread = new PipeThread();
var serverAcceptSocket = AsynchronousServerSocketChannel.open();
serverAcceptSocket.bind(null);
var clientSocket = AsynchronousSocketChannel.open();
pipeThread.thread = new Thread(() -> { private AsynchronousChannelGroup group;
try {
var serverSocket = serverAcceptSocket.accept().get();
startFunc.accept(serverSocket);
} catch (InterruptedException | ExecutionException e) {
e.printStackTrace();
}
});
pipeThread.thread.start();
pipeThread.thread.setName("TwoPartyTest server");
clientSocket.connect(serverAcceptSocket.getLocalAddress()).get(); PipeThread newPipeThread(Consumer<AsynchronousSocketChannel> startFunc) throws Exception {
pipeThread.channel = clientSocket; var pipeThread = new PipeThread();
return pipeThread; var serverAcceptSocket = AsynchronousServerSocketChannel.open(this.group);
} serverAcceptSocket.bind(null);
var clientSocket = AsynchronousSocketChannel.open();
pipeThread.thread = new Thread(() -> {
try {
var serverSocket = serverAcceptSocket.accept().get();
startFunc.accept(serverSocket);
} catch (InterruptedException | ExecutionException exc) {
exc.printStackTrace();
}
});
pipeThread.thread.start();
pipeThread.thread.setName("TwoPartyTest server");
clientSocket.connect(serverAcceptSocket.getLocalAddress()).get();
pipeThread.channel = clientSocket;
return pipeThread;
} }
PipeThread runServer(Capability.Server bootstrapInterface) throws Exception { PipeThread runServer(Capability.Server bootstrapInterface) throws Exception {
@ -47,19 +52,22 @@ public class TwoPartyTest {
} }
PipeThread runServer(Capability.Client bootstrapInterface) throws Exception { PipeThread runServer(Capability.Client bootstrapInterface) throws Exception {
return PipeThread.newPipeThread(channel -> { return newPipeThread(channel -> {
var network = new TwoPartyVatNetwork(channel, RpcTwoPartyProtocol.Side.SERVER); var network = new TwoPartyVatNetwork(channel, RpcTwoPartyProtocol.Side.SERVER);
var system = new RpcSystem<>(network, bootstrapInterface); var system = new RpcSystem<>(network, bootstrapInterface);
system.start();
network.onDisconnect().join(); network.onDisconnect().join();
}); });
} }
@Before @Before
public void setUp() { public void setUp() throws IOException {
this.group = AsynchronousChannelGroup.withThreadPool(Executors.newFixedThreadPool(5));
} }
@After @After
public void tearDown() { public void tearDown() {
this.group.shutdown();
} }
@org.junit.Test @org.junit.Test
@ -68,7 +76,7 @@ public class TwoPartyTest {
var rpcClient = new TwoPartyClient(pipe.channel); var rpcClient = new TwoPartyClient(pipe.channel);
var client = rpcClient.bootstrap(); var client = rpcClient.bootstrap();
var resolved = client.whenResolved(); var resolved = client.whenResolved();
resolved.get(); rpcClient.runUntil(resolved).join();
} }
@org.junit.Test @org.junit.Test
@ -93,11 +101,11 @@ public class TwoPartyTest {
.thenAccept(results -> Assert.fail("Expected bar() to fail")) .thenAccept(results -> Assert.fail("Expected bar() to fail"))
.exceptionally(exc -> null); .exceptionally(exc -> null);
var response1 = promise1.join(); var response1 = rpcClient.runUntil(promise1).join();
Assert.assertEquals("foo", response1.getX().toString()); Assert.assertEquals("foo", response1.getX().toString());
promise2.join(); rpcClient.runUntil(promise2).join();
promise3.join(); rpcClient.runUntil(promise3).join();
Assert.assertEquals(2, callCount.value()); Assert.assertEquals(2, callCount.value());
} }
@ -136,10 +144,10 @@ public class TwoPartyTest {
//Assert.assertEquals(0, chainedCallCount.value()); //Assert.assertEquals(0, chainedCallCount.value());
var response = pipelinePromise.join(); var response = rpcClient.runUntil(pipelinePromise).join();
Assert.assertEquals("bar", response.getX().toString()); Assert.assertEquals("bar", response.getX().toString());
var response2 = pipelinePromise2.join(); var response2 = rpcClient.runUntil(pipelinePromise2).join();
RpcTestUtil.checkTestMessage(response2); RpcTestUtil.checkTestMessage(response2);
Assert.assertEquals(1, chainedCallCount.value()); Assert.assertEquals(1, chainedCallCount.value());
@ -147,7 +155,7 @@ public class TwoPartyTest {
// disconnect the client // disconnect the client
((AsynchronousSocketChannel)pipe.channel).shutdownOutput(); ((AsynchronousSocketChannel)pipe.channel).shutdownOutput();
rpcClient.onDisconnect().join(); rpcClient.runUntil(rpcClient.onDisconnect()).join();
{ {
// Use the now-broken capability. // Use the now-broken capability.