Files
aurak/server/src/knowledge-base/embedding.service.ts
T
Developer 8686d101cd Initial commit: AuraK人才测评系统基础框架
## 已实现功能
- 题库管理后端API完整实现
- 模板管理页面(Settings-测评模板)
- 评估统计页面
- 人才测评页面(AssessmentView)
- QuestionBank前端服务层

## 技术栈
- 后端: Node.js + NestJS + TypeORM
- 前端: React + TypeScript
- 容器化: Docker Compose

## 已知待完善
- 题库列表页缺少删除按钮
- 题库详情页未实现(题目管理/AI生成/审核)
2026-05-13 21:32:41 +08:00

351 lines
11 KiB
TypeScript

import { Injectable, Logger } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { ModelConfigService } from '../model-config/model-config.service';
import { I18nService } from '../i18n/i18n.service';
/**
 * Response body returned by an OpenAI-compatible `POST /embeddings` call.
 */
export interface EmbeddingResponse {
  /**
   * One entry per embedded input. Per the OpenAI embeddings API, `index`
   * identifies which input the vector belongs to.
   */
  data: {
    embedding: number[];
    index: number;
  }[];
  /** Model identifier echoed back by the provider. */
  model: string;
  /** Token accounting reported by the provider for this request. */
  usage: {
    prompt_tokens: number;
    total_tokens: number;
  };
}
/**
 * Marks an embedding API failure that re-sending the same request cannot fix
 * (input too long for the model, authentication/permission problems and
 * other non-transient 4xx responses). The retry loop rethrows it immediately.
 */
class NonRetryableEmbeddingError extends Error {
  constructor(message: string) {
    super(message);
    this.name = 'NonRetryableEmbeddingError';
  }
}

/**
 * Generates text embeddings through a configured embedding model.
 *
 * Two backends are supported:
 * - OpenAI-compatible `/embeddings` endpoints (batched requests, retried on
 *   transient failures, automatically split when the provider rejects the
 *   batch size);
 * - local Ollama servers (`/api/embeddings`), one request per text.
 */
@Injectable()
export class EmbeddingService {
  /** Used when DEFAULT_VECTOR_DIMENSIONS is unset or not a positive integer. */
  private static readonly FALLBACK_DIMENSIONS = 2560;
  /** Total attempts (first try included) per OpenAI-compatible batch request. */
  private static readonly MAX_RETRIES = 3;
  /** Per-request timeout, covering both the response headers and the body. */
  private static readonly REQUEST_TIMEOUT_MS = 60000;
  /** Pause between consecutive batches to avoid API rate limiting. */
  private static readonly INTER_BATCH_DELAY_MS = 100;

  private readonly logger = new Logger(EmbeddingService.name);
  private readonly defaultDimensions: number;

  constructor(
    private readonly modelConfigService: ModelConfigService,
    private readonly configService: ConfigService,
    private readonly i18nService: I18nService,
  ) {
    this.defaultDimensions = this.resolveDefaultDimensions();
    this.logger.log(
      `Default vector dimensions set to ${this.defaultDimensions}`,
    );
  }

