feat: implement QuestionBank CRUD with pagination and template query
- Add pagination support to findAll (page, limit query params)
- Add findByTemplateId method to service
- Add GET /by-template/:templateId endpoint to controller
- Service already includes CRUD for QuestionBank and QuestionBankItem
This commit is contained in:
@@ -0,0 +1,286 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
import { ModelConfigService } from '../model-config/model-config.service';
|
||||
import { I18nService } from '../i18n/i18n.service';
|
||||
|
||||
/**
 * JSON body returned by the embeddings endpoint on a successful call.
 *
 * Only `data[].embedding` (and, where present, `data[].index`) is consumed by
 * the service in this file; the remaining fields describe the payload shape.
 */
export interface EmbeddingResponse {
  /** One entry per embedded input. */
  data: Array<{
    /** The embedding vector itself. */
    embedding: number[];
    /**
     * Presumably the position of the input text this vector belongs to
     * (as in OpenAI-compatible APIs) — confirm against the provider's docs.
     */
    index: number;
  }>;
  /** Model identifier echoed back by the API. */
  model: string;
  /** Token accounting reported by the API. */
  usage: {
    prompt_tokens: number;
    total_tokens: number;
  };
}
|
||||
|
||||
/**
 * Result of a single HTTP round-trip to the embeddings endpoint.
 * Transport failures (network error, timeout, unparsable body) are reported by
 * throwing instead, so that the caller can tell them apart from HTTP errors.
 */
type EmbeddingCallOutcome =
  | { ok: true; embeddings: number[][] }
  | { ok: false; status: number; statusText: string; errorText: string };

/**
 * Generates embedding vectors by POSTing texts to the `/embeddings` endpoint
 * of a configured embedding model.
 *
 * Large inputs are split into batches sized per model, transient failures are
 * retried with exponential backoff, and batches the API rejects as too large
 * are halved recursively.
 */
@Injectable()
export class EmbeddingService {
  /** Used when DEFAULT_VECTOR_DIMENSIONS is missing or not a positive integer. */
  private static readonly FALLBACK_DIMENSIONS = 2560;
  /** Total number of attempts (first try included) for one batch. */
  private static readonly MAX_ATTEMPTS = 3;
  /** Upper bound for one request, including reading the response body. */
  private static readonly REQUEST_TIMEOUT_MS = 60000;
  /** Pause between consecutive batches to stay clear of API rate limits. */
  private static readonly INTER_BATCH_DELAY_MS = 100;

  private readonly logger = new Logger(EmbeddingService.name);
  private readonly defaultDimensions: number;

  constructor(
    private modelConfigService: ModelConfigService,
    private configService: ConfigService,
    private i18nService: I18nService,
  ) {
    const rawDimensions = this.configService.get<string>(
      'DEFAULT_VECTOR_DIMENSIONS',
      '2560',
    );
    // Explicit radix, and never let a malformed value leave NaN in the field.
    const parsedDimensions = Number.parseInt(String(rawDimensions), 10);
    if (Number.isInteger(parsedDimensions) && parsedDimensions > 0) {
      this.defaultDimensions = parsedDimensions;
    } else {
      this.defaultDimensions = EmbeddingService.FALLBACK_DIMENSIONS;
      this.logger.warn(
        `Invalid DEFAULT_VECTOR_DIMENSIONS value "${rawDimensions}", falling back to ${EmbeddingService.FALLBACK_DIMENSIONS}`,
      );
    }
    this.logger.log(
      `Default vector dimensions set to ${this.defaultDimensions}`,
    );
  }

