diff --git a/runtime/src/main/java/org/capnproto/ClientHook.java b/runtime/src/main/java/org/capnproto/ClientHook.java index e8a5e08..bda1eaf 100644 --- a/runtime/src/main/java/org/capnproto/ClientHook.java +++ b/runtime/src/main/java/org/capnproto/ClientHook.java @@ -1,4 +1,52 @@ package org.capnproto; public interface ClientHook { + + Object NULL_CAPABILITY_BRAND = new Object(); + Object BROKEN_CAPABILITY_BRAND = new Object(); + + default ClientHook getResolved() { + return null; + } + + default Object getBrand() { + return NULL_CAPABILITY_BRAND; + } + + default boolean isNull() { + return getBrand() == NULL_CAPABILITY_BRAND; + } + + default boolean isError() { + return getBrand() == BROKEN_CAPABILITY_BRAND; + } + + default Integer getFd() { + return null; + } + + static ClientHook newBrokenCap(String reason) { + return newBrokenClient(reason, false, BROKEN_CAPABILITY_BRAND); + } + + static ClientHook newBrokenCap(Throwable exc) { + return newBrokenClient(exc, false, BROKEN_CAPABILITY_BRAND); + } + + static ClientHook newNullCap() { + return newBrokenClient(new RuntimeException("Called null capability"), true, NULL_CAPABILITY_BRAND); + } + + static private ClientHook newBrokenClient(String reason, boolean resolved, Object brand) { + return newBrokenClient(new RuntimeException(reason), resolved, brand); + } + + static private ClientHook newBrokenClient(Throwable exc, boolean resolved, Object brand) { + return new ClientHook() { + @Override + public Object getBrand() { + return brand; + } + }; + } } diff --git a/runtime/src/main/java/org/capnproto/WireHelpers.java b/runtime/src/main/java/org/capnproto/WireHelpers.java index 2c9da87..8600a6a 100644 --- a/runtime/src/main/java/org/capnproto/WireHelpers.java +++ b/runtime/src/main/java/org/capnproto/WireHelpers.java @@ -1312,4 +1312,38 @@ final class WireHelpers { return new Data.Reader(resolved.segment.buffer, resolved.ptr, size); } + static void setCapabilityPointer(SegmentBuilder segment, CapTableBuilder capTable, int refOffset, ClientHook cap) { + long ref = segment.get(refOffset); + + if (!WirePointer.isNull(ref)) { + zeroObject(segment, refOffset); + } + + if (cap == null) { + // TODO check zeroMemory behaviour + zeroPointerAndFars(segment, refOffset); + } + else { + WirePointer.setCap(segment.buffer, refOffset, capTable.injectCap(cap)); + } + } + + static ClientHook readCapabilityPointer(SegmentReader segment, CapTableReader capTable, int refOffset, int maxValue) { + long ref = segment.get(refOffset); + + if (WirePointer.isNull(ref)) { + return ClientHook.newNullCap(); + } + + if (WirePointer.kind(ref) != WirePointer.OTHER) { + return ClientHook.newBrokenCap("Calling capability extracted from a non-capability pointer."); + } + + var cap = capTable.extractCap(WirePointer.upper32Bits(ref)); + if (cap == null) { + return ClientHook.newBrokenCap("Calling invalid capability pointer."); + } + return cap; + } + } diff --git a/runtime/src/main/java/org/capnproto/WirePointer.java b/runtime/src/main/java/org/capnproto/WirePointer.java index 9b6f653..ba33292 100644 --- a/runtime/src/main/java/org/capnproto/WirePointer.java +++ b/runtime/src/main/java/org/capnproto/WirePointer.java @@ -85,4 +85,8 @@ final class WirePointer { public static int upper32Bits(long wirePointer) { return (int)(wirePointer >>> 32); } + + public static void setCap(ByteBuffer buffer, int offset, int cap) { + WirePointer.setOffsetAndKind(buffer, offset, (cap << 2) | OTHER); + } }