8686d101cd
## 已实现功能 - 题库管理后端API完整实现 - 模板管理页面(Settings-测评模板) - 评估统计页面 - 人才测评页面(AssessmentView) - QuestionBank前端服务层 ## 技术栈 - 后端: Node.js + NestJS + TypeORM - 前端: React + TypeScript - 容器化: Docker Compose ## 已知待完善 - 题库列表页缺少删除按钮 - 题库详情页未实现(题目管理/AI生成/审核)
351 lines
11 KiB
TypeScript
351 lines
11 KiB
TypeScript
import { Injectable, Logger } from '@nestjs/common';
|
|
import { ConfigService } from '@nestjs/config';
|
|
import { ModelConfigService } from '../model-config/model-config.service';
|
|
import { I18nService } from '../i18n/i18n.service';
|
|
|
|
/**
 * JSON body returned by the `/embeddings` endpoint that this file posts
 * `{ input, model, encoding_format }` requests to.
 */
export interface EmbeddingResponse {
  /** Embedding results; each entry carries one vector. */
  data: Array<{
    /** The embedding vector itself. */
    embedding: number[];
    /**
     * Index reported by the provider for this entry.
     * NOTE(review): presumably the position of the matching input text in the
     * request — confirm against the provider's API reference.
     */
    index: number;
  }>;
  /** Model identifier echoed back by the provider. */
  model: string;
  /** Usage figures reported by the provider (presumably token counts). */
  usage: {
    prompt_tokens: number;
    total_tokens: number;
  };
}
|
|
|
|
/**
 * Generates embedding vectors for texts using an embedding model that is
 * looked up by ID through `ModelConfigService`.
 *
 * Two transports are supported:
 * - an `/embeddings` HTTP endpoint that accepts a whole batch per request
 *   (with timeout, retry with exponential backoff, and automatic batch
 *   splitting when the provider rejects the batch size);
 * - an Ollama `/api/embeddings` endpoint, called once per text.
 */
@Injectable()
export class EmbeddingService {
  /** Maximum number of attempts for one batch request. */
  private static readonly MAX_RETRIES = 3;

  /** Abort a single HTTP request after this many milliseconds. */
  private static readonly REQUEST_TIMEOUT_MS = 60000;

  /** Pause between consecutive batches, to stay clear of rate limits. */
  private static readonly INTER_BATCH_DELAY_MS = 100;

  /** Used when DEFAULT_VECTOR_DIMENSIONS is missing or not a positive number. */
  private static readonly FALLBACK_VECTOR_DIMENSIONS = 2560;

  /**
   * Response-body substrings that the batch-splitting heuristic looks for.
   * `invalid` is very broad, which is why the heuristic is additionally gated
   * on the HTTP status (see `isBatchSizeLimitError`).
   */
  private static readonly BATCH_SIZE_ERROR_MARKERS = [
    'batch size is invalid',
    'batch_size',
    'invalid',
    'larger than',
  ];

  /** Statuses that can never be cured by sending fewer texts per request. */
  private static readonly NON_BATCH_ERROR_STATUSES = [401, 403, 404, 429];

  private readonly logger = new Logger(EmbeddingService.name);
  private readonly defaultDimensions: number;

  constructor(
    private modelConfigService: ModelConfigService,
    private configService: ConfigService,
    private i18nService: I18nService,
  ) {
    const parsedDimensions = parseInt(
      this.configService.get<string>('DEFAULT_VECTOR_DIMENSIONS', '2560'),
      10,
    );
    // A malformed environment value must not leave NaN as the default.
    this.defaultDimensions =
      Number.isFinite(parsedDimensions) && parsedDimensions > 0
        ? parsedDimensions
        : EmbeddingService.FALLBACK_VECTOR_DIMENSIONS;
    this.logger.log(
      `Default vector dimensions set to ${this.defaultDimensions}`,
    );
  }

