use weak refs to cleanup import table

This commit is contained in:
Vaci Koblizek 2020-10-15 17:41:12 +01:00
parent caec63d68c
commit 7134461e7d
2 changed files with 137 additions and 111 deletions

View file

@ -8,10 +8,10 @@ abstract class ImportTable<T> implements Iterable<T> {
private final HashMap<Integer, T> slots = new HashMap<>();
protected abstract T newImportable();
protected abstract T newImportable(int id);
public T put(int id) {
return slots.computeIfAbsent(id, key -> newImportable());
return this.slots.computeIfAbsent(id, key -> newImportable(id));
}
public T find(int id) {

View file

@ -1,5 +1,7 @@
package org.capnproto;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
@ -7,7 +9,6 @@ import java.util.concurrent.CompletionStage;
final class RpcState {
final class Question {
final int id;
CompletableFuture<RpcResponse> response = new CompletableFuture<>();
@ -49,29 +50,61 @@ final class RpcState {
}
static final class Answer {
final int answerId;
boolean active = false;
PipelineHook pipeline;
CompletionStage<RpcResponse> redirectedResults;
RpcCallContext callContext;
List<Integer> resultExports;
Answer(int answerId) {
this.answerId = answerId;
}
}
static final class Export {
final int id;
final int exportId;
int refcount;
ClientHook clientHook;
CompletionStage<?> resolveOp;
Export(int id) {
this.id = id;
Export(int exportId) {
this.exportId = exportId;
}
}
static final class Import {
ImportClient importClient;
RpcClient appClient;
final class Import {
final int importId;
ImportRef importClient;
int remoteRefCount;
WeakReference<RpcClient> appClient;
CompletableFuture<ClientHook> promise;
// If non-null, the import is a promise.
Import(int importId) {
this.importId = importId;
}
void addRemoteRef() {
this.remoteRefCount++;
}
public void dispose() {
// Remove self from the import table.
var imp = imports.find(importId);
if (imp == this) {
imports.erase(importId, imp);
}
// Send a message releasing our remote references.
if (remoteRefCount > 0 && !isDisconnected()) {
var message = connection.newOutgoingMessage(1024);
var builder = message.getBody().initAs(RpcProtocol.Message.factory).initRelease();
builder.setId(importId);
builder.setReferenceCount(remoteRefCount);
message.send();
}
}
}
final static class Embargo {
@ -97,21 +130,21 @@ final class RpcState {
}
};
private final ImportTable<Answer> answers = new ImportTable<Answer>() {
private final ImportTable<Answer> answers = new ImportTable<>() {
@Override
protected Answer newImportable() {
return new Answer();
protected Answer newImportable(int answerId) {
return new Answer(answerId);
}
};
private final ImportTable<Import> imports = new ImportTable<Import>() {
private final ImportTable<Import> imports = new ImportTable<>() {
@Override
protected Import newImportable() {
return new Import();
protected Import newImportable(int importId) {
return new Import(importId);
}
};
private final ExportTable<Embargo> embargos = new ExportTable<Embargo>() {
private final ExportTable<Embargo> embargos = new ExportTable<>() {
@Override
Embargo newExportable(int id) {
return new Embargo(id);
@ -163,6 +196,7 @@ final class RpcState {
// run message loop once
final CompletableFuture<?> runOnce() {
this.cleanupImports();
if (isDisconnected()) {
return CompletableFuture.failedFuture(disconnected);
@ -189,6 +223,8 @@ final class RpcState {
// run message loop until promise is completed
public final <T> CompletableFuture<T> messageLoop(CompletableFuture<T> done) {
this.cleanupImports();
if (done.isDone()) {
return done;
}
@ -237,6 +273,9 @@ final class RpcState {
case DISEMBARGO:
handleDisembargo(reader.getDisembargo());
break;
case RELEASE:
handleRelease(reader.getRelease());
break;
default:
if (!isDisconnected()) {
// boomin' back atcha
@ -432,8 +471,7 @@ final class RpcState {
var payload = callReturn.getResults();
var capTable = receiveCaps(payload.getCapTable(), message.getAttachedFds());
// TODO question, message unused in RpcResponseImpl
// var response = new RpcResponseImpl(question, message, capTable, payload.getContent());
var response = new RpcResponseImpl(capTable, payload.getContent());
var response = new RpcResponseImpl(question, message, capTable, payload.getContent());
question.answer(response);
break;
@ -508,40 +546,32 @@ final class RpcState {
void handleResolve(IncomingRpcMessage message, RpcProtocol.Resolve.Reader resolve) {
ClientHook replacement = null;
Throwable exc = null;
switch (resolve.which()) {
case CAP:
replacement = receiveCap(resolve.getCap(), message.getAttachedFds());
break;
case EXCEPTION:
exc = new RuntimeException(resolve.getException().getReason().toString());
break;
default:
assert false;
return;
}
var imp = imports.find(resolve.getPromiseId());
if (imp == null) {
return;
}
var fulfiller = imp.promise;
if (fulfiller != null) {
if (exc != null) {
fulfiller.completeExceptionally(exc);
}
else {
fulfiller.complete(replacement);
}
}
else if (imp.importClient != null) {
if (imp.importClient != null) {
// It appears this is a valid entry on the import table, but was not expected to be a
// promise.
assert false;
assert false: "Import already resolved.";
}
switch (resolve.which()) {
case CAP:
imp.promise.complete(receiveCap(resolve.getCap(), message.getAttachedFds()));
break;
case EXCEPTION:
imp.promise.completeExceptionally(RpcException.toException(resolve.getException()));
break;
default:
assert false;
return;
}
}
private void handleRelease(RpcProtocol.Release.Reader release) {
releaseExport(release.getId(), release.getReferenceCount());
}
void handleDisembargo(RpcProtocol.Disembargo.Reader disembargo) {
@ -669,13 +699,13 @@ final class RpcState {
var wrapped = inner.whenMoreResolved();
if (wrapped != null) {
// This is a promise. Arrange for the `Resolve` message to be sent later.
export.resolveOp = resolveExportedPromise(export.id, wrapped);
descriptor.setSenderPromise(export.id);
export.resolveOp = resolveExportedPromise(export.exportId, wrapped);
descriptor.setSenderPromise(export.exportId);
}
else {
descriptor.setSenderHosted(export.id);
descriptor.setSenderHosted(export.exportId);
}
return export.id;
return export.exportId;
}
CompletionStage<?> resolveExportedPromise(int exportId, CompletionStage<ClientHook> promise) {
@ -741,8 +771,8 @@ final class RpcState {
return;
}
if (export.refcount <= refcount) {
assert false: "Over-reducing export refcount";
if (export.refcount < refcount) {
assert false: "Over-reducing export refcount. exported=" + export.refcount + ", requested=" + refcount;
return;
}
@ -828,27 +858,35 @@ final class RpcState {
// Receive a new import.
var imp = imports.put(importId);
if (imp.importClient == null) {
imp.importClient = new ImportClient(importId, fd);
ImportClient importClient = null;
if (imp.importClient != null) {
importClient = imp.importClient.get();
}
if (importClient == null) {
importClient = new ImportClient(imp, fd);
imp.importClient = new ImportRef(importId, importClient);
}
else {
imp.importClient.setFdIfMissing(fd);
importClient.setFdIfMissing(fd);
}
imp.importClient.addRemoteRef();
imp.addRemoteRef();
if (!isPromise) {
imp.appClient = imp.importClient;
return imp.importClient;
imp.appClient = new WeakReference<>(importClient);
return importClient;
}
if (imp.appClient != null) {
return imp.appClient;
var tmp = imp.appClient.get();
if (tmp != null) {
return tmp;
}
}
imp.promise = new CompletableFuture<ClientHook>();
var result = new PromiseClient(imp.importClient, imp.promise, importId);
imp.appClient = result;
imp.promise = new CompletableFuture<>();
var result = new PromiseClient(importClient, imp.promise, importId);
imp.appClient = new WeakReference<>(result);
return result;
}
@ -923,17 +961,16 @@ final class RpcState {
}
static class RpcResponseImpl implements RpcResponse {
// TODO unused?
// private final Question question;
// private final IncomingRpcMessage message;
private final Question question;
private final IncomingRpcMessage message;
private final AnyPointer.Reader results;
RpcResponseImpl(/*Question question,
IncomingRpcMessage message,*/
RpcResponseImpl(Question question,
IncomingRpcMessage message,
List<ClientHook> capTable,
AnyPointer.Reader results) {
// this.question = question;
// this.message = message;
this.question = question;
this.message = message;
this.results = results.imbue(new ReaderCapabilityTable(capTable));
}
@ -1381,19 +1418,26 @@ final class RpcState {
}
}
private class ImportClient extends RpcClient {
private ReferenceQueue<ImportClient> importRefs = new ReferenceQueue<>();
private class ImportRef extends WeakReference<ImportClient> {
final int importId;
int remoteRefCount = 0;
Integer fd;
ImportClient(int importId, Integer fd) {
ImportRef(int importId, ImportClient hook) {
super(hook, importRefs);
this.importId = importId;
this.fd = fd;
}
}
void addRemoteRef() {
this.remoteRefCount++;
private class ImportClient extends RpcClient {
final Import imp;
Integer fd;
ImportClient(Import imp, Integer fd) {
this.imp = imp;
this.fd = fd;
}
void setFdIfMissing(Integer fd) {
@ -1402,34 +1446,15 @@ final class RpcState {
}
}
public void remove() {
// Remove self from the import table.
var imp = imports.find(importId);
if (imp != null) {
if (imp.importClient == this) {
imports.erase(importId, imp);
}
}
// Send a message releasing our remote references.
if (remoteRefCount > 0 && !isDisconnected()) {
var message = connection.newOutgoingMessage(1024);
var builder = message.getBody().initAs(RpcProtocol.Message.factory).initRelease();
builder.setId(importId);
builder.setReferenceCount(remoteRefCount);
message.send();
}
}
@Override
public Integer writeDescriptor(RpcProtocol.CapDescriptor.Builder descriptor, List<Integer> fds) {
descriptor.setReceiverHosted(importId);
descriptor.setReceiverHosted(this.imp.importId);
return null;
}
@Override
public ClientHook writeTarget(RpcProtocol.MessageTarget.Builder target) {
target.setImportedCap(importId);
target.setImportedCap(this.imp.importId);
return null;
}
@ -1444,6 +1469,20 @@ final class RpcState {
}
}
private void cleanupImports() {
while (true) {
var ref = (ImportRef) this.importRefs.poll();
if (ref == null) {
return;
}
var imp = this.imports.find(ref.importId);
assert imp != null;
if (imp != null) {
imp.dispose();
}
}
}
enum ResolutionType {
UNRESOLVED,
REMOTE,
@ -1464,9 +1503,7 @@ final class RpcState {
Integer importId) {
this.cap = initial;
this.importId = importId;
this.promise = eventual.thenApply(resolution -> {
return resolve(resolution);
});
this.promise = eventual.thenApply(resolution -> resolve(resolution));
}
public boolean isResolved() {
@ -1570,17 +1607,6 @@ final class RpcState {
public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook context) {
return null;
}
public void remove() {
if (this.importId != null) {
// This object represents an import promise. Clean that up.
var imp = imports.find(this.importId);
if (imp.appClient != null && imp.appClient == this) {
imp.appClient = null;
imp.importClient.remove();
}
}
}
}
class PipelineClient extends RpcClient {