2025-02-28 07:13:45 -08:00

213 lines
5.9 KiB
TypeScript

import { Platform } from "react-native";
import * as FileSystem from "expo-file-system";
import { File, Paths } from 'expo-file-system/next';
import { getDb } from "./db";
/**
 * URI of the directory that holds downloaded Whisper model files.
 *
 * Rooted at the app's document directory: models are downloaded at runtime,
 * so they need a writable, persistent location. The bundle directory only
 * contains assets shipped with the app and cannot be written to.
 * Falls back to `"file:///"` when no document directory is reported.
 */
export const WHISPER_MODEL_PATH = Paths.join(FileSystem.documentDirectory || "file:///", "whisper");
/** `File` handle wrapping `WHISPER_MODEL_PATH`; used as the base when joining model file paths. */
export const WHISPER_MODEL_DIR = new File(WHISPER_MODEL_PATH);
// Thanks to https://medium.com/@fabi.mofar/downloading-and-saving-files-in-react-native-expo-5b3499adda84
/**
 * Copies the file at `uri` into a user-chosen directory on Android, or hands
 * it to `shareAsync` everywhere else.
 *
 * On Android the Storage Access Framework is asked for directory permissions;
 * when they are granted the source is read as Base64 and written into a newly
 * created file named `filename` with MIME type `mimetype`. Failures while
 * creating or writing that file are logged and not rethrown.
 *
 * @param uri URI of the file to save.
 * @param filename Name for the file created in the chosen directory.
 * @param mimetype MIME type passed to the Storage Access Framework.
 */
export async function saveFile(
  uri: string,
  filename: string,
  mimetype: string
) {
  // Every non-Android platform goes straight to the share fallback.
  if (Platform.OS !== "android") {
    shareAsync(uri);
    return;
  }

  const grant =
    await FileSystem.StorageAccessFramework.requestDirectoryPermissionsAsync();
  // Without directory access there is nowhere to write, so fall back as well.
  if (!grant.granted) {
    shareAsync(uri);
    return;
  }

  const contents = await FileSystem.readAsStringAsync(uri, {
    encoding: FileSystem.EncodingType.Base64,
  });

  try {
    const createdUri = await FileSystem.StorageAccessFramework.createFileAsync(
      grant.directoryUri,
      filename,
      mimetype
    );
    await FileSystem.writeAsStringAsync(createdUri, contents, {
      encoding: FileSystem.EncodingType.Base64,
    });
  } catch (e) {
    // Best effort: creation/write problems are reported but do not reject.
    console.log(e);
  }
}
/**
 * Placeholder for sharing the file at `uri`.
 *
 * Not implemented yet: every call throws.
 *
 * @param uri URI of the file that would be shared (currently unused).
 * @throws Error always, with the message "Function not implemented."
 */
function shareAsync(uri: string) {
  const notImplemented = new Error("Function not implemented.");
  throw notImplemented;
}
/**
 * Tags of the Whisper model variants this module knows how to download.
 *
 * Declared `as const` so that `whisper_model_tag_t` is the literal union
 * `"small" | "medium" | "large"` rather than plain `string`.
 */
export const WHISPER_MODEL_TAGS = ["small", "medium", "large"] as const;
/** Union of the valid Whisper model tags. */
export type whisper_model_tag_t = (typeof WHISPER_MODEL_TAGS)[number];
/**
 * Download specification for each Whisper model tag.
 *
 * - `source`: URL the model file is downloaded from. Hugging Face serves the
 *   file itself from `/resolve/<revision>/…`; the `/blob/<revision>/…` path is
 *   the HTML file-viewer page.
 * - `target`: file name used for the model inside the Whisper model directory.
 * - `label`: human-readable name of the model.
 */
export const WHISPER_MODELS: Record<
  whisper_model_tag_t,
  { source: string; target: string; label: string }
