feat: add userId handling and enhance payment actions & optimize the whole payment process

- Introduced userId parameter in createCheckoutAction and createPortalAction for improved user-specific session management.
- Updated components to pass userId instead of customerId, aligning with the new payment structure.
- Added new actions to retrieve active subscription and lifetime status, ensuring users can check their subscription details securely.
- Created SQL migration for new payment-related tables, establishing a robust foundation for payment management.
- Enhanced error handling and logging in payment actions for better debugging and user feedback.
This commit is contained in:
javayhu 2025-04-09 14:31:08 +08:00
parent e45d2504e6
commit 3c05657fe0
17 changed files with 169 additions and 118 deletions

View File

@ -21,7 +21,7 @@ CREATE TABLE "payment" (
"type" text NOT NULL,
"interval" text,
"user_id" text NOT NULL,
"customer_id" text,
"customer_id" text NOT NULL,
"subscription_id" text,
"status" text NOT NULL,
"period_start" timestamp,

View File

@ -1,5 +1,5 @@
{
"id": "de48f1a8-7fe4-4ef7-a765-98ac690aa491",
"id": "4e932ffc-faf7-4223-b04c-382bf773e626",
"prevId": "00000000-0000-0000-0000-000000000000",
"version": "7",
"dialect": "postgresql",
@ -153,7 +153,7 @@
"name": "customer_id",
"type": "text",
"primaryKey": false,
"notNull": false
"notNull": true
},
"subscription_id": {
"name": "subscription_id",

View File

@ -5,8 +5,8 @@
{
"idx": 0,
"version": "7",
"when": 1744046260338,
"tag": "0000_low_ben_grimm",
"when": 1744167230793,
"tag": "0000_public_mongoose",
"breakpoints": true
}
]

View File

@ -9,6 +9,7 @@
"lint": "next lint",
"db:generate": "drizzle-kit generate",
"db:migrate": "drizzle-kit migrate",
"db:push": "drizzle-kit push",
"db:studio": "drizzle-kit studio",
"docs": "content-collections build",
"email": "email dev --dir src/mail/emails --port 3333"

View File

@ -14,6 +14,7 @@ const actionClient = createSafeActionClient();
// Checkout schema for validation
// metadata is optional, and may contain referral information if you need
const checkoutSchema = z.object({
userId: z.string().min(1, { message: 'User ID is required' }),
planId: z.string().min(1, { message: 'Plan ID is required' }),
priceId: z.string().min(1, { message: 'Price ID is required' }),
metadata: z.record(z.string()).optional(),
@ -25,18 +26,28 @@ const checkoutSchema = z.object({
export const createCheckoutAction = actionClient
.schema(checkoutSchema)
.action(async ({ parsedInput }) => {
// request the user to login before checkout
const { userId, planId, priceId, metadata } = parsedInput;
// Get the current user session for authorization
const session = await getSession();
if (!session) {
console.warn(`unauthorized request to create checkout session for user ${userId}`);
return {
success: false,
error: 'Unauthorized',
};
}
try {
const { planId, priceId, metadata } = parsedInput;
// Only allow users to create their own checkout session
if (session.user.id !== userId) {
console.warn(`current user ${session.user.id} is not authorized to create checkout session for user ${userId}`);
return {
success: false,
error: 'Not authorized to do this action',
};
}
try {
// Get the current locale from the request
const locale = await getLocale();

View File

@ -1,9 +1,12 @@
'use server';
import db from "@/db";
import { user } from "@/db/schema";
import { getSession } from "@/lib/server";
import { getBaseUrlWithLocale } from "@/lib/urls/urls";
import { createCustomerPortal } from "@/payment";
import { CreatePortalParams } from "@/payment/types";
import { eq } from "drizzle-orm";
import { getLocale } from "next-intl/server";
import { createSafeActionClient } from 'next-safe-action';
import { z } from 'zod';
@ -13,7 +16,7 @@ const actionClient = createSafeActionClient();
// Portal schema for validation
const portalSchema = z.object({
customerId: z.string().min(1, { message: 'Customer ID is required' }),
userId: z.string().min(1, { message: 'User ID is required' }),
returnUrl: z.string().url({ message: 'Return URL must be a valid URL' }).optional(),
});
@ -23,16 +26,42 @@ const portalSchema = z.object({
export const createPortalAction = actionClient
.schema(portalSchema)
.action(async ({ parsedInput }) => {
const { userId, returnUrl } = parsedInput;
// Get the current user session for authorization
const session = await getSession();
if (!session) {
console.warn(`unauthorized request to create portal session for user ${userId}`);
return {
success: false,
error: 'Unauthorized',
};
}
// Only allow users to create their own portal session
if (session.user.id !== userId) {
console.warn(`current user ${session.user.id} is not authorized to create portal session for user ${userId}`);
return {
success: false,
error: 'Not authorized to do this action',
};
}
try {
const { customerId, returnUrl } = parsedInput;
// Get the user's customer ID from the database
const customerResult = await db
.select({ customerId: user.customerId })
.from(user)
.where(eq(user.id, session.user.id))
.limit(1);
if (customerResult.length <= 0 || !customerResult[0].customerId) {
console.error(`No customer found for user ${session.user.id}`);
return {
success: false,
error: 'No customer found for user',
};
}
// Get the current locale from the request
const locale = await getLocale();
@ -41,7 +70,7 @@ export const createPortalAction = actionClient
const baseUrlWithLocale = getBaseUrlWithLocale(locale);
const returnUrlWithLocale = returnUrl || `${baseUrlWithLocale}/settings/billing`;
const params: CreatePortalParams = {
customerId,
customerId: customerResult[0].customerId,
returnUrl: returnUrlWithLocale,
locale
};

View File

@ -1,42 +1,52 @@
'use server';
import { getSession } from "@/lib/server";
import { listCustomerSubscriptions } from "@/payment";
import { getSubscriptions } from "@/payment";
import { createSafeActionClient } from 'next-safe-action';
import { z } from "zod";
// Create a safe action client
const actionClient = createSafeActionClient();
// Input schema
const schema = z.object({
userId: z.string().min(1, { message: 'User ID is required' }),
});
/**
* Get customer subscription data
* Get active subscription data
*
* If the user has multiple subscriptions,
* it returns the most recent active or trialing one
*/
export const getCustomerSubscriptionAction = actionClient
.action(async () => {
export const getActiveSubscriptionAction = actionClient
.schema(schema)
.action(async ({ parsedInput }) => {
const { userId } = parsedInput;
// Get the current user session for authorization
const session = await getSession();
if (!session) {
console.warn(`unauthorized request to get active subscription for user ${userId}`);
return {
success: false,
error: 'Unauthorized',
};
}
try {
// Get the effective customer ID from session
const customerId = session.user.customerId;
// const subscriptionId = session.user.subscriptionId;
if (!customerId) {
console.warn('get user subscription, no customerId');
return {
success: true,
data: null,
};
}
// Only allow users to check their own status unless they're admins
if (session.user.id !== userId && session.user.role !== 'admin') {
console.warn(`current user ${session.user.id} is not authorized to get active subscription for user ${userId}`);
return {
success: false,
error: 'Not authorized to do this action',
};
}
// Find the customer's most recent active subscription
const subscriptions = await listCustomerSubscriptions({
customerId: customerId
try {
// Find the user's most recent active subscription
const subscriptions = await getSubscriptions({
userId: session.user.id
});
// console.log('get user subscriptions:', subscriptions);
@ -56,9 +66,9 @@ export const getCustomerSubscriptionAction = actionClient
// first in the list, as they have been sorted by date
subscriptionData = subscriptions[0];
}
console.log('find subscription for customerId:', customerId);
console.log('find subscription for userId:', session.user.id);
} else {
console.log('no subscriptions found for customerId:', customerId);
console.log('no subscriptions found for userId:', session.user.id);
}
return {

View File

@ -14,7 +14,7 @@ const actionClient = createSafeActionClient();
// Input schema
const schema = z.object({
userId: z.string(),
userId: z.string().min(1, { message: 'User ID is required' }),
});
/**
@ -25,7 +25,7 @@ const schema = z.object({
* in order to do this, you have to update the logic to check the lifetime status,
* for example, just check the planId is `lifetime` or not.
*/
export const getUserLifetimeStatusAction = actionClient
export const getLifetimeStatusAction = actionClient
.schema(schema)
.action(async ({ parsedInput }) => {
const { userId } = parsedInput;
@ -33,6 +33,7 @@ export const getUserLifetimeStatusAction = actionClient
// Get the current user session for authorization
const session = await getSession();
if (!session) {
console.warn(`unauthorized request to get lifetime status for user ${userId}`);
return {
success: false,
error: 'Unauthorized',
@ -41,9 +42,10 @@ export const getUserLifetimeStatusAction = actionClient
// Only allow users to check their own status unless they're admins
if (session.user.id !== userId && session.user.role !== 'admin') {
console.warn(`current user ${session.user.id} is not authorized to get lifetime status for user ${userId}`);
return {
success: false,
error: 'Not authorized to view this user data',
error: 'Not authorized to do this action',
};
}

View File

@ -2,13 +2,13 @@
import { createCheckoutAction } from '@/actions/create-checkout-session';
import { Button } from '@/components/ui/button';
import { authClient } from '@/lib/auth-client';
import { Loader2Icon } from 'lucide-react';
import { useTranslations } from 'next-intl';
import { useState } from 'react';
import { toast } from 'sonner';
interface CheckoutButtonProps {
userId: string;
planId: string;
priceId: string;
metadata?: Record<string, string>;
@ -27,6 +27,7 @@ interface CheckoutButtonProps {
* NOTICE: Login is required when using this button.
*/
export function CheckoutButton({
userId,
planId,
priceId,
metadata,
@ -37,7 +38,6 @@ export function CheckoutButton({
}: CheckoutButtonProps) {
const t = useTranslations('PricingPage.CheckoutButton');
const [isLoading, setIsLoading] = useState(false);
const { refetch } = authClient.useSession();
const handleClick = async () => {
try {
@ -45,6 +45,7 @@ export function CheckoutButton({
// Create checkout session using server action
const result = await createCheckoutAction({
userId,
planId,
priceId,
metadata,
@ -52,11 +53,6 @@ export function CheckoutButton({
// Redirect to checkout
if (result && result.data?.success && result.data.data?.url) {
// TODO: Always refetch session to ensure we have the latest user data
if (refetch) {
await refetch();
}
// redirect to checkout page
window.location.href = result.data.data?.url;
} else {

View File

@ -9,7 +9,7 @@ import { useState } from 'react';
import { toast } from 'sonner';
interface CustomerPortalButtonProps {
customerId: string;
userId: string;
returnUrl?: string;
variant?: 'default' | 'outline' | 'destructive' | 'secondary' | 'ghost' | 'link' | null;
size?: 'default' | 'sm' | 'lg' | 'icon' | null;
@ -22,9 +22,11 @@ interface CustomerPortalButtonProps {
*
* This client component opens the Stripe customer portal
* It's used to let customers manage their billing, subscriptions, and payment methods
*
* NOTICE: Login is required when using this button.
*/
export function CustomerPortalButton({
customerId,
userId,
returnUrl,
variant = 'default',
size = 'default',
@ -40,7 +42,7 @@ export function CustomerPortalButton({
// Create customer portal session using server action
const result = await createPortalAction({
customerId,
userId,
returnUrl,
});

View File

@ -145,6 +145,7 @@ export function PricingCard({
) : isPaidPlan ? (
currentUser ? (
<CheckoutButton
userId={currentUser.id}
planId={plan.id}
priceId={price.priceId}
metadata={metadata}

View File

@ -19,7 +19,7 @@ export default function BillingCard() {
const {
isLoading: isLoadingPayment,
error: paymentError,
error: loadPaymentError,
subscription,
currentPlan,
refetch
@ -72,7 +72,7 @@ export default function BillingCard() {
}
// Render error state
if (paymentError) {
if (loadPaymentError) {
return (
<div className="grid gap-8 md:grid-cols-2">
<Card>
@ -81,7 +81,7 @@ export default function BillingCard() {
<CardDescription>{t('currentPlan.description')}</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
<div className="text-destructive text-sm">{paymentError}</div>
<div className="text-destructive text-sm">{loadPaymentError}</div>
</CardContent>
<CardFooter>
<Button
@ -115,7 +115,7 @@ export default function BillingCard() {
<div className="text-3xl font-medium">
{currentPlan?.name}
</div>
<Badge variant={isFreePlan || isLifetimeMember ? 'outline' : 'default'}>
<Badge variant='outline'>
{isLifetimeMember
? t('status.lifetime')
: subscription?.status === 'active'
@ -178,9 +178,9 @@ export default function BillingCard() {
)}
{/* user is lifetime member, show manage billing button */}
{isLifetimeMember && currentUser?.customerId && (
{isLifetimeMember && currentUser && (
<CustomerPortalButton
customerId={currentUser.customerId}
userId={currentUser.id}
className=""
>
{t('manageBilling')}
@ -188,9 +188,9 @@ export default function BillingCard() {
)}
{/* user has subscription, show manage subscription button */}
{subscription && currentUser?.customerId && (
{subscription && currentUser && (
<CustomerPortalButton
customerId={currentUser.customerId}
userId={currentUser.id}
className=""
>
{t('manageSubscription')}

View File

@ -103,14 +103,8 @@ export const auth = betterAuth({
},
user: {
// https://www.better-auth.com/docs/concepts/database#extending-core-schema
additionalFields: {
customerId: {
type: "string",
required: false,
defaultValue: "",
input: false, // don't allow user to set customerId
},
},
// additionalFields: {
// },
// https://www.better-auth.com/docs/concepts/users-accounts#delete-user
deleteUser: {
enabled: true,

View File

@ -1,4 +1,4 @@
import { PaymentProvider, PricePlan, PaymentConfig, Customer, Subscription, Payment, PaymentStatus, PlanInterval, PaymentType, Price, CreateCheckoutParams, CheckoutResult, CreatePortalParams, PortalResult, getCustomerSubscriptionsParams } from "./types";
import { PaymentProvider, PricePlan, PaymentConfig, Customer, Subscription, Payment, PaymentStatus, PlanInterval, PaymentType, Price, CreateCheckoutParams, CheckoutResult, CreatePortalParams, PortalResult, getSubscriptionsParams } from "./types";
import { StripeProvider } from "./provider/stripe";
import { paymentConfig } from "./config/payment-config";
@ -77,11 +77,11 @@ export const handleWebhookEvent = async (
* @param params Parameters for listing customer subscriptions
* @returns Array of subscriptions
*/
export const listCustomerSubscriptions = async (
params: getCustomerSubscriptionsParams
export const getSubscriptions = async (
params: getSubscriptionsParams
): Promise<Subscription[]> => {
const provider = getPaymentProvider();
return provider.getCustomerSubscriptions(params);
return provider.getSubscriptions(params);
};
/**
@ -148,5 +148,5 @@ export type {
CheckoutResult,
CreatePortalParams,
PortalResult,
getCustomerSubscriptionsParams as ListCustomerSubscriptionsParams,
getSubscriptionsParams as ListCustomerSubscriptionsParams,
};

View File

@ -8,7 +8,7 @@ import {
CheckoutResult,
CreateCheckoutParams,
CreatePortalParams,
getCustomerSubscriptionsParams,
getSubscriptionsParams,
PaymentProvider,
PaymentStatus,
PaymentTypes,
@ -47,12 +47,6 @@ export class StripeProvider implements PaymentProvider {
/**
* Create a customer in Stripe if not exists
*
* NOTICE: if you want to delete user in database,
* please delete customer in Stripe as well,
* otherwise, the user wont have a customer id in database,
* and will not be able to make payments.
*
* @param email Customer email
* @returns Stripe customer ID
*/
@ -66,7 +60,17 @@ export class StripeProvider implements PaymentProvider {
// Find existing customer
if (customers.data && customers.data.length > 0) {
return customers.data[0].id;
const customerId = customers.data[0].id;
// Find user id by customer id
const userId = await this.findUserIdByCustomerId(customerId);
// user does not exist, update user with customer id
// in case you deleted user in database, but forgot to delete customer in Stripe
if (!userId) {
console.log(`User ${email} does not exist, update with customer id ${customerId}`);
await this.updateUserWithCustomerId(customerId, email);
}
return customerId;
}
// Create new customer
@ -202,6 +206,10 @@ export class StripeProvider implements PaymentProvider {
checkoutParams.payment_intent_data = {
metadata: customMetadata,
};
// Automatically create an invoice for the one-time payment
checkoutParams.invoice_creation = {
enabled: true,
};
}
// Add subscription data for recurring payments
@ -255,19 +263,19 @@ export class StripeProvider implements PaymentProvider {
}
/**
* List customer subscriptions
* @param params Parameters for listing customer subscriptions
* Get subscriptions
* @param params Parameters for getting subscriptions
* @returns Array of subscription objects
*/
public async getCustomerSubscriptions(params: getCustomerSubscriptionsParams): Promise<Subscription[]> {
const { customerId } = params;
public async getSubscriptions(params: getSubscriptionsParams): Promise<Subscription[]> {
const { userId } = params;
try {
// Build query to fetch subscriptions from database
const subscriptions = await db
.select()
.from(payment)
.where(eq(payment.customerId, customerId))
.where(eq(payment.userId, userId))
.orderBy(desc(payment.createdAt)); // Sort by creation date, newest first
// Map database records to our subscription model

View File

@ -149,10 +149,10 @@ export interface PortalResult {
}
/**
* Parameters for listing customer subscriptions
* Parameters for getting customer subscriptions
*/
export interface getCustomerSubscriptionsParams {
customerId: string;
export interface getSubscriptionsParams {
userId: string;
}
/**
@ -172,7 +172,7 @@ export interface PaymentProvider {
/**
* Get customer subscriptions
*/
getCustomerSubscriptions(params: getCustomerSubscriptionsParams): Promise<Subscription[]>;
getSubscriptions(params: getSubscriptionsParams): Promise<Subscription[]>;
/**
* Handle webhook events

View File

@ -1,5 +1,5 @@
import { getCustomerSubscriptionAction } from '@/actions/get-customer-subscription';
import { getUserLifetimeStatusAction } from '@/actions/get-user-lifetime-status';
import { getActiveSubscriptionAction } from '@/actions/get-active-subscription';
import { getLifetimeStatusAction } from '@/actions/get-lifetime-status';
import { Session } from '@/lib/auth';
import { getAllPlans } from '@/payment';
import { PricePlan, Subscription } from '@/payment/types';
@ -55,28 +55,36 @@ export const usePaymentStore = create<PaymentState>((set, get) => ({
// Fetch subscription data
set({ isLoading: true, error: null });
// Check if user is a lifetime member directly from the database
let isLifetimeMember = false;
try {
const result = await getUserLifetimeStatusAction({ userId: user.id });
if (result?.data?.success) {
isLifetimeMember = result.data.isLifetimeMember || false;
console.log('check user lifetime status result', result);
} else {
console.warn('check user lifetime status failed');
}
} catch (error) {
console.error('check user lifetime status error:', error);
}
// Get all available plans
const plans = getAllPlans();
const freePlan = plans.find(plan => plan.isFree);
const lifetimePlan = plans.find(plan => plan.isLifetime);
// Check if user is a lifetime member directly from the database
let isLifetimeMember = false;
try {
const result = await getLifetimeStatusAction({ userId: user.id });
if (result?.data?.success) {
isLifetimeMember = result.data.isLifetimeMember || false;
console.log('get lifetime status result', result);
} else {
console.warn('get lifetime status failed', result?.data?.error);
// set({
// error: result?.data?.error || 'Failed to fetch payment data',
// isLoading: false
// });
}
} catch (error) {
console.error('get lifetime status error:', error);
// set({
// error: 'Failed to fetch payment data',
// isLoading: false
// });
}
// If lifetime member, set the lifetime plan
if (isLifetimeMember) {
console.log('setting lifetime plan for user', user.id);
console.log('set lifetime plan for user', user.id);
set({
currentPlan: lifetimePlan || null,
subscription: null,
@ -86,30 +94,19 @@ export const usePaymentStore = create<PaymentState>((set, get) => ({
return;
}
// Skip fetching if user doesn't have a customer ID (except for lifetime members)
if (!user.customerId) {
console.log('setting free plan for user', user.id);
set({
currentPlan: freePlan || null,
subscription: null,
isLoading: false,
error: null
});
return;
}
try {
const result = await getCustomerSubscriptionAction();
// Check if user has an active subscription
const result = await getActiveSubscriptionAction({ userId: user.id });
if (result?.data?.success) {
const subscriptionData = result.data.data;
const activeSubscription = result.data.data;
// Set subscription state
if (subscriptionData) {
const plan = plans.find(p => p.id === subscriptionData.planId) || null;
if (activeSubscription) {
const plan = plans.find(p => p.id === activeSubscription.planId) || null;
console.log('subscription found, setting plan for user', user.id, plan?.id);
set({
currentPlan: plan,
subscription: subscriptionData,
subscription: activeSubscription,
isLoading: false,
error: null
});
@ -123,14 +120,14 @@ export const usePaymentStore = create<PaymentState>((set, get) => ({
});
}
} else { // Failed to fetch subscription
console.warn('Failed to fetch subscription for user', user.id, result?.data?.error);
console.error('fetch subscription for user failed', result?.data?.error);
set({
error: result?.data?.error || 'Failed to fetch payment data',
isLoading: false
});
}
} catch (error) {
console.error('Fetch payment data error:', error);
console.error('fetch payment data error:', error);
set({
error: 'Failed to fetch payment data',
isLoading: false