diff --git a/runtime-rpc/src/test/java/org/capnproto/RpcTest.java b/runtime-rpc/src/test/java/org/capnproto/RpcTest.java index 26f96a0..4d3987e 100644 --- a/runtime-rpc/src/test/java/org/capnproto/RpcTest.java +++ b/runtime-rpc/src/test/java/org/capnproto/RpcTest.java @@ -53,12 +53,8 @@ public class RpcTest { } static final class TestNetworkAdapter - implements VatNetwork { - - @Override - public CompletableFuture> baseAccept() { - return this.accept().thenApply(conn -> conn); - } + implements VatNetwork, + AutoCloseable { class Connection implements VatNetwork.Connection { @@ -82,6 +78,14 @@ public class RpcTest { other.partner = this; } + void disconnect(Exception exc) { + while (!fulfillers.isEmpty()) { + fulfillers.remove().completeExceptionally(exc); + } + + this.networkException = exc; + } + TestNetwork getNetwork() { return network; } @@ -170,10 +174,6 @@ public class RpcTest { @Override public void close() { - var msg = newOutgoingMessage(0); - var abort = msg.getBody().initAs(RpcProtocol.Message.factory).initAbort(); - FromException(RpcException.disconnected(""), abort); - msg.send(); } } @@ -194,6 +194,18 @@ public class RpcTest { return new Connection(isClient, peerId); } + public CompletableFuture> baseAccept() { + return this.accept().thenApply(conn -> conn); + } + + @Override + public void close() { + var exc = RpcException.failed("Network was destroyed"); + for (var conn: this.connections.values()) { + conn.disconnect(exc); + } + } + @Override public VatNetwork.Connection connect(Test.TestSturdyRef.Reader refId) { var hostId = refId.getHostId().getHost().toString(); diff --git a/runtime-rpc/src/test/java/org/capnproto/TwoPartyTest.java b/runtime-rpc/src/test/java/org/capnproto/TwoPartyTest.java index fa92cd5..03fad39 100644 --- a/runtime-rpc/src/test/java/org/capnproto/TwoPartyTest.java +++ b/runtime-rpc/src/test/java/org/capnproto/TwoPartyTest.java @@ -12,6 +12,7 @@ import java.nio.channels.AsynchronousSocketChannel; import java.util.concurrent.ExecutionException; import java.util.function.Consumer; +@SuppressWarnings({"OverlyCoupledMethod", "OverlyLongMethod"}) public class TwoPartyTest { static final class PipeThread { @@ -163,8 +164,8 @@ public class TwoPartyTest { var pipelineRequest2 = new Test.TestExtends.Client(promise.getOutBox().getCap()).graultRequest(); var pipelinePromise2 = pipelineRequest2.send(); - Assert.assertThrows(Exception.class, () -> pipelinePromise.join()); - Assert.assertThrows(Exception.class, () -> pipelinePromise2.join()); + Assert.assertThrows(Exception.class, pipelinePromise::join); + Assert.assertThrows(Exception.class, pipelinePromise2::join); Assert.assertEquals(3, callCount.value()); Assert.assertEquals(1, chainedCallCount.value());