diff --git a/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java b/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java index 7929286..b8a2c12 100644 --- a/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java +++ b/runtime-rpc/src/main/java/org/capnproto/TwoPartyVatNetwork.java @@ -1,6 +1,7 @@ package org.capnproto; import java.io.IOException; +import java.nio.channels.AsynchronousByteChannel; import java.nio.channels.AsynchronousSocketChannel; import java.util.List; import java.util.concurrent.CompletableFuture; @@ -20,13 +21,13 @@ public class TwoPartyVatNetwork private CompletableFuture previousWrite = CompletableFuture.completedFuture(null); private final CompletableFuture disconnectPromise = new CompletableFuture<>(); - private final AsynchronousSocketChannel channel; + private final AsynchronousByteChannel channel; private final RpcTwoPartyProtocol.Side side; private final MessageBuilder peerVatId = new MessageBuilder(4); private boolean accepted; private MessageTap tap; - public TwoPartyVatNetwork(AsynchronousSocketChannel channel, RpcTwoPartyProtocol.Side side) { + public TwoPartyVatNetwork(AsynchronousByteChannel channel, RpcTwoPartyProtocol.Side side) { this.channel = channel; this.side = side; this.peerVatId.initRoot(RpcTwoPartyProtocol.VatId.factory).setSide( @@ -114,7 +115,9 @@ public class TwoPartyVatNetwork var result = this.previousWrite.thenRun(() -> { try { - this.channel.shutdownOutput(); + if (this.channel instanceof AsynchronousSocketChannel) { + ((AsynchronousSocketChannel)this.channel).shutdownOutput(); + } } catch (Exception ioExc) { }