347 lines
9.0 KiB
TypeScript
347 lines
9.0 KiB
TypeScript
import { Platform } from "react-native";
|
|
import * as FileSystem from "expo-file-system";
|
|
import { File, Paths } from "expo-file-system/next";
|
|
import { getDb } from "./db";
|
|
import * as Crypto from "expo-crypto";
|
|
|
|
// Root under which the "whisper" folder lives: the document directory reported
// by expo-file-system, or "file:///" when that value is falsy.
const whisperStorageRoot = FileSystem.documentDirectory || "file:///";

/** Path of the folder that holds downloaded Whisper model files. */
export const WHISPER_MODEL_PATH = Paths.join(whisperStorageRoot, "whisper");

/**
 * File-system handle built from {@link WHISPER_MODEL_PATH}.
 *
 * NOTE(review): this wraps a folder path in `File`; confirm whether
 * `Directory` from "expo-file-system/next" was intended.
 */
export const WHISPER_MODEL_DIR = new File(WHISPER_MODEL_PATH);
|
|
|
|
// Thanks to https://medium.com/@fabi.mofar/downloading-and-saving-files-in-react-native-expo-5b3499adda84
|
|
|
|
/**
 * Saves the file at `uri` to a location the user can reach.
 *
 * On Android the user is asked to pick a directory through the Storage Access
 * Framework; when access is granted the file is copied there (as base64) under
 * `filename` with the given `mimetype`. On any other platform, or when the
 * permission is refused, the call is handed to `shareAsync`.
 *
 * @param uri      Location of the file to save.
 * @param filename Name to give the created file.
 * @param mimetype MIME type passed to the Storage Access Framework.
 */
export async function saveFile(
  uri: string,
  filename: string,
  mimetype: string
) {
  if (Platform.OS !== "android") {
    shareAsync(uri);
    return;
  }

  const permissions =
    await FileSystem.StorageAccessFramework.requestDirectoryPermissionsAsync();
  if (!permissions.granted) {
    shareAsync(uri);
    return;
  }

  const contents = await FileSystem.readAsStringAsync(uri, {
    encoding: FileSystem.EncodingType.Base64,
  });

  // Failures while creating or writing the destination are logged, not rethrown.
  try {
    const destinationUri =
      await FileSystem.StorageAccessFramework.createFileAsync(
        permissions.directoryUri,
        filename,
        mimetype
      );
    await FileSystem.writeAsStringAsync(destinationUri, contents, {
      encoding: FileSystem.EncodingType.Base64,
    });
  } catch (e) {
    console.log(e);
  }
}
|
|
|
|
/**
 * Placeholder for sharing a file; it is not implemented and always throws.
 *
 * @param uri Location of the file that would be shared (currently unused).
 * @throws Error always, with the message "Function not implemented."
 */
function shareAsync(uri: string) {
  const notImplemented = new Error("Function not implemented.");
  throw notImplemented;
}
|
|
|
|
/** Identifier of a supported Whisper model size. */
export type whisper_model_tag_t = "small" | "medium" | "large";

/** Every supported model tag, in ascending size order. */
export const WHISPER_MODEL_TAGS: whisper_model_tag_t[] = [
  "small",
  "medium",
  "large",
];

/**
 * Static description of each downloadable model.
 *
 * - `source`: URL the weights are downloaded from.
 * - `target`: file name used for the local copy.
 * - `label`:  human-readable name.
 * - `size`:   expected size of the download in bytes.
 *
 * Declared as a `Record` keyed by the tag union so that a missing or
 * misspelled tag is a compile error.
 */
export const WHISPER_MODELS: Record<
  whisper_model_tag_t,
  {
    source: string;
    target: string;
    label: string;
    size: number;
  }
> = {
  small: {
    // Uses the same ".../resolve/main/..." layout as the other entries.
    source:
      "https://huggingface.co/openai/whisper-small/resolve/main/pytorch_model.bin",
    target: "small.bin",
    label: "Small",
    size: 967092419,
  },
  medium: {
    source:
      "https://huggingface.co/openai/whisper-medium/resolve/main/pytorch_model.bin",
    target: "medium.bin",
    label: "Medium",
    size: 3055735323,
  },
  large: {
    source:
      "https://huggingface.co/openai/whisper-large/resolve/main/pytorch_model.bin",
    target: "large.bin",
    label: "Large",
    size: 6173629930,
  },
};
|
|
|
|
/** Model size segment used when building a Hugging Face URL. */
export type whisper_tag_t = "small" | "medium" | "large";