  /**
   * Embed `texts` with the model identified by `embeddingModelConfigId`.
   *
   * @param texts Texts to embed; an empty array yields an empty result without
   *   calling the API.
   * @param embeddingModelConfigId Id passed to `ModelConfigService.findOne`.
   * @returns One vector per input text, in input order.
   * @throws Error when the model config is missing, is not of type
   *   `embedding`, is disabled, has no `baseUrl`, or when the API call fails.
   */
  async getEmbeddings(
    texts: string[],
    embeddingModelConfigId: string,
  ): Promise<number[][]> {
    this.logger.log(`Generating embeddings for ${texts.length} texts`);

    const modelConfig = await this.modelConfigService.findOne(
      embeddingModelConfigId,
    );
    if (!modelConfig || modelConfig.type !== 'embedding') {
      throw new Error(
        this.i18nService.formatMessage('embeddingModelNotFound', {
          id: embeddingModelConfigId,
        }),
      );
    }

    if (modelConfig.isEnabled === false) {
      throw new Error(
        `Model ${modelConfig.name} is disabled and cannot generate embeddings`,
      );
    }

    if (!modelConfig.baseUrl) {
      throw new Error(
        `Model ${modelConfig.name} does not have baseUrl configured`,
      );
    }

    // Nothing to embed: skip the API call (an empty `input` would be rejected).
    // Checked after validation so configuration errors still surface.
    if (texts.length === 0) {
      return [];
    }

    const maxBatchSize = this.getMaxBatchSizeForModel(
      modelConfig.modelId,
      modelConfig.maxBatchSize,
    );

    // Everything fits into one request.
    if (texts.length <= maxBatchSize) {
      return this.getEmbeddingsForBatch(texts, modelConfig, maxBatchSize);
    }

    this.logger.log(
      `Splitting ${texts.length} texts into batches (model batch limit: ${maxBatchSize})`,
    );

    const allEmbeddings: number[][] = [];
    for (let start = 0; start < texts.length; start += maxBatchSize) {
      const batch = texts.slice(start, start + maxBatchSize);
      const batchEmbeddings = await this.getEmbeddingsForBatch(
        batch,
        modelConfig,
        maxBatchSize,
      );
      allEmbeddings.push(...batchEmbeddings);

      // Wait briefly between batches (not after the last one) to avoid rate limiting.
      if (start + maxBatchSize < texts.length) {
        await new Promise((resolve) =>
          setTimeout(resolve, EmbeddingService.INTER_BATCH_DELAY_MS),
        );
      }
    }

    return allEmbeddings;
  }

  /**
   * Determine the maximum number of texts per request for a model.
   *
   * @param modelId Model identifier; matched by substring.
   * @param configuredMaxBatchSize Optional per-model override. Values that are
   *   not finite numbers >= 1 are ignored; fractional values are floored.
   * @returns A positive integer batch size.
   */
  private getMaxBatchSizeForModel(
    modelId: string,
    configuredMaxBatchSize?: number,
  ): number {
    // A zero, negative, NaN or fractional size would yield a non-positive or
    // fractional loop step in the caller, so sanitize it first.
    const numericOverride = Number(configuredMaxBatchSize);
    const configured =
      Number.isFinite(numericOverride) && numericOverride >= 1
        ? Math.floor(numericOverride)
        : undefined;

    const isTenPerRequestModel =
      modelId.includes('text-embedding-004') ||
      modelId.includes('text-embedding-v4') ||
      modelId.includes('text-embedding-ada-002');
    if (isTenPerRequestModel) {
      // Capped at 10 inputs per request (annotated as "Google limit: 10" by the original author).
      return Math.min(10, configured ?? 100);
    }

    const isV3Model =
      modelId.includes('text-embedding-3') ||
      modelId.includes('text-embedding-003');
    if (isV3Model) {
      // Annotated as the OpenAI v3 limit by the original author.
      return Math.min(2048, configured ?? 2048);
    }

    // Default: smaller of configured max or 100.
    return Math.min(configured ?? 100, 100);
  }

