From c40801b329f109bebe211d9257179cb55d5bcfee Mon Sep 17 00:00:00 2001 From: Steve Korshakov Date: Mon, 1 Sep 2025 14:49:56 -0700 Subject: [PATCH] ref: move socket handlers --- sources/app/api/socket.ts | 327 +------------------------ sources/app/api/socket/pingHandler.ts | 12 + sources/app/api/socket/rpcHandler.ts | 170 +++++++++++++ sources/app/api/socket/usageHandler.ts | 124 ++++++++++ 4 files changed, 314 insertions(+), 319 deletions(-) create mode 100644 sources/app/api/socket/pingHandler.ts create mode 100644 sources/app/api/socket/rpcHandler.ts create mode 100644 sources/app/api/socket/usageHandler.ts diff --git a/sources/app/api/socket.ts b/sources/app/api/socket.ts index bf35232..1d2fd3b 100644 --- a/sources/app/api/socket.ts +++ b/sources/app/api/socket.ts @@ -11,6 +11,9 @@ import { decrementWebSocketConnection, incrementWebSocketConnection, machineAliv import { AsyncLock } from "@/utils/lock"; import { activityCache } from "../presence/sessionCache"; import { randomKeyNaked } from "@/utils/randomKeyNaked"; +import { usageHandler } from "./socket/usageHandler"; +import { rpcHandler } from "./socket/rpcHandler"; +import { pingHandler } from "./socket/pingHandler"; export function startSocket(app: Fastify, eventRouter: EventRouter) { const io = new Server(app.server, { @@ -30,12 +33,6 @@ export function startSocket(app: Fastify, eventRouter: EventRouter) { serveClient: false // Don't serve the client files }); - // Connection tracking is now handled by EventRouter - - // Track RPC listeners: Map> - // Only session-scoped clients (CLI) register handlers, only user-scoped clients (mobile) call them - const rpcListeners = new Map>(); - io.on("connection", async (socket) => { log({ module: 'websocket' }, `New connection attempt from socket: ${socket.id}`); const token = socket.handshake.auth.token as string; @@ -117,7 +114,6 @@ export function startSocket(app: Fastify, eventRouter: EventRouter) { // Lock const receiveMessageLock = new AsyncLock(); - const receiveUsageLock = new AsyncLock(); socket.on('disconnect', () => { websocketEventsCounter.inc({ event_type: 'disconnect' }); @@ -126,28 +122,6 @@ export function startSocket(app: Fastify, eventRouter: EventRouter) { eventRouter.removeConnection(userId, connection); decrementWebSocketConnection(connection.connectionType); - // Clean up RPC listeners for this socket - const userRpcMap = rpcListeners.get(userId); - if (userRpcMap) { - // Remove all RPC methods registered by this socket - const methodsToRemove: string[] = []; - for (const [method, registeredSocket] of userRpcMap.entries()) { - if (registeredSocket === socket) { - methodsToRemove.push(method); - } - } - - if (methodsToRemove.length > 0) { - log({ module: 'websocket-rpc' }, `Cleaning up RPC methods on disconnect for socket ${socket.id}: ${methodsToRemove.join(', ')}`); - methodsToRemove.forEach(method => userRpcMap.delete(method)); - } - - if (userRpcMap.size === 0) { - rpcListeners.delete(userId); - log({ module: 'websocket-rpc' }, `All RPC listeners removed for user ${userId}`); - } - } - log({ module: 'websocket' }, `User disconnected: ${userId}`); // Broadcast daemon offline status @@ -670,297 +644,12 @@ export function startSocket(app: Fastify, eventRouter: EventRouter) { } }); - // RPC register - Register this socket as a listener for an RPC method - socket.on('rpc-register', async (data: any) => { - try { - const { method } = data; + // Handlers + rpcHandler(userId, socket, eventRouter); + usageHandler(userId, socket, eventRouter); + pingHandler(socket); - if (!method || typeof method !== 'string') { - socket.emit('rpc-error', { type: 'register', error: 'Invalid method name' }); - return; - } - - // Get or create user's RPC map - let userRpcMap = rpcListeners.get(userId); - if (!userRpcMap) { - userRpcMap = new Map(); - rpcListeners.set(userId, userRpcMap); - } - - // Check if method was already registered - const previousSocket = userRpcMap.get(method); - if (previousSocket && previousSocket !== socket) { - log({ module: 'websocket-rpc' }, `RPC method ${method} re-registered: ${previousSocket.id} -> ${socket.id}`); - } - - // Register this socket as the listener for this method - userRpcMap.set(method, socket); - - socket.emit('rpc-registered', { method }); - log({ module: 'websocket-rpc' }, `RPC method registered: ${method} on socket ${socket.id} (user: ${userId})`); - log({ module: 'websocket-rpc' }, `Active RPC methods for user ${userId}: ${Array.from(userRpcMap.keys()).join(', ')}`); - } catch (error) { - log({ module: 'websocket', level: 'error' }, `Error in rpc-register: ${error}`); - socket.emit('rpc-error', { type: 'register', error: 'Internal error' }); - } - }); - - // RPC unregister - Remove this socket as a listener for an RPC method - socket.on('rpc-unregister', async (data: any) => { - try { - const { method } = data; - - if (!method || typeof method !== 'string') { - socket.emit('rpc-error', { type: 'unregister', error: 'Invalid method name' }); - return; - } - - const userRpcMap = rpcListeners.get(userId); - if (userRpcMap && userRpcMap.get(method) === socket) { - userRpcMap.delete(method); - log({ module: 'websocket-rpc' }, `RPC method unregistered: ${method} from socket ${socket.id} (user: ${userId})`); - - if (userRpcMap.size === 0) { - rpcListeners.delete(userId); - log({ module: 'websocket-rpc' }, `All RPC methods unregistered for user ${userId}`); - } else { - log({ module: 'websocket-rpc' }, `Remaining RPC methods for user ${userId}: ${Array.from(userRpcMap.keys()).join(', ')}`); - } - } else { - log({ module: 'websocket-rpc' }, `RPC unregister ignored: ${method} not registered on socket ${socket.id}`); - } - - socket.emit('rpc-unregistered', { method }); - } catch (error) { - log({ module: 'websocket', level: 'error' }, `Error in rpc-unregister: ${error}`); - socket.emit('rpc-error', { type: 'unregister', error: 'Internal error' }); - } - }); - - // RPC call - Call an RPC method on another socket of the same user - socket.on('rpc-call', async (data: any, callback: (response: any) => void) => { - try { - const { method, params } = data; - - if (!method || typeof method !== 'string') { - if (callback) { - callback({ - ok: false, - error: 'Invalid parameters: method is required' - }); - } - return; - } - - // Find the RPC listener for this method within the same user - const userRpcMap = rpcListeners.get(userId); - if (!userRpcMap) { - log({ module: 'websocket-rpc' }, `RPC call failed: No RPC methods registered for user ${userId}`); - if (callback) { - callback({ - ok: false, - error: 'No RPC methods registered' - }); - } - return; - } - - const targetSocket = userRpcMap.get(method); - if (!targetSocket || !targetSocket.connected) { - log({ module: 'websocket-rpc' }, `RPC call failed: Method ${method} not available (disconnected or not registered)`); - if (callback) { - callback({ - ok: false, - error: 'RPC method not available' - }); - } - return; - } - - // Don't allow calling your own socket - if (targetSocket === socket) { - log({ module: 'websocket-rpc' }, `RPC call failed: Attempted self-call on method ${method}`); - if (callback) { - callback({ - ok: false, - error: 'Cannot call RPC on the same socket' - }); - } - return; - } - - // Log RPC call initiation - const startTime = Date.now(); - log({ module: 'websocket-rpc' }, `RPC call initiated: ${socket.id} -> ${method} (target: ${targetSocket.id})`); - - // Forward the RPC request to the target socket using emitWithAck - try { - const response = await targetSocket.timeout(30000).emitWithAck('rpc-request', { - method, - params - }); - - const duration = Date.now() - startTime; - log({ module: 'websocket-rpc' }, `RPC call succeeded: ${method} (${duration}ms)`); - - // Forward the response back to the caller via callback - if (callback) { - callback({ - ok: true, - result: response - }); - } - - } catch (error) { - const duration = Date.now() - startTime; - const errorMsg = error instanceof Error ? error.message : 'RPC call failed'; - log({ module: 'websocket-rpc' }, `RPC call failed: ${method} - ${errorMsg} (${duration}ms)`); - - // Timeout or error occurred - if (callback) { - callback({ - ok: false, - error: errorMsg - }); - } - } - } catch (error) { - log({ module: 'websocket', level: 'error' }, `Error in rpc-call: ${error}`); - if (callback) { - callback({ - ok: false, - error: 'Internal error' - }); - } - } - }); - - socket.on('ping', async (callback: (response: any) => void) => { - try { - callback({}); - } catch (error) { - log({ module: 'websocket', level: 'error' }, `Error in ping: ${error}`); - } - }); - - // Usage reporting - socket.on('usage-report', async (data: any, callback?: (response: any) => void) => { - await receiveUsageLock.inLock(async () => { - try { - const { key, sessionId, tokens, cost } = data; - - // Validate required fields - if (!key || typeof key !== 'string') { - if (callback) { - callback({ success: false, error: 'Invalid key' }); - } - return; - } - - // Validate tokens and cost objects - if (!tokens || typeof tokens !== 'object' || typeof tokens.total !== 'number') { - if (callback) { - callback({ success: false, error: 'Invalid tokens object - must include total' }); - } - return; - } - - if (!cost || typeof cost !== 'object' || typeof cost.total !== 'number') { - if (callback) { - callback({ success: false, error: 'Invalid cost object - must include total' }); - } - return; - } - - // Validate sessionId if provided - if (sessionId && typeof sessionId !== 'string') { - if (callback) { - callback({ success: false, error: 'Invalid sessionId' }); - } - return; - } - - try { - // If sessionId provided, verify it belongs to the user - if (sessionId) { - const session = await db.session.findFirst({ - where: { - id: sessionId, - accountId: userId - } - }); - - if (!session) { - if (callback) { - callback({ success: false, error: 'Session not found' }); - } - return; - } - } - - // Prepare usage data - const usageData: PrismaJson.UsageReportData = { - tokens, - cost - }; - - // Upsert the usage report - const report = await db.usageReport.upsert({ - where: { - accountId_sessionId_key: { - accountId: userId, - sessionId: sessionId || null, - key - } - }, - update: { - data: usageData, - updatedAt: new Date() - }, - create: { - accountId: userId, - sessionId: sessionId || null, - key, - data: usageData - } - }); - - log({ module: 'websocket' }, `Usage report saved: key=${key}, sessionId=${sessionId || 'none'}, userId=${userId}`); - - // Emit usage ephemeral update if sessionId is provided - if (sessionId) { - const usageEvent = buildUsageEphemeral(sessionId, key, usageData.tokens, usageData.cost); - eventRouter.emitEphemeral({ - userId, - payload: usageEvent, - recipientFilter: { type: 'user-scoped-only' } - }); - } - - if (callback) { - callback({ - success: true, - reportId: report.id, - createdAt: report.createdAt.getTime(), - updatedAt: report.updatedAt.getTime() - }); - } - } catch (error) { - log({ module: 'websocket', level: 'error' }, `Failed to save usage report: ${error}`); - if (callback) { - callback({ success: false, error: 'Failed to save usage report' }); - } - } - } catch (error) { - log({ module: 'websocket', level: 'error' }, `Error in usage-report handler: ${error}`); - if (callback) { - callback({ success: false, error: 'Internal error' }); - } - } - }); - }); - - socket.emit('auth', { success: true, user: userId }); + // Ready log({ module: 'websocket' }, `User connected: ${userId}`); }); diff --git a/sources/app/api/socket/pingHandler.ts b/sources/app/api/socket/pingHandler.ts new file mode 100644 index 0000000..d82766b --- /dev/null +++ b/sources/app/api/socket/pingHandler.ts @@ -0,0 +1,12 @@ +import { log } from "@/utils/log"; +import { Socket } from "socket.io"; + +export function pingHandler(socket: Socket) { + socket.on('ping', async (callback: (response: any) => void) => { + try { + callback({}); + } catch (error) { + log({ module: 'websocket', level: 'error' }, `Error in ping: ${error}`); + } + }); +} \ No newline at end of file diff --git a/sources/app/api/socket/rpcHandler.ts b/sources/app/api/socket/rpcHandler.ts new file mode 100644 index 0000000..d7a38e4 --- /dev/null +++ b/sources/app/api/socket/rpcHandler.ts @@ -0,0 +1,170 @@ +import { EventRouter } from "@/modules/eventRouter"; +import { log } from "@/utils/log"; +import { Socket } from "socket.io"; + +export function rpcHandler(userId: string, socket: Socket, eventRouter: EventRouter) { + const rpcListeners = new Map(); + // RPC register - Register this socket as a listener for an RPC method + socket.on('rpc-register', async (data: any) => { + try { + const { method } = data; + + if (!method || typeof method !== 'string') { + socket.emit('rpc-error', { type: 'register', error: 'Invalid method name' }); + return; + } + + // Check if method was already registered + const previousSocket = rpcListeners.get(method); + if (previousSocket && previousSocket !== socket) { + log({ module: 'websocket-rpc' }, `RPC method ${method} re-registered: ${previousSocket.id} -> ${socket.id}`); + } + + // Register this socket as the listener for this method + rpcListeners.set(method, socket); + + socket.emit('rpc-registered', { method }); + log({ module: 'websocket-rpc' }, `RPC method registered: ${method} on socket ${socket.id} (user: ${userId})`); + log({ module: 'websocket-rpc' }, `Active RPC methods for user ${userId}: ${Array.from(rpcListeners.keys()).join(', ')}`); + } catch (error) { + log({ module: 'websocket', level: 'error' }, `Error in rpc-register: ${error}`); + socket.emit('rpc-error', { type: 'register', error: 'Internal error' }); + } + }); + + // RPC unregister - Remove this socket as a listener for an RPC method + socket.on('rpc-unregister', async (data: any) => { + try { + const { method } = data; + + if (!method || typeof method !== 'string') { + socket.emit('rpc-error', { type: 'unregister', error: 'Invalid method name' }); + return; + } + + if (rpcListeners.get(method) === socket) { + rpcListeners.delete(method); + log({ module: 'websocket-rpc' }, `RPC method unregistered: ${method} from socket ${socket.id} (user: ${userId})`); + + if (rpcListeners.size === 0) { + rpcListeners.delete(userId); + log({ module: 'websocket-rpc' }, `All RPC methods unregistered for user ${userId}`); + } else { + log({ module: 'websocket-rpc' }, `Remaining RPC methods for user ${userId}: ${Array.from(rpcListeners.keys()).join(', ')}`); + } + } else { + log({ module: 'websocket-rpc' }, `RPC unregister ignored: ${method} not registered on socket ${socket.id}`); + } + + socket.emit('rpc-unregistered', { method }); + } catch (error) { + log({ module: 'websocket', level: 'error' }, `Error in rpc-unregister: ${error}`); + socket.emit('rpc-error', { type: 'unregister', error: 'Internal error' }); + } + }); + + // RPC call - Call an RPC method on another socket of the same user + socket.on('rpc-call', async (data: any, callback: (response: any) => void) => { + try { + const { method, params } = data; + + if (!method || typeof method !== 'string') { + if (callback) { + callback({ + ok: false, + error: 'Invalid parameters: method is required' + }); + } + return; + } + + const targetSocket = rpcListeners.get(method); + if (!targetSocket || !targetSocket.connected) { + log({ module: 'websocket-rpc' }, `RPC call failed: Method ${method} not available (disconnected or not registered)`); + if (callback) { + callback({ + ok: false, + error: 'RPC method not available' + }); + } + return; + } + + // Don't allow calling your own socket + if (targetSocket === socket) { + log({ module: 'websocket-rpc' }, `RPC call failed: Attempted self-call on method ${method}`); + if (callback) { + callback({ + ok: false, + error: 'Cannot call RPC on the same socket' + }); + } + return; + } + + // Log RPC call initiation + const startTime = Date.now(); + log({ module: 'websocket-rpc' }, `RPC call initiated: ${socket.id} -> ${method} (target: ${targetSocket.id})`); + + // Forward the RPC request to the target socket using emitWithAck + try { + const response = await targetSocket.timeout(30000).emitWithAck('rpc-request', { + method, + params + }); + + const duration = Date.now() - startTime; + log({ module: 'websocket-rpc' }, `RPC call succeeded: ${method} (${duration}ms)`); + + // Forward the response back to the caller via callback + if (callback) { + callback({ + ok: true, + result: response + }); + } + + } catch (error) { + const duration = Date.now() - startTime; + const errorMsg = error instanceof Error ? error.message : 'RPC call failed'; + log({ module: 'websocket-rpc' }, `RPC call failed: ${method} - ${errorMsg} (${duration}ms)`); + + // Timeout or error occurred + if (callback) { + callback({ + ok: false, + error: errorMsg + }); + } + } + } catch (error) { + log({ module: 'websocket', level: 'error' }, `Error in rpc-call: ${error}`); + if (callback) { + callback({ + ok: false, + error: 'Internal error' + }); + } + } + }); + + socket.on('disconnect', () => { + + const methodsToRemove: string[] = []; + for (const [method, registeredSocket] of rpcListeners.entries()) { + if (registeredSocket === socket) { + methodsToRemove.push(method); + } + } + + if (methodsToRemove.length > 0) { + log({ module: 'websocket-rpc' }, `Cleaning up RPC methods on disconnect for socket ${socket.id}: ${methodsToRemove.join(', ')}`); + methodsToRemove.forEach(method => rpcListeners.delete(method)); + } + + if (rpcListeners.size === 0) { + rpcListeners.delete(userId); + log({ module: 'websocket-rpc' }, `All RPC listeners removed for user ${userId}`); + } + }); +} \ No newline at end of file diff --git a/sources/app/api/socket/usageHandler.ts b/sources/app/api/socket/usageHandler.ts new file mode 100644 index 0000000..244d7c9 --- /dev/null +++ b/sources/app/api/socket/usageHandler.ts @@ -0,0 +1,124 @@ +import { Socket } from "socket.io"; +import { AsyncLock } from "@/utils/lock"; +import { db } from "@/storage/db"; +import { buildUsageEphemeral, EventRouter } from "@/modules/eventRouter"; +import { log } from "@/utils/log"; + +export function usageHandler(userId: string, socket: Socket, eventRouter: EventRouter) { + const receiveUsageLock = new AsyncLock(); + socket.on('usage-report', async (data: any, callback?: (response: any) => void) => { + await receiveUsageLock.inLock(async () => { + try { + const { key, sessionId, tokens, cost } = data; + + // Validate required fields + if (!key || typeof key !== 'string') { + if (callback) { + callback({ success: false, error: 'Invalid key' }); + } + return; + } + + // Validate tokens and cost objects + if (!tokens || typeof tokens !== 'object' || typeof tokens.total !== 'number') { + if (callback) { + callback({ success: false, error: 'Invalid tokens object - must include total' }); + } + return; + } + + if (!cost || typeof cost !== 'object' || typeof cost.total !== 'number') { + if (callback) { + callback({ success: false, error: 'Invalid cost object - must include total' }); + } + return; + } + + // Validate sessionId if provided + if (sessionId && typeof sessionId !== 'string') { + if (callback) { + callback({ success: false, error: 'Invalid sessionId' }); + } + return; + } + + try { + // If sessionId provided, verify it belongs to the user + if (sessionId) { + const session = await db.session.findFirst({ + where: { + id: sessionId, + accountId: userId + } + }); + + if (!session) { + if (callback) { + callback({ success: false, error: 'Session not found' }); + } + return; + } + } + + // Prepare usage data + const usageData: PrismaJson.UsageReportData = { + tokens, + cost + }; + + // Upsert the usage report + const report = await db.usageReport.upsert({ + where: { + accountId_sessionId_key: { + accountId: userId, + sessionId: sessionId || null, + key + } + }, + update: { + data: usageData, + updatedAt: new Date() + }, + create: { + accountId: userId, + sessionId: sessionId || null, + key, + data: usageData + } + }); + + log({ module: 'websocket' }, `Usage report saved: key=${key}, sessionId=${sessionId || 'none'}, userId=${userId}`); + + // Emit usage ephemeral update if sessionId is provided + if (sessionId) { + const usageEvent = buildUsageEphemeral(sessionId, key, usageData.tokens, usageData.cost); + eventRouter.emitEphemeral({ + userId, + payload: usageEvent, + recipientFilter: { type: 'user-scoped-only' } + }); + } + + if (callback) { + callback({ + success: true, + reportId: report.id, + createdAt: report.createdAt.getTime(), + updatedAt: report.updatedAt.getTime() + }); + } + } catch (error) { + log({ module: 'websocket', level: 'error' }, `Failed to save usage report: ${error}`); + if (callback) { + callback({ success: false, error: 'Failed to save usage report' }); + } + } + } catch (error) { + log({ module: 'websocket', level: 'error' }, `Error in usage-report handler: ${error}`); + if (callback) { + callback({ success: false, error: 'Internal error' }); + } + } + }); + }); +} \ No newline at end of file