/** Path segment that selects which variant of the file URL is built. */
export type hf_channel_t = "raw" | "resolve";

/** Common prefix of every model URL; the model tag is appended directly. */
export const HF_URL_BASE = "https://huggingface.co/openai/whisper-";

/** Channel segment `"raw"`. */
export const HF_URL_RAW = "raw";

/** Channel segment `"resolve"`. */
export const HF_URL_RESOLVE = "resolve";

/** Common suffix of every model URL (branch and file name). */
export const HF_URL_END = "/main/pytorch_model.bin";

/**
 * Builds the URL of a model's `pytorch_model.bin` on the given channel.
 *
 * @param tag     Model size to address.
 * @param channel Channel segment to place between the repository and branch.
 * @returns `HF_URL_BASE + tag + "/" + channel + HF_URL_END`.
 */
export function create_hf_url(tag: whisper_tag_t, channel: hf_channel_t) {
  const repository = HF_URL_BASE + tag;
  const filePath = channel + HF_URL_END;
  return [repository, filePath].join("/");
}
|
|
|
|
/**
 * Key/value metadata parsed from the text served at a model's metadata URL.
 * All values are kept as the raw strings found in the response.
 */
export type hf_metadata_t = {
  /** Value of the `version` line. */
  version: string;
  /** Value of the `oid` line; compared against the local file's SHA-256. */
  oid: string;
  /** Value of the `size` line; parsed with `parseInt` before use. */
  size: string;
};
|
|
|
|
/** Snapshot of how far the download of a model file has come. */
export type download_status_t = {
  /** Whether the local target file exists. */
  doesTargetExist: boolean;
  /** Whether the local file passed the completeness check. */
  isDownloadComplete: boolean;
  /** Whether a local file exists but is not yet complete. */
  hasDownloadStarted: boolean;
  /** Byte counters; only present while a download is in progress. */
  progress?: {
    /** Bytes already on disk. */
    current: number;
    /** Expected size of the finished file in bytes. */
    total: number;
    /** Bytes still missing (`total - current`). */
    remaining: number;
    /** `remaining` as a percentage of `total`. */
    percentRemaining: number;
  };
};
|
|
|
|
/**
 * One Whisper model file on the device: where it lives, how to download it,
 * and how to check the local copy against recorded and remote hashes.
 */
export class WhisperFile {
  /**
   * @param tag            Model size this instance represents.
   * @param targetFileName Local file name; defaults to `<tag>.bin`.
   * @param label          Display name; defaults to the capitalised tag.
   * @param size           Expected size in bytes, when already known.
   */
  constructor(
    public tag: whisper_model_tag_t,
    private targetFileName?: string,
    public label?: string,
    public size?: number
  ) {
    this.targetFileName = this.targetFileName || `${tag}.bin`;
    // `toUpperCase` must be *called*; interpolating the method itself would
    // embed the function's source text in the label.
    this.label =
      this.label ||
      `${tag.charAt(0).toUpperCase()}${tag.substring(1).toLowerCase()}`;
  }

  /** Renders a digest as lowercase hexadecimal, two characters per byte. */
  private static toHex(buffer: ArrayBuffer): string {
    return Array.from(new Uint8Array(buffer), (byte) =>
      byte.toString(16).padStart(2, "0")
    ).join("");
  }

  /** Path of the local model file inside the Whisper model folder. */
  get targetPath() {
    return Paths.join(WHISPER_MODEL_DIR, this.targetFileName as string);
  }

  /** File handle for {@link targetPath}. */
  get targetFile() {
    return new File(this.targetPath);
  }

  /** Returns the file-system info record for the local model file. */
  async getTargetInfo() {
    return await FileSystem.getInfoAsync(this.targetPath);
  }

  /** Resolves to `true` when the local model file exists. */
  async doesTargetExist() {
    return (await this.getTargetInfo()).exists;
  }

