From 7134461e7d6997f4bf849dcdca86f851e1411904 Mon Sep 17 00:00:00 2001 From: Vaci Koblizek Date: Thu, 15 Oct 2020 17:41:12 +0100 Subject: [PATCH] use weak refs to cleanup import table --- .../main/java/org/capnproto/ImportTable.java | 4 +- .../src/main/java/org/capnproto/RpcState.java | 244 ++++++++++-------- 2 files changed, 137 insertions(+), 111 deletions(-) diff --git a/runtime/src/main/java/org/capnproto/ImportTable.java b/runtime/src/main/java/org/capnproto/ImportTable.java index b15d2ea..6905060 100644 --- a/runtime/src/main/java/org/capnproto/ImportTable.java +++ b/runtime/src/main/java/org/capnproto/ImportTable.java @@ -8,10 +8,10 @@ abstract class ImportTable implements Iterable { private final HashMap slots = new HashMap<>(); - protected abstract T newImportable(); + protected abstract T newImportable(int id); public T put(int id) { - return slots.computeIfAbsent(id, key -> newImportable()); + return this.slots.computeIfAbsent(id, key -> newImportable(id)); } public T find(int id) { diff --git a/runtime/src/main/java/org/capnproto/RpcState.java b/runtime/src/main/java/org/capnproto/RpcState.java index 5b098fd..92cad37 100644 --- a/runtime/src/main/java/org/capnproto/RpcState.java +++ b/runtime/src/main/java/org/capnproto/RpcState.java @@ -1,5 +1,7 @@ package org.capnproto; +import java.lang.ref.ReferenceQueue; +import java.lang.ref.WeakReference; import java.util.*; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; @@ -7,7 +9,6 @@ import java.util.concurrent.CompletionStage; final class RpcState { - final class Question { final int id; CompletableFuture response = new CompletableFuture<>(); @@ -49,29 +50,61 @@ final class RpcState { } static final class Answer { + final int answerId; boolean active = false; PipelineHook pipeline; CompletionStage redirectedResults; RpcCallContext callContext; List resultExports; + + Answer(int answerId) { + this.answerId = answerId; + } } static final class Export { - final int id; + final int exportId; int refcount; ClientHook clientHook; CompletionStage resolveOp; - Export(int id) { - this.id = id; + Export(int exportId) { + this.exportId = exportId; } } - static final class Import { - ImportClient importClient; - RpcClient appClient; + final class Import { + final int importId; + ImportRef importClient; + int remoteRefCount; + WeakReference appClient; CompletableFuture promise; // If non-null, the import is a promise. + + Import(int importId) { + this.importId = importId; + } + + void addRemoteRef() { + this.remoteRefCount++; + } + + public void dispose() { + // Remove self from the import table. + var imp = imports.find(importId); + if (imp == this) { + imports.erase(importId, imp); + } + + // Send a message releasing our remote references. + if (remoteRefCount > 0 && !isDisconnected()) { + var message = connection.newOutgoingMessage(1024); + var builder = message.getBody().initAs(RpcProtocol.Message.factory).initRelease(); + builder.setId(importId); + builder.setReferenceCount(remoteRefCount); + message.send(); + } + } } final static class Embargo { @@ -97,21 +130,21 @@ final class RpcState { } }; - private final ImportTable answers = new ImportTable() { + private final ImportTable answers = new ImportTable<>() { @Override - protected Answer newImportable() { - return new Answer(); + protected Answer newImportable(int answerId) { + return new Answer(answerId); } }; - private final ImportTable imports = new ImportTable() { + private final ImportTable imports = new ImportTable<>() { @Override - protected Import newImportable() { - return new Import(); + protected Import newImportable(int importId) { + return new Import(importId); } }; - private final ExportTable embargos = new ExportTable() { + private final ExportTable embargos = new ExportTable<>() { @Override Embargo newExportable(int id) { return new Embargo(id); @@ -163,6 +196,7 @@ final class RpcState { // run message loop once final CompletableFuture runOnce() { + this.cleanupImports(); if (isDisconnected()) { return CompletableFuture.failedFuture(disconnected); @@ -189,6 +223,8 @@ final class RpcState { // run message loop until promise is completed public final CompletableFuture messageLoop(CompletableFuture done) { + this.cleanupImports(); + if (done.isDone()) { return done; } @@ -237,6 +273,9 @@ final class RpcState { case DISEMBARGO: handleDisembargo(reader.getDisembargo()); break; + case RELEASE: + handleRelease(reader.getRelease()); + break; default: if (!isDisconnected()) { // boomin' back atcha @@ -432,8 +471,7 @@ final class RpcState { var payload = callReturn.getResults(); var capTable = receiveCaps(payload.getCapTable(), message.getAttachedFds()); // TODO question, message unused in RpcResponseImpl - // var response = new RpcResponseImpl(question, message, capTable, payload.getContent()); - var response = new RpcResponseImpl(capTable, payload.getContent()); + var response = new RpcResponseImpl(question, message, capTable, payload.getContent()); question.answer(response); break; @@ -508,40 +546,32 @@ final class RpcState { void handleResolve(IncomingRpcMessage message, RpcProtocol.Resolve.Reader resolve) { - ClientHook replacement = null; - Throwable exc = null; - - switch (resolve.which()) { - case CAP: - replacement = receiveCap(resolve.getCap(), message.getAttachedFds()); - break; - case EXCEPTION: - exc = new RuntimeException(resolve.getException().getReason().toString()); - break; - default: - assert false; - return; - } - var imp = imports.find(resolve.getPromiseId()); if (imp == null) { return; } - var fulfiller = imp.promise; - if (fulfiller != null) { - if (exc != null) { - fulfiller.completeExceptionally(exc); - } - else { - fulfiller.complete(replacement); - } - } - else if (imp.importClient != null) { + if (imp.importClient != null) { // It appears this is a valid entry on the import table, but was not expected to be a // promise. - assert false; + assert false: "Import already resolved."; } + + switch (resolve.which()) { + case CAP: + imp.promise.complete(receiveCap(resolve.getCap(), message.getAttachedFds())); + break; + case EXCEPTION: + imp.promise.completeExceptionally(RpcException.toException(resolve.getException())); + break; + default: + assert false; + return; + } + } + + private void handleRelease(RpcProtocol.Release.Reader release) { + releaseExport(release.getId(), release.getReferenceCount()); } void handleDisembargo(RpcProtocol.Disembargo.Reader disembargo) { @@ -669,13 +699,13 @@ final class RpcState { var wrapped = inner.whenMoreResolved(); if (wrapped != null) { // This is a promise. Arrange for the `Resolve` message to be sent later. - export.resolveOp = resolveExportedPromise(export.id, wrapped); - descriptor.setSenderPromise(export.id); + export.resolveOp = resolveExportedPromise(export.exportId, wrapped); + descriptor.setSenderPromise(export.exportId); } else { - descriptor.setSenderHosted(export.id); + descriptor.setSenderHosted(export.exportId); } - return export.id; + return export.exportId; } CompletionStage resolveExportedPromise(int exportId, CompletionStage promise) { @@ -741,8 +771,8 @@ final class RpcState { return; } - if (export.refcount <= refcount) { - assert false: "Over-reducing export refcount"; + if (export.refcount < refcount) { + assert false: "Over-reducing export refcount. exported=" + export.refcount + ", requested=" + refcount; return; } @@ -828,27 +858,35 @@ final class RpcState { // Receive a new import. var imp = imports.put(importId); - - if (imp.importClient == null) { - imp.importClient = new ImportClient(importId, fd); + ImportClient importClient = null; + if (imp.importClient != null) { + importClient = imp.importClient.get(); + } + if (importClient == null) { + importClient = new ImportClient(imp, fd); + imp.importClient = new ImportRef(importId, importClient); } else { - imp.importClient.setFdIfMissing(fd); + importClient.setFdIfMissing(fd); } - imp.importClient.addRemoteRef(); + + imp.addRemoteRef(); if (!isPromise) { - imp.appClient = imp.importClient; - return imp.importClient; + imp.appClient = new WeakReference<>(importClient); + return importClient; } if (imp.appClient != null) { - return imp.appClient; + var tmp = imp.appClient.get(); + if (tmp != null) { + return tmp; + } } - imp.promise = new CompletableFuture(); - var result = new PromiseClient(imp.importClient, imp.promise, importId); - imp.appClient = result; + imp.promise = new CompletableFuture<>(); + var result = new PromiseClient(importClient, imp.promise, importId); + imp.appClient = new WeakReference<>(result); return result; } @@ -923,17 +961,16 @@ final class RpcState { } static class RpcResponseImpl implements RpcResponse { - // TODO unused? - // private final Question question; - // private final IncomingRpcMessage message; + private final Question question; + private final IncomingRpcMessage message; private final AnyPointer.Reader results; - RpcResponseImpl(/*Question question, - IncomingRpcMessage message,*/ + RpcResponseImpl(Question question, + IncomingRpcMessage message, List capTable, AnyPointer.Reader results) { - // this.question = question; - // this.message = message; + this.question = question; + this.message = message; this.results = results.imbue(new ReaderCapabilityTable(capTable)); } @@ -1381,55 +1418,43 @@ final class RpcState { } } - private class ImportClient extends RpcClient { + private ReferenceQueue importRefs = new ReferenceQueue<>(); + + private class ImportRef extends WeakReference { final int importId; - int remoteRefCount = 0; + + ImportRef(int importId, ImportClient hook) { + super(hook, importRefs); + this.importId = importId; + } + } + + private class ImportClient extends RpcClient { + + final Import imp; Integer fd; - ImportClient(int importId, Integer fd) { - this.importId = importId; + ImportClient(Import imp, Integer fd) { + this.imp = imp; this.fd = fd; } - void addRemoteRef() { - this.remoteRefCount++; - } - void setFdIfMissing(Integer fd) { if (this.fd == null) { this.fd = fd; } } - public void remove() { - // Remove self from the import table. - var imp = imports.find(importId); - if (imp != null) { - if (imp.importClient == this) { - imports.erase(importId, imp); - } - } - - // Send a message releasing our remote references. - if (remoteRefCount > 0 && !isDisconnected()) { - var message = connection.newOutgoingMessage(1024); - var builder = message.getBody().initAs(RpcProtocol.Message.factory).initRelease(); - builder.setId(importId); - builder.setReferenceCount(remoteRefCount); - message.send(); - } - } - @Override public Integer writeDescriptor(RpcProtocol.CapDescriptor.Builder descriptor, List fds) { - descriptor.setReceiverHosted(importId); + descriptor.setReceiverHosted(this.imp.importId); return null; } @Override public ClientHook writeTarget(RpcProtocol.MessageTarget.Builder target) { - target.setImportedCap(importId); + target.setImportedCap(this.imp.importId); return null; } @@ -1444,6 +1469,20 @@ final class RpcState { } } + private void cleanupImports() { + while (true) { + var ref = (ImportRef) this.importRefs.poll(); + if (ref == null) { + return; + } + var imp = this.imports.find(ref.importId); + assert imp != null; + if (imp != null) { + imp.dispose(); + } + } + } + enum ResolutionType { UNRESOLVED, REMOTE, @@ -1464,9 +1503,7 @@ final class RpcState { Integer importId) { this.cap = initial; this.importId = importId; - this.promise = eventual.thenApply(resolution -> { - return resolve(resolution); - }); + this.promise = eventual.thenApply(resolution -> resolve(resolution)); } public boolean isResolved() { @@ -1570,17 +1607,6 @@ final class RpcState { public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook context) { return null; } - - public void remove() { - if (this.importId != null) { - // This object represents an import promise. Clean that up. - var imp = imports.find(this.importId); - if (imp.appClient != null && imp.appClient == this) { - imp.appClient = null; - imp.importClient.remove(); - } - } - } } class PipelineClient extends RpcClient {