  /**
   * Returns one embedding vector per input text, in input order.
   *
   * @param texts Texts to embed. An empty array yields an empty result
   *   without contacting the provider.
   * @param embeddingModelConfigId ID of the model configuration to use.
   * @returns The embedding vectors, positionally aligned with `texts`.
   * @throws Error when the configuration is missing, is not of type
   *   `embedding`, is disabled, has no `baseUrl`, or when the provider call
   *   ultimately fails.
   */
  async getEmbeddings(
    texts: string[],
    embeddingModelConfigId: string,
  ): Promise<number[][]> {
    this.logger.log(`Generating embeddings for ${texts.length} texts`);

    const modelConfig = await this.modelConfigService.findOne(
      embeddingModelConfigId,
    );
    if (!modelConfig || modelConfig.type !== 'embedding') {
      throw new Error(
        this.i18nService.formatMessage('embeddingModelNotFound', {
          id: embeddingModelConfigId,
        }),
      );
    }

    if (modelConfig.isEnabled === false) {
      throw new Error(
        `Model ${modelConfig.name} is disabled and cannot generate embeddings`,
      );
    }

    if (!modelConfig.baseUrl) {
      throw new Error(
        `Model ${modelConfig.name} does not have baseUrl configured`,
      );
    }

    // Nothing to embed: skip the network round trip (configuration problems
    // above are still reported).
    if (texts.length === 0) {
      return [];
    }

    // Determine max batch size based on model name
    const maxBatchSize = this.getMaxBatchSizeForModel(
      modelConfig.modelId,
      modelConfig.maxBatchSize,
    );

    // Normal processing (within batch size)
    if (texts.length <= maxBatchSize) {
      return this.getEmbeddingsForBatch(texts, modelConfig, maxBatchSize);
    }

    // Split processing if batch size exceeds limit
    this.logger.log(
      `Splitting ${texts.length} texts into batches (model batch limit: ${maxBatchSize})`,
    );

    const allEmbeddings: number[][] = [];
    for (let offset = 0; offset < texts.length; offset += maxBatchSize) {
      const batch = texts.slice(offset, offset + maxBatchSize);
      const batchEmbeddings = await this.getEmbeddingsForBatch(
        batch,
        modelConfig,
        maxBatchSize,
      );
      allEmbeddings.push(...batchEmbeddings);

      // Wait briefly to avoid API rate limiting (not needed after the last batch)
      if (offset + maxBatchSize < texts.length) {
        await EmbeddingService.sleep(EmbeddingService.INTER_BATCH_DELAY_MS);
      }
    }

    return allEmbeddings;
  }

  /**
   * Determine max batch size based on model ID.
   *
   * @param modelId Provider model identifier; matched by substring.
   * @param configuredMaxBatchSize Optional limit from the model configuration.
   *   Values that are missing, zero, negative or not finite are ignored so the
   *   result is always a positive integer (a non-positive batch size would
   *   make the batching loop in `getEmbeddings` never terminate).
   * @returns The number of texts to send per request (always >= 1).
   */
  private getMaxBatchSizeForModel(
    modelId: string,
    configuredMaxBatchSize?: number,
  ): number {
    const id = modelId ?? '';
    const configured =
      typeof configuredMaxBatchSize === 'number' &&
      Number.isFinite(configuredMaxBatchSize) &&
      configuredMaxBatchSize >= 1
        ? Math.floor(configuredMaxBatchSize)
        : undefined;

    // Model families this service restricts to at most 10 texts per request
    if (
      id.includes('text-embedding-004') ||
      id.includes('text-embedding-v4') ||
      id.includes('text-embedding-ada-002')
    ) {
      return Math.min(10, configured ?? 100);
    }

    // Model families allowed up to 2048 texts per request
    if (id.includes('text-embedding-3') || id.includes('text-embedding-003')) {
      return Math.min(2048, configured ?? 2048);
    }

    // Default: smaller of configured max or 100
    return Math.min(configured ?? 100, 100);
  }

