diff --git a/runtime/src/main/java/org/capnproto/Capability.java b/runtime/src/main/java/org/capnproto/Capability.java index 05338c3..0718748 100644 --- a/runtime/src/main/java/org/capnproto/Capability.java +++ b/runtime/src/main/java/org/capnproto/Capability.java @@ -1,6 +1,8 @@ package org.capnproto; -public class Capability { +import java.util.concurrent.CompletableFuture; + +public final class Capability { public static class Client { @@ -9,7 +11,105 @@ public class Capability { public Client(ClientHook hook) { this.hook = hook; } - } + static ClientHook newLocalPromiseClient(CompletableFuture promise) { + return new QueuedClient(promise); + } + + static class LocalRequest implements RequestHook { + + final MessageBuilder message = new MessageBuilder(); + final long interfaceId; + final short methodId; + ClientHook client; + + LocalRequest(long interfaceId, short methodId, ClientHook client) { + this.interfaceId = interfaceId; + this.methodId = methodId; + this.client = client; + } + + @Override + public RemotePromise send() { + var cancelPaf = new CompletableFuture(); + var context = new LocalCallContext(message, client, cancelPaf); + var promiseAndPipeline = client.call(interfaceId, methodId, context); + var promise = promiseAndPipeline.promise.thenApply(x -> { + context.getResults(); // force allocation + return context.response; + }); + + return new RemotePromise(promise, promiseAndPipeline.pipeline); + } + + @Override + public Object getBrand() { + return null; + } + } + + static class LocalResponse implements ResponseHook { + final MessageBuilder message = new MessageBuilder(); + } + + static class LocalCallContext implements CallContextHook { + + final CompletableFuture cancelAllowed; + MessageBuilder request; + Response response; + AnyPointer.Builder responseBuilder; + ClientHook clientRef; + + LocalCallContext(MessageBuilder request, + ClientHook clientRef, + CompletableFuture cancelAllowed) { + this.request = request; + this.clientRef = clientRef; + this.cancelAllowed = cancelAllowed; + } + + @Override + public AnyPointer.Reader getParams() { + return request.getRoot(AnyPointer.factory).asReader(); + } + + @Override + public void releaseParams() { + this.request = null; + } + + @Override + public AnyPointer.Builder getResults() { + if (this.response == null) { + var localResponse = new LocalResponse(); + this.responseBuilder = localResponse.message.getRoot(AnyPointer.factory); + this.response = new Response(this.responseBuilder.asReader(), localResponse); + } + return this.responseBuilder; + } + + @Override + public void allowCancellation() { + this.cancelAllowed.complete(null); + } + + @Override + public CompletableFuture tailCall(RequestHook request) { + // TODO implement tailCall + return null; + } + + @Override + public CompletableFuture onTailCall() { + // TODO implement onTailCall + return null; + } + + @Override + public ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request) { + // TODO implement directTailCall + return null; + } + } } diff --git a/runtime/src/main/java/org/capnproto/QueuedClient.java b/runtime/src/main/java/org/capnproto/QueuedClient.java new file mode 100644 index 0000000..886de67 --- /dev/null +++ b/runtime/src/main/java/org/capnproto/QueuedClient.java @@ -0,0 +1,47 @@ +package org.capnproto; + +import java.util.concurrent.CompletableFuture; + +class QueuedClient implements ClientHook { + + final CompletableFuture promise; + final CompletableFuture promiseForCallForwarding; + final CompletableFuture promiseForClientResolution; + final CompletableFuture setResolutionOp; + ClientHook redirect; + + QueuedClient(CompletableFuture promise) { + // TODO revisit futures + this.promise = promise.copy(); + this.promiseForCallForwarding = promise.copy(); + this.promiseForClientResolution = promise.copy(); + this.setResolutionOp = promise.thenAccept(inner -> { + this.redirect = inner; + }).exceptionally(exc -> { + this.redirect = ClientHook.newBrokenCap(exc); + return null; + }); + } + + @Override + public Request newCall(long interfaceId, short methodId) { + var hook = new Capability.LocalRequest(interfaceId, methodId, this); + var root = hook.message.getRoot(AnyPointer.factory); + return new Request<>(root, hook); + } + + @Override + public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook ctx) { + return null; + } + + @Override + public ClientHook getResolved() { + return redirect; + } + + @Override + public CompletableFuture whenMoreResolved() { + return promiseForClientResolution; + } +} diff --git a/runtime/src/main/java/org/capnproto/RpcState.java b/runtime/src/main/java/org/capnproto/RpcState.java index 6b2cfbe..1e55de8 100644 --- a/runtime/src/main/java/org/capnproto/RpcState.java +++ b/runtime/src/main/java/org/capnproto/RpcState.java @@ -81,7 +81,7 @@ final class RpcState { return this.disconnected != null; } - void handleMessage(IncomingRpcMessage message) { + void handleMessage(IncomingRpcMessage message) throws RpcException { var reader = message.getBody().getAs(RpcProtocol.Message.factory); switch (reader.which()) { @@ -110,7 +110,12 @@ final class RpcState { handleDisembargo(reader.getDisembargo()); break; default: - // TODO send unimplemented response + if (!isDisconnected()) { + // boomin' back atcha + var msg = connection.newOutgoingMessage(1024); + msg.getBody().initAs(RpcProtocol.Message.factory).setUnimplemented(reader); + msg.send(); + } break; } } @@ -150,7 +155,8 @@ final class RpcState { } } - void handleAbort(RpcProtocol.Exception.Reader abort) { + void handleAbort(RpcProtocol.Exception.Reader abort) throws RpcException { + throw RpcException.toException(abort); } void handleBootstrap(IncomingRpcMessage message, RpcProtocol.Bootstrap.Reader bootstrap) { @@ -338,7 +344,6 @@ final class RpcState { }); } - void releaseExport(int exportId, int refcount) { var export = exports.find(exportId); assert export != null; @@ -356,6 +361,112 @@ final class RpcState { } } + private List receiveCaps(StructList.Reader capTable, List fds) { + var result = new ArrayList(); + for (var cap: capTable) { + result.add(receiveCap(cap, fds)); + } + return result; + } + + private ClientHook receiveCap(RpcProtocol.CapDescriptor.Reader descriptor, List fds) { + // TODO AutoCloseFd + Integer fd = null; + + int fdIndex = descriptor.getAttachedFd(); + if (fdIndex >= 0 && fdIndex < fds.size()) { + fd = fds.get(fdIndex); + if (fd != null) { + fds.set(fdIndex, null); + } + } + + switch (descriptor.which()) { + case NONE: + return null; + + case SENDER_HOSTED: + return importCap(descriptor.getSenderHosted(), false, fd); + + case SENDER_PROMISE: + return importCap(descriptor.getSenderPromise(), true, fd); + + case RECEIVER_HOSTED: + var exp = exports.find(descriptor.getReceiverHosted()); + if (exp == null) { + return ClientHook.newBrokenCap("invalid 'receiverHosted' export ID"); + } + if (exp.clientHook.getBrand() == this) { + // TODO Tribble 4-way race! + return exp.clientHook; + } + + return exp.clientHook; + + case RECEIVER_ANSWER: + var promisedAnswer = descriptor.getReceiverAnswer(); + var answer = answers.find(promisedAnswer.getQuestionId()); + var ops = PipelineOp.ToPipelineOps(promisedAnswer); + + if (answer == null || !answer.active || answer.pipeline == null || ops == null) { + return ClientHook.newBrokenCap("invalid 'receiverAnswer'"); + } + + var result = answer.pipeline.getPipelinedCap(ops); + if (result == null) { + return ClientHook.newBrokenCap("Unrecognised pipeline ops"); + } + + if (result.getBrand() == this) { + // TODO Tribble 4-way race! + return result; + } + + return result; + + case THIRD_PARTY_HOSTED: + return ClientHook.newBrokenCap("Third party caps not supported"); + + default: + return ClientHook.newBrokenCap("unknown CapDescriptor type"); + } + } + + + private ClientHook importCap(int importId, boolean isPromise, Integer fd) { + // Receive a new import. + + var imp = imports.put(importId); + + if (imp.importClient == null) { + imp.importClient = new ImportClient(importId, fd); + } + else { + imp.importClient.setFdIfMissing(fd); + } + imp.importClient.addRemoteRef(); + + if (!isPromise) { + imp.appClient = imp.importClient; + return imp.importClient; + } + + if (imp.appClient != null) { + return imp.appClient; + } + + imp.promise = new CompletableFuture(); + var result = new PromiseClient(imp.importClient, imp.promise, importId); + imp.appClient = result; + return result; + } + + ClientHook writeTarget(ClientHook cap, RpcProtocol.MessageTarget.Builder target) { + return cap.getBrand() == this + ? ((RpcClient)cap).writeTarget(target) + : cap; + } + ClientHook getInnermostClient(ClientHook client) { for (;;) { var inner = client.getResolved(); @@ -465,16 +576,194 @@ final class RpcState { } class ImportClient extends RpcClient { + + final int importId; + int remoteRefCount = 0; + Integer fd; + + ImportClient(int importId, Integer fd) { + this.importId = importId; + this.fd = fd; + } + + void addRemoteRef() { + this.remoteRefCount++; + } + + void setFdIfMissing(Integer fd) { + if (this.fd == null) { + this.fd = fd; + } + } + + public void dispose() { + // TODO manage destruction... + var imp = imports.find(importId); + if (imp != null) { + if (imp.importClient == this) { + imports.erase(importId, imp); + } + } + + if (remoteRefCount > 0 && !isDisconnected()) { + var message = connection.newOutgoingMessage(1024); + var builder = message.getBody().initAs(RpcProtocol.Message.factory).initRelease(); + builder.setId(importId); + builder.setReferenceCount(remoteRefCount); + message.send(); + } + } + @Override public Integer writeDescriptor(RpcProtocol.CapDescriptor.Builder descriptor, List fds) { + descriptor.setReceiverHosted(importId); return null; } @Override public ClientHook writeTarget(RpcProtocol.MessageTarget.Builder target) { + target.setImportedCap(importId); + return null; + } + + @Override + public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook context) { + return null; + } + + @Override + public CompletableFuture whenMoreResolved() { return null; } } + enum ResolutionType { + UNRESOLVED, + REMOTE, + REFLECTED, + MERGED, + BROKEN + } + + class PromiseClient extends RpcClient { + final ClientHook cap; + final Integer importId; + final CompletableFuture promise; + boolean receivedCall = false; + ResolutionType resolutionType = ResolutionType.UNRESOLVED; + + public PromiseClient(RpcClient initial, + CompletableFuture eventual, + Integer importId) { + this.cap = initial; + this.importId = importId; + this.promise = eventual.thenApply(resolution -> { + return resolve(resolution); + }); + } + + public boolean isResolved() { + return resolutionType != ResolutionType.UNRESOLVED; + } + + private ClientHook resolve(ClientHook replacement) { + assert !isResolved(); + + var replacementBrand = replacement.getBrand(); + boolean isSameConnection = replacementBrand == RpcState.this; + if (isSameConnection) { + var promise = replacement.whenMoreResolved(); + if (promise != null) { + var other = (PromiseClient)replacement; + while (other.resolutionType == ResolutionType.MERGED) { + replacement = other.cap; + other = (PromiseClient)replacement; + assert replacement.getBrand() == replacementBrand; + } + + if (other.isResolved()) { + resolutionType = other.resolutionType; + } + else { + other.receivedCall = other.receivedCall || receivedCall; + resolutionType = ResolutionType.MERGED; + } + } + else { + resolutionType = ResolutionType.REMOTE; + } + } + else { + if (replacementBrand == NULL_CAPABILITY_BRAND || + replacementBrand == BROKEN_CAPABILITY_BRAND) { + resolutionType = ResolutionType.BROKEN; + } + else { + resolutionType = ResolutionType.REFLECTED; + } + } + + assert isResolved(); + + // TODO Flow control + + if (resolutionType == ResolutionType.REFLECTED && receivedCall && !isDisconnected()) { + var message = connection.newOutgoingMessage(1024); + var disembargo = message.getBody().initAs(RpcProtocol.Message.factory).initDisembargo(); + { + var redirect = RpcState.this.writeTarget(cap, disembargo.initTarget()); + assert redirect == null; + } + + var embargo = new Embargo(); + var embargoId = embargos.next(embargo); + disembargo.getContext().setSenderLoopback(embargoId); + + embargo.fulfiller = new CompletableFuture<>(); + + final ClientHook finalReplacement = replacement; + var embargoPromise = embargo.fulfiller.thenApply(x -> { + return finalReplacement; + }); + + replacement = Capability.newLocalPromiseClient(embargoPromise); + message.send(); + + } + return replacement; + } + + ClientHook writeTarget(ClientHook cap, RpcProtocol.MessageTarget.Builder target) { + if (cap.getBrand() == this) { + return ((RpcClient)cap).writeTarget(target); + } + else { + return cap; + } + } + + @Override + public Integer writeDescriptor(RpcProtocol.CapDescriptor.Builder target, List fds) { + receivedCall = true; + return RpcState.this.writeDescriptor(cap, target, fds); + } + + @Override + public ClientHook writeTarget(RpcProtocol.MessageTarget.Builder target) { + receivedCall = true; + return RpcState.this.writeTarget(cap, target); + } + + @Override + public ClientHook getInnermostClient() { + receivedCall = true; + return RpcState.this.getInnermostClient(cap); + } + + @Override + public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook context) { + return null; + } + } } diff --git a/runtime/src/test/java/org/capnproto/RpcStateTest.java b/runtime/src/test/java/org/capnproto/RpcStateTest.java index 1d0e64e..43faab3 100644 --- a/runtime/src/test/java/org/capnproto/RpcStateTest.java +++ b/runtime/src/test/java/org/capnproto/RpcStateTest.java @@ -9,8 +9,6 @@ import java.util.ArrayDeque; import java.util.Queue; import java.util.concurrent.CompletableFuture; -import static org.junit.Assert.*; - public class RpcStateTest { class TestMessage implements IncomingRpcMessage { @@ -73,7 +71,7 @@ public class RpcStateTest { } @Test - public void handleUnimplemented() { + public void handleUnimplemented() throws RpcException { var msg = new TestMessage(); msg.builder.getRoot(RpcProtocol.Message.factory).initUnimplemented(); rpc.handleMessage(msg); @@ -82,10 +80,14 @@ public class RpcStateTest { @Test public void handleAbort() { + var msg = new TestMessage(); + var builder = msg.builder.getRoot(RpcProtocol.Message.factory); + RpcException.fromException(RpcException.failed("Test abort"), builder.initAbort()); + Assert.assertThrows(RpcException.class, () -> rpc.handleMessage(msg)); } @Test - public void handleBootstrap() { + public void handleBootstrap() throws RpcException { var msg = new TestMessage(); var bootstrap = msg.builder.getRoot(RpcProtocol.Message.factory).initBootstrap(); bootstrap.setQuestionId(0);