80 lines
2.1 KiB
TypeScript
80 lines
2.1 KiB
TypeScript
import { Worker } from "worker_threads";
|
||
|
||
/**
 * Runs a single inference job in a fresh worker thread.
 *
 * The input is sent with `postMessage` (structured clone, so the caller's
 * array is copied, not transferred). The promise resolves with the first
 * message the worker posts back, after which the worker is terminated.
 *
 * NOTE(review): the reply is cast to `Float32Array` without a runtime check;
 * this assumes the worker script answers with exactly one Float32Array —
 * confirm against the worker implementation.
 *
 * @param data - Input data handed to the worker.
 * @param workerPath - Script the worker runs. A relative path starting with
 *   `./` is resolved by Node against the process working directory, not
 *   against this module's location.
 * @returns The worker's reply.
 * @throws Rejects if the worker emits `'error'`, or exits (with any code)
 *   before posting a reply.
 */
function createWorker(
  data: Float32Array,
  workerPath: string = "./vino.worker.js",
): Promise<Float32Array> {
  return new Promise<Float32Array>((resolve, reject) => {
    const worker = new Worker(workerPath);
    let settled = false;

    // Settle exactly once, then shut the worker down. Without terminate(), a
    // worker that keeps listening on its parent port never exits, leaking one
    // thread per job and keeping the process alive.
    const settle = (outcome: () => void): void => {
      if (settled) return;
      settled = true;
      outcome();
      worker.terminate().catch(() => undefined);
    };

    worker.on("message", (result: unknown) => {
      settle(() => resolve(result as Float32Array));
    });

    worker.on("error", (err: Error) => {
      settle(() => reject(err));
    });

    worker.on("exit", (code: number) => {
      // Any exit before a reply is a failure. Treating a clean exit (code 0)
      // as success-without-result would leave the promise pending forever.
      settle(() =>
        reject(
          new Error(
            code === 0
              ? "Worker exited with code 0 before posting a result"
              : `Worker stopped with exit code ${code}`,
          ),
        ),
      );
    });

    worker.postMessage(data);
  });
}
|
||
|
||
/**
 * Minimal FIFO scheduler that runs at most `concurrency` async tasks at once.
 *
 * Tasks start in the order they were added; each `addTask` call returns a
 * promise that settles with the outcome of that particular task. A failing
 * task does not stop the queue: its slot is released and the next queued
 * task is started.
 */
class TaskQueue {
  /** Maximum number of tasks allowed in flight simultaneously. */
  private readonly concurrency: number;
  /** Launchers for tasks that have not started yet, oldest first. */
  private readonly queue: Array<() => void> = [];
  /** Number of tasks currently in flight. */
  private running = 0;

  /**
   * @param concurrency - Maximum number of tasks running at the same time.
   * @throws RangeError if `concurrency` is not at least 1 (or is NaN); such a
   *   queue would never start any task and every `addTask` promise would hang.
   */
  constructor(concurrency: number) {
    if (!(concurrency >= 1)) {
      throw new RangeError(`concurrency must be >= 1, got ${concurrency}`);
    }
    this.concurrency = concurrency;
  }

  /**
   * Enqueues `task` and starts it as soon as a slot is free.
   *
   * @param task - Factory producing the work; it is not invoked until scheduled.
   * @returns A promise that resolves or rejects with the task's own result.
   */
  public addTask<T>(task: () => Promise<T>): Promise<T> {
    return new Promise<T>((resolve, reject) => {
      this.queue.push(() => {
        void this.runTask(task).then(resolve, reject);
      });
      this.next();
    });
  }

  /**
   * Runs one task while holding a slot. The slot is released and the next
   * queued task is started before this promise settles, whether the task
   * succeeded or failed.
   */
  private async runTask<T>(task: () => Promise<T>): Promise<T> {
    // Incremented synchronously so `next()` sees the slot as taken immediately.
    this.running++;
    try {
      return await task();
    } finally {
      this.running--;
      this.next();
    }
  }

  /** Starts queued tasks while free slots remain. */
  private next(): void {
    while (this.running < this.concurrency) {
      const launch = this.queue.shift();
      if (launch === undefined) return;
      launch();
    }
  }
}
|
||
|
||
/**
 * Benchmarks CPU inference by pushing `numTasks` randomly generated input
 * patches through worker threads, at most `concurrency` at a time, and logs
 * the total wall-clock time via the "总推理用时" console timer.
 *
 * Each input holds the product of `patchSize` float32 values drawn uniformly
 * from [-1, 1). Failures are logged, not rethrown, so the returned promise
 * resolves once the run finishes or the first task fails.
 *
 * @param options.numTasks - Number of inference jobs to run (default 128).
 * @param options.patchSize - Dimensions of one input patch (default [1, 1, 80, 160, 160]).
 * @param options.concurrency - Maximum number of jobs in flight (default 2).
 */
export async function cpuStartInfer(
  options: {
    numTasks?: number;
    patchSize?: readonly number[];
    concurrency?: number;
  } = {},
): Promise<void> {
  const {
    numTasks = 128,
    patchSize = [1, 1, 80, 160, 160],
    concurrency = 2,
  } = options;

  // Element count of one patch; the initial value keeps an empty shape from throwing.
  const sampleLength = patchSize.reduce((a, b) => a * b, 1);
  // Number of concurrent worker jobs.
  const taskQueue = new TaskQueue(concurrency);

  console.time("总推理用时");
  try {
    const tasks = Array.from({ length: numTasks }, () =>
      taskQueue.addTask(() => {
        // Build the random input only when the queue actually starts this
        // task. Generating every sample up front kept numTasks * sampleLength
        // floats alive at once (~1 GiB with the defaults) before any work began.
        const sample = new Float32Array(sampleLength);
        for (let i = 0; i < sample.length; i++) {
          sample[i] = Math.random() * 2.0 - 1.0;
        }
        return createWorker(sample);
      }),
    );

    const results = await Promise.all(tasks);
    console.log("所有任务完成,结果数量:", results.length);
  } catch (e: unknown) {
    console.error("推理失败:", e);
  } finally {
    console.timeEnd("总推理用时");
  }
}
|