  /**
   * Process single batch embedding.
   *
   * Ollama endpoints are delegated to `getOllamaEmbeddings`. Otherwise the
   * batch is POSTed to `<baseUrl>/embeddings` with up to `MAX_RETRIES`
   * attempts and exponential backoff (1s, 2s) between them.
   *
   * Only failures that a repeat of the same request might cure are retried:
   * network errors, timeouts, HTTP 429 / 5xx and malformed success bodies.
   * Errors from recursively split sub-batches and deterministic client errors
   * are propagated immediately; each sub-batch already ran its own retry
   * loop, so retrying again at every level would multiply requests.
   *
   * @param texts Texts of this batch.
   * @param modelConfig Model configuration (reads `baseUrl`, `apiKey`,
   *   `modelId`, `name`, `maxInputTokens`).
   * @param maxBatchSize Nominal batch limit; used for logging and halved on
   *   each recursive split.
   * @returns One vector per text, in input order.
   * @throws Error the last failure once retries are exhausted or the failure
   *   is not retryable.
   */
  private async getEmbeddingsForBatch(
    texts: string[],
    modelConfig: any,
    maxBatchSize: number,
  ): Promise<number[][]> {
    // Detect Ollama by port 11434 or /api/embeddings path
    const isOllama =
      modelConfig.baseUrl.includes(':11434') ||
      modelConfig.baseUrl.includes('/api/embeddings');

    if (isOllama) {
      return this.getOllamaEmbeddings(texts, modelConfig);
    }

    // Tolerate a trailing slash so the URL never becomes ".../v1//embeddings"
    const baseUrl = EmbeddingService.stripTrailingSlashes(modelConfig.baseUrl);
    const apiUrl = baseUrl.endsWith('/embeddings')
      ? baseUrl
      : `${baseUrl}/embeddings`;

    let lastError: unknown;

    for (
      let attempt = 1;
      attempt <= EmbeddingService.MAX_RETRIES;
      attempt++
    ) {
      // Cleared right before any failure that repeating this request cannot fix
      let retryable = true;

      try {
        const controller = new AbortController();
        const timeoutId = setTimeout(() => {
          controller.abort();
          this.logger.error(`Embedding API timeout after 60s: ${apiUrl}`);
        }, EmbeddingService.REQUEST_TIMEOUT_MS);

        this.logger.log(
          `[Model call] Type: Embedding, Model: ${modelConfig.name} (${modelConfig.modelId}), Text count: ${texts.length}`,
        );
        this.logger.log(
          `Calling embedding API (attempt ${attempt}/${EmbeddingService.MAX_RETRIES}): ${apiUrl}`,
        );

        let response;
        try {
          response = await fetch(apiUrl, {
            method: 'POST',
            headers: {
              'Content-Type': 'application/json',
              Authorization: `Bearer ${modelConfig.apiKey}`,
            },
            body: JSON.stringify({
              encoding_format: 'float',
              input: texts,
              model: modelConfig.modelId,
            }),
            signal: controller.signal,
          });
        } finally {
          clearTimeout(timeoutId);
        }

        if (!response.ok) {
          const errorText = await response.text();
          const isTransientStatus =
            response.status === 429 || response.status >= 500;

          // Detect batch size limit error; a single text cannot be split further
          if (
            texts.length > 1 &&
            this.isBatchSizeLimitError(response.status, errorText)
          ) {
            const halvedLimit = Math.max(1, Math.floor(maxBatchSize / 2));
            this.logger.warn(
              `Batch size limit error detected. Splitting batch in half and retrying: ${maxBatchSize} -> ${halvedLimit}`,
            );

            // The halves retry on their own; do not retry their failures here
            retryable = false;

            const midPoint = Math.floor(texts.length / 2);
            const firstResult = await this.getEmbeddingsForBatch(
              texts.slice(0, midPoint),
              modelConfig,
              halvedLimit,
            );
            const secondResult = await this.getEmbeddingsForBatch(
              texts.slice(midPoint),
              modelConfig,
              halvedLimit,
            );

            return [...firstResult, ...secondResult];
          }

          // Detect context length excess error
          if (
            errorText.includes('context length') ||
            errorText.includes('exceeds')
          ) {
            const totalLength = texts.reduce(
              (sum, text) => sum + text.length,
              0,
            );
            const avgLength = totalLength / texts.length;
            this.logger.error(
              `Text length exceeds limit: ${texts.length} texts, ` +
                `total ${totalLength} characters, average ${Math.round(avgLength)} characters, ` +
                `model limit: ${modelConfig.maxInputTokens || 8192} tokens`,
            );
            // Same input would fail again, unless the status itself is transient
            retryable = isTransientStatus;
            throw new Error(
              `Text length exceeds model limit. ` +
                `Current: ${texts.length} texts with total ${totalLength} characters, ` +
                `model limit: ${modelConfig.maxInputTokens || 8192} tokens. ` +
                `Advice: Reduce chunk size or batch size`,
            );
          }

          // Retry on 429 (Too Many Requests) or 5xx (Server Error)
          if (isTransientStatus) {
            this.logger.warn(
              `Temporary error from embedding API (${response.status}): ${errorText}`,
            );
            throw new Error(`API Error ${response.status}: ${errorText}`);
          }

          this.logger.error(`Embedding API error details: ${errorText}`);
          this.logger.error(
            `Request parameters: model=${modelConfig.modelId}, inputLength=${texts[0]?.length}`,
          );
          // Remaining statuses are client errors; repeating the request is futile
          retryable = false;
          throw new Error(
            `Embedding API call failed: ${response.statusText} - ${errorText}`,
          );
        }

        const data: EmbeddingResponse = await response.json();
        if (!data || !Array.isArray(data.data)) {
          throw new Error(
            `Embedding API returned an unexpected response body from ${apiUrl}`,
          );
        }

        // Order by the reported index so vectors line up with `texts` even if
        // the provider answers out of order. The sort is stable, so entries
        // without an index keep their arrival order.
        const embeddings = [...data.data]
          .sort((a, b) => (a.index ?? 0) - (b.index ?? 0))
          .map((item) => item.embedding);

        // A short or long answer would silently misalign vectors and texts
        if (embeddings.length !== texts.length) {
          throw new Error(
            `Embedding API returned ${embeddings.length} vectors for ${texts.length} texts`,
          );
        }

        // Get dimensions from actual response
        const actualDimensions =
          embeddings[0]?.length || this.defaultDimensions;
        this.logger.log(
          `Got ${embeddings.length} embedding vectors from ${modelConfig.name}. Dimensions: ${actualDimensions}`,
        );

        return embeddings;
      } catch (error) {
        lastError = error;

        if (!retryable || attempt >= EmbeddingService.MAX_RETRIES) {
          break;
        }

        const delay = Math.pow(2, attempt - 1) * 1000; // 1s, 2s
        this.logger.warn(
          `Embedding request failed. Retrying after ${delay}ms: ${EmbeddingService.describeError(error)}`,
        );
        await EmbeddingService.sleep(delay);
      }
    }

    throw lastError;
  }