  /**
   * Embeds `texts` with the embedding model identified by
   * `embeddingModelConfigId`.
   *
   * @param texts - Texts to embed. The result is aligned index-for-index
   *   with this array.
   * @param embeddingModelConfigId - ID of a model config whose `type` is
   *   `'embedding'`.
   * @returns One embedding vector per input text, in input order (an empty
   *   array when `texts` is empty).
   * @throws Error if the model config is missing, is not an embedding model,
   *   is disabled, has no `baseUrl`, or if the provider call ultimately fails.
   */
  async getEmbeddings(
    texts: string[],
    embeddingModelConfigId: string,
  ): Promise<number[][]> {
    this.logger.log(`Generating embeddings for ${texts.length} texts`);
    const modelConfig = await this.modelConfigService.findOne(
      embeddingModelConfigId,
    );
    if (!modelConfig || modelConfig.type !== 'embedding') {
      throw new Error(
        this.i18nService.formatMessage('embeddingModelNotFound', {
          id: embeddingModelConfigId,
        }),
      );
    }
    if (modelConfig.isEnabled === false) {
      throw new Error(
        `Model ${modelConfig.name} is disabled and cannot generate embeddings`,
      );
    }
    if (!modelConfig.baseUrl) {
      throw new Error(
        `Model ${modelConfig.name} does not have baseUrl configured`,
      );
    }
    // Nothing to embed: avoid sending an empty `input` array to the provider.
    if (texts.length === 0) {
      return [];
    }

    const maxBatchSize = this.getMaxBatchSizeForModel(
      modelConfig.modelId,
      modelConfig.maxBatchSize,
    );
    if (texts.length > maxBatchSize) {
      this.logger.log(
        `Splitting ${texts.length} texts into batches (model batch limit: ${maxBatchSize})`,
      );
    }

    const allEmbeddings: number[][] = [];
    for (let start = 0; start < texts.length; start += maxBatchSize) {
      if (start > 0) {
        // Wait briefly between batches to avoid API rate limiting
        await this.sleep(EmbeddingService.INTER_BATCH_DELAY_MS);
      }
      const batch = texts.slice(start, start + maxBatchSize);
      const batchEmbeddings = await this.getEmbeddingsForBatch(
        batch,
        modelConfig,
        maxBatchSize,
      );
      allEmbeddings.push(...batchEmbeddings);
    }
    return allEmbeddings;
  }

  /**
   * Reads DEFAULT_VECTOR_DIMENSIONS, falling back to FALLBACK_DIMENSIONS when
   * the value is missing or is not a positive integer.
   */
  private resolveDefaultDimensions(): number {
    const raw = this.configService.get<string>(
      'DEFAULT_VECTOR_DIMENSIONS',
      String(EmbeddingService.FALLBACK_DIMENSIONS),
    );
    const parsed = Number.parseInt(String(raw), 10);
    if (Number.isFinite(parsed) && parsed > 0) {
      return parsed;
    }
    this.logger.warn(
      `Invalid DEFAULT_VECTOR_DIMENSIONS "${raw}", falling back to ${EmbeddingService.FALLBACK_DIMENSIONS}`,
    );
    return EmbeddingService.FALLBACK_DIMENSIONS;
  }

  /**
   * Determines the max number of texts sent per request for a model.
   *
   * @param modelId - Provider model identifier, matched by substring.
   * @param configuredMaxBatchSize - Optional per-config limit. Falsy values
   *   (unset, 0, NaN) fall back to the model default.
   * @returns A positive integer batch size.
   */
  private getMaxBatchSizeForModel(
    modelId: string,
    configuredMaxBatchSize?: number,
  ): number {
    const id = modelId ?? '';
    let limit: number;
    if (
      id.includes('text-embedding-004') ||
      id.includes('text-embedding-v4') ||
      id.includes('text-embedding-ada-002')
    ) {
      // Capped at 10 inputs per request (Google text-embedding-004 limit).
      // NOTE(review): ada-002 is an OpenAI model; confirm it needs this cap.
      limit = Math.min(10, configuredMaxBatchSize || 100);
    } else if (
      id.includes('text-embedding-3') ||
      id.includes('text-embedding-003')
    ) {
      limit = Math.min(2048, configuredMaxBatchSize || 2048); // OpenAI v3 limit: 2048
    } else {
      // Default: smaller of configured max or 100
      limit = Math.min(configuredMaxBatchSize || 100, 100);
    }
    // A negative or fractional configured value would make the batching loop
    // in getEmbeddings never terminate or skip/duplicate texts.
    return Math.max(1, Math.floor(limit));
  }