  /**
   * Hashes the local file and stores the result as `last_hash` for this model
   * in the `whisper_models` table. Does nothing when the file is missing.
   */
  public async recordLatestTargetHash() {
    // getActualTargetHash() already logs and returns undefined for a missing file.
    const latestHash = await this.getActualTargetHash();
    if (!latestHash) {
      return;
    }
    const db = await getDb();
    await db("whisper_models")
      .upsert({
        model: this.tag,
        last_hash: latestHash,
      })
      .where({ model: this.tag });
  }

  /**
   * Reads the hash last stored by {@link recordLatestTargetHash}.
   *
   * @returns The stored hash, or an empty string when no row or no hash has
   *          been recorded for this model yet.
   */
  public async getRecordedTargetHash(): Promise<string> {
    const db = await getDb();
    const row = await db("whisper_models")
      .select("last_hash")
      .where({
        model: this.tag,
      })
      .first();
    // `first()` yields nothing when the model was never recorded.
    return row?.last_hash ?? "";
  }

  /**
   * Computes the SHA-256 of the local model file.
   *
   * @returns The digest as lowercase hex, or `undefined` when the file does
   *          not exist.
   */
  public async getActualTargetHash(): Promise<string | undefined> {
    if (!(await this.doesTargetExist())) {
      console.debug("%s does not exist", this.targetPath);
      return undefined;
    }
    const digest = await Crypto.digest(
      Crypto.CryptoDigestAlgorithm.SHA256,
      this.targetFile.bytes()
    );
    // Raw digest bytes are not valid UTF-8 text; hex-encode them so the value
    // is lossless and comparable with published checksums.
    return WhisperFile.toHex(digest);
  }

  /**
   * Reports whether the local file differs from the hash recorded for it.
   * Returns `false` when neither a local hash nor a recorded hash exists.
   */
  async isTargetCorrupted() {
    const recordedTargetHash = await this.getRecordedTargetHash();
    const actualTargetHash = await this.getActualTargetHash();
    if (!(actualTargetHash || recordedTargetHash)) return false;
    return actualTargetHash !== recordedTargetHash;
  }

  /**
   * Checks whether the local file matches the hash published in the remote
   * metadata.
   *
   * @returns `true` only when the file exists and its SHA-256 equals the
   *          remote `oid`.
   * @throws Whatever {@link fetchMetadata} throws when the metadata cannot be
   *         retrieved.
   */
  async isDownloadComplete() {
    const actualHash = await this.getActualTargetHash();
    if (!actualHash) {
      return false;
    }
    const meta = await this.fetchMetadata();
    // Assumes the metadata is a Git LFS pointer, whose oid is written as
    // "sha256:<hex>"; strip the algorithm prefix before comparing.
    const expectedHash = (meta.oid ?? "").replace(/^sha256:/, "");
    if (actualHash !== expectedHash) {
      console.debug(
        "sha256 of '%s' does not match expected '%s'",
        actualHash,
        expectedHash
      );
      return false;
    }
    return true;
  }

  /**
   * Deletes the local model file.
   *
   * @param ignoreErrors When `true` (default) failures are only logged; when
   *                     `false` they are logged and rethrown.
   */
  delete(ignoreErrors = true) {
    try {
      this.targetFile.delete();
    } catch (err) {
      console.error(err);
      if (!ignoreErrors) {
        throw err;
      }
    }
  }

  /** URL the model weights are downloaded from. */
  get modelUrl() {
    return create_hf_url(this.tag, "resolve");
  }

  /** URL the model's metadata text is fetched from. */
  get metadataUrl() {
    return create_hf_url(this.tag, "raw");
  }