  /**
   * Heuristically decides whether a failed response means "too many texts in
   * one request".
   *
   * Authentication, authorization, routing and throttling statuses are
   * excluded up front: their bodies frequently contain the word `invalid`
   * (for example an invalid API key) but are unrelated to the batch size.
   */
  private isBatchSizeLimitError(status: number, errorText: string): boolean {
    if (EmbeddingService.NON_BATCH_ERROR_STATUSES.includes(status)) {
      return false;
    }
    return EmbeddingService.BATCH_SIZE_ERROR_MARKERS.some((marker) =>
      errorText.includes(marker),
    );
  }

  /**
   * Returns the vector dimension to assume for a model.
   *
   * @param modelId Currently ignored; every model maps to the default taken
   *   from the environment variable.
   */
  private getEstimatedDimensions(modelId: string): number {
    // Use default dimensions from environment variable
    return this.defaultDimensions;
  }

  /**
   * Get embeddings from local Ollama.
   *
   * Sends one `POST <baseUrl>/api/embeddings` request per text, sequentially,
   * and stops at the first failure.
   *
   * @param texts Texts to embed.
   * @param modelConfig Model configuration; `baseUrl` defaults to
   *   `http://localhost:11434` and `modelId` to `nomic-embed-text`.
   * @returns One vector per text, in input order.
   * @throws Error when a request fails or a response carries no embedding.
   */
  private async getOllamaEmbeddings(
    texts: string[],
    modelConfig: any,
  ): Promise<number[][]> {
    const baseUrl = EmbeddingService.stripTrailingSlashes(
      modelConfig.baseUrl || 'http://localhost:11434',
    );
    const modelName = modelConfig.modelId || 'nomic-embed-text';
    const url = baseUrl.endsWith('/api/embeddings')
      ? baseUrl
      : `${baseUrl}/api/embeddings`;

    this.logger.log(
      `[Ollama] Generating embeddings for ${texts.length} texts using ${modelName}`,
    );

    const embeddings: number[][] = [];

    for (let i = 0; i < texts.length; i++) {
      try {
        const response = await fetch(url, {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
          },
          body: JSON.stringify({
            model: modelName,
            prompt: texts[i],
          }),
        });

        if (!response.ok) {
          const errorText = await response.text();
          throw new Error(`Ollama API error: ${response.status} - ${errorText}`);
        }

        const data = await response.json();
        // Never push `undefined` into the result; it would surface far away
        if (!data || !Array.isArray(data.embedding)) {
          throw new Error('Ollama API response did not contain an embedding');
        }
        embeddings.push(data.embedding);
      } catch (error) {
        this.logger.error(
          `Ollama embedding error for text ${i}: ${EmbeddingService.describeError(error)}`,
        );
        throw error;
      }
    }

    this.logger.log(
      `[Ollama] Got ${embeddings.length} embeddings, dimensions: ${embeddings[0]?.length || 0}`,
    );

    return embeddings;
  }

  /** Removes any trailing `/` characters from a URL. */
  private static stripTrailingSlashes(url: string): string {
    return url.replace(/\/+$/, '');
  }

  /** Extracts a printable message from a caught value of unknown type. */
  private static describeError(error: unknown): string {
    return error instanceof Error ? error.message : String(error);
  }

  /** Resolves after `ms` milliseconds. */
  private static sleep(ms: number): Promise<void> {
    return new Promise((resolve) => setTimeout(resolve, ms));
  }
}
|