> = {
  small: {
    source:
      "https://huggingface.co/openai/whisper-small/resolve/main/pytorch_model.bin",
    target: "small.bin",
    label: "Small",
  },
  medium: {
    source:
      "https://huggingface.co/openai/whisper-medium/resolve/main/pytorch_model.bin",
    target: "medium.bin",
    label: "Medium",
  },
  large: {
    source:
      "https://huggingface.co/openai/whisper-large/resolve/main/pytorch_model.bin",
    target: "large.bin",
    label: "Large",
  },
};
/**
 * Builds the `File` handle for a model's on-disk location.
 *
 * @param key Whisper model tag to look up in `WHISPER_MODELS`.
 * @returns A `File` for the model's target file name joined onto `WHISPER_MODEL_DIR`.
 */
export function getWhisperTarget(key : whisper_model_tag_t) {
  const { target } = WHISPER_MODELS[key];
  return new File(Paths.join(WHISPER_MODEL_DIR, target));
}
/**
 * Progress of a Whisper model download, discriminated on `status`.
 *
 * Only the `"in_progress"` variant carries byte counts: `bytes.done` out of
 * `bytes.total`.
 */
export type download_status =
  | { status: "not_started" | "complete" }
  | { status: "in_progress"; bytes: { total: number; done: number } };
/**
 * Reports the size of a model's file as currently present on disk.
 *
 * @param whisper_model Whisper model tag whose target file is inspected.
 * @returns The target file's `size`, or `undefined` when the file does not exist.
 */
export async function getModelFileSize(whisper_model: whisper_model_tag_t) {
  const modelFile = getWhisperTarget(whisper_model);
  return modelFile.exists ? modelFile.size : undefined;
}
/**
 * Looks up how far the download of a Whisper model has progressed, based on
 * the model's row in the `whisper_models` table.
 *
 * @param whisper_model The whisper model key to check (e.g. `"small"`)
 * @returns `not_started` when the model has no row, `in_progress` (with byte
 * counts) while fewer bytes are done than the recorded total, and `complete`
 * otherwise.
 */
export async function getWhisperDownloadStatus(
  whisper_model: whisper_model_tag_t
): Promise<download_status> {
  const db = await getDb();
  // The table and `bytes_total` column match what updateModelSize /
  // getExpectedModelSize in this file use.
  // NOTE(review): nothing visible in this file writes `bytes_done`; confirm
  // that the column exists in the schema and is maintained elsewhere.
  const row = db.getFirstSync(
    "SELECT bytes_done, bytes_total AS total FROM whisper_models WHERE model = ?",
    [whisper_model]
  ) as { bytes_done: number | null; total: number } | null | undefined;

  if (!row) {
    return { status: "not_started" };
  }

  // A NULL progress column means nothing has been recorded yet.
  const done = row.bytes_done ?? 0;
  if (done < row.total) {
    return {
      status: "in_progress",
      bytes: { done, total: row.total },
    };
  }
  return { status: "complete" };
}
/**
 * Tells whether the target file of a Whisper model is present on disk.
 *
 * @param whisper_model Whisper model tag to check.
 * @returns The `exists` flag of the model's target `File`.
 */
export function whisperFileExists(whisper_model : whisper_model_tag_t) {
  return getWhisperTarget(whisper_model).exists;
}
/**
 * Callback invoked with the progress data of a running download.
 *
 * The return value is typed `unknown` rather than `any`: any function
 * (sync or async) is still assignable, but the result cannot be used
 * unchecked.
 */
export type DownloadCallback = (arg0 : FileSystem.DownloadProgressData) => unknown;
/**
 * Records the expected total size of a model in the `whisper_models` table,
 * inserting the row or replacing an existing one for the same model.
 *
 * @param model_label Value stored in the `model` column.
 * @param size Value stored in the `bytes_total` column.
 */
async function updateModelSize(model_label : string, size : number) {
  const db = await getDb();
  const query = "INSERT OR REPLACE INTO whisper_models (model, bytes_total) VALUES (?, ?)"
  const stmt = db.prepareSync(query);
  try {
    stmt.executeSync(model_label, size);
  } finally {
    // Prepared statements hold native resources until finalized; this function
    // runs on download progress ticks, so leaking one per call adds up quickly.
    stmt.finalizeSync();
  }
}
/**
 * Reads the expected total size recorded for a model in `whisper_models`.
 *
 * @param model_label Value matched against the `model` column.
 * @returns The stored `bytes_total`, or `undefined` when the model has no row
 * or the stored value is NULL.
 */
