refactor connection and disconnection

This commit is contained in:
Vaci Koblizek 2020-11-13 17:57:49 +00:00
parent 37aa04b262
commit ad17a4c148
12 changed files with 161 additions and 223 deletions

View file

@ -5,9 +5,11 @@ import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.nio.channels.ClosedChannelException;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.function.Consumer;
@ -288,7 +290,7 @@ final class RpcState<VatId> {
startMessageLoop();
}
public CompletableFuture<java.lang.Void> getMessageLoop() {
CompletableFuture<java.lang.Void> onDisconnection() {
return this.messageLoop;
}
@ -363,6 +365,12 @@ final class RpcState<VatId> {
return CompletableFuture.completedFuture(null);
}
}
else if (ioExc instanceof CompletionException) {
var compExc = (CompletionException)ioExc;
if (compExc.getCause() instanceof ClosedChannelException) {
return CompletableFuture.completedFuture(null);
}
}
return CompletableFuture.failedFuture(ioExc);
});
@ -371,9 +379,7 @@ final class RpcState<VatId> {
this.disconnectFulfiller.complete(new DisconnectInfo(shutdownPromise));
for (var pipeline: pipelinesToRelease) {
if (pipeline instanceof RpcState<?>.RpcPipeline) {
((RpcPipeline) pipeline).redirectLater.completeExceptionally(networkExc);
}
pipeline.cancel(networkExc);
}
}
@ -1556,8 +1562,8 @@ final class RpcState<VatId> {
}
@Override
public void close() {
this.question.finish();
public void cancel(Throwable exc) {
this.question.reject(exc);
}
}

View file

