Source code
Revision control
Copy as Markdown
Other Tools
Test Info:
/* Any copyright is dedicated to the Public Domain.
   http://creativecommons.org/publicdomain/zero/1.0/ */
"use strict";
/**
* @import { Request as EngineRequest, MLEngine as MLEngineClass } from "../../actors/MLEngineParent.sys.mjs"
* @import { StaticEmbeddingsOptions } from "../../content/backends/StaticEmbeddingsPipeline.d.ts"
*/
const { parseNpy } = ChromeUtils.importESModule(
"chrome://global/content/ml/Utils.sys.mjs"
);
// Row count handed to generateFloat16Numpy; equals the number of entries in
// the mocked tokenizer vocabulary below ("[UNK]" plus eight words).
const vocabSize = 9;
// Embedding width: handed to generateFloat16Numpy, embedded in the mocked
// .npy URL, passed in the static embeddings options, and asserted as the
// length of the embedding the engine returns.
const dimensions = 8;
/**
 * Build the mocked values that stand in for the model files in this test.
 *
 * The result contains:
 *  - `tokenizer`: a minimal WordPiece tokenizer configuration with a
 *    nine-entry vocabulary ("[UNK]" plus eight lowercase words).
 *  - the fp16 `.npy` embeddings URL, mapped to the `encoding` produced by
 *    `generateFloat16Numpy(vocabSize, dimensions)`.
 *
 * NOTE(review): the embeddings entry is keyed by its full URL while the
 * tokenizer is keyed by the bare name `tokenizer` - confirm that this is the
 * key the mocking layer looks up.
 *
 * @returns {object} The mocked values, suitable for
 *   `StaticEmbeddingsOptions.mockedValues`.
 */
function getMockedValues() {
  const numpy = generateFloat16Numpy(vocabSize, dimensions);

  // The position of each word in this list is its token id.
  const orderedWords = [
    "[UNK]",
    "the",
    "quick",
    "brown",
    "dog",
    "jumped",
    "over",
    "lazy",
    "fox",
  ];
  const vocab = Object.fromEntries(
    orderedWords.map((word, tokenId) => [word, tokenId])
  );

  const tokenizer = {
    version: "1.0",
    truncation: null,
    padding: null,
    added_tokens: [
      {
        id: 0,
        content: "[UNK]",
        single_word: false,
        lstrip: false,
        rstrip: false,
        normalized: false,
        special: true,
      },
    ],
    normalizer: {
      type: "BertNormalizer",
      clean_text: true,
      handle_chinese_chars: true,
      strip_accents: null,
      lowercase: true,
    },
    pre_tokenizer: { type: "BertPreTokenizer" },
    post_processor: {
      type: "TemplateProcessing",
      single: [
        { SpecialToken: { id: "[CLS]", type_id: 0 } },
        { Sequence: { id: "A", type_id: 0 } },
        { SpecialToken: { id: "[SEP]", type_id: 0 } },
      ],
      pair: [],
      special_tokens: {},
    },
    decoder: { type: "WordPiece", prefix: "##", cleanup: true },
    model: {
      type: "WordPiece",
      unk_token: "[UNK]",
      continuing_subword_prefix: "##",
      max_input_chars_per_word: 100,
      vocab,
    },
  };

  const embeddingsUrl = `https://model-hub.mozilla.org/mozilla/static-embeddings/v1.0.0/models/minishlab/potion-retrieval-32M/fp16.d${dimensions}.npy`;

  const mockedValues = { tokenizer };
  mockedValues[embeddingsUrl] = numpy.encoding;
  return mockedValues;
}
/**
 * Create a "static-embeddings" engine backed by the mocked model values, embed
 * a single sentence with mean pooling and normalization requested, and check
 * the count, length, array type, and values of the returned embedding.
 */
add_task(async function test_static_embeddings() {
  const mockedValues = getMockedValues();

  /** @type {StaticEmbeddingsOptions} */
  const staticEmbeddingsOptions = {
    dtype: "fp16",
    subfolder: "models/minishlab/potion-retrieval-32M",
    dimensions,
    mockedValues,
    compression: false,
  };

  const pipelineOptions = new PipelineOptions({
    featureId: "simple-text-embedder",
    engineId: "test-static-embeddings",
    modelId: "mozilla/static-embeddings",
    modelRevision: "v1.0.0",
    taskName: "static-embeddings",
    modelHub: "mozilla",
    backend: "static-embeddings",
    staticEmbeddingsOptions,
  });

  /** @type {MLEngineClass} */
  const engine = await createEngine(pipelineOptions);

  const request = {
    args: ["The quick brown fox jumped over the lazy fox"],
    options: {
      pooling: "mean",
      normalize: true,
    },
  };
  const response = await engine.run(request);
  const embeddings = response.output;

  is(embeddings.length, 1, "One embedding was returned");

  const [firstEmbedding] = embeddings;
  is(firstEmbedding.length, dimensions, "The dimensions match");
  is(
    firstEmbedding.constructor.name,
    "Float32Array",
    "The embedding was returned as a Float32Array"
  );

  // Reference values for the mocked model; compared within `epsilon`.
  const expectedEmbedding = [
    0.3156551122, 0.3262447714, 0.3368626534, 0.3474076688, 0.3580137789,
    0.3685869872, 0.3791790008, 0.3898085951,
  ];
  const epsilon = 0.00001;
  assertFloatArraysMatch(
    firstEmbedding,
    expectedEmbedding,
    "The embeddings were computed as expected.",
    epsilon
  );
});