  /**
   * Embeds one batch of texts, dispatching to Ollama or to an
   * OpenAI-compatible endpoint.
   *
   * For OpenAI-compatible endpoints, transient failures (network errors,
   * timeouts, 408/429/5xx, malformed responses) are retried up to
   * MAX_RETRIES attempts with exponential backoff. If the provider rejects the
   * batch size, the batch is split in half and each half is embedded
   * recursively.
   *
   * @param texts - Non-empty batch of texts.
   * @param modelConfig - Validated embedding model config (has `baseUrl`).
   * @param maxBatchSize - Batch limit in effect; used for logging only.
   * @returns One vector per text, in input order.
   * @throws Error when the request fails permanently or retries run out.
   */
  private async getEmbeddingsForBatch(
    texts: string[],
    modelConfig: any,
    maxBatchSize: number,
  ): Promise<number[][]> {
    const baseUrl: string = modelConfig.baseUrl;
    // Detect Ollama by port 11434 or /api/embeddings path
    if (baseUrl.includes(':11434') || baseUrl.includes('/api/embeddings')) {
      return await this.getOllamaEmbeddings(texts, modelConfig);
    }

    // Strip trailing slashes so "https://host/v1/" does not become "//embeddings".
    const trimmedBase = baseUrl.replace(/\/+$/, '');
    const apiUrl = trimmedBase.endsWith('/embeddings')
      ? trimmedBase
      : `${trimmedBase}/embeddings`;
    const maxRetries = EmbeddingService.MAX_RETRIES;

    let lastError: unknown;
    for (let attempt = 1; attempt <= maxRetries; attempt++) {
      this.logger.log(
        `[Model call] Type: Embedding, Model: ${modelConfig.name} (${modelConfig.modelId}), Text count: ${texts.length}`,
      );
      this.logger.log(
        `Calling embedding API (attempt ${attempt}/${maxRetries}): ${apiUrl}`,
      );

      let embeddings: number[][] | null;
      try {
        embeddings = await this.requestEmbeddings(
          apiUrl,
          texts,
          modelConfig,
          maxBatchSize,
        );
      } catch (error) {
        lastError = error;
        if (
          error instanceof NonRetryableEmbeddingError ||
          attempt === maxRetries
        ) {
          throw error;
        }
        const delay = 2 ** (attempt - 1) * 1000; // 1s, 2s, ...
        this.logger.warn(
          `Embedding request failed. Retrying after ${delay}ms: ${EmbeddingService.describeError(error)}`,
        );
        await this.sleep(delay);
        continue;
      }

      if (embeddings !== null) {
        return embeddings;
      }
      // The provider rejected the batch size. The split runs outside the
      // try/catch so a failure inside a half is not retried again by this
      // level, which previously multiplied the number of requests.
      return await this.embedInHalves(texts, modelConfig, maxBatchSize);
    }
    throw lastError;
  }

