80 lines
2.1 KiB
TypeScript
80 lines
2.1 KiB
TypeScript
|
import { Worker } from "worker_threads";
|
|||
|
|
|||
|
/**
 * Runs one job in a fresh worker thread and resolves with the first message
 * the worker posts back.
 *
 * The worker is terminated as soon as it produces a result or an error.
 * Without that, a worker script that keeps `parentPort` listening would leak
 * one live thread per call and keep the process from exiting.
 *
 * @param data - Input posted to the worker. It is structured-cloned, not
 *   transferred, so the caller's array stays usable.
 * @param workerPath - Worker script to run. Relative string paths are resolved
 *   by Node against the process's current working directory, not against this
 *   module.
 * @returns The first message received from the worker, typed as a
 *   `Float32Array` (not validated at runtime).
 * @throws The returned promise rejects in three cases:
 *   - the worker emits `error`;
 *   - posting the input fails;
 *   - the worker exits before sending a result.
 *   A non-zero exit code is reported in the message.
 */
function createWorker(
  data: Float32Array,
  workerPath: string | URL = "./vino.worker.js",
): Promise<Float32Array> {
  return new Promise<Float32Array>((resolve, reject) => {
    const worker = new Worker(workerPath);
    // Only the first terminal event counts. For example, terminate() after a
    // result fires `exit` with code 1, which must not turn into a rejection.
    let settled = false;

    const settle = (outcome: () => void, stopWorker: boolean): void => {
      if (settled) return;
      settled = true;
      if (stopWorker) void worker.terminate();
      outcome();
    };

    worker.on("message", (result) =>
      settle(() => resolve(result as Float32Array), true),
    );
    worker.on("error", (err) => settle(() => reject(err), true));
    worker.on("exit", (code) => {
      // The thread is already gone here, so there is nothing to terminate.
      settle(
        () =>
          reject(
            code !== 0
              ? new Error(`Worker stopped with exit code ${code}`)
              : new Error("Worker exited without sending a result"),
          ),
        false,
      );
    });

    try {
      worker.postMessage(data);
    } catch (err) {
      // e.g. the input could not be cloned; do not leave the thread running.
      settle(() => reject(err), true);
    }
  });
}
|
|||
|
|
|||
|
/**
 * Promise-based concurrency limiter. At most `concurrency` tasks added through
 * `addTask` run at the same time; the rest wait in FIFO order.
 */
class TaskQueue {
  /** Thunks that start a pending task; each one settles its caller's promise. */
  private readonly queue: (() => void)[] = [];
  /** Number of tasks started but not yet settled. */
  private running = 0;

  /**
   * @param concurrency - Maximum number of tasks running at once. Must be at
   *   least 1. A smaller value (or NaN) would never start any task, and every
   *   promise returned by `addTask` would hang forever.
   * @throws RangeError if `concurrency` is not >= 1.
   */
  constructor(private readonly concurrency: number) {
    if (!(concurrency >= 1)) {
      throw new RangeError(`concurrency must be >= 1, got ${concurrency}`);
    }
  }

  /**
   * Increments the running count, awaits `task`, then releases the slot and
   * starts the next queued task, whether the task resolved or rejected.
   */
  private async runTask<T>(task: () => Promise<T>): Promise<T> {
    this.running++;
    try {
      return await task();
    } finally {
      this.running--;
      this.next();
    }
  }

  /**
   * Schedules `task` and returns a promise that settles with its outcome.
   * `task` itself is not invoked until a concurrency slot is free.
   */
  public addTask<T>(task: () => Promise<T>): Promise<T> {
    return new Promise<T>((resolve, reject) => {
      this.queue.push(() => {
        void this.runTask(task).then(resolve, reject);
      });
      this.next();
    });
  }

  /** Starts queued tasks until the limit is reached or the queue is empty. */
  private next(): void {
    while (this.running < this.concurrency && this.queue.length > 0) {
      const start = this.queue.shift();
      // runTask increments `running` synchronously, so the loop condition stays accurate.
      start?.();
    }
  }
}
|
|||
|
|
|||
|
/**
 * CPU inference run: sends `numTasks` random inputs through `createWorker`,
 * at most `concurrency` at a time, and logs the total wall-clock time.
 *
 * Each input is generated lazily, right before its job starts, so only the
 * inputs of in-flight jobs are alive at once. Generating them all up front
 * holds numTasks × product(patchSize) floats, about 1 GB with the defaults.
 *
 * Errors are logged, not rethrown. Promise.all stops waiting at the first
 * failure, but jobs already queued in the TaskQueue keep running afterwards.
 *
 * @param options.numTasks - Number of jobs to run (default 128).
 * @param options.patchSize - Input shape. Each job receives a flat
 *   Float32Array with product(patchSize) elements. Default: [1, 1, 80, 160, 160].
 * @param options.concurrency - Maximum jobs in flight (default 2).
 */
export async function cpuStartInfer(
  options: {
    numTasks?: number;
    patchSize?: readonly number[];
    concurrency?: number;
  } = {},
): Promise<void> {
  const {
    numTasks = 128,
    patchSize = [1, 1, 80, 160, 160],
    concurrency = 2,
  } = options;
  // The initial value 1 keeps an empty shape from throwing in reduce().
  const sampleLength = patchSize.reduce((a, b) => a * b, 1);
  const taskQueue = new TaskQueue(concurrency);

  console.time("总推理用时");
  try {
    const tasks = Array.from({ length: numTasks }, () =>
      // The sample is created inside the task, i.e. only once a slot is free.
      taskQueue.addTask(() => createWorker(randomSample(sampleLength))),
    );
    const results = await Promise.all(tasks);
    console.log("所有任务完成,结果数量:", results.length);
  } catch (e) {
    console.error("推理失败:", e);
  } finally {
    console.timeEnd("总推理用时");
  }
}

/**
 * Returns a Float32Array of `length` values computed as
 * `Math.random() * 2 - 1`, i.e. roughly uniform over [-1, 1].
 */
function randomSample(length: number): Float32Array {
  // Single allocation; `new Float32Array(n).map(...)` would allocate twice.
  const sample = new Float32Array(length);
  for (let i = 0; i < sample.length; i++) {
    sample[i] = Math.random() * 2.0 - 1.0;
  }
  return sample;
}
|