async function getExpectedModelSize(model_label : string) : Promise<number | undefined> {
  const db = await getDb();
  const query = "SELECT bytes_total FROM whisper_models WHERE model = ?"
  const stmt = db.prepareSync(query);
  try {
    // Read the row before the statement is finalized below.
    const row = stmt.executeSync(model_label).getFirstSync() as
      | { bytes_total: number | null }
      | null
      | undefined;
    return row?.bytes_total ?? undefined;
  } finally {
    // Release the native statement whether or not the query succeeded.
    stmt.finalizeSync();
  }
}
/**
 * Prepares a resumable download of a Whisper model into the model directory.
 *
 * The download is only created, not started; the caller drives the returned
 * object. While it runs in the foreground, the expected total size is
 * recorded via `updateModelSize` and `options.onDownload` is invoked with
 * each progress update.
 *
 * @param whisper_model Tag of the model to download.
 * @param options.force_redownload When true, an existing target file is
 * deleted and downloaded again even if it looks complete.
 * @param options.onDownload Optional progress callback.
 * @returns The resumable download, or `undefined` when the target file
 * already exists with the recorded expected size and no re-download was forced.
 */
export async function initiateWhisperDownload(
  whisper_model: whisper_model_tag_t,
  options: {
    force_redownload?: boolean;
    onDownload?: DownloadCallback | undefined;
  } = {
    force_redownload: false,
    onDownload: undefined,
  }
): Promise<FileSystem.DownloadResumable | undefined> {
  console.debug("Starting download of %s", whisper_model);
  if (!WHISPER_MODEL_DIR.exists) {
    await FileSystem.makeDirectoryAsync(WHISPER_MODEL_PATH, {
      intermediates: true,
    });
    // Log the path string; the File object does not format usefully with %s.
    console.debug("Created %s", WHISPER_MODEL_PATH);
  }
  const whisperTarget = getWhisperTarget(whisper_model);
  if (whisperTarget.exists) {
    if (options.force_redownload) {
      // Forced: discard whatever is on disk.
      whisperTarget.delete()
    } else {
      // A file whose size matches the recorded total is treated as complete.
      const expected = await getExpectedModelSize(whisper_model);
      if (whisperTarget.size === expected) {
        console.warn("Whisper model for %s already exists", whisper_model);
        return undefined;
      }
    }
  }
  // Initiate a new resumable download.
  const spec = WHISPER_MODELS[whisper_model];
  console.log("Downloading %s", spec.source);
  // Last total written to the database, so the same value is not rewritten
  // on every progress tick.
  let recordedTotal: number | undefined;
  const resumable = FileSystem.createDownloadResumable(
    spec.source,
    whisperTarget.uri,
    {},
    // On each data write, update the whisper model download status.
    // Note that since createDownloadResumable callback only works in the foreground,
    // a background process will also be updating the file size.
    async (data) => {
      console.log("%s: %d bytes of %d", whisperTarget.uri, data.totalBytesWritten, data.totalBytesExpectedToWrite);
      try {
        const total = data.totalBytesExpectedToWrite;
        // A non-positive total means the size is unknown; storing it would
        // corrupt the "expected size" used by the completeness check above.
        if (total > 0 && total !== recordedTotal) {
          await updateModelSize(whisper_model, total);
          recordedTotal = total;
        }
        if (options.onDownload) await options.onDownload(data);
      } catch (e) {
        // Nothing awaits this callback, so a rejection here would otherwise
        // surface as an unhandled promise rejection.
        console.warn("Progress handling failed for %s", whisper_model, e);
      }
    },
    // No resumeData: it must be the token produced by pausing a previous
    // download, not the contents of a partially written file.
  );
  return resumable;
}