  /**
   * Embed one batch, retrying transient failures.
   *
   * - Transport failures (network error, timeout, malformed body) and HTTP
   *   429 / 5xx are retried with exponential backoff.
   * - A batch-size rejection splits the batch in half and embeds each half
   *   through this same method.
   * - All other HTTP errors fail immediately; repeating the identical request
   *   cannot succeed.
   *
   * @param texts Non-empty batch of texts.
   * @param modelConfig Model configuration record; this method reads `baseUrl`,
   *   `name`, `modelId`, `apiKey` and `maxInputTokens`. The declared type lives
   *   outside this file, hence `any`.
   * @param maxBatchSize Batch limit in effect for this call (used for logging
   *   and halved on split).
   * @returns One vector per text, in input order.
   */
  private async getEmbeddingsForBatch(
    texts: string[],
    modelConfig: any,
    maxBatchSize: number,
  ): Promise<number[][]> {
    const apiUrl = this.resolveEmbeddingsUrl(modelConfig.baseUrl);
    const maxAttempts = EmbeddingService.MAX_ATTEMPTS;
    let lastError: unknown;

    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
      let outcome: EmbeddingCallOutcome;
      try {
        outcome = await this.requestEmbeddings(
          apiUrl,
          texts,
          modelConfig,
          attempt,
          maxAttempts,
        );
      } catch (error: unknown) {
        // Transport-level failure: worth another attempt.
        lastError = error;
        if (attempt < maxAttempts) {
          await this.waitBeforeRetry(attempt, error);
        }
        continue;
      }

      if (outcome.ok) {
        return outcome.embeddings;
      }

      const { status, statusText, errorText } = outcome;

      // Batch too large for the API: halve it. This runs outside the retry
      // try/catch on purpose, so a failure inside a half (which already had
      // its own retries) is not retried again by every enclosing level.
      if (texts.length > 1 && this.isBatchSizeError(status, errorText)) {
        const halvedLimit = Math.max(1, Math.floor(maxBatchSize / 2));
        this.logger.warn(
          `Batch size limit error detected. Splitting batch in half and retrying: ${maxBatchSize} -> ${halvedLimit}`,
        );

        const midPoint = Math.floor(texts.length / 2);
        const firstResult = await this.getEmbeddingsForBatch(
          texts.slice(0, midPoint),
          modelConfig,
          halvedLimit,
        );
        const secondResult = await this.getEmbeddingsForBatch(
          texts.slice(midPoint),
          modelConfig,
          halvedLimit,
        );
        return [...firstResult, ...secondResult];
      }

      // Input too long for the model: deterministic, so fail right away.
      if (this.isContextLengthError(errorText)) {
        const totalLength = texts.reduce((sum, text) => sum + text.length, 0);
        const avgLength = totalLength / texts.length;
        const tokenLimit = modelConfig.maxInputTokens || 8192;
        this.logger.error(
          `Text length exceeds limit: ${texts.length} texts, ` +
            `total ${totalLength} characters, average ${Math.round(avgLength)} characters, ` +
            `model limit: ${tokenLimit} tokens`,
        );
        throw new Error(
          `Text length exceeds model limit. ` +
            `Current: ${texts.length} texts with total ${totalLength} characters, ` +
            `model limit: ${tokenLimit} tokens. ` +
            `Advice: Reduce chunk size or batch size`,
        );
      }

      // Retry on 429 (Too Many Requests) or 5xx (Server Error).
      if (status === 429 || status >= 500) {
        this.logger.warn(
          `Temporary error from embedding API (${status}): ${errorText}`,
        );
        lastError = new Error(`API Error ${status}: ${errorText}`);
        if (attempt < maxAttempts) {
          await this.waitBeforeRetry(attempt, lastError);
        }
        continue;
      }

      // Any other HTTP error is permanent.
      this.logger.error(`Embedding API error details: ${errorText}`);
      this.logger.error(
        `Request parameters: model=${modelConfig.modelId}, inputLength=${texts[0]?.length}`,
      );
      throw new Error(
        `Embedding API call failed: ${statusText} - ${errorText}`,
      );
    }

