implement local, queued and promised hooks
This commit is contained in:
parent
37fe39bcde
commit
15b83a9c05
4 changed files with 448 additions and 10 deletions
|
@ -1,6 +1,8 @@
|
|||
package org.capnproto;
|
||||
|
||||
public class Capability {
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
public final class Capability {
|
||||
|
||||
public static class Client {
|
||||
|
||||
|
@ -9,7 +11,105 @@ public class Capability {
|
|||
public Client(ClientHook hook) {
|
||||
this.hook = hook;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static ClientHook newLocalPromiseClient(CompletableFuture<ClientHook> promise) {
|
||||
return new QueuedClient(promise);
|
||||
}
|
||||
|
||||
static class LocalRequest implements RequestHook {
|
||||
|
||||
final MessageBuilder message = new MessageBuilder();
|
||||
final long interfaceId;
|
||||
final short methodId;
|
||||
ClientHook client;
|
||||
|
||||
LocalRequest(long interfaceId, short methodId, ClientHook client) {
|
||||
this.interfaceId = interfaceId;
|
||||
this.methodId = methodId;
|
||||
this.client = client;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RemotePromise<AnyPointer.Reader> send() {
|
||||
var cancelPaf = new CompletableFuture<java.lang.Void>();
|
||||
var context = new LocalCallContext(message, client, cancelPaf);
|
||||
var promiseAndPipeline = client.call(interfaceId, methodId, context);
|
||||
var promise = promiseAndPipeline.promise.thenApply(x -> {
|
||||
context.getResults(); // force allocation
|
||||
return context.response;
|
||||
});
|
||||
|
||||
return new RemotePromise<AnyPointer.Reader>(promise, promiseAndPipeline.pipeline);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getBrand() {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
static class LocalResponse implements ResponseHook {
|
||||
final MessageBuilder message = new MessageBuilder();
|
||||
}
|
||||
|
||||
static class LocalCallContext implements CallContextHook {
|
||||
|
||||
final CompletableFuture<?> cancelAllowed;
|
||||
MessageBuilder request;
|
||||
Response response;
|
||||
AnyPointer.Builder responseBuilder;
|
||||
ClientHook clientRef;
|
||||
|
||||
LocalCallContext(MessageBuilder request,
|
||||
ClientHook clientRef,
|
||||
CompletableFuture<?> cancelAllowed) {
|
||||
this.request = request;
|
||||
this.clientRef = clientRef;
|
||||
this.cancelAllowed = cancelAllowed;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AnyPointer.Reader getParams() {
|
||||
return request.getRoot(AnyPointer.factory).asReader();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void releaseParams() {
|
||||
this.request = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AnyPointer.Builder getResults() {
|
||||
if (this.response == null) {
|
||||
var localResponse = new LocalResponse();
|
||||
this.responseBuilder = localResponse.message.getRoot(AnyPointer.factory);
|
||||
this.response = new Response(this.responseBuilder.asReader(), localResponse);
|
||||
}
|
||||
return this.responseBuilder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void allowCancellation() {
|
||||
this.cancelAllowed.complete(null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletableFuture<java.lang.Void> tailCall(RequestHook request) {
|
||||
// TODO implement tailCall
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletableFuture<PipelineHook> onTailCall() {
|
||||
// TODO implement onTailCall
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request) {
|
||||
// TODO implement directTailCall
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
47
runtime/src/main/java/org/capnproto/QueuedClient.java
Normal file
47
runtime/src/main/java/org/capnproto/QueuedClient.java
Normal file
|
@ -0,0 +1,47 @@
|
|||
package org.capnproto;
|
||||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
class QueuedClient implements ClientHook {
|
||||
|
||||
final CompletableFuture<ClientHook> promise;
|
||||
final CompletableFuture<ClientHook> promiseForCallForwarding;
|
||||
final CompletableFuture<ClientHook> promiseForClientResolution;
|
||||
final CompletableFuture<java.lang.Void> setResolutionOp;
|
||||
ClientHook redirect;
|
||||
|
||||
QueuedClient(CompletableFuture<ClientHook> promise) {
|
||||
// TODO revisit futures
|
||||
this.promise = promise.copy();
|
||||
this.promiseForCallForwarding = promise.copy();
|
||||
this.promiseForClientResolution = promise.copy();
|
||||
this.setResolutionOp = promise.thenAccept(inner -> {
|
||||
this.redirect = inner;
|
||||
}).exceptionally(exc -> {
|
||||
this.redirect = ClientHook.newBrokenCap(exc);
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public Request<AnyPointer.Builder, AnyPointer.Reader> newCall(long interfaceId, short methodId) {
|
||||
var hook = new Capability.LocalRequest(interfaceId, methodId, this);
|
||||
var root = hook.message.getRoot(AnyPointer.factory);
|
||||
return new Request<>(root, hook);
|
||||
}
|
||||
|
||||
@Override
|
||||
public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook ctx) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClientHook getResolved() {
|
||||
return redirect;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletableFuture<ClientHook> whenMoreResolved() {
|
||||
return promiseForClientResolution;
|
||||
}
|
||||
}
|
|
@ -81,7 +81,7 @@ final class RpcState {
|
|||
return this.disconnected != null;
|
||||
}
|
||||
|
||||
void handleMessage(IncomingRpcMessage message) {
|
||||
void handleMessage(IncomingRpcMessage message) throws RpcException {
|
||||
var reader = message.getBody().getAs(RpcProtocol.Message.factory);
|
||||
|
||||
switch (reader.which()) {
|
||||
|
@ -110,7 +110,12 @@ final class RpcState {
|
|||
handleDisembargo(reader.getDisembargo());
|
||||
break;
|
||||
default:
|
||||
// TODO send unimplemented response
|
||||
if (!isDisconnected()) {
|
||||
// boomin' back atcha
|
||||
var msg = connection.newOutgoingMessage(1024);
|
||||
msg.getBody().initAs(RpcProtocol.Message.factory).setUnimplemented(reader);
|
||||
msg.send();
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -150,7 +155,8 @@ final class RpcState {
|
|||
}
|
||||
}
|
||||
|
||||
void handleAbort(RpcProtocol.Exception.Reader abort) {
|
||||
void handleAbort(RpcProtocol.Exception.Reader abort) throws RpcException {
|
||||
throw RpcException.toException(abort);
|
||||
}
|
||||
|
||||
void handleBootstrap(IncomingRpcMessage message, RpcProtocol.Bootstrap.Reader bootstrap) {
|
||||
|
@ -338,7 +344,6 @@ final class RpcState {
|
|||
});
|
||||
}
|
||||
|
||||
|
||||
void releaseExport(int exportId, int refcount) {
|
||||
var export = exports.find(exportId);
|
||||
assert export != null;
|
||||
|
@ -356,6 +361,112 @@ final class RpcState {
|
|||
}
|
||||
}
|
||||
|
||||
private List<ClientHook> receiveCaps(StructList.Reader<RpcProtocol.CapDescriptor.Reader> capTable, List<Integer> fds) {
|
||||
var result = new ArrayList<ClientHook>();
|
||||
for (var cap: capTable) {
|
||||
result.add(receiveCap(cap, fds));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private ClientHook receiveCap(RpcProtocol.CapDescriptor.Reader descriptor, List<Integer> fds) {
|
||||
// TODO AutoCloseFd
|
||||
Integer fd = null;
|
||||
|
||||
int fdIndex = descriptor.getAttachedFd();
|
||||
if (fdIndex >= 0 && fdIndex < fds.size()) {
|
||||
fd = fds.get(fdIndex);
|
||||
if (fd != null) {
|
||||
fds.set(fdIndex, null);
|
||||
}
|
||||
}
|
||||
|
||||
switch (descriptor.which()) {
|
||||
case NONE:
|
||||
return null;
|
||||
|
||||
case SENDER_HOSTED:
|
||||
return importCap(descriptor.getSenderHosted(), false, fd);
|
||||
|
||||
case SENDER_PROMISE:
|
||||
return importCap(descriptor.getSenderPromise(), true, fd);
|
||||
|
||||
case RECEIVER_HOSTED:
|
||||
var exp = exports.find(descriptor.getReceiverHosted());
|
||||
if (exp == null) {
|
||||
return ClientHook.newBrokenCap("invalid 'receiverHosted' export ID");
|
||||
}
|
||||
if (exp.clientHook.getBrand() == this) {
|
||||
// TODO Tribble 4-way race!
|
||||
return exp.clientHook;
|
||||
}
|
||||
|
||||
return exp.clientHook;
|
||||
|
||||
case RECEIVER_ANSWER:
|
||||
var promisedAnswer = descriptor.getReceiverAnswer();
|
||||
var answer = answers.find(promisedAnswer.getQuestionId());
|
||||
var ops = PipelineOp.ToPipelineOps(promisedAnswer);
|
||||
|
||||
if (answer == null || !answer.active || answer.pipeline == null || ops == null) {
|
||||
return ClientHook.newBrokenCap("invalid 'receiverAnswer'");
|
||||
}
|
||||
|
||||
var result = answer.pipeline.getPipelinedCap(ops);
|
||||
if (result == null) {
|
||||
return ClientHook.newBrokenCap("Unrecognised pipeline ops");
|
||||
}
|
||||
|
||||
if (result.getBrand() == this) {
|
||||
// TODO Tribble 4-way race!
|
||||
return result;
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
case THIRD_PARTY_HOSTED:
|
||||
return ClientHook.newBrokenCap("Third party caps not supported");
|
||||
|
||||
default:
|
||||
return ClientHook.newBrokenCap("unknown CapDescriptor type");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private ClientHook importCap(int importId, boolean isPromise, Integer fd) {
|
||||
// Receive a new import.
|
||||
|
||||
var imp = imports.put(importId);
|
||||
|
||||
if (imp.importClient == null) {
|
||||
imp.importClient = new ImportClient(importId, fd);
|
||||
}
|
||||
else {
|
||||
imp.importClient.setFdIfMissing(fd);
|
||||
}
|
||||
imp.importClient.addRemoteRef();
|
||||
|
||||
if (!isPromise) {
|
||||
imp.appClient = imp.importClient;
|
||||
return imp.importClient;
|
||||
}
|
||||
|
||||
if (imp.appClient != null) {
|
||||
return imp.appClient;
|
||||
}
|
||||
|
||||
imp.promise = new CompletableFuture<ClientHook>();
|
||||
var result = new PromiseClient(imp.importClient, imp.promise, importId);
|
||||
imp.appClient = result;
|
||||
return result;
|
||||
}
|
||||
|
||||
ClientHook writeTarget(ClientHook cap, RpcProtocol.MessageTarget.Builder target) {
|
||||
return cap.getBrand() == this
|
||||
? ((RpcClient)cap).writeTarget(target)
|
||||
: cap;
|
||||
}
|
||||
|
||||
ClientHook getInnermostClient(ClientHook client) {
|
||||
for (;;) {
|
||||
var inner = client.getResolved();
|
||||
|
@ -465,16 +576,194 @@ final class RpcState {
|
|||
}
|
||||
|
||||
class ImportClient extends RpcClient {
|
||||
|
||||
final int importId;
|
||||
int remoteRefCount = 0;
|
||||
Integer fd;
|
||||
|
||||
ImportClient(int importId, Integer fd) {
|
||||
this.importId = importId;
|
||||
this.fd = fd;
|
||||
}
|
||||
|
||||
void addRemoteRef() {
|
||||
this.remoteRefCount++;
|
||||
}
|
||||
|
||||
void setFdIfMissing(Integer fd) {
|
||||
if (this.fd == null) {
|
||||
this.fd = fd;
|
||||
}
|
||||
}
|
||||
|
||||
public void dispose() {
|
||||
// TODO manage destruction...
|
||||
var imp = imports.find(importId);
|
||||
if (imp != null) {
|
||||
if (imp.importClient == this) {
|
||||
imports.erase(importId, imp);
|
||||
}
|
||||
}
|
||||
|
||||
if (remoteRefCount > 0 && !isDisconnected()) {
|
||||
var message = connection.newOutgoingMessage(1024);
|
||||
var builder = message.getBody().initAs(RpcProtocol.Message.factory).initRelease();
|
||||
builder.setId(importId);
|
||||
builder.setReferenceCount(remoteRefCount);
|
||||
message.send();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer writeDescriptor(RpcProtocol.CapDescriptor.Builder descriptor, List<Integer> fds) {
|
||||
descriptor.setReceiverHosted(importId);
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClientHook writeTarget(RpcProtocol.MessageTarget.Builder target) {
|
||||
target.setImportedCap(importId);
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook context) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletableFuture<ClientHook> whenMoreResolved() {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
enum ResolutionType {
|
||||
UNRESOLVED,
|
||||
REMOTE,
|
||||
REFLECTED,
|
||||
MERGED,
|
||||
BROKEN
|
||||
}
|
||||
|
||||
class PromiseClient extends RpcClient {
|
||||
final ClientHook cap;
|
||||
final Integer importId;
|
||||
final CompletableFuture<ClientHook> promise;
|
||||
boolean receivedCall = false;
|
||||
ResolutionType resolutionType = ResolutionType.UNRESOLVED;
|
||||
|
||||
public PromiseClient(RpcClient initial,
|
||||
CompletableFuture<ClientHook> eventual,
|
||||
Integer importId) {
|
||||
this.cap = initial;
|
||||
this.importId = importId;
|
||||
this.promise = eventual.thenApply(resolution -> {
|
||||
return resolve(resolution);
|
||||
});
|
||||
}
|
||||
|
||||
public boolean isResolved() {
|
||||
return resolutionType != ResolutionType.UNRESOLVED;
|
||||
}
|
||||
|
||||
private ClientHook resolve(ClientHook replacement) {
|
||||
assert !isResolved();
|
||||
|
||||
var replacementBrand = replacement.getBrand();
|
||||
boolean isSameConnection = replacementBrand == RpcState.this;
|
||||
if (isSameConnection) {
|
||||
var promise = replacement.whenMoreResolved();
|
||||
if (promise != null) {
|
||||
var other = (PromiseClient)replacement;
|
||||
while (other.resolutionType == ResolutionType.MERGED) {
|
||||
replacement = other.cap;
|
||||
other = (PromiseClient)replacement;
|
||||
assert replacement.getBrand() == replacementBrand;
|
||||
}
|
||||
|
||||
if (other.isResolved()) {
|
||||
resolutionType = other.resolutionType;
|
||||
}
|
||||
else {
|
||||
other.receivedCall = other.receivedCall || receivedCall;
|
||||
resolutionType = ResolutionType.MERGED;
|
||||
}
|
||||
}
|
||||
else {
|
||||
resolutionType = ResolutionType.REMOTE;
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (replacementBrand == NULL_CAPABILITY_BRAND ||
|
||||
replacementBrand == BROKEN_CAPABILITY_BRAND) {
|
||||
resolutionType = ResolutionType.BROKEN;
|
||||
}
|
||||
else {
|
||||
resolutionType = ResolutionType.REFLECTED;
|
||||
}
|
||||
}
|
||||
|
||||
assert isResolved();
|
||||
|
||||
// TODO Flow control
|
||||
|
||||
if (resolutionType == ResolutionType.REFLECTED && receivedCall && !isDisconnected()) {
|
||||
var message = connection.newOutgoingMessage(1024);
|
||||
var disembargo = message.getBody().initAs(RpcProtocol.Message.factory).initDisembargo();
|
||||
{
|
||||
var redirect = RpcState.this.writeTarget(cap, disembargo.initTarget());
|
||||
assert redirect == null;
|
||||
}
|
||||
|
||||
var embargo = new Embargo();
|
||||
var embargoId = embargos.next(embargo);
|
||||
disembargo.getContext().setSenderLoopback(embargoId);
|
||||
|
||||
embargo.fulfiller = new CompletableFuture<>();
|
||||
|
||||
final ClientHook finalReplacement = replacement;
|
||||
var embargoPromise = embargo.fulfiller.thenApply(x -> {
|
||||
return finalReplacement;
|
||||
});
|
||||
|
||||
replacement = Capability.newLocalPromiseClient(embargoPromise);
|
||||
message.send();
|
||||
|
||||
}
|
||||
return replacement;
|
||||
}
|
||||
|
||||
ClientHook writeTarget(ClientHook cap, RpcProtocol.MessageTarget.Builder target) {
|
||||
if (cap.getBrand() == this) {
|
||||
return ((RpcClient)cap).writeTarget(target);
|
||||
}
|
||||
else {
|
||||
return cap;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer writeDescriptor(RpcProtocol.CapDescriptor.Builder target, List<Integer> fds) {
|
||||
receivedCall = true;
|
||||
return RpcState.this.writeDescriptor(cap, target, fds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClientHook writeTarget(RpcProtocol.MessageTarget.Builder target) {
|
||||
receivedCall = true;
|
||||
return RpcState.this.writeTarget(cap, target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClientHook getInnermostClient() {
|
||||
receivedCall = true;
|
||||
return RpcState.this.getInnermostClient(cap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook context) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -9,8 +9,6 @@ import java.util.ArrayDeque;
|
|||
import java.util.Queue;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class RpcStateTest {
|
||||
|
||||
class TestMessage implements IncomingRpcMessage {
|
||||
|
@ -73,7 +71,7 @@ public class RpcStateTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void handleUnimplemented() {
|
||||
public void handleUnimplemented() throws RpcException {
|
||||
var msg = new TestMessage();
|
||||
msg.builder.getRoot(RpcProtocol.Message.factory).initUnimplemented();
|
||||
rpc.handleMessage(msg);
|
||||
|
@ -82,10 +80,14 @@ public class RpcStateTest {
|
|||
|
||||
@Test
|
||||
public void handleAbort() {
|
||||
var msg = new TestMessage();
|
||||
var builder = msg.builder.getRoot(RpcProtocol.Message.factory);
|
||||
RpcException.fromException(RpcException.failed("Test abort"), builder.initAbort());
|
||||
Assert.assertThrows(RpcException.class, () -> rpc.handleMessage(msg));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handleBootstrap() {
|
||||
public void handleBootstrap() throws RpcException {
|
||||
var msg = new TestMessage();
|
||||
var bootstrap = msg.builder.getRoot(RpcProtocol.Message.factory).initBootstrap();
|
||||
bootstrap.setQuestionId(0);
|
||||
|
|
Loading…
Reference in a new issue