  /**
   * Sends a single POST to an OpenAI-compatible embeddings endpoint.
   *
   * @returns The embeddings in input order, or `null` when the provider
   *   rejected the request because the batch was too large and `texts` can
   *   still be split.
   * @throws NonRetryableEmbeddingError for failures a retry cannot fix.
   * @throws Error for transient failures the caller may retry.
   */
  private async requestEmbeddings(
    apiUrl: string,
    texts: string[],
    modelConfig: any,
    maxBatchSize: number,
  ): Promise<number[][] | null> {
    const controller = new AbortController();
    const timeoutId = setTimeout(() => {
      controller.abort();
      this.logger.error(`Embedding API timeout after 60s: ${apiUrl}`);
    }, EmbeddingService.REQUEST_TIMEOUT_MS);
    try {
      const response = await fetch(apiUrl, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
          Authorization: `Bearer ${modelConfig.apiKey}`,
        },
        body: JSON.stringify({
          encoding_format: 'float',
          input: texts,
          model: modelConfig.modelId,
        }),
        signal: controller.signal,
      });
      if (!response.ok) {
        const errorText = await response.text();
        return this.handleErrorResponse(
          response.status,
          response.statusText,
          errorText,
          texts,
          modelConfig,
          maxBatchSize,
        );
      }
      const data = (await response.json()) as Partial<EmbeddingResponse> | null;
      return this.extractEmbeddings(data, texts.length, modelConfig);
    } finally {
      // Cleared only after the body has been read, so a stalled body also
      // hits the timeout.
      clearTimeout(timeoutId);
    }
  }

  /**
   * Classifies a non-2xx response from the embeddings endpoint.
   *
   * @returns `null` when the batch should be split in half and retried.
   * @throws NonRetryableEmbeddingError for context-length errors and
   *   non-transient client errors.
   * @throws Error for transient errors (408, 429, 5xx).
   */
  private handleErrorResponse(
    status: number,
    statusText: string,
    errorText: string,
    texts: string[],
    modelConfig: any,
    maxBatchSize: number,
  ): null {
    // Detect batch size limit error
    if (
      texts.length > 1 &&
      EmbeddingService.isBatchSizeError(status, errorText)
    ) {
      this.logger.warn(
        `Batch size limit error detected. Splitting batch of ${texts.length} texts in half and retrying: ${maxBatchSize} -> ${Math.floor(maxBatchSize / 2)}`,
      );
      return null;
    }

    // Detect context length excess error
    if (EmbeddingService.isContextLengthError(status, errorText)) {
      const totalLength = texts.reduce((sum, text) => sum + text.length, 0);
      const avgLength = totalLength / texts.length;
      const tokenLimit = modelConfig.maxInputTokens || 8192;
      this.logger.error(
        `Text length exceeds limit: ${texts.length} texts, ` +
          `total ${totalLength} characters, average ${Math.round(avgLength)} characters, ` +
          `model limit: ${tokenLimit} tokens`,
      );
      throw new NonRetryableEmbeddingError(
        `Text length exceeds model limit. ` +
          `Current: ${texts.length} texts with total ${totalLength} characters, ` +
          `model limit: ${tokenLimit} tokens. ` +
          `Advice: Reduce chunk size or batch size`,
      );
    }

    // Retry on 408 (Request Timeout), 429 (Too Many Requests) or 5xx (Server Error)
    if (status === 408 || status === 429 || status >= 500) {
      this.logger.warn(
        `Temporary error from embedding API (${status}): ${errorText}`,
      );
      throw new Error(`API Error ${status}: ${errorText}`);
    }

    this.logger.error(`Embedding API error details: ${errorText}`);
    this.logger.error(
      `Request parameters: model=${modelConfig.modelId}, inputLength=${texts[0]?.length}`,
    );
    throw new NonRetryableEmbeddingError(
      `Embedding API call failed: ${statusText} - ${errorText}`,
    );
  }

  /**
   * Heuristically detects a "too many inputs in one request" rejection.
   * Auth (401/403) and rate-limit (429) responses are excluded because their
   * messages often contain words like "invalid" too.
   */
  private static isBatchSizeError(status: number, errorText: string): boolean {
    if (status === 401 || status === 403 || status === 429) {
      return false;
    }
    const text = errorText.toLowerCase();
    return ['batch size', 'batch_size', 'invalid', 'larger than'].some(
      (pattern) => text.includes(pattern),
    );
  }

  /** Heuristically detects an "input longer than the model context" rejection. */
  private static isContextLengthError(
    status: number,
    errorText: string,
  ): boolean {
    if (status === 429) {
      return false;
    }
    const text = errorText.toLowerCase();
    return text.includes('context length') || text.includes('exceeds');
  }

  /**
   * Extracts the vectors from a successful response, ordered by `index`.
   *
   * @throws Error when the payload has no `data` array or the number of
   *   vectors does not match the number of inputs (misaligned vectors would
   *   otherwise be stored against the wrong texts).
   */
  private extractEmbeddings(
    data: Partial<EmbeddingResponse> | null,
    expectedCount: number,
    modelConfig: any,
  ): number[][] {
    if (!data || !Array.isArray(data.data)) {
      throw new Error('Embedding API response does not contain a "data" array');
    }
    const embeddings = [...data.data]
      .sort((a, b) => a.index - b.index)
      .map((item) => item.embedding);
    if (embeddings.length !== expectedCount) {
      throw new Error(
        `Embedding API returned ${embeddings.length} vectors for ${expectedCount} inputs`,
      );
    }
    // Get dimensions from actual response
    const actualDimensions = embeddings[0]?.length || this.defaultDimensions;
    this.logger.log(
      `Got ${embeddings.length} embedding vectors from ${modelConfig.name}. Dimensions: ${actualDimensions}`,
    );
    return embeddings;
  }

  /**
   * Splits `texts` (length > 1) in half and embeds each half sequentially.
   */
  private async embedInHalves(
    texts: string[],
    modelConfig: any,
    maxBatchSize: number,
  ): Promise<number[][]> {
    const midPoint = Math.floor(texts.length / 2);
    const halfBatchSize = Math.max(1, Math.floor(maxBatchSize / 2));
    const firstResult = await this.getEmbeddingsForBatch(
      texts.slice(0, midPoint),
      modelConfig,
      halfBatchSize,
    );
    const secondResult = await this.getEmbeddingsForBatch(
      texts.slice(midPoint),
      modelConfig,
      halfBatchSize,
    );
    return [...firstResult, ...secondResult];
  }

  /**
   * Returns the configured default vector dimensions. The model ID is
   * currently ignored. NOTE(review): not referenced within this class.
   */
  private getEstimatedDimensions(_modelId: string): number {
    return this.defaultDimensions;
  }

  /**
   * Gets embeddings from an Ollama server, one request per text.
   *
   * @param texts - Texts to embed.
   * @param modelConfig - Model config. `baseUrl` defaults to
   *   http://localhost:11434 and `modelId` to nomic-embed-text.
   * @returns One vector per text, in input order.
   * @throws Error on a non-2xx response or a response without an
   *   `embedding` array.
   */
  private async getOllamaEmbeddings(
    texts: string[],
    modelConfig: any,
  ): Promise<number[][]> {
    const baseUrl = String(
      modelConfig.baseUrl || 'http://localhost:11434',
    ).replace(/\/+$/, '');
    const modelName = modelConfig.modelId || 'nomic-embed-text';
    const url = baseUrl.endsWith('/api/embeddings')
      ? baseUrl
      : `${baseUrl}/api/embeddings`;
    this.logger.log(
      `[Ollama] Generating embeddings for ${texts.length} texts using ${modelName}`,
    );

    const embeddings: number[][] = [];
    for (let i = 0; i < texts.length; i++) {
      try {
        const response = await fetch(url, {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
          },
          body: JSON.stringify({
            model: modelName,
            prompt: texts[i],
          }),
        });
        if (!response.ok) {
          const errorText = await response.text();
          throw new Error(`Ollama API error: ${response.status} - ${errorText}`);
        }
        const data = (await response.json()) as { embedding?: unknown } | null;
        const vector = data?.embedding;
        if (!Array.isArray(vector)) {
          throw new Error(`Ollama API returned no embedding for text ${i}`);
        }
        embeddings.push(vector);
      } catch (error) {
        this.logger.error(
          `Ollama embedding error for text ${i}: ${EmbeddingService.describeError(error)}`,
        );
        throw error;
      }
    }
    this.logger.log(
      `[Ollama] Got ${embeddings.length} embeddings, dimensions: ${embeddings[0]?.length || 0}`,
    );
    return embeddings;
  }

  /** Best-effort human-readable message for any thrown value. */
  private static describeError(error: unknown): string {
    return error instanceof Error ? error.message : String(error);
  }

  /** Resolves after `ms` milliseconds. */
  private sleep(ms: number): Promise<void> {
    return new Promise((resolve) => setTimeout(resolve, ms));
  }
}