From e04adc90b637c7440c728fb7b4a219b4cf085036 Mon Sep 17 00:00:00 2001 From: Vaci Koblizek Date: Sun, 8 Nov 2020 19:45:47 +0000 Subject: [PATCH] embargo test and tribble --- .../src/main/java/org/capnproto/RpcState.java | 79 +++++++++++++------ .../src/test/java/org/capnproto/RpcTest.java | 41 ++++++++++ .../test/java/org/capnproto/RpcTestUtil.java | 15 +++- 3 files changed, 108 insertions(+), 27 deletions(-) diff --git a/runtime-rpc/src/main/java/org/capnproto/RpcState.java b/runtime-rpc/src/main/java/org/capnproto/RpcState.java index 369cbb6..7a1f52f 100644 --- a/runtime-rpc/src/main/java/org/capnproto/RpcState.java +++ b/runtime-rpc/src/main/java/org/capnproto/RpcState.java @@ -1041,19 +1041,18 @@ final class RpcState { case SENDER_PROMISE: return importCap(descriptor.getSenderPromise(), true, fd); - case RECEIVER_HOSTED: + case RECEIVER_HOSTED: { var exp = exports.find(descriptor.getReceiverHosted()); if (exp == null) { return Capability.newBrokenCap("invalid 'receiverHosted' export ID"); - } - if (exp.clientHook.getBrand() == this) { - // TODO Tribble 4-way race! + } else if (exp.clientHook.getBrand() == this) { + return new TribbleRaceBlocker(exp.clientHook); + } else { return exp.clientHook; } + } - return exp.clientHook; - - case RECEIVER_ANSWER: + case RECEIVER_ANSWER: { var promisedAnswer = descriptor.getReceiverAnswer(); var answer = answers.find(promisedAnswer.getQuestionId()); var ops = ToPipelineOps(promisedAnswer); @@ -1065,14 +1064,12 @@ final class RpcState { var result = answer.pipeline.getPipelinedCap(ops); if (result == null) { return Capability.newBrokenCap("Unrecognised pipeline ops"); - } - - if (result.getBrand() == this) { - // TODO Tribble 4-way race! + } else if (result.getBrand() == this) { + return new TribbleRaceBlocker(result); + } else { return result; } - - return result; + } case THIRD_PARTY_HOSTED: return Capability.newBrokenCap("Third party caps not supported"); @@ -1579,15 +1576,15 @@ final class RpcState { } @Override - public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook context) { - return null; + public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook ctx) { + return this.callNoIntercept(interfaceId, methodId, ctx); } - public VoidPromiseAndPipeline callNoIntercept(long interfaceId, short methodId, CallContextHook context) { - var params = context.getParams(); + public VoidPromiseAndPipeline callNoIntercept(long interfaceId, short methodId, CallContextHook ctx) { + var params = ctx.getParams(); var request = newCallNoIntercept(interfaceId, methodId); - context.allowCancellation(); - return context.directTailCall(request.getHook()); + ctx.allowCancellation(); + return ctx.directTailCall(request.getHook()); } @Override @@ -1760,11 +1757,6 @@ final class RpcState { return null; } - @Override - public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook context) { - return null; - } - @Override public CompletableFuture whenMoreResolved() { return null; @@ -2001,4 +1993,43 @@ final class RpcState { } return new RpcException(type, reader.getReason().toString()); } + + class TribbleRaceBlocker implements ClientHook { + + final ClientHook inner; + + TribbleRaceBlocker(ClientHook inner) { + this.inner = inner; + } + + @Override + public Request newCall(long interfaceId, short methodId) { + return this.inner.newCall(interfaceId, methodId); + } + + @Override + public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook ctx) { + return this.inner.call(interfaceId, methodId, ctx); + } + + @Override + public ClientHook getResolved() { + return null; + } + + @Override + public CompletableFuture whenMoreResolved() { + return null; + } + + @Override + public Object getBrand() { + return null; + } + + @Override + public Integer getFd() { + return this.inner.getFd(); + } + } } diff --git a/runtime-rpc/src/test/java/org/capnproto/RpcTest.java b/runtime-rpc/src/test/java/org/capnproto/RpcTest.java index 9fe93cf..df8fa0f 100644 --- a/runtime-rpc/src/test/java/org/capnproto/RpcTest.java +++ b/runtime-rpc/src/test/java/org/capnproto/RpcTest.java @@ -431,5 +431,46 @@ public class RpcTest { Assert.assertEquals(2, dependentCall2.join().getN()); Assert.assertEquals(1, calleeCallCount.value()); } + + static CompletableFuture getCallSequence( + Test.TestCallOrder.Client client, int expected) { + var req = client.getCallSequenceRequest(); + req.getParams().setExpected(expected); + return req.send(); + } + + @org.junit.Test + public void testEmbargo() { + var context = new TestContext(bootstrapFactory); + var client = new Test.TestMoreStuff.Client(context.connect(Test.TestSturdyRefObjectId.Tag.TEST_MORE_STUFF)); + + var cap = new Test.TestCallOrder.Client(new TestCallOrderImpl()); + var earlyCall = client.getCallSequenceRequest().send(); + + var echoRequest = client.echoRequest(); + echoRequest.getParams().setCap(cap); + var echo = echoRequest.send(); + + var pipeline = echo.getCap(); + var call0 = getCallSequence(pipeline, 0); + var call1 = getCallSequence(pipeline, 1); + + earlyCall.join(); + + var call2 = getCallSequence(pipeline, 2); + + var resolved = echo.join().getCap(); + + var call3 = getCallSequence(pipeline, 3); + var call4 = getCallSequence(pipeline, 4); + var call5 = getCallSequence(pipeline, 5); + + Assert.assertEquals(0, call0.join().getN()); + Assert.assertEquals(1, call1.join().getN()); + Assert.assertEquals(2, call2.join().getN()); + Assert.assertEquals(3, call3.join().getN()); + Assert.assertEquals(4, call4.join().getN()); + Assert.assertEquals(5, call5.join().getN()); + } } diff --git a/runtime-rpc/src/test/java/org/capnproto/RpcTestUtil.java b/runtime-rpc/src/test/java/org/capnproto/RpcTestUtil.java index 5bc80aa..5fc0072 100644 --- a/runtime-rpc/src/test/java/org/capnproto/RpcTestUtil.java +++ b/runtime-rpc/src/test/java/org/capnproto/RpcTestUtil.java @@ -114,17 +114,26 @@ class RpcTestUtil { this.handleCount = handleCount; } + @Override + protected CompletableFuture echo(CallContext context) { + this.callCount.inc(); + var params = context.getParams(); + var result = context.getResults(); + result.setCap(params.getCap()); + return READY_NOW; + } + @Override protected CompletableFuture getHandle(CallContext context) { context.getResults().setHandle(new HandleImpl(this.handleCount)); - return Capability.Server.READY_NOW; + return READY_NOW; } @Override protected CompletableFuture getCallSequence(CallContext context) { var result = context.getResults(); result.setN(this.callCount.inc()); - return Capability.Server.READY_NOW; + return READY_NOW; } @Override @@ -179,7 +188,7 @@ class RpcTestUtil { results.setI(params.getI()); results.setT(params.getT()); results.setC(new TestCallOrderImpl()); - return Capability.Server.READY_NOW; + return READY_NOW; } }