From dca3987e18efcf477b40e4db0665f5bb489d8759 Mon Sep 17 00:00:00 2001 From: Jordan Date: Tue, 11 Mar 2025 07:26:49 -0700 Subject: [PATCH] work on downloading some more. --- __mocks__/db.ts | 3 +- app/i18n/api.ts | 216 ++++++++++++--------- app/lib/__tests__/settings.spec.tsx | 3 +- app/lib/__tests__/whisper.spec.tsx | 217 ++++++++++++++-------- app/lib/migrations.ts | 11 +- app/lib/settings.ts | 17 +- app/lib/util.ts | 5 + app/lib/whisper.ts | 201 +++++++------------- components/LanguageSelection.tsx | 22 ++- components/Settings.tsx | 97 +++++----- components/__tests__/index.spec.tsx | 21 ++- components/ui/ISpeakButton.tsx | 12 +- components/ui/__tests__/Settings.spec.tsx | 1 + jestSetup.ts | 68 ++++--- 14 files changed, 480 insertions(+), 414 deletions(-) create mode 100644 app/lib/util.ts diff --git a/__mocks__/db.ts b/__mocks__/db.ts index 54d8c12..26f1ef1 100644 --- a/__mocks__/db.ts +++ b/__mocks__/db.ts @@ -1,7 +1,8 @@ export default { getDb: jest.fn(() => { return { - runAsync: jest.fn((statement: string, value: string) => {}), + runAsync: jest.fn((statement: string, ... values: string []) => {}), + runSync: jest.fn((statement: string, ... values : string []) => {}), getFirstAsync: jest.fn((statement: string, value: string) => { return []; }), diff --git a/app/i18n/api.ts b/app/i18n/api.ts index e2ab4ec..76a7582 100644 --- a/app/i18n/api.ts +++ b/app/i18n/api.ts @@ -1,123 +1,155 @@ import { Cache } from "react-native-cache"; import { LIBRETRANSLATE_BASE_URL } from "@/constants/api"; -import AsyncStorage from '@react-native-async-storage/async-storage'; +import AsyncStorage from "@react-native-async-storage/async-storage"; import { Settings } from "../lib/settings"; type language_t = string; const cache = new Cache({ - namespace: "translation_terrace", - policy: { - maxEntries: 50000, // if unspecified, it can have unlimited entries - stdTTL: 0 // the standard ttl as number in seconds, default: 0 (unlimited) - }, - backend: AsyncStorage + namespace: "translation_terrace", + policy: { + maxEntries: 50000, // if unspecified, it can have unlimited entries + stdTTL: 0, // the standard ttl as number in seconds, default: 0 (unlimited) + }, + backend: AsyncStorage, }); export type language_matrix_entry = { - code: string, - name: string, - targets: string [] -} + code: string; + name: string; + targets: string[]; +}; export type language_matrix = { - [key:string] : language_matrix_entry -} + [key: string]: language_matrix_entry; +}; -export async function fetchWithTimeout(url : string, options : RequestInit, timeout = 5000) : Promise { - return Promise.race([ - fetch(url, options), - new Promise((_, reject) => setTimeout(() => reject(new Error('timeout')), timeout)) - ]); +export async function fetchWithTimeout( + url: string, + options: RequestInit, + timeout = 5000 +): Promise { + return Promise.race([ + fetch(url, options), + new Promise((_, reject) => + setTimeout(() => reject(new Error("timeout")), timeout) + ), + ]); } export class LanguageServer { - constructor(public baseUrl : string) {} + constructor(public baseUrl: string) {} - async fetchLanguages(timeout = 500) : Promise { - let data = {}; - const res = await fetchWithTimeout(this.baseUrl + "/languages", { - headers: { - "Content-Type": "application/json" - } - }, timeout); - try { - data = await res.json(); - } catch (e) { - throw new Error(`Parsing data from ${await res.text()}: ${e}`) - } - try { - return Object.fromEntries( - Object.values(data as language_matrix_entry []).map((obj : language_matrix_entry) => { - return [ - obj["code"], - obj, - ] - }) - ) - } catch(e) { - throw new Error(`Can't extract values from data: ${JSON.stringify(data)}`) - } + async fetchLanguages(timeout = 500): Promise { + let data = {}; + const res = await fetchWithTimeout( + this.baseUrl + "/languages", + { + headers: { + "Content-Type": "application/json", + }, + }, + timeout + ); + try { + data = await res.json(); + } catch (e) { + throw new Error(`Parsing data from ${await res.text()}: ${e}`); } + try { + return Object.fromEntries( + Object.values(data as language_matrix_entry[]).map( + (obj: language_matrix_entry) => { + return [obj["code"], obj]; + } + ) + ); + } catch (e) { + throw new Error( + `Can't extract values from data: ${JSON.stringify(data)}` + ); + } + } - static async getDefault() { - const settings = await Settings.getDefault(); - return new LanguageServer(await settings.getLibretranslateBaseUrl() || LIBRETRANSLATE_BASE_URL); - } + static async getDefault() { + const settings = await Settings.getDefault(); + return new LanguageServer( + (await settings.getLibretranslateBaseUrl()) || LIBRETRANSLATE_BASE_URL + ); + } } export class Translator { - constructor(public source : language_t, public defaultTarget : string = "en", private _languageServer : LanguageServer) { - } + constructor( + public source: language_t, + public defaultTarget: string = "en", + private _languageServer: LanguageServer + ) {} - get languageServer() { - return this._languageServer; - } + get languageServer() { + return this._languageServer; + } - async translate(text : string, target : string|undefined = undefined) { - const url = this._languageServer.baseUrl + `/translate`; - const res = await fetch(url, { - method: "POST", - body: JSON.stringify({ - q: text, - source: this.source, - target: target || this.defaultTarget, - format: "text", - alternatives: 3, - api_key: "" - }), - headers: { "Content-Type": "application/json" } - }); + async translate(text: string, target: string | undefined = undefined) { + const url = this._languageServer.baseUrl + `/translate`; + console.log(url); + const postData = { + method: "POST", + body: JSON.stringify({ + q: text, + source: this.source, + target: target || this.defaultTarget, + format: "text", + alternatives: 3, + api_key: "", + }), + headers: { "Content-Type": "application/json" }, + }; - - const data = await res.json(); - console.log(data) - return data.translatedText - } + console.debug("Requesting %s with %o", url, postData); - static async getDefault(defaultTarget: string | undefined = undefined) { - const settings = await Settings.getDefault(); - const source = await settings.getHostLanguage(); - return new Translator(source, defaultTarget, await LanguageServer.getDefault()) + const res = await fetch(url, postData); + + const data = await res.json(); + if (res.status === 200) { + console.log(data); + return data.translatedText; + } else { + console.error(data); } + } + + static async getDefault(defaultTarget: string | undefined = undefined) { + const settings = await Settings.getDefault(); + const source = await settings.getHostLanguage(); + return new Translator( + source, + defaultTarget, + await LanguageServer.getDefault() + ); + } } export class CachedTranslator extends Translator { - async translate (text : string, target : string|undefined = undefined) { - const targetKey = target || this.defaultTarget; - // console.debug(`Translating from ${this.source} -> ${targetKey}`) - const key1 = `${this.source}::${targetKey}::${text}` - const tr1 = await cache.get(key1); - if (tr1) return tr1; - const tr2 = await super.translate(text, target); - const key2 = `${this.source}::${targetKey}::${text}` - await cache.set(key2, tr2); - return tr2; - } + async translate(text: string, target: string | undefined = undefined) { + const targetKey = target || this.defaultTarget; + // console.debug(`Translating from ${this.source} -> ${targetKey}`) + const key1 = `${this.source}::${targetKey}::${text}`; + const tr1 = await cache.get(key1); + if (tr1) return tr1; + const tr2 = await super.translate(text, target); + const key2 = `${this.source}::${targetKey}::${text}`; + await cache.set(key2, tr2); + return tr2; + } - static async getDefault(defaultTarget: string | undefined = undefined) { - const settings = await Settings.getDefault(); - const source = await settings.getHostLanguage(); - return new CachedTranslator(source, defaultTarget, await LanguageServer.getDefault()) - } -} \ No newline at end of file + static async getDefault(defaultTarget: string | undefined = undefined) { + const settings = await Settings.getDefault(); + const source = await settings.getHostLanguage() || "en"; + return new CachedTranslator( + source, + defaultTarget, + await LanguageServer.getDefault() + ); + } +} diff --git a/app/lib/__tests__/settings.spec.tsx b/app/lib/__tests__/settings.spec.tsx index c24286b..b51b84e 100644 --- a/app/lib/__tests__/settings.spec.tsx +++ b/app/lib/__tests__/settings.spec.tsx @@ -8,11 +8,12 @@ describe('Settings', () => { beforeEach(async () => { db = await getDb("development"); + await migrateDb("development"); settings = new Settings(db); }); afterEach(async () => { - await migrateDb("development"); + await migrateDb("development", "down"); }); it('should set the host language in the database', async () => { diff --git a/app/lib/__tests__/whisper.spec.tsx b/app/lib/__tests__/whisper.spec.tsx index ab9dd35..4125147 100644 --- a/app/lib/__tests__/whisper.spec.tsx +++ b/app/lib/__tests__/whisper.spec.tsx @@ -1,101 +1,170 @@ -// components/ui/__tests__/WhisperFile.spec.tsx +// app/lib/__tests__/whisper.spec.tsx import React from "react"; -import { render, act } from "@testing-library/react-native"; -import { WhisperFile } from "@/app/lib/whisper"; // Adjust the import path as necessary +import { getDb } from "@/app/lib/db"; +import { WhisperFile, WhisperModelTag } from "@/app/lib/whisper"; // Corrected to use WhisperFile and WhisperModelTag instead of WhisperDownloader +import { Settings } from "@/app/lib/settings"; +import { File } from "expo-file-system/next"; +jest.mock('expo-file-system'); +import * as FileSystem from 'expo-file-system'; + +jest.mock("@/app/lib/db", () => ({ + getDb: jest.fn().mockResolvedValue({ + runAsync: jest.fn(), + upsert: jest.fn(), // Mock the upsert method used in addToDatabase + }), +})); + +jest.mock("@/app/lib/settings", () => ({ + Settings: { + getDefault: jest.fn(() => ({ + getValue: jest.fn((key) => { + switch (key) { + case "whisper_model": + return "base"; + default: + throw new Error(`Invalid setting: '${key}'`); + } + }), + })), + }, +})); + + +jest.mock("expo-file-system/next", () => { + const _next = jest.requireActual("expo-file-system/next"); + return { + ..._next, + File: jest.fn().mockImplementation(() => ({ + ..._next.File, + text: jest.fn(() => { + return new String("text"); + }), + })), + }; +}); describe("WhisperFile", () => { + // Corrected to use WhisperFile instead of WhisperDownloader let whisperFile: WhisperFile; - beforeEach(() => { + beforeEach(async () => { whisperFile = new WhisperFile("small"); }); - it("should initialize correctly", () => { - expect(whisperFile).toBeInstanceOf(WhisperFile); + it("should create a download resumable with existing data if available", async () => { + const mockExistingData = "mockExistingData"; + jest.spyOn(whisperFile, "doesTargetExist").mockResolvedValue(true); + + await whisperFile.createDownloadResumable(); + // expect(whisperFile.targetFileName).toEqual("small.bin"); + expect(whisperFile.targetPath).toContain("small.bin"); + + expect(FileSystem.createDownloadResumable).toHaveBeenCalledWith( + "https://huggingface.co/openai/whisper-small/resolve/main/pytorch_model.bin", + "file:///whisper/small.bin", + {}, + expect.any(Function), + expect.anything(), + ); }); - describe("getModelFileSize", () => { - it("should return the correct model file size", async () => { - expect(whisperFile.size).toBeUndefined(); - await whisperFile.updateMetadata(); - expect(whisperFile.size).toBeGreaterThan(1000); - }); - }); + // it("should create a download resumable without existing data if not available", async () => { + // jest.spyOn(whisperFile, "doesTargetExist").mockResolvedValue(false); - describe("getWhisperDownloadStatus", () => { - it("should return the correct download status", async () => { - const mockStatus = { - doesTargetExist: true, - isDownloadComplete: false, - hasDownloadStarted: true, - progress: { - current: 100, - total: 200, - remaining: 100, - percentRemaining: 50.0, - }, - }; - jest - .spyOn(whisperFile, "getDownloadStatus") - .mockResolvedValue(mockStatus); + // await whisperFile.createDownloadResumable(); // Updated to use createDownloadResumable instead of download - const result = await whisperFile.getDownloadStatus(); + // expect(FileSystem.createDownloadResumable).toHaveBeenCalledWith( + // "http://mock.model.com/model", + // "mockTargetPath", + // {}, + // expect.any(Function), + // undefined + // ); + // }); - expect(result).toEqual(mockStatus); - }); - }); + // it("should update the download status in the database", async () => { + // const mockRunAsync = jest.fn(); + // (getDb as jest.Mock).mockResolvedValue({ runAsync: mockRunAsync }); - describe("initiateWhisperDownload", () => { - it("should initiate the download with default options", async () => { - const mockModelLabel = "small"; - jest - .spyOn(whisperFile, "createDownloadResumable") - .mockResolvedValue(true); + // const downloadable = await whisperFile.createDownloadResumable(); // Updated to use createDownloadResumable instead of download + // await downloadable.resumeAsync(); - await whisperFile.initiateWhisperDownload(mockModelLabel); + // jest.advanceTimersByTime(1000); - expect(whisperFile.createDownloadResumable).toHaveBeenCalledWith( - mockModelLabel - ); - }); + // expect(mockRunAsync).toHaveBeenCalled(); + // }); - it("should initiate the download with custom options", async () => { - const mockModelLabel = "small"; - const mockOptions = { force_redownload: true }; - jest - .spyOn(whisperFile, "createDownloadResumable") - .mockResolvedValue(true); + // it("should record the latest target hash after downloading", async () => { + // const mockRecordLatestTargetHash = jest.spyOn( + // whisperFile, + // "recordLatestTargetHash" + // ); - await whisperFile.initiateWhisperDownload(mockModelLabel, mockOptions); + // await whisperFile.createDownloadResumable(); // Updated to use createDownloadResumable instead of download - expect(whisperFile.createDownloadResumable).toHaveBeenCalledWith( - mockModelLabel, - mockOptions - ); - }); + // expect(mockRecordLatestTargetHash).toHaveBeenCalled(); + // }); - it("should return the correct download status when target exists and is complete", async () => { - jest.spyOn(whisperFile, "doesTargetExist").mockResolvedValue(true); - jest.spyOn(whisperFile, "isDownloadComplete").mockResolvedValue(true); + // it("should call the onData callback if provided", async () => { + // const mockOnData = jest.fn(); + // const options = { onData: mockOnData }; - expect(await whisperFile.doesTargetExist()).toEqual(true); - expect(await whisperFile.isDownloadComplete()).toEqual(true); - }); + // await whisperFile.createDownloadResumable(options); // Updated to use createDownloadResumable instead of download - it("should return the correct download status when target does not exist", async () => { - jest.spyOn(whisperFile, "doesTargetExist").mockResolvedValue(false); + // expect(mockOnData).toHaveBeenCalledWith(expect.any(Object)); + // }); - const result = await whisperFile.getDownloadStatus(); + // describe("getDownloadStatus", () => { + // it("should return the correct download status when model size is known and download has started", async () => { + // whisperFile.size = 1024; + // jest.spyOn(whisperFile, "doesTargetExist").mockResolvedValue(true); + // jest.spyOn(whisperFile, "isDownloadComplete").mockResolvedValue(false); + // jest.spyOn(whisperFile, "targetFile").mockReturnValue({ + // size: 512, + // }); - expect(result).toEqual({ - doesTargetExist: false, - isDownloadComplete: false, - hasDownloadStarted: false, - progress: undefined, - }); - }); - }); + // const status = await whisperFile.getDownloadStatus(); - // Add more tests as needed for other methods in WhisperFile + // expect(status).toEqual({ + // doesTargetExist: true, + // isDownloadComplete: false, + // hasDownloadStarted: true, + // progress: { + // current: 512, + // total: 1024, + // remaining: 512, + // percentRemaining: 50.0, + // }, + // }); + // }); + + // it("should return the correct download status when model size is known and download is complete", async () => { + // whisperFile.size = 1024; + // jest.spyOn(whisperFile, "doesTargetExist").mockResolvedValue(true); + // jest.spyOn(whisperFile, "isDownloadComplete").mockResolvedValue(true); + + // const status = await whisperFile.getDownloadStatus(); + + // expect(status).toEqual({ + // doesTargetExist: true, + // isDownloadComplete: true, + // hasDownloadStarted: false, + // progress: undefined, + // }); + // }); + + // it("should return the correct download status when model size is unknown", async () => { + // jest.spyOn(whisperFile, "doesTargetExist").mockResolvedValue(false); + + // const status = await whisperFile.getDownloadStatus(); + + // expect(status).toEqual({ + // doesTargetExist: false, + // isDownloadComplete: false, + // hasDownloadStarted: false, + // progress: undefined, + // }); + // }); + // }); }); diff --git a/app/lib/migrations.ts b/app/lib/migrations.ts index 9da55d8..98c1fab 100644 --- a/app/lib/migrations.ts +++ b/app/lib/migrations.ts @@ -2,15 +2,16 @@ export const MIGRATE_UP = { 1: [ `CREATE TABLE IF NOT EXISTS settings ( - host_language TEXT, - libretranslate_base_url TEXT, - ui_direction INTEGER, - whisper_model TEXT - )`, + key TEXT PRIMARY KEY, + value TEXT + )`, ], 2: [ `CREATE TABLE IF NOT EXISTS whisper_models ( model TEXT PRIMARY KEY, + download_status STRING(255), + expected_size INTEGER, + last_hash STRING(1024), bytes_done INTEGER, bytes_total INTEGER )`, diff --git a/app/lib/settings.ts b/app/lib/settings.ts index 99b8cc3..ba3eba2 100644 --- a/app/lib/settings.ts +++ b/app/lib/settings.ts @@ -1,5 +1,6 @@ import { SQLiteDatabase } from "expo-sqlite"; import { getDb } from "./db"; +import { WhisperFile, whisper_model_tag_t } from "./whisper"; export class Settings { @@ -20,10 +21,9 @@ export class Settings { throw new Error(`Invalid setting: '${key}'`) } - const row: { [key: string]: string } | null = this.db.getFirstSync(`SELECT ${key} from settings LIMIT 1`) + const row: { value: string } | null = this.db.getFirstSync(`SELECT value FROM settings WHERE key = ?`, key) - if (!(row && row[key])) return undefined; - return row[key]; + return row?.value; } @@ -33,13 +33,11 @@ export class Settings { } // Check if the key already exists - this.db.runSync(`INSERT INTO OR UPDATE + this.db.runSync(`INSERT OR REPLACE INTO settings - (${key}) + (key, value) VALUES - (?) - WHERE - ${key} IS NOT NULL`, value); + (?, ?)`, key, value); } async setHostLanguage(value: string) { @@ -63,11 +61,10 @@ export class Settings { } async getWhisperModel() { - return await this.getValue("whisper_model"); + return await this.getValue("whisper_model") as whisper_model_tag_t; } static async getDefault() { return new Settings(await getDb()) } - } \ No newline at end of file diff --git a/app/lib/util.ts b/app/lib/util.ts new file mode 100644 index 0000000..b81db9f --- /dev/null +++ b/app/lib/util.ts @@ -0,0 +1,5 @@ +import { TextDecoder } from "util"; + +export async function arrbufToStr(arrayBuffer : ArrayBuffer) { + return new TextDecoder().decode(new Uint8Array(arrayBuffer)); +} \ No newline at end of file diff --git a/app/lib/whisper.ts b/app/lib/whisper.ts index fef323c..503dc30 100644 --- a/app/lib/whisper.ts +++ b/app/lib/whisper.ts @@ -3,6 +3,7 @@ import * as FileSystem from "expo-file-system"; import { File, Paths } from "expo-file-system/next"; import { getDb } from "./db"; import * as Crypto from "expo-crypto"; +import { arrbufToStr } from "./util"; export const WHISPER_MODEL_PATH = Paths.join( FileSystem.documentDirectory || "file:///", @@ -114,6 +115,12 @@ export type download_status_t = { }; export class WhisperFile { + hf_metadata: hf_metadata_t | undefined; + + target_hash: string | undefined; + does_target_exist: boolean = false; + download_data: FileSystem.DownloadProgressData | undefined; + constructor( public tag: whisper_model_tag_t, private targetFileName?: string, @@ -122,11 +129,11 @@ export class WhisperFile { ) { this.targetFileName = this.targetFileName || `${tag}.bin`; this.label = - this.label || `${tag[0].toUpperCase}${tag.substring(1).toLowerCase()}`; + this.label || `${tag[0].toUpperCase()}${tag.substring(1).toLowerCase()}`; } get targetPath() { - return Paths.join(WHISPER_MODEL_DIR, this.targetFileName as string); + return Paths.join(WHISPER_MODEL_PATH, this.targetFileName as string); } get targetFile() { @@ -137,79 +144,30 @@ export class WhisperFile { return await FileSystem.getInfoAsync(this.targetPath); } - async doesTargetExist() { - return (await this.getTargetInfo()).exists; + async updateTargetExistence() { + this.does_target_exist = (await this.getTargetInfo()).exists; } - public async recordLatestTargetHash() { - if (!(await this.doesTargetExist())) { - console.debug("%s does not exist", this.targetPath); - } - if (!this.label) { - throw new Error("No label"); - } - const digest1Str = await this.getActualTargetHash(); - if (!digest1Str) { - return; - } - const db = await getDb(); - db.runSync(`INSERT OR UPDATE - INTO whisper_models - (model, last_hash) - VALUES (?, ?) - WHERE - model = ?`, this.label, digest1Str, this.label); - } - - public async getRecordedTargetHash(): Promise { - const db = await getDb(); - const row = db.getFirstSync("SELECT last_hash FROM whisper_models WHERE model = ?", this.tag); - return (row as {last_hash: string}).last_hash - } - - public async getActualTargetHash(): Promise { - if (!(await this.doesTargetExist())) { + public async getTargetSha() { + await this.updateTargetExistence(); + if (!this.does_target_exist) { console.debug("%s does not exist", this.targetPath); return undefined; } - const digest1 = await Crypto.digest( + return await Crypto.digest( Crypto.CryptoDigestAlgorithm.SHA256, this.targetFile.bytes() ); - const digest1Str = new TextDecoder().decode(new Uint8Array(digest1)); - return digest1Str; } - async isTargetCorrupted() { - const recordedTargetHash = await this.getRecordedTargetHash(); - const actualTargetHash = await this.getActualTargetHash(); - if (!(actualTargetHash || recordedTargetHash)) return false; - return actualTargetHash !== recordedTargetHash; + public async updateTargetHash() { + const targetSha = await this.getTargetSha(); + if (!targetSha) return; + this.target_hash = await arrbufToStr(targetSha); } - async isDownloadComplete() { - if (!(await this.doesTargetExist())) { - console.debug("%s does not exist", this.targetPath); - return false; - } - const data = this.targetFile.bytes(); - const meta = await this.fetchMetadata(); - const expectedHash = meta.oid; - const digest1: ArrayBuffer = await Crypto.digest( - Crypto.CryptoDigestAlgorithm.SHA256, - data - ); - const digest1Str = new TextDecoder().decode(new Uint8Array(digest1)); - const doesMatch = digest1Str === expectedHash; - if (!doesMatch) { - console.debug( - "sha256 of '%s' does not match expected '%s'", - digest1Str, - expectedHash - ); - return false; - } - return true; + get isHashValid() { + return this.target_hash === this.hf_metadata?.oid; } delete(ignoreErrors = true) { @@ -232,7 +190,21 @@ export class WhisperFile { return create_hf_url(this.tag, "raw"); } - private async fetchMetadata(): Promise { + get percentDone() { + if (!this.download_data) return 0; + return ( + (this.download_data.totalBytesWritten / + this.download_data.totalBytesExpectedToWrite) * + 100 + ); + } + + get percentLeft() { + if (!this.download_data) return 0; + return 100 - this.percentDone; + } + + public async syncHfMetadata() { try { const resp = await fetch(this.metadataUrl, { credentials: "include", @@ -254,7 +226,7 @@ export class WhisperFile { mode: "cors", }); const text = await resp.text(); - return Object.fromEntries( + this.hf_metadata = Object.fromEntries( text.split("\n").map((line) => line.split(" ")) ) as hf_metadata_t; } catch (err) { @@ -263,95 +235,50 @@ export class WhisperFile { } } - async updateMetadata() { - const metadata = await this.fetchMetadata(); - this.size = Number.parseInt(metadata.size); - } - - async addToDatabase() { - const db = await getDb(); - if (!(this.size && this.tag)) { - throw new Error(); - } - db.runSync(`INSERT OR UPDATE - INTO whisper_models - (model, expected_size) - VALUES - (?, ?) - WHERE - model = ?`, this.tag, this.size.valueOf(), this.tag); - } - async createDownloadResumable( options: { onData?: DownloadCallback | undefined; } = { - onData: undefined, - } + onData: undefined, + } ) { - const existingData = (await this.doesTargetExist()) + await this.syncHfMetadata(); + + // If the whisper model dir doesn't exist, create it. + if (!WHISPER_MODEL_DIR.exists) { + FileSystem.makeDirectoryAsync(WHISPER_MODEL_PATH, { + intermediates: true, + }); + } + + // Check for the existence of the target file + // If it exists, load the existing data. + await this.updateTargetExistence(); + const existingData = this.does_target_exist ? this.targetFile.text() : undefined; - if (await this.doesTargetExist()) { - } - + // Create the resumable. return FileSystem.createDownloadResumable( this.modelUrl, this.targetPath, {}, async (data: FileSystem.DownloadProgressData) => { - const db = await getDb(); - db.runAsync(`INSERT INTO OR UPDATE - whisper_models - (model, download_status) - VALUES - (?, ?) - WHERE - model = ? - `, this.tag, "active", this.tag); - await this.recordLatestTargetHash(); - if (options.onData) await options.onData(data); + this.download_data = data; + await this.syncHfMetadata(); + await this.updateTargetHash(); + await this.updateTargetExistence(); + if (options.onData) await options.onData(this); }, existingData ? existingData : undefined ); } - - async getDownloadStatus(): Promise { - const doesTargetExist = await this.doesTargetExist(); - const isDownloadComplete = await this.isDownloadComplete(); - const hasDownloadStarted = doesTargetExist && !isDownloadComplete; - - if (!this.size) { - return { - doesTargetExist: false, - isDownloadComplete: false, - hasDownloadStarted: false, - progress: undefined, - } - } - - const remaining = hasDownloadStarted - ? this.size - (this.targetFile.size as number) - : 0; - - const progress = hasDownloadStarted - ? { - current: this.targetFile.size || 0, - total: this.size, - remaining: this.size - (this.targetFile.size as number), - percentRemaining: (remaining / this.size) * 100.0, - } - : undefined; - - return { - doesTargetExist, - isDownloadComplete, - hasDownloadStarted, - progress, - }; - } } +export type DownloadCallback = (arg0: WhisperFile) => any; -export type DownloadCallback = (arg0: FileSystem.DownloadProgressData) => any; +export const WHISPER_FILES = { + small: new WhisperFile("small"), + medium: new WhisperFile("medium"), + large: new WhisperFile("large"), +}; diff --git a/components/LanguageSelection.tsx b/components/LanguageSelection.tsx index 0771cd0..8126be5 100644 --- a/components/LanguageSelection.tsx +++ b/components/LanguageSelection.tsx @@ -8,6 +8,7 @@ import { SafeAreaProvider, SafeAreaView } from "react-native-safe-area-context"; import { Conversation, Speaker } from "@/app/lib/conversation"; import { NavigationProp, ParamListBase } from "@react-navigation/native"; import { Link, useNavigation } from "expo-router"; +import { migrateDb } from "@/app/lib/db"; export function LanguageSelection(props: { @@ -30,6 +31,7 @@ export function LanguageSelection(props: { useEffect(() => { (async () => { + await migrateDb(); try { // Replace with your actual async data fetching logic setTranslator(await CachedTranslator.getDefault()); @@ -49,12 +51,12 @@ export function LanguageSelection(props: { Settings - - + + {(languages && languagesLoaded) ? Object.entries(languages).filter((l) => (LANG_FLAGS as any)[l[0]] !== undefined).map( ([lang, lang_entry]) => { return ( - + ); } ) : Waiting... @@ -66,11 +68,15 @@ export function LanguageSelection(props: { ) } +const DEBUG_BORDER = { + borderWidth: 3, + borderStyle: "dotted", + borderColor: "blue", +} + const styles = StyleSheet.create({ - column: { - flex: 1, - flexDirection: 'row', - flexWrap: 'wrap', - padding: 8, + table: { + flexDirection: "row", + flexWrap: "wrap", }, }) \ No newline at end of file diff --git a/components/Settings.tsx b/components/Settings.tsx index 9490655..5c8d10d 100644 --- a/components/Settings.tsx +++ b/components/Settings.tsx @@ -1,6 +1,7 @@ import React, { useState, useEffect } from "react"; import { View, Text, TextInput, Pressable, StyleSheet } from "react-native"; import { + WHISPER_FILES, WhisperFile, download_status_t, whisper_tag_t, @@ -34,33 +35,36 @@ const SettingsComponent = () => { } | null>(null); const [whisperModel, setWhisperModel] = useState("small"); - const [downloader, setDownloader] = useState(null); const [whisperFile, setWhisperFile] = useState(); - const [downloadStatus, setDownloadStatus] = useState< - undefined | download_status_t - >(); + const [whisperFileExists, setWhisperFileExists] = useState(false); + const [isWhisperHashValid, setIsWhisperHashValid] = useState(false); + const [downloader, setDownloader] = useState(null); + const [bytesDone, setBytesDone] = useState(); + const [bytesRemaining, setBytesRemaining] = useState(); const [statusTimeout, setStatusTimeout] = useState< NodeJS.Timeout | undefined >(); useEffect(() => { - loadSettings(); - }, []); + (async function () { + const settings = await Settings.getDefault(); + setHostLanguage((await settings.getHostLanguage()) || "en"); + setLibretranslateBaseUrl( + (await settings.getLibretranslateBaseUrl()) || LIBRETRANSLATE_BASE_URL + ); + setWhisperModel((await settings.getWhisperModel()) || "small"); + setWhisperFile(WHISPER_FILES[whisperModel]); + await whisperFile?.syncHfMetadata(); + await whisperFile?.updateTargetExistence(); + await whisperFile?.updateTargetHash(); + })(); + }, [whisperFile]); const getLanguageOptions = async () => { const languageServer = await LanguageServer.getDefault(); setLanguageOptions(await languageServer.fetchLanguages()); }; - const loadSettings = async () => { - const settings = await Settings.getDefault(); - setHostLanguage((await settings.getHostLanguage()) || "en"); - setLibretranslateBaseUrl( - (await settings.getLibretranslateBaseUrl()) || LIBRETRANSLATE_BASE_URL - ); - setWhisperModel(await settings.getWhisperModel()); - }; - const handleHostLanguageChange = async (lang: string) => { const settings = await Settings.getDefault(); setHostLanguage(lang); @@ -83,17 +87,24 @@ const SettingsComponent = () => { } }; - const intervalUpdateDownloadStatus = async () => { - if (!whisperFile) return; - const status = await whisperFile.getDownloadStatus(); - setDownloadStatus(status); - }; - const handleWhisperModelChange = async (model: whisper_tag_t) => { const settings = await Settings.getDefault(); await settings.setWhisperModel(model); setWhisperModel(model); - setWhisperFile(new WhisperFile(model)); + const wFile = WHISPER_FILES[whisperModel]; + await wFile.syncHfMetadata(); + await wFile.updateTargetExistence(); + await wFile.updateTargetHash(); + setIsWhisperHashValid(wFile.isHashValid); + setWhisperFile(wFile); + setWhisperFileExists(wFile.does_target_exist); + }; + + const doSetDownloadStatus = (arg0: WhisperFile) => { + console.log("Downloading ....") + setIsWhisperHashValid(arg0.isHashValid); + setBytesDone(arg0.download_data?.totalBytesWritten); + setBytesRemaining(arg0.download_data?.totalBytesExpectedToWrite); }; const doDownload = async () => { @@ -101,16 +112,16 @@ const SettingsComponent = () => { throw new Error("Could not start download because whisperModel not set."); } - console.log("Starging download of %s", whisperModel) + console.log("Starting download of %s", whisperModel); - const whisperFile = new WhisperFile(whisperModel); + if (!whisperFile) throw new Error("No whisper file"); - const resumable = await whisperFile.createDownloadResumable(); - setDownloader(resumable); try { - await resumable.downloadAsync(); - const statusTimeout = setInterval(intervalUpdateDownloadStatus, 200); - setStatusTimeout(statusTimeout); + const resumable = await whisperFile.createDownloadResumable({ + onData: doSetDownloadStatus, + }); + setDownloader(resumable); + await resumable.resumeAsync(); } catch (error) { console.error("Failed to download whisper model:", error); } @@ -174,28 +185,22 @@ const SettingsComponent = () => { ))} - {whisperModel && - (downloadStatus?.isDownloadComplete ? ( - downloadStatus?.doesTargetExist ? ( - - DELETE {whisperModel.toUpperCase()} - - ) : ( - - PAUSE - - ) - ) : ( - + {/* whisper file: { whisperFile?.tag } */} + {whisperFile && + ( whisperFileExists && ( + DELETE {whisperModel.toUpperCase()} + )) +} + DOWNLOAD {whisperModel.toUpperCase()} ))} - {downloadStatus?.progress && ( + {bytesDone && bytesRemaining && ( - {downloadStatus.progress.current} of{" "} - {downloadStatus.progress.total} ( - {downloadStatus.progress.percentRemaining} %){" "} + {bytesDone} of{" "} + {bytesRemaining} ( + {bytesDone / bytesRemaining * 100} %){" "} )} diff --git a/components/__tests__/index.spec.tsx b/components/__tests__/index.spec.tsx index b4ea7a0..f7a21b1 100644 --- a/components/__tests__/index.spec.tsx +++ b/components/__tests__/index.spec.tsx @@ -1,5 +1,5 @@ jest.mock("@/app/i18n/api", () => require("../../__mocks__/api.ts")); -import { renderRouter} from 'expo-router/testing-library'; +import { renderRouter } from "expo-router/testing-library"; import React from "react"; import { act, @@ -13,14 +13,21 @@ import { createNavigationContainerRef, } from "@react-navigation/native"; import TTNavStack from "../TTNavStack"; +import { migrateDb } from "@/app/lib/db"; describe("Navigation", () => { - beforeEach(() => { + beforeEach(async () => { + await migrateDb("development", "up"); // Reset the navigation state before each test - jest.clearAllMocks(); jest.useFakeTimers(); }); + afterEach(async () => { + await migrateDb("development", "down"); + jest.clearAllMocks(); + jest.useRealTimers(); + }); + it("Navigates to ConversationThread on language selection", async () => { const MockComponent = jest.fn(() => ); renderRouter( @@ -28,7 +35,7 @@ describe("Navigation", () => { index: MockComponent, }, { - initialUrl: '/', + initialUrl: "/", } ); const languageSelectionText = await waitFor(() => @@ -47,14 +54,16 @@ describe("Navigation", () => { index: MockComponent, }, { - initialUrl: '/', + initialUrl: "/", } ); const settingsButton = await waitFor(() => screen.getByText(/.*Settings.*/i) ); fireEvent.press(settingsButton); - expect(await waitFor(() => screen.getByText(/Settings/i))).toBeOnTheScreen(); + expect( + await waitFor(() => screen.getByText(/Settings/i)) + ).toBeOnTheScreen(); // expect(waitFor(() => screen.getByText(/Settings/i))).toBeTruthy() expect(screen.getByText("Settings")).toBeOnTheScreen(); }); diff --git a/components/ui/ISpeakButton.tsx b/components/ui/ISpeakButton.tsx index ec6a754..2072d9f 100644 --- a/components/ui/ISpeakButton.tsx +++ b/components/ui/ISpeakButton.tsx @@ -106,7 +106,12 @@ const ISpeakButton = (props: ISpeakButtonProps) => { {countries && countries.map((c) => { - return ; + return ( + + {c} + + + ); })} @@ -121,14 +126,13 @@ const ISpeakButton = (props: ISpeakButtonProps) => { const styles = StyleSheet.create({ button: { - width: "20%", borderRadius: 10, borderColor: "white", borderWidth: 1, borderStyle: "solid", height: 110, - alignSelf: "flex-start", - margin: 8, + width: 170, + margin: 10, }, flag: {}, iSpeak: { diff --git a/components/ui/__tests__/Settings.spec.tsx b/components/ui/__tests__/Settings.spec.tsx index 50cdfd0..33ae1cb 100644 --- a/components/ui/__tests__/Settings.spec.tsx +++ b/components/ui/__tests__/Settings.spec.tsx @@ -83,6 +83,7 @@ describe("SettingsComponent", () => { beforeEach(async () => { db = await getDb("development"); + await migrateDb("development"); settings = new Settings(db); jest.spyOn(Settings, 'getDefault').mockResolvedValue(settings); await settings.setHostLanguage("en"); diff --git a/jestSetup.ts b/jestSetup.ts index c244553..e199f00 100644 --- a/jestSetup.ts +++ b/jestSetup.ts @@ -9,43 +9,37 @@ jest.mock("expo-sqlite", () => { const { MIGRATE_UP } = jest.requireActual("./app/lib/migrations"); + const genericRun = (sql: string, ... params : string []) => { + // console.log("Running %s with %s", sql, params); + try { + const stmt = db.prepare(sql); + stmt.run(...params); + } catch (e) { + throw new Error( + `running ${sql} with params ${JSON.stringify(params)}: ${e}` + ); + } + } + + const genericGetFirst = (sql: string, params = []) => { + const stmt = db.prepare(sql); + // const result = stmt.run(...params); + return stmt.get(params); + }; + const openDatabaseAsync = async (name: string) => { return { closeAsync: jest.fn(() => db.close()), executeSql: jest.fn((sql: string) => db.exec(sql)), - runAsync: jest.fn(async (sql: string, params = []) => { - for (let m of Object.values(MIGRATE_UP)) { - for (let stmt of m) { - const s = db.prepare(stmt); - s.run(); - } - } - const stmt = db.prepare(sql); - // console.log("Running %s with %s", sql, params); - try { - stmt.run(params); - } catch (e) { - throw new Error( - `running ${sql} with params ${JSON.stringify(params)}: ${e}` - ); - } - }), - getFirstAsync: jest.fn(async (sql: string, params = []) => { - for (let m of Object.values(MIGRATE_UP)) { - for (let stmt of m) { - const s = db.prepare(stmt); - s.run(); - } - } - const stmt = db.prepare(sql); - // const result = stmt.run(...params); - return stmt.get(params); - }), + runAsync: jest.fn(genericRun), + runSync: jest.fn(genericRun), + getFirstAsync: jest.fn(genericGetFirst), + getFirstSync: jest.fn(genericGetFirst), }; }; return { migrateDb: async (direction: "up" | "down" = "up") => { - const db = await openDatabaseAsync("translation_terrace"); + const db = await openDatabaseAsync("translation_terrace_development"); for (let m of Object.values(MIGRATE_UP)) { for (let stmt of m) { await db.executeSql(stmt); @@ -119,4 +113,18 @@ jest.mock('@/app/lib/settings', () => { ...originalModule, default: MockSettings }; -}); \ No newline at end of file +}); + +jest.mock('expo-file-system', () => ({ + // ... other methods ... + createDownloadResumable: jest.fn(() => ({ + downloadAsync: jest.fn(() => Promise.resolve({ uri: 'mocked-uri' })), + pauseAsync: jest.fn(() => Promise.resolve()), + resumeAsync: jest.fn(() => Promise.resolve()), + cancelAsync: jest.fn(() => Promise.resolve()), + })), + getInfoAsync: jest.fn(() => ({ + exists: () => false, + })) + // ... other methods ... +}));