use weak refs to cleanup import table

This commit is contained in:
Vaci Koblizek 2020-10-15 17:41:12 +01:00
parent caec63d68c
commit 7134461e7d
2 changed files with 137 additions and 111 deletions

View file

@ -8,10 +8,10 @@ abstract class ImportTable<T> implements Iterable<T> {
private final HashMap<Integer, T> slots = new HashMap<>(); private final HashMap<Integer, T> slots = new HashMap<>();
protected abstract T newImportable(); protected abstract T newImportable(int id);
public T put(int id) { public T put(int id) {
return slots.computeIfAbsent(id, key -> newImportable()); return this.slots.computeIfAbsent(id, key -> newImportable(id));
} }
public T find(int id) { public T find(int id) {

View file

@ -1,5 +1,7 @@
package org.capnproto; package org.capnproto;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.util.*; import java.util.*;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -7,7 +9,6 @@ import java.util.concurrent.CompletionStage;
final class RpcState { final class RpcState {
final class Question { final class Question {
final int id; final int id;
CompletableFuture<RpcResponse> response = new CompletableFuture<>(); CompletableFuture<RpcResponse> response = new CompletableFuture<>();
@ -49,29 +50,61 @@ final class RpcState {
} }
static final class Answer { static final class Answer {
final int answerId;
boolean active = false; boolean active = false;
PipelineHook pipeline; PipelineHook pipeline;
CompletionStage<RpcResponse> redirectedResults; CompletionStage<RpcResponse> redirectedResults;
RpcCallContext callContext; RpcCallContext callContext;
List<Integer> resultExports; List<Integer> resultExports;
Answer(int answerId) {
this.answerId = answerId;
}
} }
static final class Export { static final class Export {
final int id; final int exportId;
int refcount; int refcount;
ClientHook clientHook; ClientHook clientHook;
CompletionStage<?> resolveOp; CompletionStage<?> resolveOp;
Export(int id) { Export(int exportId) {
this.id = id; this.exportId = exportId;
} }
} }
static final class Import { final class Import {
ImportClient importClient; final int importId;
RpcClient appClient; ImportRef importClient;
int remoteRefCount;
WeakReference<RpcClient> appClient;
CompletableFuture<ClientHook> promise; CompletableFuture<ClientHook> promise;
// If non-null, the import is a promise. // If non-null, the import is a promise.
Import(int importId) {
this.importId = importId;
}
void addRemoteRef() {
this.remoteRefCount++;
}
public void dispose() {
// Remove self from the import table.
var imp = imports.find(importId);
if (imp == this) {
imports.erase(importId, imp);
}
// Send a message releasing our remote references.
if (remoteRefCount > 0 && !isDisconnected()) {
var message = connection.newOutgoingMessage(1024);
var builder = message.getBody().initAs(RpcProtocol.Message.factory).initRelease();
builder.setId(importId);
builder.setReferenceCount(remoteRefCount);
message.send();
}
}
} }
final static class Embargo { final static class Embargo {
@ -97,21 +130,21 @@ final class RpcState {
} }
}; };
private final ImportTable<Answer> answers = new ImportTable<Answer>() { private final ImportTable<Answer> answers = new ImportTable<>() {
@Override @Override
protected Answer newImportable() { protected Answer newImportable(int answerId) {
return new Answer(); return new Answer(answerId);
} }
}; };
private final ImportTable<Import> imports = new ImportTable<Import>() { private final ImportTable<Import> imports = new ImportTable<>() {
@Override @Override
protected Import newImportable() { protected Import newImportable(int importId) {
return new Import(); return new Import(importId);
} }
}; };
private final ExportTable<Embargo> embargos = new ExportTable<Embargo>() { private final ExportTable<Embargo> embargos = new ExportTable<>() {
@Override @Override
Embargo newExportable(int id) { Embargo newExportable(int id) {
return new Embargo(id); return new Embargo(id);
@ -163,6 +196,7 @@ final class RpcState {
// run message loop once // run message loop once
final CompletableFuture<?> runOnce() { final CompletableFuture<?> runOnce() {
this.cleanupImports();
if (isDisconnected()) { if (isDisconnected()) {
return CompletableFuture.failedFuture(disconnected); return CompletableFuture.failedFuture(disconnected);
@ -189,6 +223,8 @@ final class RpcState {
// run message loop until promise is completed // run message loop until promise is completed
public final <T> CompletableFuture<T> messageLoop(CompletableFuture<T> done) { public final <T> CompletableFuture<T> messageLoop(CompletableFuture<T> done) {
this.cleanupImports();
if (done.isDone()) { if (done.isDone()) {
return done; return done;
} }
@ -237,6 +273,9 @@ final class RpcState {
case DISEMBARGO: case DISEMBARGO:
handleDisembargo(reader.getDisembargo()); handleDisembargo(reader.getDisembargo());
break; break;
case RELEASE:
handleRelease(reader.getRelease());
break;
default: default:
if (!isDisconnected()) { if (!isDisconnected()) {
// boomin' back atcha // boomin' back atcha
@ -432,8 +471,7 @@ final class RpcState {
var payload = callReturn.getResults(); var payload = callReturn.getResults();
var capTable = receiveCaps(payload.getCapTable(), message.getAttachedFds()); var capTable = receiveCaps(payload.getCapTable(), message.getAttachedFds());
// TODO question, message unused in RpcResponseImpl // TODO question, message unused in RpcResponseImpl
// var response = new RpcResponseImpl(question, message, capTable, payload.getContent()); var response = new RpcResponseImpl(question, message, capTable, payload.getContent());
var response = new RpcResponseImpl(capTable, payload.getContent());
question.answer(response); question.answer(response);
break; break;
@ -508,40 +546,32 @@ final class RpcState {
void handleResolve(IncomingRpcMessage message, RpcProtocol.Resolve.Reader resolve) { void handleResolve(IncomingRpcMessage message, RpcProtocol.Resolve.Reader resolve) {
ClientHook replacement = null;
Throwable exc = null;
switch (resolve.which()) {
case CAP:
replacement = receiveCap(resolve.getCap(), message.getAttachedFds());
break;
case EXCEPTION:
exc = new RuntimeException(resolve.getException().getReason().toString());
break;
default:
assert false;
return;
}
var imp = imports.find(resolve.getPromiseId()); var imp = imports.find(resolve.getPromiseId());
if (imp == null) { if (imp == null) {
return; return;
} }
var fulfiller = imp.promise; if (imp.importClient != null) {
if (fulfiller != null) {
if (exc != null) {
fulfiller.completeExceptionally(exc);
}
else {
fulfiller.complete(replacement);
}
}
else if (imp.importClient != null) {
// It appears this is a valid entry on the import table, but was not expected to be a // It appears this is a valid entry on the import table, but was not expected to be a
// promise. // promise.
assert false; assert false: "Import already resolved.";
} }
switch (resolve.which()) {
case CAP:
imp.promise.complete(receiveCap(resolve.getCap(), message.getAttachedFds()));
break;
case EXCEPTION:
imp.promise.completeExceptionally(RpcException.toException(resolve.getException()));
break;
default:
assert false;
return;
}
}
private void handleRelease(RpcProtocol.Release.Reader release) {
releaseExport(release.getId(), release.getReferenceCount());
} }
void handleDisembargo(RpcProtocol.Disembargo.Reader disembargo) { void handleDisembargo(RpcProtocol.Disembargo.Reader disembargo) {
@ -669,13 +699,13 @@ final class RpcState {
var wrapped = inner.whenMoreResolved(); var wrapped = inner.whenMoreResolved();
if (wrapped != null) { if (wrapped != null) {
// This is a promise. Arrange for the `Resolve` message to be sent later. // This is a promise. Arrange for the `Resolve` message to be sent later.
export.resolveOp = resolveExportedPromise(export.id, wrapped); export.resolveOp = resolveExportedPromise(export.exportId, wrapped);
descriptor.setSenderPromise(export.id); descriptor.setSenderPromise(export.exportId);
} }
else { else {
descriptor.setSenderHosted(export.id); descriptor.setSenderHosted(export.exportId);
} }
return export.id; return export.exportId;
} }
CompletionStage<?> resolveExportedPromise(int exportId, CompletionStage<ClientHook> promise) { CompletionStage<?> resolveExportedPromise(int exportId, CompletionStage<ClientHook> promise) {
@ -741,8 +771,8 @@ final class RpcState {
return; return;
} }
if (export.refcount <= refcount) { if (export.refcount < refcount) {
assert false: "Over-reducing export refcount"; assert false: "Over-reducing export refcount. exported=" + export.refcount + ", requested=" + refcount;
return; return;
} }
@ -828,27 +858,35 @@ final class RpcState {
// Receive a new import. // Receive a new import.
var imp = imports.put(importId); var imp = imports.put(importId);
ImportClient importClient = null;
if (imp.importClient == null) { if (imp.importClient != null) {
imp.importClient = new ImportClient(importId, fd); importClient = imp.importClient.get();
}
if (importClient == null) {
importClient = new ImportClient(imp, fd);
imp.importClient = new ImportRef(importId, importClient);
} }
else { else {
imp.importClient.setFdIfMissing(fd); importClient.setFdIfMissing(fd);
} }
imp.importClient.addRemoteRef();
imp.addRemoteRef();
if (!isPromise) { if (!isPromise) {
imp.appClient = imp.importClient; imp.appClient = new WeakReference<>(importClient);
return imp.importClient; return importClient;
} }
if (imp.appClient != null) { if (imp.appClient != null) {
return imp.appClient; var tmp = imp.appClient.get();
if (tmp != null) {
return tmp;
}
} }
imp.promise = new CompletableFuture<ClientHook>(); imp.promise = new CompletableFuture<>();
var result = new PromiseClient(imp.importClient, imp.promise, importId); var result = new PromiseClient(importClient, imp.promise, importId);
imp.appClient = result; imp.appClient = new WeakReference<>(result);
return result; return result;
} }
@ -923,17 +961,16 @@ final class RpcState {
} }
static class RpcResponseImpl implements RpcResponse { static class RpcResponseImpl implements RpcResponse {
// TODO unused? private final Question question;
// private final Question question; private final IncomingRpcMessage message;
// private final IncomingRpcMessage message;
private final AnyPointer.Reader results; private final AnyPointer.Reader results;
RpcResponseImpl(/*Question question, RpcResponseImpl(Question question,
IncomingRpcMessage message,*/ IncomingRpcMessage message,
List<ClientHook> capTable, List<ClientHook> capTable,
AnyPointer.Reader results) { AnyPointer.Reader results) {
// this.question = question; this.question = question;
// this.message = message; this.message = message;
this.results = results.imbue(new ReaderCapabilityTable(capTable)); this.results = results.imbue(new ReaderCapabilityTable(capTable));
} }
@ -1381,19 +1418,26 @@ final class RpcState {
} }
} }
private class ImportClient extends RpcClient { private ReferenceQueue<ImportClient> importRefs = new ReferenceQueue<>();
private class ImportRef extends WeakReference<ImportClient> {
final int importId; final int importId;
int remoteRefCount = 0;
Integer fd;
ImportClient(int importId, Integer fd) { ImportRef(int importId, ImportClient hook) {
super(hook, importRefs);
this.importId = importId; this.importId = importId;
this.fd = fd; }
} }
void addRemoteRef() { private class ImportClient extends RpcClient {
this.remoteRefCount++;
final Import imp;
Integer fd;
ImportClient(Import imp, Integer fd) {
this.imp = imp;
this.fd = fd;
} }
void setFdIfMissing(Integer fd) { void setFdIfMissing(Integer fd) {
@ -1402,34 +1446,15 @@ final class RpcState {
} }
} }
public void remove() {
// Remove self from the import table.
var imp = imports.find(importId);
if (imp != null) {
if (imp.importClient == this) {
imports.erase(importId, imp);
}
}
// Send a message releasing our remote references.
if (remoteRefCount > 0 && !isDisconnected()) {
var message = connection.newOutgoingMessage(1024);
var builder = message.getBody().initAs(RpcProtocol.Message.factory).initRelease();
builder.setId(importId);
builder.setReferenceCount(remoteRefCount);
message.send();
}
}
@Override @Override
public Integer writeDescriptor(RpcProtocol.CapDescriptor.Builder descriptor, List<Integer> fds) { public Integer writeDescriptor(RpcProtocol.CapDescriptor.Builder descriptor, List<Integer> fds) {
descriptor.setReceiverHosted(importId); descriptor.setReceiverHosted(this.imp.importId);
return null; return null;
} }
@Override @Override
public ClientHook writeTarget(RpcProtocol.MessageTarget.Builder target) { public ClientHook writeTarget(RpcProtocol.MessageTarget.Builder target) {
target.setImportedCap(importId); target.setImportedCap(this.imp.importId);
return null; return null;
} }
@ -1444,6 +1469,20 @@ final class RpcState {
} }
} }
private void cleanupImports() {
while (true) {
var ref = (ImportRef) this.importRefs.poll();
if (ref == null) {
return;
}
var imp = this.imports.find(ref.importId);
assert imp != null;
if (imp != null) {
imp.dispose();
}
}
}
enum ResolutionType { enum ResolutionType {
UNRESOLVED, UNRESOLVED,
REMOTE, REMOTE,
@ -1464,9 +1503,7 @@ final class RpcState {
Integer importId) { Integer importId) {
this.cap = initial; this.cap = initial;
this.importId = importId; this.importId = importId;
this.promise = eventual.thenApply(resolution -> { this.promise = eventual.thenApply(resolution -> resolve(resolution));
return resolve(resolution);
});
} }
public boolean isResolved() { public boolean isResolved() {
@ -1570,17 +1607,6 @@ final class RpcState {
public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook context) { public VoidPromiseAndPipeline call(long interfaceId, short methodId, CallContextHook context) {
return null; return null;
} }
public void remove() {
if (this.importId != null) {
// This object represents an import promise. Clean that up.
var imp = imports.find(this.importId);
if (imp.appClient != null && imp.appClient == this) {
imp.appClient = null;
imp.importClient.remove();
}
}
}
} }
class PipelineClient extends RpcClient { class PipelineClient extends RpcClient {