@ -1,27 +1,21 @@
package org.capnproto;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
public class RpcSystem<VatId extends StructReader> {
private final VatNetwork<VatId> network;
private final BootstrapFactory<VatId> bootstrapFactory;
private final Map<VatNetwork.Connection<VatId>, RpcState<VatId>> connections = new HashMap<>();
private final CompletableFuture<java.lang.Void> messageLoop;
private final CompletableFuture<java.lang.Void> acceptLoop;
private final Map<VatNetwork.Connection<VatId>, RpcState<VatId>> connections = new ConcurrentHashMap<>();
public RpcSystem(VatNetwork<VatId> network) {
this.network = network;
this.bootstrapFactory = null;
this.acceptLoop = new CompletableFuture<>();
this.messageLoop = doMessageLoop();
}
public VatNetwork<VatId> getNetwork() {
return this.network;
this(network, (BootstrapFactory)null);
}
public RpcSystem(VatNetwork<VatId> network,
@ -49,8 +43,7 @@ public class RpcSystem<VatId extends StructReader> {
BootstrapFactory<VatId> bootstrapFactory) {
this.network = network;
this.bootstrapFactory = bootstrapFactory;
this.acceptLoop = doAcceptLoop();
this.messageLoop = doMessageLoop();
this.startAcceptLoop();
}
public Capability.Client bootstrap(VatId vatId) {
@ -68,21 +61,19 @@ public class RpcSystem<VatId extends StructReader> {
}
}
RpcState<VatId> getConnectionState(VatNetwork.Connection<VatId> connection) {
var state = this.connections.get(connection);
if (state == null) {
var onDisconnect = new CompletableFuture<RpcState.DisconnectInfo>()
.whenComplete((info, exc) -> {
this.connections.remove(connection);
try {
connection.close();
} catch (IOException ignored) {
public VatNetwork<VatId> getNetwork() {
return this.network;
}
});
state = new RpcState<>(this.bootstrapFactory, connection, onDisconnect);
this.connections.put(connection, state);
}
RpcState<VatId> getConnectionState(VatNetwork.Connection<VatId> connection) {
var state = this.connections.computeIfAbsent(connection, conn -> {
var onDisconnect = new CompletableFuture<RpcState.DisconnectInfo>();
onDisconnect.thenCompose(info -> {
this.connections.remove(connection);
return info.shutdownPromise.thenRun(() -> connection.close());
});
return new RpcState<>(this.bootstrapFactory, conn, onDisconnect);
});
return state;
}
@ -90,27 +81,10 @@ public class RpcSystem<VatId extends StructReader> {
getConnectionState(connection);
}
private CompletableFuture<java.lang.Void> doAcceptLoop() {
return this.network.baseAccept().thenCompose(connection -> {
this.accept(connection);
return this.doAcceptLoop();
});
}
private CompletableFuture<java.lang.Void> doMessageLoop() {
var accept = this.getAcceptLoop();
for (var conn: this.connections.values()) {
accept = accept.acceptEither(conn.getMessageLoop(), x -> {});
}
return accept.thenCompose(x -> this.doMessageLoop());
}
public CompletableFuture<java.lang.Void> getMessageLoop() {
return this.messageLoop;
}
private CompletableFuture<java.lang.Void> getAcceptLoop() {
return this.acceptLoop;
private void startAcceptLoop() {
this.network.baseAccept()
.thenAccept(this::accept)
.thenRunAsync(this::startAcceptLoop);
}
public static <VatId extends StructReader>

View file

@ -1,6 +1,6 @@
package org.capnproto;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.AsynchronousByteChannel;
import java.util.concurrent.CompletableFuture;
public class TwoPartyClient {
@ -8,15 +8,15 @@ public class TwoPartyClient {
private final TwoPartyVatNetwork network;
private final RpcSystem<RpcTwoPartyProtocol.VatId.Reader> rpcSystem;
public TwoPartyClient(AsynchronousSocketChannel channel) {
public TwoPartyClient(AsynchronousByteChannel channel) {
this(channel, null);
}
public TwoPartyClient(AsynchronousSocketChannel channel, Capability.Client bootstrapInterface) {
public TwoPartyClient(AsynchronousByteChannel channel, Capability.Client bootstrapInterface) {
this(channel, bootstrapInterface, RpcTwoPartyProtocol.Side.CLIENT);
}
public TwoPartyClient(AsynchronousSocketChannel channel,
public TwoPartyClient(AsynchronousByteChannel channel,
Capability.Client bootstrapInterface,
RpcTwoPartyProtocol.Side side) {
this.network = new TwoPartyVatNetwork(channel, side);
@ -31,4 +31,8 @@ public class TwoPartyClient {
: RpcTwoPartyProtocol.Side.CLIENT);
return rpcSystem.bootstrap(vatId.asReader());
}
CompletableFuture<java.lang.Void> onDisconnect() {
return this.network.onDisconnect();
}
}

View file

@ -10,135 +10,80 @@ import java.util.concurrent.CompletableFuture;
public class TwoPartyServer {
private class AcceptedConnection {
final AsynchronousSocketChannel channel;
final AsynchronousSocketChannel connection;
final TwoPartyVatNetwork network;
final RpcSystem<RpcTwoPartyProtocol.VatId.Reader> rpcSystem;
private final CompletableFuture<?> messageLoop;
AcceptedConnection(Capability.Client bootstrapInterface, AsynchronousSocketChannel channel) {
this.channel = channel;
this.network = new TwoPartyVatNetwork(channel, RpcTwoPartyProtocol.Side.SERVER);
AcceptedConnection(Capability.Client bootstrapInterface, AsynchronousSocketChannel connection) {
this.connection = connection;
this.network = new TwoPartyVatNetwork(this.connection, RpcTwoPartyProtocol.Side.SERVER);
this.rpcSystem = new RpcSystem<>(network, bootstrapInterface);
this.messageLoop = this.rpcSystem.getMessageLoop().exceptionally(exc -> {
connections.remove(this);
return null;
});
}
public CompletableFuture<?> getMessageLoop() {
return this.messageLoop;
}
}
class ConnectionReceiver {
AsynchronousServerSocketChannel listener;
final CompletableFuture<?> messageLoop;
public ConnectionReceiver(AsynchronousServerSocketChannel listener) {
final AsynchronousServerSocketChannel listener;
ConnectionReceiver(AsynchronousServerSocketChannel listener) {
this.listener = listener;
this.messageLoop = doMessageLoop();
}
public CompletableFuture<?> getMessageLoop() {
return this.messageLoop;
}
private CompletableFuture<?> doMessageLoop() {
final var accepted = new CompletableFuture<AsynchronousSocketChannel>();
listener.accept(null, new CompletionHandler<>() {
CompletableFuture<AsynchronousSocketChannel> accept() {
CompletableFuture<AsynchronousSocketChannel> result = new CompletableFuture<>();
this.listener.accept(null, new CompletionHandler<>() {
@Override
public void completed(AsynchronousSocketChannel channel, Object attachment) {
accepted.complete(channel);
result.complete(channel);
}
@Override
public void failed(Throwable exc, Object attachment) {
accepted.completeExceptionally(exc);
result.completeExceptionally(exc);
}
});
return accepted.thenCompose(channel -> CompletableFuture.allOf(
accept(channel),
doMessageLoop()));
return result.copy();
}
}
private final Capability.Client bootstrapInterface;
private final List<AcceptedConnection> connections = new ArrayList<>();
private final List<ConnectionReceiver> listeners = new ArrayList<>();
private final CompletableFuture<?> messageLoop;
public TwoPartyServer(Capability.Client bootstrapInterface) {
this.bootstrapInterface = bootstrapInterface;
this.messageLoop = doMessageLoop();
}
public TwoPartyServer(Capability.Server bootstrapServer) {
this(new Capability.Client(bootstrapServer));
}
private CompletableFuture<?> getMessageLoop() {
return this.messageLoop;
}
public CompletableFuture<?> drain() {
CompletableFuture<java.lang.Void> done = new CompletableFuture<>();
for (var conn: this.connections) {
done = CompletableFuture.allOf(done, conn.getMessageLoop());
}
return done;
}
private CompletableFuture<java.lang.Void> accept(AsynchronousSocketChannel channel) {
public void accept(AsynchronousSocketChannel channel) {
var connection = new AcceptedConnection(this.bootstrapInterface, channel);
this.connections.add(connection);
return connection.network.onDisconnect().whenComplete((x, exc) -> {
connection.network.onDisconnect().whenComplete((x, exc) -> {
this.connections.remove(connection);
});
}
/*
private final CompletableFuture<?> acceptLoop(AsynchronousServerSocketChannel listener) {
final var accepted = new CompletableFuture<AsynchronousSocketChannel>();
listener.accept(null, new CompletionHandler<>() {
@Override
public void completed(AsynchronousSocketChannel channel, Object attachment) {
accepted.complete(channel);
public CompletableFuture<java.lang.Void> listen(AsynchronousServerSocketChannel listener) {
return this.listen(wrapListenSocket(listener));
}
@Override
public void failed(Throwable exc, Object attachment) {
accepted.completeExceptionally(exc);
}
CompletableFuture<java.lang.Void> listen(ConnectionReceiver listener) {
return listener.accept().thenCompose(channel -> {
this.accept(channel);
return this.listen(listener);
});
return accepted.thenCompose(channel -> CompletableFuture.anyOf(
accept(channel),
acceptLoop(listener)));
}
*/
public CompletableFuture<?> listen(AsynchronousServerSocketChannel listener) {
var receiver = new ConnectionReceiver(listener);
this.listeners.add(receiver);
return receiver.getMessageLoop();
}
private CompletableFuture<?> doMessageLoop() {
var done = new CompletableFuture<>();
CompletableFuture<java.lang.Void> drain() {
CompletableFuture<java.lang.Void> loop = CompletableFuture.completedFuture(null);
for (var conn: this.connections) {
done = CompletableFuture.anyOf(done, conn.getMessageLoop());
loop = CompletableFuture.allOf(loop, conn.network.onDisconnect());
}
for (var listener: this.listeners) {
done = CompletableFuture.anyOf(done, listener.getMessageLoop());
}
return done.thenCompose(x -> doMessageLoop());
return loop;
}
/*
public CompletableFuture<?> runOnce() {
var done = new CompletableFuture<>();
for (var conn: connections) {
done = CompletableFuture.anyOf(done, conn.runOnce());
ConnectionReceiver wrapListenSocket(AsynchronousServerSocketChannel channel) {
return new ConnectionReceiver(channel);
}
return done;
}
*/
}

View file

@ -37,10 +37,15 @@ public class TwoPartyVatNetwork
}
@Override
public void close() throws IOException {
public void close() {
try {
this.channel.close();
this.disconnectPromise.complete(null);
}
catch (Exception exc) {
this.disconnectPromise.completeExceptionally(exc);
}
}
public RpcTwoPartyProtocol.Side getSide() {
return side;
@ -113,13 +118,13 @@ public class TwoPartyVatNetwork
public CompletableFuture<java.lang.Void> shutdown() {
assert this.previousWrite != null: "Already shut down";
var result = this.previousWrite.thenRun(() -> {
var result = this.previousWrite.whenComplete((void_, exc) -> {
try {
if (this.channel instanceof AsynchronousSocketChannel) {
((AsynchronousSocketChannel)this.channel).shutdownOutput();
}
}
catch (Exception ioExc) {
catch (Exception ignored) {
}
});

View file

@ -13,7 +13,7 @@ public interface VatNetwork<VatId>
CompletableFuture<IncomingRpcMessage> receiveIncomingMessage();
CompletableFuture<java.lang.Void> shutdown();
VatId getPeerVatId();
void close() throws IOException;
void close();
}
CompletableFuture<Connection<VatId>> baseAccept();

View file

@ -4,78 +4,78 @@ import org.capnproto.rpctest.*;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.function.ThrowingRunnable;
import java.io.IOException;
import java.nio.channels.AsynchronousByteChannel;
import java.nio.channels.AsynchronousServerSocketChannel;
import java.nio.channels.AsynchronousSocketChannel;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;
public class TwoPartyTest {
private Thread runServer(org.capnproto.TwoPartyVatNetwork network) {
var thread = new Thread(() -> {
static final class PipeThread {
Thread thread;
AsynchronousByteChannel channel;
static PipeThread newPipeThread(Consumer<AsynchronousByteChannel> startFunc) throws Exception {
var pipeThread = new PipeThread();
var serverAcceptSocket = AsynchronousServerSocketChannel.open();
serverAcceptSocket.bind(null);
var clientSocket = AsynchronousSocketChannel.open();
pipeThread.thread = new Thread(() -> {
try {
network.onDisconnect().get();
} catch (InterruptedException e) {
e.printStackTrace();
} catch (ExecutionException e) {
var serverSocket = serverAcceptSocket.accept().get();
startFunc.accept(serverSocket);
} catch (InterruptedException | ExecutionException e) {
e.printStackTrace();
}
}, "Server");
});
pipeThread.thread.start();
pipeThread.thread.setName("TwoPartyTest server");
thread.start();
return thread;
clientSocket.connect(serverAcceptSocket.getLocalAddress()).get();
pipeThread.channel = clientSocket;
return pipeThread;
}
}
AsynchronousServerSocketChannel serverAcceptSocket;
AsynchronousSocketChannel serverSocket;
AsynchronousSocketChannel clientSocket;
TwoPartyClient client;
org.capnproto.TwoPartyVatNetwork serverNetwork;
Thread serverThread;
PipeThread runServer(Capability.Server bootstrapInterface) throws Exception {
return runServer(new Capability.Client(bootstrapInterface));
}
PipeThread runServer(Capability.Client bootstrapInterface) throws Exception {
return PipeThread.newPipeThread(channel -> {
var network = new TwoPartyVatNetwork(channel, RpcTwoPartyProtocol.Side.SERVER);
var system = new RpcSystem<>(network, bootstrapInterface);
network.onDisconnect().join();
});
}
@Before
public void setUp() throws Exception {
this.serverAcceptSocket = AsynchronousServerSocketChannel.open();
this.serverAcceptSocket.bind(null);
this.clientSocket = AsynchronousSocketChannel.open();
this.clientSocket.connect(this.serverAcceptSocket.getLocalAddress()).get();
this.client = new TwoPartyClient(clientSocket);
//this.client.getNetwork().setTap(new Tap());
this.serverSocket = serverAcceptSocket.accept().get();
this.serverNetwork = new org.capnproto.TwoPartyVatNetwork(this.serverSocket, RpcTwoPartyProtocol.Side.SERVER);
//this.serverNetwork.setTap(new Tap());
//this.serverNetwork.dumper.addSchema(Demo.TestCap1);
this.serverThread = runServer(this.serverNetwork);
public void setUp() {
}
@After
public void tearDown() throws Exception {
this.clientSocket.close();
this.serverSocket.close();
this.serverAcceptSocket.close();
this.serverThread.join();
this.client = null;
public void tearDown() {
}
@org.junit.Test
public void testNullCap() throws ExecutionException, InterruptedException {
var server = new RpcSystem<>(this.serverNetwork, new Capability.Client());
var cap = this.client.bootstrap();
var resolved = cap.whenResolved();
public void testNullCap() throws Exception {
var pipe = runServer(new Capability.Client());
var rpcClient = new TwoPartyClient(pipe.channel);
var client = rpcClient.bootstrap();
var resolved = client.whenResolved();
resolved.get();
}
@org.junit.Test
public void testBasic() throws InterruptedException, IOException {
public void testBasic() throws Exception {
var callCount = new Counter();
var server = new RpcSystem<>(this.serverNetwork, new RpcTestUtil.TestInterfaceImpl(callCount));
var client = new Test.TestInterface.Client(this.client.bootstrap());
var pipe = runServer(new RpcTestUtil.TestInterfaceImpl(callCount));
var rpcClient = new TwoPartyClient(pipe.channel);
var client = new Test.TestInterface.Client(rpcClient.bootstrap());
var request1 = client.fooRequest();
request1.getParams().setI(123);
request1.getParams().setJ(true);
@ -99,24 +99,22 @@ public class TwoPartyTest {
promise3.join();
Assert.assertEquals(2, callCount.value());
this.clientSocket.shutdownOutput();
serverThread.join();
}
@org.junit.Test
public void testDisconnect() throws IOException {
this.serverSocket.shutdownOutput();
this.serverNetwork.close();
this.serverNetwork.onDisconnect().join();
//this.serverSocket.shutdownOutput();
//this.serverNetwork.close();
//this.serverNetwork.onDisconnect().join();
}
@org.junit.Test
public void testPipelining() throws IOException {
public void testPipelining() throws Exception {
var callCount = new Counter();
var chainedCallCount = new Counter();
var server = new RpcSystem<>(this.serverNetwork, new RpcTestUtil.TestPipelineImpl(callCount));
var client = new Test.TestPipeline.Client(this.client.bootstrap());
var pipe = runServer(new RpcTestUtil.TestPipelineImpl(callCount));
var rpcClient = new TwoPartyClient(pipe.channel);
var client = new Test.TestPipeline.Client(rpcClient.bootstrap());
{
var request = client.getCapRequest();
@ -146,11 +144,9 @@ public class TwoPartyTest {
Assert.assertEquals(1, chainedCallCount.value());
}
/*
// disconnect the server
//this.serverSocket.shutdownOutput();
this.serverNetwork.close();
this.serverNetwork.onDisconnect().join();
// disconnect the client
((AsynchronousSocketChannel)pipe.channel).shutdownOutput();
rpcClient.onDisconnect().join();
{
// Use the now-broken capability.
@ -173,8 +169,11 @@ public class TwoPartyTest {
Assert.assertEquals(3, callCount.value());
Assert.assertEquals(1, chainedCallCount.value());
}
}
@org.junit.Test
public void testAbort() {
*/
}
/*
@org.junit.Test

View file

@ -167,6 +167,11 @@ public final class AnyPointer {
return this;
}
@Override
public void cancel(Throwable exc) {
this.hook.cancel(exc);
}
public Pipeline noop() {
return new Pipeline(this.hook, this.ops.clone());
}

View file

@ -388,11 +388,6 @@ public final class Capability {
public final ClientHook getPipelinedCap(PipelineOp[] ops) {
return this.results.getPipelinedCap(ops);
}
@Override
public void close() {
this.ctx.allowCancellation();
}
}
private static final class LocalResponse implements ResponseHook {
@ -542,7 +537,7 @@ public final class Capability {
: new QueuedClient(this.promise.thenApply(
pipeline -> pipeline.getPipelinedCap(ops)));
}
/*
@Override
public void close() {
if (this.redirect != null) {
@ -552,6 +547,7 @@ public final class Capability {
this.promise.cancel(false);
}
}
*/
}
// A ClientHook which simply queues calls while waiting for a ClientHook to which to forward them.

View file

@ -1,5 +1,10 @@
package org.capnproto;
public interface Pipeline {
AnyPointer.Pipeline typelessPipeline();
default void cancel(Throwable exc) {
this.typelessPipeline().cancel(exc);
}
}

View file

@ -1,14 +1,13 @@
package org.capnproto;
public interface PipelineHook extends AutoCloseable {
public interface PipelineHook {
ClientHook getPipelinedCap(PipelineOp[] ops);
default void cancel(Throwable exc) {
}
static PipelineHook newBrokenPipeline(Throwable exc) {
return ops -> Capability.newBrokenCap(exc);
}
@Override
default void close() {
}
}

View file

@ -27,8 +27,8 @@ public class RemotePromise<Results>
}
@Override
public void close() throws Exception {
this.pipeline.hook.close();
public void close() {
this.pipeline.cancel(RpcException.failed("Cancelled"));
this.join();
}