From 17c7d677433f60dff22b0920c3c8c613cbe2fdad Mon Sep 17 00:00:00 2001 From: javayhu Date: Thu, 12 Jun 2025 01:40:44 +0800 Subject: [PATCH] refactor: add getDb function & update all db calls --- src/actions/create-customer-portal-session.ts | 3 ++- src/actions/get-lifetime-status.ts | 3 ++- src/actions/get-users.ts | 3 ++- src/db/index.ts | 17 ++++++++--------- src/lib/auth.ts | 4 ++-- src/payment/provider/stripe.ts | 9 ++++++++- 6 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/actions/create-customer-portal-session.ts b/src/actions/create-customer-portal-session.ts index e4e46fa..a10962f 100644 --- a/src/actions/create-customer-portal-session.ts +++ b/src/actions/create-customer-portal-session.ts @@ -1,6 +1,6 @@ 'use server'; -import db from '@/db'; +import { getDb } from '@/db'; import { user } from '@/db/schema'; import { getSession } from '@/lib/server'; import { getUrlWithLocale } from '@/lib/urls/urls'; @@ -56,6 +56,7 @@ export const createPortalAction = actionClient try { // Get the user's customer ID from the database + const db = await getDb(); const customerResult = await db .select({ customerId: user.customerId }) .from(user) diff --git a/src/actions/get-lifetime-status.ts b/src/actions/get-lifetime-status.ts index ac98a56..1fe9937 100644 --- a/src/actions/get-lifetime-status.ts +++ b/src/actions/get-lifetime-status.ts @@ -1,6 +1,6 @@ 'use server'; -import db from '@/db'; +import { getDb } from '@/db'; import { payment } from '@/db/schema'; import { findPlanByPriceId, getAllPricePlans } from '@/lib/price-plan'; import { getSession } from '@/lib/server'; @@ -69,6 +69,7 @@ export const getLifetimeStatusAction = actionClient } // Query the database for one-time payments with lifetime plans + const db = await getDb(); const result = await db .select({ id: payment.id, diff --git a/src/actions/get-users.ts b/src/actions/get-users.ts index 3ed1f10..f01702a 100644 --- a/src/actions/get-users.ts +++ b/src/actions/get-users.ts @@ -1,6 +1,6 @@ 'use server'; -import db from '@/db'; +import { getDb } from '@/db'; import { user } from '@/db/schema'; import { asc, desc, ilike, or, sql } from 'drizzle-orm'; import { createSafeActionClient } from 'next-safe-action'; @@ -57,6 +57,7 @@ export const getUsersAction = actionClient : user.createdAt; const sortDirection = sortConfig?.desc ? desc : asc; + const db = await getDb(); let [items, [{ count }]] = await Promise.all([ db .select() diff --git a/src/db/index.ts b/src/db/index.ts index bac13fa..e7b4e3a 100644 --- a/src/db/index.ts +++ b/src/db/index.ts @@ -6,14 +6,15 @@ import { drizzle } from 'drizzle-orm/postgres-js'; import postgres from 'postgres'; import * as schema from './schema'; -const connectionString = process.env.DATABASE_URL; -if (!connectionString) { - throw new Error('DATABASE_URL is not set'); -} +let db: ReturnType | null = null; -// Disable prefetch as it is not supported for "Transaction" pool mode -const client = postgres(connectionString, { prepare: false }); -const db = drizzle(client, { schema }); +export async function getDb() { + if (db) return db; + const connectionString = process.env.DATABASE_URL!; + const client = postgres(connectionString, { prepare: false }); + db = drizzle(client, { schema }); + return db; +} /** * Connect to Neon Database @@ -41,5 +42,3 @@ const db = drizzle(client, { schema }); * Drizzle with Supabase Database * https://orm.drizzle.team/docs/tutorials/drizzle-with-supabase */ - -export default db; diff --git a/src/lib/auth.ts b/src/lib/auth.ts index 1bd2c34..bd4c3c7 100644 --- a/src/lib/auth.ts +++ b/src/lib/auth.ts @@ -1,5 +1,5 @@ import { websiteConfig } from '@/config/website'; -import db from '@/db/index'; +import { getDb } from '@/db/index'; import { defaultMessages } from '@/i18n/messages'; import { LOCALE_COOKIE_NAME, routing } from '@/i18n/routing'; import { sendEmail } from '@/mail'; @@ -21,7 +21,7 @@ import { getBaseUrl, getUrlWithLocaleInCallbackUrl } from './urls/urls'; export const auth = betterAuth({ baseURL: getBaseUrl(), appName: defaultMessages.Metadata.name, - database: drizzleAdapter(db, { + database: drizzleAdapter(await getDb(), { provider: 'pg', // or "mysql", "sqlite" }), session: { diff --git a/src/payment/provider/stripe.ts b/src/payment/provider/stripe.ts index 9a773c2..5800534 100644 --- a/src/payment/provider/stripe.ts +++ b/src/payment/provider/stripe.ts @@ -1,5 +1,5 @@ import { randomUUID } from 'crypto'; -import db from '@/db'; +import { getDb } from '@/db'; import { payment, session, user } from '@/db/schema'; import { sendMessageToDiscord } from '@/lib/discord'; import { @@ -114,6 +114,7 @@ export class StripeProvider implements PaymentProvider { ): Promise { try { // Update user record with customer ID if email matches + const db = await getDb(); const result = await db .update(user) .set({ @@ -144,6 +145,7 @@ export class StripeProvider implements PaymentProvider { ): Promise { try { // Query the user table for a matching customerId + const db = await getDb(); const result = await db .select({ id: user.id }) .from(user) @@ -318,6 +320,7 @@ export class StripeProvider implements PaymentProvider { try { // Build query to fetch subscriptions from database + const db = await getDb(); const subscriptions = await db .select() .from(payment) @@ -459,6 +462,7 @@ export class StripeProvider implements PaymentProvider { updatedAt: new Date(), }; + const db = await getDb(); const result = await db .insert(payment) .values(createFields) @@ -518,6 +522,7 @@ export class StripeProvider implements PaymentProvider { updatedAt: new Date(), }; + const db = await getDb(); const result = await db .update(payment) .set(updateFields) @@ -545,6 +550,7 @@ export class StripeProvider implements PaymentProvider { console.log( `>> Mark payment record for Stripe subscription ${stripeSubscription.id} as canceled` ); + const db = await getDb(); const result = await db .update(payment) .set({ @@ -594,6 +600,7 @@ export class StripeProvider implements PaymentProvider { // Create a one-time payment record const now = new Date(); + const db = await getDb(); const result = await db .insert(payment) .values({