  /**
   * Downloads and parses the metadata text for this model. Each non-empty
   * line is split at its first space into a key and a value.
   *
   * @throws Error when the request fails or the server answers with a
   *         non-success status; the error is logged before being rethrown.
   */
  private async fetchMetadata(): Promise<hf_metadata_t> {
    try {
      const resp = await fetch(this.metadataUrl, {
        credentials: "include",
        // No `If-None-Match`: a conditional request with a fixed ETag can be
        // answered with an empty 304, leaving nothing to parse.
        headers: {
          "User-Agent":
            "Mozilla/5.0 (X11; Linux x86_64; rv:135.0) Gecko/20100101 Firefox/135.0",
          Accept:
            "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
          "Accept-Language": "en-US,en;q=0.5",
          "Sec-GPC": "1",
          "Upgrade-Insecure-Requests": "1",
          "Sec-Fetch-Dest": "document",
          "Sec-Fetch-Mode": "navigate",
          "Sec-Fetch-Site": "cross-site",
          Priority: "u=0, i",
        },
        method: "GET",
        mode: "cors",
      });
      if (!resp.ok) {
        throw new Error(`Unexpected HTTP status ${resp.status}`);
      }
      const text = await resp.text();
      const entries = text
        .split("\n")
        .map((line) => line.trim())
        .filter((line) => line.length > 0)
        .map((line): [string, string] => {
          const separator = line.indexOf(" ");
          return separator === -1
            ? [line, ""]
            : [line.slice(0, separator), line.slice(separator + 1)];
        });
      return Object.fromEntries(entries) as hf_metadata_t;
    } catch (err) {
      console.error("Failed to fetch %s: %s", this.metadataUrl, err);
      throw err;
    }
  }

  /** Refreshes {@link size} from the remote metadata. */
  async updateMetadata() {
    const metadata = await this.fetchMetadata();
    this.size = Number.parseInt(metadata.size, 10);
  }

  /** Stores this model's tag and expected size in `whisper_models`. */
  async addToDatabase() {
    const db = await getDb();
    await db("whisper_models")
      .upsert({
        model: this.tag,
        expected_size: this.size,
      })
      .where({
        model: this.tag,
      });
  }

  /**
   * Creates a resumable download of the model into {@link targetPath}.
   *
   * On every progress event the model is marked as `"active"` in the
   * database, the current file hash is recorded, and `options.onData` is
   * invoked with the progress data.
   *
   * @param options.onData     Optional progress listener.
   * @param options.resumeData Optional resume token saved from a previously
   *                           paused download. The partially downloaded file's
   *                           contents are not a valid token and are never
   *                           used as one.
   */
  async createDownloadResumable(
    options: {
      onData?: DownloadCallback | undefined;
      resumeData?: string | undefined;
    } = {
      onData: undefined,
    }
  ) {
    return FileSystem.createDownloadResumable(
      this.modelUrl,
      this.targetPath,
      {},
      async (data: FileSystem.DownloadProgressData) => {
        // Nobody awaits this callback, so failures are logged here instead of
        // surfacing as unhandled rejections.
        try {
          const db = await getDb();
          await db("whisper_models")
            .upsert({
              model: this.tag,
              download_status: "active",
            })
            .where({ model: this.tag });
          // NOTE(review): this re-hashes the whole partial file on every
          // progress event; confirm that cost is acceptable for large models.
          await this.recordLatestTargetHash();
        } catch (err) {
          console.error("Failed to record download progress: %s", err);
        }
        try {
          if (options.onData) await options.onData(data);
        } catch (err) {
          console.error("Download progress listener failed: %s", err);
        }
      },
      options.resumeData
    );
  }

  /**
   * Summarises the state of the local file relative to the expected size.
   *
   * When {@link size} is unknown (unset or zero) every flag is reported as
   * `false` and no file or network check is performed.
   */
  async getDownloadStatus(): Promise<download_status_t> {
    const total = this.size;
    if (!total) {
      return {
        doesTargetExist: false,
        isDownloadComplete: false,
        hasDownloadStarted: false,
        progress: undefined,
      };
    }

    const doesTargetExist = await this.doesTargetExist();
    const isDownloadComplete = await this.isDownloadComplete();
    const hasDownloadStarted = doesTargetExist && !isDownloadComplete;

    if (!hasDownloadStarted) {
      return {
        doesTargetExist,
        isDownloadComplete,
        hasDownloadStarted,
        progress: undefined,
      };
    }

    // A missing size reading counts as zero bytes on disk.
    const current = this.targetFile.size ?? 0;
    const remaining = total - current;
    return {
      doesTargetExist,
      isDownloadComplete,
      hasDownloadStarted,
      progress: {
        current,
        total,
        remaining,
        percentRemaining: (remaining / total) * 100.0,
      },
    };
  }
}
|
|
|
|
|
|
/**
 * Listener invoked with each download progress event. It may be synchronous
 * or return a promise; the return value itself is not used.
 */
export type DownloadCallback = (
  progress: FileSystem.DownloadProgressData
) => unknown;