implement local, queued and promised hooks

This commit is contained in:
Vaci Koblizek 2020-09-28 21:59:55 +01:00
parent 37fe39bcde
commit 15b83a9c05
4 changed files with 448 additions and 10 deletions

View file

@ -1,6 +1,8 @@
package org.capnproto;
public class Capability {
import java.util.concurrent.CompletableFuture;
public final class Capability {
public static class Client {
@ -9,7 +11,105 @@ public class Capability {
public Client(ClientHook hook) {
this.hook = hook;
}
}
static ClientHook newLocalPromiseClient(CompletableFuture<ClientHook> promise) {
return new QueuedClient(promise);
}
static class LocalRequest implements RequestHook {
final MessageBuilder message = new MessageBuilder();
final long interfaceId;
final short methodId;
ClientHook client;
LocalRequest(long interfaceId, short methodId, ClientHook client) {
this.interfaceId = interfaceId;
this.methodId = methodId;
this.client = client;
}
@Override
public RemotePromise<AnyPointer.Reader> send() {
var cancelPaf = new CompletableFuture<java.lang.Void>();
var context = new LocalCallContext(message, client, cancelPaf);
var promiseAndPipeline = client.call(interfaceId, methodId, context);
var promise = promiseAndPipeline.promise.thenApply(x -> {
context.getResults(); // force allocation
return context.response;
});
return new RemotePromise<AnyPointer.Reader>(promise, promiseAndPipeline.pipeline);
}
@Override
public Object getBrand() {
return null;
}
}
static class LocalResponse implements ResponseHook {
final MessageBuilder message = new MessageBuilder();
}
static class LocalCallContext implements CallContextHook {
final CompletableFuture<?> cancelAllowed;
MessageBuilder request;
Response response;
AnyPointer.Builder responseBuilder;
ClientHook clientRef;
LocalCallContext(MessageBuilder request,
ClientHook clientRef,
CompletableFuture<?> cancelAllowed) {
this.request = request;
this.clientRef = clientRef;
this.cancelAllowed = cancelAllowed;
}
@Override
public AnyPointer.Reader getParams() {
return request.getRoot(AnyPointer.factory).asReader();
}
@Override
public void releaseParams() {
this.request = null;
}
@Override
public AnyPointer.Builder getResults() {
if (this.response == null) {
var localResponse = new LocalResponse();
this.responseBuilder = localResponse.message.getRoot(AnyPointer.factory);
this.response = new Response(this.responseBuilder.asReader(), localResponse);
}
return this.responseBuilder;
}
@Override
public void allowCancellation() {
this.cancelAllowed.complete(null);
}
@Override
public CompletableFuture<java.lang.Void> tailCall(RequestHook request) {
// TODO implement tailCall
return null;
}
@Override
public CompletableFuture<PipelineHook> onTailCall() {
// TODO implement onTailCall
return null;
}
@Override
public ClientHook.VoidPromiseAndPipeline directTailCall(RequestHook request) {
// TODO implement directTailCall
return null;
}
}
}

View file

@ -0,0 +1,47 @@
package org.capnproto;
import java.util.concurrent.CompletableFuture;
class QueuedClient implements ClientHook {
final CompletableFuture<ClientHook> promise;
final CompletableFuture<ClientHook> promiseForCallForwarding;
final CompletableFuture<ClientHook> promiseForClientResolution;
final CompletableFuture<java.lang.Void> setResolutionOp;
ClientHook redirect;
QueuedClient(CompletableFuture<ClientHook> promise) {
// TODO revisit futures
this.promise = promise.copy();
this.promiseForCallForwarding = promise.copy();
this.promiseForClientResolution = promise.copy();
this.setResolutionOp = promise.thenAccept(inner -> {
this.redirect = inner;
}).exceptionally(exc -> {
this.redirect = ClientHook.newBrokenCap(exc);
return null;
});
}
@Override
public Request<AnyPointer.Builder, AnyPointer.Reader> newCall(long interfaceId, short methodId) {
var hook = new Capability.LocalRequest(interfaceId, methodId, this);
var root = hook.message.getRoot(AnyPointer.factory);
return new Request<>(root, hook);
}
@Override
public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook ctx) {
return null;
}
@Override
public ClientHook getResolved() {
return redirect;
}
@Override
public CompletableFuture<ClientHook> whenMoreResolved() {
return promiseForClientResolution;
}
}

View file

@ -81,7 +81,7 @@ final class RpcState {
return this.disconnected != null;
}
void handleMessage(IncomingRpcMessage message) {
void handleMessage(IncomingRpcMessage message) throws RpcException {
var reader = message.getBody().getAs(RpcProtocol.Message.factory);
switch (reader.which()) {
@ -110,7 +110,12 @@ final class RpcState {
handleDisembargo(reader.getDisembargo());
break;
default:
// TODO send unimplemented response
if (!isDisconnected()) {
// boomin' back atcha
var msg = connection.newOutgoingMessage(1024);
msg.getBody().initAs(RpcProtocol.Message.factory).setUnimplemented(reader);
msg.send();
}
break;
}
}
@ -150,7 +155,8 @@ final class RpcState {
}
}
void handleAbort(RpcProtocol.Exception.Reader abort) {
void handleAbort(RpcProtocol.Exception.Reader abort) throws RpcException {
throw RpcException.toException(abort);
}
void handleBootstrap(IncomingRpcMessage message, RpcProtocol.Bootstrap.Reader bootstrap) {
@ -338,7 +344,6 @@ final class RpcState {
});
}
void releaseExport(int exportId, int refcount) {
var export = exports.find(exportId);
assert export != null;
@ -356,6 +361,112 @@ final class RpcState {
}
}
private List<ClientHook> receiveCaps(StructList.Reader<RpcProtocol.CapDescriptor.Reader> capTable, List<Integer> fds) {
var result = new ArrayList<ClientHook>();
for (var cap: capTable) {
result.add(receiveCap(cap, fds));
}
return result;
}
private ClientHook receiveCap(RpcProtocol.CapDescriptor.Reader descriptor, List<Integer> fds) {
// TODO AutoCloseFd
Integer fd = null;
int fdIndex = descriptor.getAttachedFd();
if (fdIndex >= 0 && fdIndex < fds.size()) {
fd = fds.get(fdIndex);
if (fd != null) {
fds.set(fdIndex, null);
}
}
switch (descriptor.which()) {
case NONE:
return null;
case SENDER_HOSTED:
return importCap(descriptor.getSenderHosted(), false, fd);
case SENDER_PROMISE:
return importCap(descriptor.getSenderPromise(), true, fd);
case RECEIVER_HOSTED:
var exp = exports.find(descriptor.getReceiverHosted());
if (exp == null) {
return ClientHook.newBrokenCap("invalid 'receiverHosted' export ID");
}
if (exp.clientHook.getBrand() == this) {
// TODO Tribble 4-way race!
return exp.clientHook;
}
return exp.clientHook;
case RECEIVER_ANSWER:
var promisedAnswer = descriptor.getReceiverAnswer();
var answer = answers.find(promisedAnswer.getQuestionId());
var ops = PipelineOp.ToPipelineOps(promisedAnswer);
if (answer == null || !answer.active || answer.pipeline == null || ops == null) {
return ClientHook.newBrokenCap("invalid 'receiverAnswer'");
}
var result = answer.pipeline.getPipelinedCap(ops);
if (result == null) {
return ClientHook.newBrokenCap("Unrecognised pipeline ops");
}
if (result.getBrand() == this) {
// TODO Tribble 4-way race!
return result;
}
return result;
case THIRD_PARTY_HOSTED:
return ClientHook.newBrokenCap("Third party caps not supported");
default:
return ClientHook.newBrokenCap("unknown CapDescriptor type");
}
}
private ClientHook importCap(int importId, boolean isPromise, Integer fd) {
// Receive a new import.
var imp = imports.put(importId);
if (imp.importClient == null) {
imp.importClient = new ImportClient(importId, fd);
}
else {
imp.importClient.setFdIfMissing(fd);
}
imp.importClient.addRemoteRef();
if (!isPromise) {
imp.appClient = imp.importClient;
return imp.importClient;
}
if (imp.appClient != null) {
return imp.appClient;
}
imp.promise = new CompletableFuture<ClientHook>();
var result = new PromiseClient(imp.importClient, imp.promise, importId);
imp.appClient = result;
return result;
}
ClientHook writeTarget(ClientHook cap, RpcProtocol.MessageTarget.Builder target) {
return cap.getBrand() == this
? ((RpcClient)cap).writeTarget(target)
: cap;
}
ClientHook getInnermostClient(ClientHook client) {
for (;;) {
var inner = client.getResolved();
@ -465,16 +576,194 @@ final class RpcState {
}
class ImportClient extends RpcClient {
final int importId;
int remoteRefCount = 0;
Integer fd;
ImportClient(int importId, Integer fd) {
this.importId = importId;
this.fd = fd;
}
void addRemoteRef() {
this.remoteRefCount++;
}
void setFdIfMissing(Integer fd) {
if (this.fd == null) {
this.fd = fd;
}
}
public void dispose() {
// TODO manage destruction...
var imp = imports.find(importId);
if (imp != null) {
if (imp.importClient == this) {
imports.erase(importId, imp);
}
}
if (remoteRefCount > 0 && !isDisconnected()) {
var message = connection.newOutgoingMessage(1024);
var builder = message.getBody().initAs(RpcProtocol.Message.factory).initRelease();
builder.setId(importId);
builder.setReferenceCount(remoteRefCount);
message.send();
}
}
@Override
public Integer writeDescriptor(RpcProtocol.CapDescriptor.Builder descriptor, List<Integer> fds) {
descriptor.setReceiverHosted(importId);
return null;
}
@Override
public ClientHook writeTarget(RpcProtocol.MessageTarget.Builder target) {
target.setImportedCap(importId);
return null;
}
@Override
public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook context) {
return null;
}
@Override
public CompletableFuture<ClientHook> whenMoreResolved() {
return null;
}
}
enum ResolutionType {
UNRESOLVED,
REMOTE,
REFLECTED,
MERGED,
BROKEN
}
class PromiseClient extends RpcClient {
final ClientHook cap;
final Integer importId;
final CompletableFuture<ClientHook> promise;
boolean receivedCall = false;
ResolutionType resolutionType = ResolutionType.UNRESOLVED;
public PromiseClient(RpcClient initial,
CompletableFuture<ClientHook> eventual,
Integer importId) {
this.cap = initial;
this.importId = importId;
this.promise = eventual.thenApply(resolution -> {
return resolve(resolution);
});
}
public boolean isResolved() {
return resolutionType != ResolutionType.UNRESOLVED;
}
private ClientHook resolve(ClientHook replacement) {
assert !isResolved();
var replacementBrand = replacement.getBrand();
boolean isSameConnection = replacementBrand == RpcState.this;
if (isSameConnection) {
var promise = replacement.whenMoreResolved();
if (promise != null) {
var other = (PromiseClient)replacement;
while (other.resolutionType == ResolutionType.MERGED) {
replacement = other.cap;
other = (PromiseClient)replacement;
assert replacement.getBrand() == replacementBrand;
}
if (other.isResolved()) {
resolutionType = other.resolutionType;
}
else {
other.receivedCall = other.receivedCall || receivedCall;
resolutionType = ResolutionType.MERGED;
}
}
else {
resolutionType = ResolutionType.REMOTE;
}
}
else {
if (replacementBrand == NULL_CAPABILITY_BRAND ||
replacementBrand == BROKEN_CAPABILITY_BRAND) {
resolutionType = ResolutionType.BROKEN;
}
else {
resolutionType = ResolutionType.REFLECTED;
}
}
assert isResolved();
// TODO Flow control
if (resolutionType == ResolutionType.REFLECTED && receivedCall && !isDisconnected()) {
var message = connection.newOutgoingMessage(1024);
var disembargo = message.getBody().initAs(RpcProtocol.Message.factory).initDisembargo();
{
var redirect = RpcState.this.writeTarget(cap, disembargo.initTarget());
assert redirect == null;
}
var embargo = new Embargo();
var embargoId = embargos.next(embargo);
disembargo.getContext().setSenderLoopback(embargoId);
embargo.fulfiller = new CompletableFuture<>();
final ClientHook finalReplacement = replacement;
var embargoPromise = embargo.fulfiller.thenApply(x -> {
return finalReplacement;
});
replacement = Capability.newLocalPromiseClient(embargoPromise);
message.send();
}
return replacement;
}
ClientHook writeTarget(ClientHook cap, RpcProtocol.MessageTarget.Builder target) {
if (cap.getBrand() == this) {
return ((RpcClient)cap).writeTarget(target);
}
else {
return cap;
}
}
@Override
public Integer writeDescriptor(RpcProtocol.CapDescriptor.Builder target, List<Integer> fds) {
receivedCall = true;
return RpcState.this.writeDescriptor(cap, target, fds);
}
@Override
public ClientHook writeTarget(RpcProtocol.MessageTarget.Builder target) {
receivedCall = true;
return RpcState.this.writeTarget(cap, target);
}
@Override
public ClientHook getInnermostClient() {
receivedCall = true;
return RpcState.this.getInnermostClient(cap);
}
@Override
public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook context) {
return null;
}
}
}

View file

@ -9,8 +9,6 @@ import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import static org.junit.Assert.*;
public class RpcStateTest {
class TestMessage implements IncomingRpcMessage {
@ -73,7 +71,7 @@ public class RpcStateTest {
}
@Test
public void handleUnimplemented() {
public void handleUnimplemented() throws RpcException {
var msg = new TestMessage();
msg.builder.getRoot(RpcProtocol.Message.factory).initUnimplemented();
rpc.handleMessage(msg);
@ -82,10 +80,14 @@ public class RpcStateTest {
@Test
public void handleAbort() {
var msg = new TestMessage();
var builder = msg.builder.getRoot(RpcProtocol.Message.factory);
RpcException.fromException(RpcException.failed("Test abort"), builder.initAbort());
Assert.assertThrows(RpcException.class, () -> rpc.handleMessage(msg));
}
@Test
public void handleBootstrap() {
public void handleBootstrap() throws RpcException {
var msg = new TestMessage();
var bootstrap = msg.builder.getRoot(RpcProtocol.Message.factory).initBootstrap();
bootstrap.setQuestionId(0);