work on downloading some more.

This commit is contained in:
Jordan 2025-03-11 07:26:49 -07:00
parent 8f67d0421b
commit dca3987e18
14 changed files with 480 additions and 414 deletions

View File

@ -1,7 +1,8 @@
export default { export default {
getDb: jest.fn(() => { getDb: jest.fn(() => {
return { return {
runAsync: jest.fn((statement: string, value: string) => {}), runAsync: jest.fn((statement: string, ... values: string []) => {}),
runSync: jest.fn((statement: string, ... values : string []) => {}),
getFirstAsync: jest.fn((statement: string, value: string) => { getFirstAsync: jest.fn((statement: string, value: string) => {
return []; return [];
}), }),

View File

@ -1,6 +1,6 @@
import { Cache } from "react-native-cache"; import { Cache } from "react-native-cache";
import { LIBRETRANSLATE_BASE_URL } from "@/constants/api"; import { LIBRETRANSLATE_BASE_URL } from "@/constants/api";
import AsyncStorage from '@react-native-async-storage/async-storage'; import AsyncStorage from "@react-native-async-storage/async-storage";
import { Settings } from "../lib/settings"; import { Settings } from "../lib/settings";
type language_t = string; type language_t = string;
@ -9,25 +9,31 @@ const cache = new Cache({
namespace: "translation_terrace", namespace: "translation_terrace",
policy: { policy: {
maxEntries: 50000, // if unspecified, it can have unlimited entries maxEntries: 50000, // if unspecified, it can have unlimited entries
stdTTL: 0 // the standard ttl as number in seconds, default: 0 (unlimited) stdTTL: 0, // the standard ttl as number in seconds, default: 0 (unlimited)
}, },
backend: AsyncStorage backend: AsyncStorage,
}); });
export type language_matrix_entry = { export type language_matrix_entry = {
code: string, code: string;
name: string, name: string;
targets: string [] targets: string[];
} };
export type language_matrix = { export type language_matrix = {
[key:string] : language_matrix_entry [key: string]: language_matrix_entry;
} };
export async function fetchWithTimeout(url : string, options : RequestInit, timeout = 5000) : Promise<Response> { export async function fetchWithTimeout(
url: string,
options: RequestInit,
timeout = 5000
): Promise<Response> {
return Promise.race([ return Promise.race([
fetch(url, options), fetch(url, options),
new Promise((_, reject) => setTimeout(() => reject(new Error('timeout')), timeout)) new Promise((_, reject) =>
setTimeout(() => reject(new Error("timeout")), timeout)
),
]); ]);
} }
@ -36,39 +42,49 @@ export class LanguageServer {
async fetchLanguages(timeout = 500): Promise<language_matrix> { async fetchLanguages(timeout = 500): Promise<language_matrix> {
let data = {}; let data = {};
const res = await fetchWithTimeout(this.baseUrl + "/languages", { const res = await fetchWithTimeout(
this.baseUrl + "/languages",
{
headers: { headers: {
"Content-Type": "application/json" "Content-Type": "application/json",
} },
}, timeout); },
timeout
);
try { try {
data = await res.json(); data = await res.json();
} catch (e) { } catch (e) {
throw new Error(`Parsing data from ${await res.text()}: ${e}`) throw new Error(`Parsing data from ${await res.text()}: ${e}`);
} }
try { try {
return Object.fromEntries( return Object.fromEntries(
Object.values(data as language_matrix_entry []).map((obj : language_matrix_entry) => { Object.values(data as language_matrix_entry[]).map(
return [ (obj: language_matrix_entry) => {
obj["code"], return [obj["code"], obj];
obj, }
]
})
) )
);
} catch (e) { } catch (e) {
throw new Error(`Can't extract values from data: ${JSON.stringify(data)}`) throw new Error(
`Can't extract values from data: ${JSON.stringify(data)}`
);
} }
} }
static async getDefault() { static async getDefault() {
const settings = await Settings.getDefault(); const settings = await Settings.getDefault();
return new LanguageServer(await settings.getLibretranslateBaseUrl() || LIBRETRANSLATE_BASE_URL); return new LanguageServer(
(await settings.getLibretranslateBaseUrl()) || LIBRETRANSLATE_BASE_URL
);
} }
} }
export class Translator { export class Translator {
constructor(public source : language_t, public defaultTarget : string = "en", private _languageServer : LanguageServer) { constructor(
} public source: language_t,
public defaultTarget: string = "en",
private _languageServer: LanguageServer
) {}
get languageServer() { get languageServer() {
return this._languageServer; return this._languageServer;
@ -76,7 +92,8 @@ export class Translator {
async translate(text: string, target: string | undefined = undefined) { async translate(text: string, target: string | undefined = undefined) {
const url = this._languageServer.baseUrl + `/translate`; const url = this._languageServer.baseUrl + `/translate`;
const res = await fetch(url, { console.log(url);
const postData = {
method: "POST", method: "POST",
body: JSON.stringify({ body: JSON.stringify({
q: text, q: text,
@ -84,21 +101,32 @@ export class Translator {
target: target || this.defaultTarget, target: target || this.defaultTarget,
format: "text", format: "text",
alternatives: 3, alternatives: 3,
api_key: "" api_key: "",
}), }),
headers: { "Content-Type": "application/json" } headers: { "Content-Type": "application/json" },
}); };
console.debug("Requesting %s with %o", url, postData);
const res = await fetch(url, postData);
const data = await res.json(); const data = await res.json();
console.log(data) if (res.status === 200) {
return data.translatedText console.log(data);
return data.translatedText;
} else {
console.error(data);
}
} }
static async getDefault(defaultTarget: string | undefined = undefined) { static async getDefault(defaultTarget: string | undefined = undefined) {
const settings = await Settings.getDefault(); const settings = await Settings.getDefault();
const source = await settings.getHostLanguage(); const source = await settings.getHostLanguage();
return new Translator(source, defaultTarget, await LanguageServer.getDefault()) return new Translator(
source,
defaultTarget,
await LanguageServer.getDefault()
);
} }
} }
@ -106,18 +134,22 @@ export class CachedTranslator extends Translator {
async translate(text: string, target: string | undefined = undefined) { async translate(text: string, target: string | undefined = undefined) {
const targetKey = target || this.defaultTarget; const targetKey = target || this.defaultTarget;
// console.debug(`Translating from ${this.source} -> ${targetKey}`) // console.debug(`Translating from ${this.source} -> ${targetKey}`)
const key1 = `${this.source}::${targetKey}::${text}` const key1 = `${this.source}::${targetKey}::${text}`;
const tr1 = await cache.get(key1); const tr1 = await cache.get(key1);
if (tr1) return tr1; if (tr1) return tr1;
const tr2 = await super.translate(text, target); const tr2 = await super.translate(text, target);
const key2 = `${this.source}::${targetKey}::${text}` const key2 = `${this.source}::${targetKey}::${text}`;
await cache.set(key2, tr2); await cache.set(key2, tr2);
return tr2; return tr2;
} }
static async getDefault(defaultTarget: string | undefined = undefined) { static async getDefault(defaultTarget: string | undefined = undefined) {
const settings = await Settings.getDefault(); const settings = await Settings.getDefault();
const source = await settings.getHostLanguage(); const source = await settings.getHostLanguage() || "en";
return new CachedTranslator(source, defaultTarget, await LanguageServer.getDefault()) return new CachedTranslator(
source,
defaultTarget,
await LanguageServer.getDefault()
);
} }
} }

View File

@ -8,11 +8,12 @@ describe('Settings', () => {
beforeEach(async () => { beforeEach(async () => {
db = await getDb("development"); db = await getDb("development");
await migrateDb("development");
settings = new Settings(db); settings = new Settings(db);
}); });
afterEach(async () => { afterEach(async () => {
await migrateDb("development"); await migrateDb("development", "down");
}); });
it('should set the host language in the database', async () => { it('should set the host language in the database', async () => {

View File

@ -1,101 +1,170 @@
// components/ui/__tests__/WhisperFile.spec.tsx // app/lib/__tests__/whisper.spec.tsx
import React from "react"; import React from "react";
import { render, act } from "@testing-library/react-native"; import { getDb } from "@/app/lib/db";
import { WhisperFile } from "@/app/lib/whisper"; // Adjust the import path as necessary import { WhisperFile, WhisperModelTag } from "@/app/lib/whisper"; // Corrected to use WhisperFile and WhisperModelTag instead of WhisperDownloader
import { Settings } from "@/app/lib/settings";
import { File } from "expo-file-system/next";
jest.mock('expo-file-system');
import * as FileSystem from 'expo-file-system';
jest.mock("@/app/lib/db", () => ({
getDb: jest.fn().mockResolvedValue({
runAsync: jest.fn(),
upsert: jest.fn(), // Mock the upsert method used in addToDatabase
}),
}));
jest.mock("@/app/lib/settings", () => ({
Settings: {
getDefault: jest.fn(() => ({
getValue: jest.fn((key) => {
switch (key) {
case "whisper_model":
return "base";
default:
throw new Error(`Invalid setting: '${key}'`);
}
}),
})),
},
}));
jest.mock("expo-file-system/next", () => {
const _next = jest.requireActual("expo-file-system/next");
return {
..._next,
File: jest.fn().mockImplementation(() => ({
..._next.File,
text: jest.fn(() => {
return new String("text");
}),
})),
};
});
describe("WhisperFile", () => { describe("WhisperFile", () => {
// Corrected to use WhisperFile instead of WhisperDownloader
let whisperFile: WhisperFile; let whisperFile: WhisperFile;
beforeEach(() => { beforeEach(async () => {
whisperFile = new WhisperFile("small"); whisperFile = new WhisperFile("small");
}); });
it("should initialize correctly", () => { it("should create a download resumable with existing data if available", async () => {
expect(whisperFile).toBeInstanceOf(WhisperFile); const mockExistingData = "mockExistingData";
});
describe("getModelFileSize", () => {
it("should return the correct model file size", async () => {
expect(whisperFile.size).toBeUndefined();
await whisperFile.updateMetadata();
expect(whisperFile.size).toBeGreaterThan(1000);
});
});
describe("getWhisperDownloadStatus", () => {
it("should return the correct download status", async () => {
const mockStatus = {
doesTargetExist: true,
isDownloadComplete: false,
hasDownloadStarted: true,
progress: {
current: 100,
total: 200,
remaining: 100,
percentRemaining: 50.0,
},
};
jest
.spyOn(whisperFile, "getDownloadStatus")
.mockResolvedValue(mockStatus);
const result = await whisperFile.getDownloadStatus();
expect(result).toEqual(mockStatus);
});
});
describe("initiateWhisperDownload", () => {
it("should initiate the download with default options", async () => {
const mockModelLabel = "small";
jest
.spyOn(whisperFile, "createDownloadResumable")
.mockResolvedValue(true);
await whisperFile.initiateWhisperDownload(mockModelLabel);
expect(whisperFile.createDownloadResumable).toHaveBeenCalledWith(
mockModelLabel
);
});
it("should initiate the download with custom options", async () => {
const mockModelLabel = "small";
const mockOptions = { force_redownload: true };
jest
.spyOn(whisperFile, "createDownloadResumable")
.mockResolvedValue(true);
await whisperFile.initiateWhisperDownload(mockModelLabel, mockOptions);
expect(whisperFile.createDownloadResumable).toHaveBeenCalledWith(
mockModelLabel,
mockOptions
);
});
it("should return the correct download status when target exists and is complete", async () => {
jest.spyOn(whisperFile, "doesTargetExist").mockResolvedValue(true); jest.spyOn(whisperFile, "doesTargetExist").mockResolvedValue(true);
jest.spyOn(whisperFile, "isDownloadComplete").mockResolvedValue(true);
expect(await whisperFile.doesTargetExist()).toEqual(true); await whisperFile.createDownloadResumable();
expect(await whisperFile.isDownloadComplete()).toEqual(true); // expect(whisperFile.targetFileName).toEqual("small.bin");
expect(whisperFile.targetPath).toContain("small.bin");
expect(FileSystem.createDownloadResumable).toHaveBeenCalledWith(
"https://huggingface.co/openai/whisper-small/resolve/main/pytorch_model.bin",
"file:///whisper/small.bin",
{},
expect.any(Function),
expect.anything(),
);
}); });
it("should return the correct download status when target does not exist", async () => { // it("should create a download resumable without existing data if not available", async () => {
jest.spyOn(whisperFile, "doesTargetExist").mockResolvedValue(false); // jest.spyOn(whisperFile, "doesTargetExist").mockResolvedValue(false);
const result = await whisperFile.getDownloadStatus(); // await whisperFile.createDownloadResumable(); // Updated to use createDownloadResumable instead of download
expect(result).toEqual({ // expect(FileSystem.createDownloadResumable).toHaveBeenCalledWith(
doesTargetExist: false, // "http://mock.model.com/model",
isDownloadComplete: false, // "mockTargetPath",
hasDownloadStarted: false, // {},
progress: undefined, // expect.any(Function),
}); // undefined
}); // );
}); // });
// Add more tests as needed for other methods in WhisperFile // it("should update the download status in the database", async () => {
// const mockRunAsync = jest.fn();
// (getDb as jest.Mock).mockResolvedValue({ runAsync: mockRunAsync });
// const downloadable = await whisperFile.createDownloadResumable(); // Updated to use createDownloadResumable instead of download
// await downloadable.resumeAsync();
// jest.advanceTimersByTime(1000);
// expect(mockRunAsync).toHaveBeenCalled();
// });
// it("should record the latest target hash after downloading", async () => {
// const mockRecordLatestTargetHash = jest.spyOn(
// whisperFile,
// "recordLatestTargetHash"
// );
// await whisperFile.createDownloadResumable(); // Updated to use createDownloadResumable instead of download
// expect(mockRecordLatestTargetHash).toHaveBeenCalled();
// });
// it("should call the onData callback if provided", async () => {
// const mockOnData = jest.fn();
// const options = { onData: mockOnData };
// await whisperFile.createDownloadResumable(options); // Updated to use createDownloadResumable instead of download
// expect(mockOnData).toHaveBeenCalledWith(expect.any(Object));
// });
// describe("getDownloadStatus", () => {
// it("should return the correct download status when model size is known and download has started", async () => {
// whisperFile.size = 1024;
// jest.spyOn(whisperFile, "doesTargetExist").mockResolvedValue(true);
// jest.spyOn(whisperFile, "isDownloadComplete").mockResolvedValue(false);
// jest.spyOn(whisperFile, "targetFile").mockReturnValue({
// size: 512,
// });
// const status = await whisperFile.getDownloadStatus();
// expect(status).toEqual({
// doesTargetExist: true,
// isDownloadComplete: false,
// hasDownloadStarted: true,
// progress: {
// current: 512,
// total: 1024,
// remaining: 512,
// percentRemaining: 50.0,
// },
// });
// });
// it("should return the correct download status when model size is known and download is complete", async () => {
// whisperFile.size = 1024;
// jest.spyOn(whisperFile, "doesTargetExist").mockResolvedValue(true);
// jest.spyOn(whisperFile, "isDownloadComplete").mockResolvedValue(true);
// const status = await whisperFile.getDownloadStatus();
// expect(status).toEqual({
// doesTargetExist: true,
// isDownloadComplete: true,
// hasDownloadStarted: false,
// progress: undefined,
// });
// });
// it("should return the correct download status when model size is unknown", async () => {
// jest.spyOn(whisperFile, "doesTargetExist").mockResolvedValue(false);
// const status = await whisperFile.getDownloadStatus();
// expect(status).toEqual({
// doesTargetExist: false,
// isDownloadComplete: false,
// hasDownloadStarted: false,
// progress: undefined,
// });
// });
// });
}); });

View File

@ -2,15 +2,16 @@
export const MIGRATE_UP = { export const MIGRATE_UP = {
1: [ 1: [
`CREATE TABLE IF NOT EXISTS settings ( `CREATE TABLE IF NOT EXISTS settings (
host_language TEXT, key TEXT PRIMARY KEY,
libretranslate_base_url TEXT, value TEXT
ui_direction INTEGER,
whisper_model TEXT
)`, )`,
], ],
2: [ 2: [
`CREATE TABLE IF NOT EXISTS whisper_models ( `CREATE TABLE IF NOT EXISTS whisper_models (
model TEXT PRIMARY KEY, model TEXT PRIMARY KEY,
download_status STRING(255),
expected_size INTEGER,
last_hash STRING(1024),
bytes_done INTEGER, bytes_done INTEGER,
bytes_total INTEGER bytes_total INTEGER
)`, )`,

View File

@ -1,5 +1,6 @@
import { SQLiteDatabase } from "expo-sqlite"; import { SQLiteDatabase } from "expo-sqlite";
import { getDb } from "./db"; import { getDb } from "./db";
import { WhisperFile, whisper_model_tag_t } from "./whisper";
export class Settings { export class Settings {
@ -20,10 +21,9 @@ export class Settings {
throw new Error(`Invalid setting: '${key}'`) throw new Error(`Invalid setting: '${key}'`)
} }
const row: { [key: string]: string } | null = this.db.getFirstSync(`SELECT ${key} from settings LIMIT 1`) const row: { value: string } | null = this.db.getFirstSync(`SELECT value FROM settings WHERE key = ?`, key)
if (!(row && row[key])) return undefined; return row?.value;
return row[key];
} }
@ -33,13 +33,11 @@ export class Settings {
} }
// Check if the key already exists // Check if the key already exists
this.db.runSync(`INSERT INTO OR UPDATE this.db.runSync(`INSERT OR REPLACE INTO
settings settings
(${key}) (key, value)
VALUES VALUES
(?) (?, ?)`, key, value);
WHERE
${key} IS NOT NULL`, value);
} }
async setHostLanguage(value: string) { async setHostLanguage(value: string) {
@ -63,11 +61,10 @@ export class Settings {
} }
async getWhisperModel() { async getWhisperModel() {
return await this.getValue("whisper_model"); return await this.getValue("whisper_model") as whisper_model_tag_t;
} }
static async getDefault() { static async getDefault() {
return new Settings(await getDb()) return new Settings(await getDb())
} }
} }

5
app/lib/util.ts Normal file
View File

@ -0,0 +1,5 @@
import { TextDecoder } from "util";
export async function arrbufToStr(arrayBuffer : ArrayBuffer) {
return new TextDecoder().decode(new Uint8Array(arrayBuffer));
}

View File

@ -3,6 +3,7 @@ import * as FileSystem from "expo-file-system";
import { File, Paths } from "expo-file-system/next"; import { File, Paths } from "expo-file-system/next";
import { getDb } from "./db"; import { getDb } from "./db";
import * as Crypto from "expo-crypto"; import * as Crypto from "expo-crypto";
import { arrbufToStr } from "./util";
export const WHISPER_MODEL_PATH = Paths.join( export const WHISPER_MODEL_PATH = Paths.join(
FileSystem.documentDirectory || "file:///", FileSystem.documentDirectory || "file:///",
@ -114,6 +115,12 @@ export type download_status_t = {
}; };
export class WhisperFile { export class WhisperFile {
hf_metadata: hf_metadata_t | undefined;
target_hash: string | undefined;
does_target_exist: boolean = false;
download_data: FileSystem.DownloadProgressData | undefined;
constructor( constructor(
public tag: whisper_model_tag_t, public tag: whisper_model_tag_t,
private targetFileName?: string, private targetFileName?: string,
@ -122,11 +129,11 @@ export class WhisperFile {
) { ) {
this.targetFileName = this.targetFileName || `${tag}.bin`; this.targetFileName = this.targetFileName || `${tag}.bin`;
this.label = this.label =
this.label || `${tag[0].toUpperCase}${tag.substring(1).toLowerCase()}`; this.label || `${tag[0].toUpperCase()}${tag.substring(1).toLowerCase()}`;
} }
get targetPath() { get targetPath() {
return Paths.join(WHISPER_MODEL_DIR, this.targetFileName as string); return Paths.join(WHISPER_MODEL_PATH, this.targetFileName as string);
} }
get targetFile() { get targetFile() {
@ -137,79 +144,30 @@ export class WhisperFile {
return await FileSystem.getInfoAsync(this.targetPath); return await FileSystem.getInfoAsync(this.targetPath);
} }
async doesTargetExist() { async updateTargetExistence() {
return (await this.getTargetInfo()).exists; this.does_target_exist = (await this.getTargetInfo()).exists;
} }
public async recordLatestTargetHash() { public async getTargetSha() {
if (!(await this.doesTargetExist())) { await this.updateTargetExistence();
console.debug("%s does not exist", this.targetPath); if (!this.does_target_exist) {
}
if (!this.label) {
throw new Error("No label");
}
const digest1Str = await this.getActualTargetHash();
if (!digest1Str) {
return;
}
const db = await getDb();
db.runSync(`INSERT OR UPDATE
INTO whisper_models
(model, last_hash)
VALUES (?, ?)
WHERE
model = ?`, this.label, digest1Str, this.label);
}
public async getRecordedTargetHash(): Promise<string> {
const db = await getDb();
const row = db.getFirstSync("SELECT last_hash FROM whisper_models WHERE model = ?", this.tag);
return (row as {last_hash: string}).last_hash
}
public async getActualTargetHash(): Promise<string | undefined> {
if (!(await this.doesTargetExist())) {
console.debug("%s does not exist", this.targetPath); console.debug("%s does not exist", this.targetPath);
return undefined; return undefined;
} }
const digest1 = await Crypto.digest( return await Crypto.digest(
Crypto.CryptoDigestAlgorithm.SHA256, Crypto.CryptoDigestAlgorithm.SHA256,
this.targetFile.bytes() this.targetFile.bytes()
); );
const digest1Str = new TextDecoder().decode(new Uint8Array(digest1));
return digest1Str;
} }
async isTargetCorrupted() { public async updateTargetHash() {
const recordedTargetHash = await this.getRecordedTargetHash(); const targetSha = await this.getTargetSha();
const actualTargetHash = await this.getActualTargetHash(); if (!targetSha) return;
if (!(actualTargetHash || recordedTargetHash)) return false; this.target_hash = await arrbufToStr(targetSha);
return actualTargetHash !== recordedTargetHash;
} }
async isDownloadComplete() { get isHashValid() {
if (!(await this.doesTargetExist())) { return this.target_hash === this.hf_metadata?.oid;
console.debug("%s does not exist", this.targetPath);
return false;
}
const data = this.targetFile.bytes();
const meta = await this.fetchMetadata();
const expectedHash = meta.oid;
const digest1: ArrayBuffer = await Crypto.digest(
Crypto.CryptoDigestAlgorithm.SHA256,
data
);
const digest1Str = new TextDecoder().decode(new Uint8Array(digest1));
const doesMatch = digest1Str === expectedHash;
if (!doesMatch) {
console.debug(
"sha256 of '%s' does not match expected '%s'",
digest1Str,
expectedHash
);
return false;
}
return true;
} }
delete(ignoreErrors = true) { delete(ignoreErrors = true) {
@ -232,7 +190,21 @@ export class WhisperFile {
return create_hf_url(this.tag, "raw"); return create_hf_url(this.tag, "raw");
} }
private async fetchMetadata(): Promise<hf_metadata_t> { get percentDone() {
if (!this.download_data) return 0;
return (
(this.download_data.totalBytesWritten /
this.download_data.totalBytesExpectedToWrite) *
100
);
}
get percentLeft() {
if (!this.download_data) return 0;
return 100 - this.percentDone;
}
public async syncHfMetadata() {
try { try {
const resp = await fetch(this.metadataUrl, { const resp = await fetch(this.metadataUrl, {
credentials: "include", credentials: "include",
@ -254,7 +226,7 @@ export class WhisperFile {
mode: "cors", mode: "cors",
}); });
const text = await resp.text(); const text = await resp.text();
return Object.fromEntries( this.hf_metadata = Object.fromEntries(
text.split("\n").map((line) => line.split(" ")) text.split("\n").map((line) => line.split(" "))
) as hf_metadata_t; ) as hf_metadata_t;
} catch (err) { } catch (err) {
@ -263,25 +235,6 @@ export class WhisperFile {
} }
} }
async updateMetadata() {
const metadata = await this.fetchMetadata();
this.size = Number.parseInt(metadata.size);
}
async addToDatabase() {
const db = await getDb();
if (!(this.size && this.tag)) {
throw new Error();
}
db.runSync(`INSERT OR UPDATE
INTO whisper_models
(model, expected_size)
VALUES
(?, ?)
WHERE
model = ?`, this.tag, this.size.valueOf(), this.tag);
}
async createDownloadResumable( async createDownloadResumable(
options: { options: {
onData?: DownloadCallback | undefined; onData?: DownloadCallback | undefined;
@ -289,69 +242,43 @@ export class WhisperFile {
onData: undefined, onData: undefined,
} }
) { ) {
const existingData = (await this.doesTargetExist()) await this.syncHfMetadata();
// If the whisper model dir doesn't exist, create it.
if (!WHISPER_MODEL_DIR.exists) {
FileSystem.makeDirectoryAsync(WHISPER_MODEL_PATH, {
intermediates: true,
});
}
// Check for the existence of the target file
// If it exists, load the existing data.
await this.updateTargetExistence();
const existingData = this.does_target_exist
? this.targetFile.text() ? this.targetFile.text()
: undefined; : undefined;
if (await this.doesTargetExist()) { // Create the resumable.
}
return FileSystem.createDownloadResumable( return FileSystem.createDownloadResumable(
this.modelUrl, this.modelUrl,
this.targetPath, this.targetPath,
{}, {},
async (data: FileSystem.DownloadProgressData) => { async (data: FileSystem.DownloadProgressData) => {
const db = await getDb(); this.download_data = data;
db.runAsync(`INSERT INTO OR UPDATE await this.syncHfMetadata();
whisper_models await this.updateTargetHash();
(model, download_status) await this.updateTargetExistence();
VALUES if (options.onData) await options.onData(this);
(?, ?)
WHERE
model = ?
`, this.tag, "active", this.tag);
await this.recordLatestTargetHash();
if (options.onData) await options.onData(data);
}, },
existingData ? existingData : undefined existingData ? existingData : undefined
); );
} }
async getDownloadStatus(): Promise<download_status_t> {
const doesTargetExist = await this.doesTargetExist();
const isDownloadComplete = await this.isDownloadComplete();
const hasDownloadStarted = doesTargetExist && !isDownloadComplete;
if (!this.size) {
return {
doesTargetExist: false,
isDownloadComplete: false,
hasDownloadStarted: false,
progress: undefined,
}
} }
const remaining = hasDownloadStarted export type DownloadCallback = (arg0: WhisperFile) => any;
? this.size - (this.targetFile.size as number)
: 0;
const progress = hasDownloadStarted export const WHISPER_FILES = {
? { small: new WhisperFile("small"),
current: this.targetFile.size || 0, medium: new WhisperFile("medium"),
total: this.size, large: new WhisperFile("large"),
remaining: this.size - (this.targetFile.size as number),
percentRemaining: (remaining / this.size) * 100.0,
}
: undefined;
return {
doesTargetExist,
isDownloadComplete,
hasDownloadStarted,
progress,
}; };
}
}
export type DownloadCallback = (arg0: FileSystem.DownloadProgressData) => any;

View File

@ -8,6 +8,7 @@ import { SafeAreaProvider, SafeAreaView } from "react-native-safe-area-context";
import { Conversation, Speaker } from "@/app/lib/conversation"; import { Conversation, Speaker } from "@/app/lib/conversation";
import { NavigationProp, ParamListBase } from "@react-navigation/native"; import { NavigationProp, ParamListBase } from "@react-navigation/native";
import { Link, useNavigation } from "expo-router"; import { Link, useNavigation } from "expo-router";
import { migrateDb } from "@/app/lib/db";
export function LanguageSelection(props: { export function LanguageSelection(props: {
@ -30,6 +31,7 @@ export function LanguageSelection(props: {
useEffect(() => { useEffect(() => {
(async () => { (async () => {
await migrateDb();
try { try {
// Replace with your actual async data fetching logic // Replace with your actual async data fetching logic
setTranslator(await CachedTranslator.getDefault()); setTranslator(await CachedTranslator.getDefault());
@ -50,7 +52,7 @@ export function LanguageSelection(props: {
</Pressable> </Pressable>
<ScrollView > <ScrollView >
<SafeAreaProvider> <SafeAreaProvider>
<SafeAreaView> <SafeAreaView style={styles.table}>
{(languages && languagesLoaded) ? Object.entries(languages).filter((l) => (LANG_FLAGS as any)[l[0]] !== undefined).map( {(languages && languagesLoaded) ? Object.entries(languages).filter((l) => (LANG_FLAGS as any)[l[0]] !== undefined).map(
([lang, lang_entry]) => { ([lang, lang_entry]) => {
return ( return (
@ -66,11 +68,15 @@ export function LanguageSelection(props: {
) )
} }
const DEBUG_BORDER = {
borderWidth: 3,
borderStyle: "dotted",
borderColor: "blue",
}
const styles = StyleSheet.create({ const styles = StyleSheet.create({
column: { table: {
flex: 1, flexDirection: "row",
flexDirection: 'row', flexWrap: "wrap",
flexWrap: 'wrap',
padding: 8,
}, },
}) })

View File

@ -1,6 +1,7 @@
import React, { useState, useEffect } from "react"; import React, { useState, useEffect } from "react";
import { View, Text, TextInput, Pressable, StyleSheet } from "react-native"; import { View, Text, TextInput, Pressable, StyleSheet } from "react-native";
import { import {
WHISPER_FILES,
WhisperFile, WhisperFile,
download_status_t, download_status_t,
whisper_tag_t, whisper_tag_t,
@ -34,31 +35,34 @@ const SettingsComponent = () => {
} | null>(null); } | null>(null);
const [whisperModel, setWhisperModel] = const [whisperModel, setWhisperModel] =
useState<keyof typeof WHISPER_MODELS>("small"); useState<keyof typeof WHISPER_MODELS>("small");
const [downloader, setDownloader] = useState<any>(null);
const [whisperFile, setWhisperFile] = useState<WhisperFile | undefined>(); const [whisperFile, setWhisperFile] = useState<WhisperFile | undefined>();
const [downloadStatus, setDownloadStatus] = useState< const [whisperFileExists, setWhisperFileExists] = useState<boolean>(false);
undefined | download_status_t const [isWhisperHashValid, setIsWhisperHashValid] = useState<boolean>(false);
>(); const [downloader, setDownloader] = useState<any>(null);
const [bytesDone, setBytesDone] = useState<number | undefined>();
const [bytesRemaining, setBytesRemaining] = useState<number | undefined>();
const [statusTimeout, setStatusTimeout] = useState< const [statusTimeout, setStatusTimeout] = useState<
NodeJS.Timeout | undefined NodeJS.Timeout | undefined
>(); >();
useEffect(() => { useEffect(() => {
loadSettings(); (async function () {
}, []);
const getLanguageOptions = async () => {
const languageServer = await LanguageServer.getDefault();
setLanguageOptions(await languageServer.fetchLanguages());
};
const loadSettings = async () => {
const settings = await Settings.getDefault(); const settings = await Settings.getDefault();
setHostLanguage((await settings.getHostLanguage()) || "en"); setHostLanguage((await settings.getHostLanguage()) || "en");
setLibretranslateBaseUrl( setLibretranslateBaseUrl(
(await settings.getLibretranslateBaseUrl()) || LIBRETRANSLATE_BASE_URL (await settings.getLibretranslateBaseUrl()) || LIBRETRANSLATE_BASE_URL
); );
setWhisperModel(await settings.getWhisperModel()); setWhisperModel((await settings.getWhisperModel()) || "small");
setWhisperFile(WHISPER_FILES[whisperModel]);
await whisperFile?.syncHfMetadata();
await whisperFile?.updateTargetExistence();
await whisperFile?.updateTargetHash();
})();
}, [whisperFile]);
const getLanguageOptions = async () => {
const languageServer = await LanguageServer.getDefault();
setLanguageOptions(await languageServer.fetchLanguages());
}; };
const handleHostLanguageChange = async (lang: string) => { const handleHostLanguageChange = async (lang: string) => {
@ -83,17 +87,24 @@ const SettingsComponent = () => {
} }
}; };
const intervalUpdateDownloadStatus = async () => {
if (!whisperFile) return;
const status = await whisperFile.getDownloadStatus();
setDownloadStatus(status);
};
const handleWhisperModelChange = async (model: whisper_tag_t) => { const handleWhisperModelChange = async (model: whisper_tag_t) => {
const settings = await Settings.getDefault(); const settings = await Settings.getDefault();
await settings.setWhisperModel(model); await settings.setWhisperModel(model);
setWhisperModel(model); setWhisperModel(model);
setWhisperFile(new WhisperFile(model)); const wFile = WHISPER_FILES[whisperModel];
await wFile.syncHfMetadata();
await wFile.updateTargetExistence();
await wFile.updateTargetHash();
setIsWhisperHashValid(wFile.isHashValid);
setWhisperFile(wFile);
setWhisperFileExists(wFile.does_target_exist);
};
const doSetDownloadStatus = (arg0: WhisperFile) => {
console.log("Downloading ....")
setIsWhisperHashValid(arg0.isHashValid);
setBytesDone(arg0.download_data?.totalBytesWritten);
setBytesRemaining(arg0.download_data?.totalBytesExpectedToWrite);
}; };
const doDownload = async () => { const doDownload = async () => {
@ -101,16 +112,16 @@ const SettingsComponent = () => {
throw new Error("Could not start download because whisperModel not set."); throw new Error("Could not start download because whisperModel not set.");
} }
console.log("Starging download of %s", whisperModel) console.log("Starting download of %s", whisperModel);
const whisperFile = new WhisperFile(whisperModel); if (!whisperFile) throw new Error("No whisper file");
const resumable = await whisperFile.createDownloadResumable();
setDownloader(resumable);
try { try {
await resumable.downloadAsync(); const resumable = await whisperFile.createDownloadResumable({
const statusTimeout = setInterval(intervalUpdateDownloadStatus, 200); onData: doSetDownloadStatus,
setStatusTimeout(statusTimeout); });
setDownloader(resumable);
await resumable.resumeAsync();
} catch (error) { } catch (error) {
console.error("Failed to download whisper model:", error); console.error("Failed to download whisper model:", error);
} }
@ -174,28 +185,22 @@ const SettingsComponent = () => {
))} ))}
</Picker> </Picker>
<View> <View>
{whisperModel && {/* <Text>whisper file: { whisperFile?.tag }</Text> */}
(downloadStatus?.isDownloadComplete ? ( {whisperFile &&
downloadStatus?.doesTargetExist ? ( ( whisperFileExists && (<Pressable onPress={doDelete} style={styles.deleteButton}>
<Pressable onPress={doDelete}>
<Text>DELETE {whisperModel.toUpperCase()}</Text> <Text>DELETE {whisperModel.toUpperCase()}</Text>
</Pressable> </Pressable>))
) : ( }
<Pressable onPress={doStopDownload}> <Pressable onPress={doDownload} style={styles.pauseDownloadButton}>
<Text>PAUSE</Text>
</Pressable>
)
) : (
<Pressable onPress={doDownload}>
<Text>DOWNLOAD {whisperModel.toUpperCase()}</Text> <Text>DOWNLOAD {whisperModel.toUpperCase()}</Text>
</Pressable> </Pressable>
))} ))}
{downloadStatus?.progress && ( {bytesDone && bytesRemaining && (
<View> <View>
<Text> <Text>
{downloadStatus.progress.current} of{" "} {bytesDone} of{" "}
{downloadStatus.progress.total} ( {bytesRemaining} (
{downloadStatus.progress.percentRemaining} %){" "} {bytesDone / bytesRemaining * 100} %){" "}
</Text> </Text>
</View> </View>
)} )}

View File

@ -1,5 +1,5 @@
jest.mock("@/app/i18n/api", () => require("../../__mocks__/api.ts")); jest.mock("@/app/i18n/api", () => require("../../__mocks__/api.ts"));
import { renderRouter} from 'expo-router/testing-library'; import { renderRouter } from "expo-router/testing-library";
import React from "react"; import React from "react";
import { import {
act, act,
@ -13,14 +13,21 @@ import {
createNavigationContainerRef, createNavigationContainerRef,
} from "@react-navigation/native"; } from "@react-navigation/native";
import TTNavStack from "../TTNavStack"; import TTNavStack from "../TTNavStack";
import { migrateDb } from "@/app/lib/db";
describe("Navigation", () => { describe("Navigation", () => {
beforeEach(() => { beforeEach(async () => {
await migrateDb("development", "up");
// Reset the navigation state before each test // Reset the navigation state before each test
jest.clearAllMocks();
jest.useFakeTimers(); jest.useFakeTimers();
}); });
afterEach(async () => {
await migrateDb("development", "down");
jest.clearAllMocks();
jest.useRealTimers();
});
it("Navigates to ConversationThread on language selection", async () => { it("Navigates to ConversationThread on language selection", async () => {
const MockComponent = jest.fn(() => <TTNavStack />); const MockComponent = jest.fn(() => <TTNavStack />);
renderRouter( renderRouter(
@ -28,7 +35,7 @@ describe("Navigation", () => {
index: MockComponent, index: MockComponent,
}, },
{ {
initialUrl: '/', initialUrl: "/",
} }
); );
const languageSelectionText = await waitFor(() => const languageSelectionText = await waitFor(() =>
@ -47,14 +54,16 @@ describe("Navigation", () => {
index: MockComponent, index: MockComponent,
}, },
{ {
initialUrl: '/', initialUrl: "/",
} }
); );
const settingsButton = await waitFor(() => const settingsButton = await waitFor(() =>
screen.getByText(/.*Settings.*/i) screen.getByText(/.*Settings.*/i)
); );
fireEvent.press(settingsButton); fireEvent.press(settingsButton);
expect(await waitFor(() => screen.getByText(/Settings/i))).toBeOnTheScreen(); expect(
await waitFor(() => screen.getByText(/Settings/i))
).toBeOnTheScreen();
// expect(waitFor(() => screen.getByText(/Settings/i))).toBeTruthy() // expect(waitFor(() => screen.getByText(/Settings/i))).toBeTruthy()
expect(screen.getByText("Settings")).toBeOnTheScreen(); expect(screen.getByText("Settings")).toBeOnTheScreen();
}); });

View File

@ -106,7 +106,12 @@ const ISpeakButton = (props: ISpeakButtonProps) => {
<View style={styles.flag}> <View style={styles.flag}>
{countries && {countries &&
countries.map((c) => { countries.map((c) => {
return <CountryFlag isoCode={c} size={25} key={c} />; return (
<View>
<Text>{c}</Text>
<CountryFlag isoCode={c} size={25} key={c} />
</View>
);
})} })}
</View> </View>
<View> <View>
@ -121,14 +126,13 @@ const ISpeakButton = (props: ISpeakButtonProps) => {
const styles = StyleSheet.create({ const styles = StyleSheet.create({
button: { button: {
width: "20%",
borderRadius: 10, borderRadius: 10,
borderColor: "white", borderColor: "white",
borderWidth: 1, borderWidth: 1,
borderStyle: "solid", borderStyle: "solid",
height: 110, height: 110,
alignSelf: "flex-start", width: 170,
margin: 8, margin: 10,
}, },
flag: {}, flag: {},
iSpeak: { iSpeak: {

View File

@ -83,6 +83,7 @@ describe("SettingsComponent", () => {
beforeEach(async () => { beforeEach(async () => {
db = await getDb("development"); db = await getDb("development");
await migrateDb("development");
settings = new Settings(db); settings = new Settings(db);
jest.spyOn(Settings, 'getDefault').mockResolvedValue(settings); jest.spyOn(Settings, 'getDefault').mockResolvedValue(settings);
await settings.setHostLanguage("en"); await settings.setHostLanguage("en");

View File

@ -9,43 +9,37 @@ jest.mock("expo-sqlite", () => {
const { MIGRATE_UP } = jest.requireActual("./app/lib/migrations"); const { MIGRATE_UP } = jest.requireActual("./app/lib/migrations");
const openDatabaseAsync = async (name: string) => { const genericRun = (sql: string, ... params : string []) => {
return {
closeAsync: jest.fn(() => db.close()),
executeSql: jest.fn((sql: string) => db.exec(sql)),
runAsync: jest.fn(async (sql: string, params = []) => {
for (let m of Object.values(MIGRATE_UP)) {
for (let stmt of m) {
const s = db.prepare(stmt);
s.run();
}
}
const stmt = db.prepare(sql);
// console.log("Running %s with %s", sql, params); // console.log("Running %s with %s", sql, params);
try { try {
stmt.run(params); const stmt = db.prepare(sql);
stmt.run(...params);
} catch (e) { } catch (e) {
throw new Error( throw new Error(
`running ${sql} with params ${JSON.stringify(params)}: ${e}` `running ${sql} with params ${JSON.stringify(params)}: ${e}`
); );
} }
}),
getFirstAsync: jest.fn(async (sql: string, params = []) => {
for (let m of Object.values(MIGRATE_UP)) {
for (let stmt of m) {
const s = db.prepare(stmt);
s.run();
}
} }
const genericGetFirst = (sql: string, params = []) => {
const stmt = db.prepare(sql); const stmt = db.prepare(sql);
// const result = stmt.run(...params); // const result = stmt.run(...params);
return stmt.get(params); return stmt.get(params);
}), };
const openDatabaseAsync = async (name: string) => {
return {
closeAsync: jest.fn(() => db.close()),
executeSql: jest.fn((sql: string) => db.exec(sql)),
runAsync: jest.fn(genericRun),
runSync: jest.fn(genericRun),
getFirstAsync: jest.fn(genericGetFirst),
getFirstSync: jest.fn(genericGetFirst),
}; };
}; };
return { return {
migrateDb: async (direction: "up" | "down" = "up") => { migrateDb: async (direction: "up" | "down" = "up") => {
const db = await openDatabaseAsync("translation_terrace"); const db = await openDatabaseAsync("translation_terrace_development");
for (let m of Object.values(MIGRATE_UP)) { for (let m of Object.values(MIGRATE_UP)) {
for (let stmt of m) { for (let stmt of m) {
await db.executeSql(stmt); await db.executeSql(stmt);
@ -120,3 +114,17 @@ jest.mock('@/app/lib/settings', () => {
default: MockSettings default: MockSettings
}; };
}); });
jest.mock('expo-file-system', () => ({
// ... other methods ...
createDownloadResumable: jest.fn(() => ({
downloadAsync: jest.fn(() => Promise.resolve({ uri: 'mocked-uri' })),
pauseAsync: jest.fn(() => Promise.resolve()),
resumeAsync: jest.fn(() => Promise.resolve()),
cancelAsync: jest.fn(() => Promise.resolve()),
})),
getInfoAsync: jest.fn(() => ({
exists: () => false,
}))
// ... other methods ...
}));