refactor: add getDb function & update all db calls

This commit is contained in:
javayhu 2025-06-12 01:40:44 +08:00
parent 0684b16278
commit 17c7d67743
6 changed files with 24 additions and 15 deletions

View File

@ -1,6 +1,6 @@
'use server'; 'use server';
import db from '@/db'; import { getDb } from '@/db';
import { user } from '@/db/schema'; import { user } from '@/db/schema';
import { getSession } from '@/lib/server'; import { getSession } from '@/lib/server';
import { getUrlWithLocale } from '@/lib/urls/urls'; import { getUrlWithLocale } from '@/lib/urls/urls';
@ -56,6 +56,7 @@ export const createPortalAction = actionClient
try { try {
// Get the user's customer ID from the database // Get the user's customer ID from the database
const db = await getDb();
const customerResult = await db const customerResult = await db
.select({ customerId: user.customerId }) .select({ customerId: user.customerId })
.from(user) .from(user)

View File

@ -1,6 +1,6 @@
'use server'; 'use server';
import db from '@/db'; import { getDb } from '@/db';
import { payment } from '@/db/schema'; import { payment } from '@/db/schema';
import { findPlanByPriceId, getAllPricePlans } from '@/lib/price-plan'; import { findPlanByPriceId, getAllPricePlans } from '@/lib/price-plan';
import { getSession } from '@/lib/server'; import { getSession } from '@/lib/server';
@ -69,6 +69,7 @@ export const getLifetimeStatusAction = actionClient
} }
// Query the database for one-time payments with lifetime plans // Query the database for one-time payments with lifetime plans
const db = await getDb();
const result = await db const result = await db
.select({ .select({
id: payment.id, id: payment.id,

View File

@ -1,6 +1,6 @@
'use server'; 'use server';
import db from '@/db'; import { getDb } from '@/db';
import { user } from '@/db/schema'; import { user } from '@/db/schema';
import { asc, desc, ilike, or, sql } from 'drizzle-orm'; import { asc, desc, ilike, or, sql } from 'drizzle-orm';
import { createSafeActionClient } from 'next-safe-action'; import { createSafeActionClient } from 'next-safe-action';
@ -57,6 +57,7 @@ export const getUsersAction = actionClient
: user.createdAt; : user.createdAt;
const sortDirection = sortConfig?.desc ? desc : asc; const sortDirection = sortConfig?.desc ? desc : asc;
const db = await getDb();
let [items, [{ count }]] = await Promise.all([ let [items, [{ count }]] = await Promise.all([
db db
.select() .select()

View File

@ -6,14 +6,15 @@ import { drizzle } from 'drizzle-orm/postgres-js';
import postgres from 'postgres'; import postgres from 'postgres';
import * as schema from './schema'; import * as schema from './schema';
const connectionString = process.env.DATABASE_URL; let db: ReturnType<typeof drizzle> | null = null;
if (!connectionString) {
throw new Error('DATABASE_URL is not set');
}
// Disable prefetch as it is not supported for "Transaction" pool mode export async function getDb() {
const client = postgres(connectionString, { prepare: false }); if (db) return db;
const db = drizzle(client, { schema }); const connectionString = process.env.DATABASE_URL!;
const client = postgres(connectionString, { prepare: false });
db = drizzle(client, { schema });
return db;
}
/** /**
* Connect to Neon Database * Connect to Neon Database
@ -41,5 +42,3 @@ const db = drizzle(client, { schema });
* Drizzle with Supabase Database * Drizzle with Supabase Database
* https://orm.drizzle.team/docs/tutorials/drizzle-with-supabase * https://orm.drizzle.team/docs/tutorials/drizzle-with-supabase
*/ */
export default db;

View File

@ -1,5 +1,5 @@
import { websiteConfig } from '@/config/website'; import { websiteConfig } from '@/config/website';
import db from '@/db/index'; import { getDb } from '@/db/index';
import { defaultMessages } from '@/i18n/messages'; import { defaultMessages } from '@/i18n/messages';
import { LOCALE_COOKIE_NAME, routing } from '@/i18n/routing'; import { LOCALE_COOKIE_NAME, routing } from '@/i18n/routing';
import { sendEmail } from '@/mail'; import { sendEmail } from '@/mail';
@ -21,7 +21,7 @@ import { getBaseUrl, getUrlWithLocaleInCallbackUrl } from './urls/urls';
export const auth = betterAuth({ export const auth = betterAuth({
baseURL: getBaseUrl(), baseURL: getBaseUrl(),
appName: defaultMessages.Metadata.name, appName: defaultMessages.Metadata.name,
database: drizzleAdapter(db, { database: drizzleAdapter(await getDb(), {
provider: 'pg', // or "mysql", "sqlite" provider: 'pg', // or "mysql", "sqlite"
}), }),
session: { session: {

View File

@ -1,5 +1,5 @@
import { randomUUID } from 'crypto'; import { randomUUID } from 'crypto';
import db from '@/db'; import { getDb } from '@/db';
import { payment, session, user } from '@/db/schema'; import { payment, session, user } from '@/db/schema';
import { sendMessageToDiscord } from '@/lib/discord'; import { sendMessageToDiscord } from '@/lib/discord';
import { import {
@ -114,6 +114,7 @@ export class StripeProvider implements PaymentProvider {
): Promise<void> { ): Promise<void> {
try { try {
// Update user record with customer ID if email matches // Update user record with customer ID if email matches
const db = await getDb();
const result = await db const result = await db
.update(user) .update(user)
.set({ .set({
@ -144,6 +145,7 @@ export class StripeProvider implements PaymentProvider {
): Promise<string | undefined> { ): Promise<string | undefined> {
try { try {
// Query the user table for a matching customerId // Query the user table for a matching customerId
const db = await getDb();
const result = await db const result = await db
.select({ id: user.id }) .select({ id: user.id })
.from(user) .from(user)
@ -318,6 +320,7 @@ export class StripeProvider implements PaymentProvider {
try { try {
// Build query to fetch subscriptions from database // Build query to fetch subscriptions from database
const db = await getDb();
const subscriptions = await db const subscriptions = await db
.select() .select()
.from(payment) .from(payment)
@ -459,6 +462,7 @@ export class StripeProvider implements PaymentProvider {
updatedAt: new Date(), updatedAt: new Date(),
}; };
const db = await getDb();
const result = await db const result = await db
.insert(payment) .insert(payment)
.values(createFields) .values(createFields)
@ -518,6 +522,7 @@ export class StripeProvider implements PaymentProvider {
updatedAt: new Date(), updatedAt: new Date(),
}; };
const db = await getDb();
const result = await db const result = await db
.update(payment) .update(payment)
.set(updateFields) .set(updateFields)
@ -545,6 +550,7 @@ export class StripeProvider implements PaymentProvider {
console.log( console.log(
`>> Mark payment record for Stripe subscription ${stripeSubscription.id} as canceled` `>> Mark payment record for Stripe subscription ${stripeSubscription.id} as canceled`
); );
const db = await getDb();
const result = await db const result = await db
.update(payment) .update(payment)
.set({ .set({
@ -594,6 +600,7 @@ export class StripeProvider implements PaymentProvider {
// Create a one-time payment record // Create a one-time payment record
const now = new Date(); const now = new Date();
const db = await getDb();
const result = await db const result = await db
.insert(payment) .insert(payment)
.values({ .values({