From 998615fc11cfcca1d1a51464550d599addb8e49f Mon Sep 17 00:00:00 2001 From: Meier Lukas Date: Sun, 7 Jul 2024 09:58:20 +0200 Subject: [PATCH] fix: memory leak caused by many unclosed redis subscriptions (#750) * fix: memory leak caused by many unclosed redis subscriptions * chore: address pull request feedback --- packages/api/src/router/cron-jobs.ts | 11 ++- packages/api/src/router/log.ts | 7 +- packages/api/src/router/widgets/app.ts | 8 +- .../api/src/router/widgets/media-server.ts | 11 ++- packages/api/src/router/widgets/smart-home.ts | 7 +- .../src/lib/channel-subscription-tracker.ts | 92 +++++++++++++++++++ packages/redis/src/lib/channel.ts | 22 +---- 7 files changed, 114 insertions(+), 44 deletions(-) create mode 100644 packages/redis/src/lib/channel-subscription-tracker.ts diff --git a/packages/api/src/router/cron-jobs.ts b/packages/api/src/router/cron-jobs.ts index feb88682e..a63c72ff9 100644 --- a/packages/api/src/router/cron-jobs.ts +++ b/packages/api/src/router/cron-jobs.ts @@ -21,21 +21,22 @@ export const cronJobsRouter = createTRPCRouter({ }), subscribeToStatusUpdates: publicProcedure.subscription(() => { return observable((emit) => { - let isConnectionClosed = false; + const unsubscribes: (() => void)[] = []; for (const job of jobGroup.getJobRegistry().values()) { const channel = createCronJobStatusChannel(job.name); - channel.subscribe((data) => { - if (isConnectionClosed) return; - + const unsubscribe = channel.subscribe((data) => { emit.next(data); }); + unsubscribes.push(unsubscribe); } logger.info("A tRPC client has connected to the cron job status updates procedure"); return () => { - isConnectionClosed = true; + unsubscribes.forEach((unsubscribe) => { + unsubscribe(); + }); }; }); }), diff --git a/packages/api/src/router/log.ts b/packages/api/src/router/log.ts index 1586e4239..ba5690657 100644 --- a/packages/api/src/router/log.ts +++ b/packages/api/src/router/log.ts @@ -9,16 +9,13 @@ import { createTRPCRouter, publicProcedure } from "../trpc"; export const logRouter = createTRPCRouter({ subscribe: publicProcedure.subscription(() => { return observable((emit) => { - let isConnectionClosed = false; - - loggingChannel.subscribe((data) => { - if (isConnectionClosed) return; + const unsubscribe = loggingChannel.subscribe((data) => { emit.next(data); }); logger.info("A tRPC client has connected to the logging procedure"); return () => { - isConnectionClosed = true; + unsubscribe(); }; }); }), diff --git a/packages/api/src/router/widgets/app.ts b/packages/api/src/router/widgets/app.ts index 142fe1972..bfb6063f0 100644 --- a/packages/api/src/router/widgets/app.ts +++ b/packages/api/src/router/widgets/app.ts @@ -27,19 +27,15 @@ export const appRouter = createTRPCRouter({ const pingResult = await sendPingRequestAsync(input.url); return observable<{ url: string; statusCode: number } | { url: string; error: string }>((emit) => { - let isConnectionClosed = false; - emit.next({ url: input.url, ...pingResult }); - pingChannel.subscribe((message) => { - if (isConnectionClosed) return; - + const unsubscribe = pingChannel.subscribe((message) => { // Only emit if same url if (message.url !== input.url) return; emit.next(message); }); return () => { - isConnectionClosed = true; + unsubscribe(); void pingUrlChannel.removeAsync(input.url); }; }); diff --git a/packages/api/src/router/widgets/media-server.ts b/packages/api/src/router/widgets/media-server.ts index 4eba38250..18fcff10d 100644 --- a/packages/api/src/router/widgets/media-server.ts +++ b/packages/api/src/router/widgets/media-server.ts @@ -25,20 +25,21 @@ export const mediaServerRouter = createTRPCRouter({ .unstable_concat(createManyIntegrationMiddleware("jellyfin", "plex")) .subscription(({ ctx }) => { return observable<{ integrationId: string; data: StreamSession[] }>((emit) => { - let isConnectionClosed = false; - + const unsubscribes: (() => void)[] = []; for (const integration of ctx.integrations) { const channel = createItemAndIntegrationChannel("mediaServer", integration.id); - void channel.subscribeAsync((sessions) => { - if (isConnectionClosed) return; + const unsubscribe = channel.subscribe((sessions) => { emit.next({ integrationId: integration.id, data: sessions, }); }); + unsubscribes.push(unsubscribe); } return () => { - isConnectionClosed = true; + unsubscribes.forEach((unsubscribe) => { + unsubscribe(); + }); }; }); }), diff --git a/packages/api/src/router/widgets/smart-home.ts b/packages/api/src/router/widgets/smart-home.ts index 8ba8b9e56..7b0b177fd 100644 --- a/packages/api/src/router/widgets/smart-home.ts +++ b/packages/api/src/router/widgets/smart-home.ts @@ -13,10 +13,7 @@ export const smartHomeRouter = createTRPCRouter({ entityId: string; state: string; }>((emit) => { - let isConnectionClosed = false; - - homeAssistantEntityState.subscribe((message) => { - if (isConnectionClosed) return; + const unsubscribe = homeAssistantEntityState.subscribe((message) => { if (message.entityId !== input.entityId) { return; } @@ -24,7 +21,7 @@ export const smartHomeRouter = createTRPCRouter({ }); return () => { - isConnectionClosed = true; + unsubscribe(); }; }); }), diff --git a/packages/redis/src/lib/channel-subscription-tracker.ts b/packages/redis/src/lib/channel-subscription-tracker.ts new file mode 100644 index 000000000..8b98b4732 --- /dev/null +++ b/packages/redis/src/lib/channel-subscription-tracker.ts @@ -0,0 +1,92 @@ +import { randomUUID } from "crypto"; + +import type { MaybePromise } from "@homarr/common/types"; +import { logger } from "@homarr/log"; + +import { createRedisConnection } from "./connection"; + +type SubscriptionCallback = (message: string) => MaybePromise; + +/** + * This class is used to deduplicate redis subscriptions. + * It keeps track of all subscriptions and only subscribes to a channel if there are any subscriptions to it. + * It also provides a way to remove the callback from the channel. + * It fixes a potential memory leak where the redis client would keep creating new subscriptions to the same channel. + * @see https://github.com/homarr-labs/homarr/issues/744 + */ +export class ChannelSubscriptionTracker { + private static subscriptions = new Map>(); + private static redis = createRedisConnection(); + private static listenerActive = false; + + /** + * Subscribes to a channel. + * @param channelName name of the channel + * @param callback callback function to be called when a message is received + * @returns a function to unsubscribe from the channel + */ + public static subscribe(channelName: string, callback: SubscriptionCallback) { + logger.debug(`Adding redis channel callback channel='${channelName}'`); + + // We only want to activate the listener once + if (!this.listenerActive) { + this.activateListener(); + this.listenerActive = true; + } + + const channelSubscriptions = this.subscriptions.get(channelName) ?? new Map(); + const id = randomUUID(); + + // If there are no subscriptions to the channel, subscribe to it + if (channelSubscriptions.size === 0) { + logger.debug(`Subscribing to redis channel channel='${channelName}'`); + void this.redis.subscribe(channelName); + } + + logger.debug(`Adding redis channel callback channel='${channelName}' id='${id}'`); + channelSubscriptions.set(id, callback); + + this.subscriptions.set(channelName, channelSubscriptions); + + // Return a function to unsubscribe + return () => { + logger.debug(`Removing redis channel callback channel='${channelName}' id='${id}'`); + + const channelSubscriptions = this.subscriptions.get(channelName); + if (!channelSubscriptions) return; + + channelSubscriptions.delete(id); + + // If there are no subscriptions to the channel, unsubscribe from it + if (channelSubscriptions.size >= 1) { + return; + } + + logger.debug(`Unsubscribing from redis channel channel='${channelName}'`); + void this.redis.unsubscribe(channelName); + this.subscriptions.delete(channelName); + }; + } + + /** + * Activates the listener for the redis client. + */ + private static activateListener() { + logger.debug("Activating listener"); + this.redis.on("message", (channel, message) => { + const channelSubscriptions = this.subscriptions.get(channel); + if (!channelSubscriptions) { + logger.warn(`Received message on unknown channel channel='${channel}'`); + return; + } + + for (const [id, callback] of channelSubscriptions.entries()) { + // Don't log messages from the logging channel as it would create an infinite loop + if (channel !== "pubSub:logging") { + logger.debug(`Calling subscription callback channel='${channel}' id='${id}'`); + } + void callback(message); + } + }); + } +} diff --git a/packages/redis/src/lib/channel.ts b/packages/redis/src/lib/channel.ts index 0dcbac883..6279be9f1 100644 --- a/packages/redis/src/lib/channel.ts +++ b/packages/redis/src/lib/channel.ts @@ -4,9 +4,9 @@ import { createId } from "@homarr/db"; import type { WidgetKind } from "@homarr/definitions"; import { logger } from "@homarr/log"; +import { ChannelSubscriptionTracker } from "./channel-subscription-tracker"; import { createRedisConnection } from "./connection"; -const subscriber = createRedisConnection(); // Used for subscribing to channels - after subscribing it can only be used for subscribing const publisher = createRedisConnection(); const lastDataClient = createRedisConnection(); @@ -31,15 +31,7 @@ export const createSubPubChannel = (name: string, { persist }: { persist: } }); } - void subscriber.subscribe(channelName, (err) => { - if (!err) { - return; - } - logger.error(`Error with channel '${channelName}': ${err.name} (${err.message})`); - }); - subscriber.on("message", (channel, message) => { - if (channel !== channelName) return; // TODO: check if this is necessary - it should be handled by the redis client - + return ChannelSubscriptionTracker.subscribe(channelName, (message) => { callback(superjson.parse(message)); }); }, @@ -172,15 +164,9 @@ export const createCacheChannel = (name: string, cacheDurationMs: number export const createItemAndIntegrationChannel = (kind: WidgetKind, integrationId: string) => { const channelName = `item:${kind}:integration:${integrationId}`; return { - subscribeAsync: async (callback: (data: TData) => void) => { - await subscriber.subscribe(channelName); - subscriber.on("message", (channel, message) => { - if (channel !== channelName) { - logger.warn(`received message on ${channel} channel but was looking for ${channelName}`); - return; - } + subscribe: (callback: (data: TData) => void) => { + return ChannelSubscriptionTracker.subscribe(channelName, (message) => { callback(superjson.parse(message)); - logger.debug(`sent message on ${channelName}`); }); }, publishAndUpdateLastStateAsync: async (data: TData) => {