    throw lastError;
  }

  /** Build the endpoint URL, tolerating trailing slashes on the base URL. */
  private resolveEmbeddingsUrl(baseUrl: string): string {
    const trimmed = baseUrl.replace(/\/+$/, '');
    return trimmed.endsWith('/embeddings') ? trimmed : `${trimmed}/embeddings`;
  }

  /**
   * Perform exactly one POST to the embeddings endpoint.
   *
   * @returns The parsed vectors on HTTP success, or the status and body text
   *   on an HTTP error response.
   * @throws On network failure, timeout, or a malformed success body.
   */
  private async requestEmbeddings(
    apiUrl: string,
    texts: string[],
    modelConfig: any,
    attempt: number,
    maxAttempts: number,
  ): Promise<EmbeddingCallOutcome> {
    this.logger.log(
      `[Model call] Type: Embedding, Model: ${modelConfig.name} (${modelConfig.modelId}), Text count: ${texts.length}`,
    );
    this.logger.log(
      `Calling embedding API (attempt ${attempt}/${maxAttempts}): ${apiUrl}`,
    );

    const headers: Record<string, string> = {
      'Content-Type': 'application/json',
    };
    // Do not send "Bearer undefined" when no key is configured.
    if (modelConfig.apiKey) {
      headers.Authorization = `Bearer ${modelConfig.apiKey}`;
    }

    const controller = new AbortController();
    const timeoutId = setTimeout(() => {
      controller.abort();
      this.logger.error(`Embedding API timeout after 60s: ${apiUrl}`);
    }, EmbeddingService.REQUEST_TIMEOUT_MS);

    try {
      const response = await fetch(apiUrl, {
        method: 'POST',
        headers,
        body: JSON.stringify({
          encoding_format: 'float',
          input: texts,
          model: modelConfig.modelId,
        }),
        signal: controller.signal,
      });

      if (!response.ok) {
        return {
          ok: false,
          status: response.status,
          statusText: response.statusText,
          errorText: await response.text(),
        };
      }

      const data: EmbeddingResponse = await response.json();
      if (!data || !Array.isArray(data.data)) {
        throw new Error('Embedding API returned a malformed response body');
      }

      // Order by the reported index so vectors line up with the inputs.
      // Assumes `index` is the input position (OpenAI-compatible APIs); the
      // sort is stable, so items without a usable index keep response order.
      const embeddings = [...data.data]
        .sort((a, b) => a.index - b.index)
        .map((item) => item.embedding);

      // A count mismatch would silently misalign vectors and texts downstream.
      if (embeddings.length !== texts.length) {
        throw new Error(
          `Embedding API returned ${embeddings.length} vectors for ${texts.length} texts`,
        );
      }

      // Get dimensions from actual response.
      const actualDimensions =
        embeddings[0]?.length || this.defaultDimensions;
      this.logger.log(
        `Got ${embeddings.length} embedding vectors from ${modelConfig.name}. Dimensions: ${actualDimensions}`,
      );

      return { ok: true, embeddings };
    } finally {
      // The timer stays armed until the body has been consumed, so a stalled
      // body stream is aborted as well.
      clearTimeout(timeoutId);
    }
  }

  /** Sleep with exponential backoff: 1s after attempt 1, 2s after attempt 2, ... */
  private async waitBeforeRetry(
    attempt: number,
    error: unknown,
  ): Promise<void> {
    const delay = Math.pow(2, attempt - 1) * 1000;
    const reason = error instanceof Error ? error.message : String(error);
    this.logger.warn(
      `Embedding request failed. Retrying after ${delay}ms: ${reason}`,
    );
    await new Promise((resolve) => setTimeout(resolve, delay));
  }

  /**
   * Whether an HTTP error looks like "too many inputs in one request".
   * Authentication, routing and throttling statuses are excluded: their bodies
   * often contain the word "invalid" but are never fixed by a smaller batch.
   */
  private isBatchSizeError(status: number, errorText: string): boolean {
    if ([401, 403, 404, 429].includes(status)) {
      return false;
    }
    return (
      errorText.includes('batch size is invalid') ||
      errorText.includes('batch_size') ||
      errorText.includes('invalid') ||
      errorText.includes('larger than')
    );
  }

  /** Whether an HTTP error body reports that the input exceeds the model's context. */
  private isContextLengthError(errorText: string): boolean {
    return (
      errorText.includes('context length') || errorText.includes('exceeds')
    );
  }

  /**
   * Estimated vector size for a model.
   * Currently ignores `modelId` and returns the configured default.
   */
  private getEstimatedDimensions(modelId: string): number {
    // Use default dimensions from environment variable
    return this.defaultDimensions;
  }
}
|
||||
Reference in New Issue
Block a user