Source code

Revision control

Copy as Markdown

Other Tools

/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
// This file tests support for the sqlite-vec extension.
/**
 * Packs a vector of numbers into raw bytes for binding as a SQL blob.
 *
 * Each component is converted to a 32-bit float; the returned byte array is a
 * view over that float buffer. The tests in this file bind the result to vec0
 * FLOAT[] columns and MATCH queries.
 *
 * @param {Iterable<number>|ArrayLike<number>} tensor - The vector components.
 * @returns {Uint8ClampedArray} A byte view over the float32 storage.
 */
function tensorToBlob(tensor) {
  const floats = new Float32Array(tensor);
  const bytes = new Uint8ClampedArray(floats.buffer);
  return bytes;
}
// Reset shared test state before any task runs. cleanup() is a harness
// helper; presumably it closes and removes the shared test database.
add_setup(async () => {
  cleanup();
});
add_task(async function test_synchronous() {
  info("Testing synchronous connection");

  // SQL used by this task (kept verbatim).
  const createBeforeLoad = `CREATE VIRTUAL TABLE test USING vec0(
embedding FLOAT[4]
);`;
  const createAfterLoad = `
CREATE VIRTUAL TABLE test USING vec0(
embedding FLOAT[4]
)
`;
  const insertSql = `
INSERT INTO test(rowid, embedding)
VALUES (1, :vector)
`;
  const nearestSql = `
SELECT
rowid,
distance
FROM test
WHERE embedding MATCH :vector
ORDER BY distance
LIMIT 1
`;
  const vector = [0.3, 0.3, 0.3, 0.3];

  const conn = getOpenedUnsharedDatabase();

  // vec0 must be unavailable until the extension is loaded.
  Assert.throws(
    () => conn.executeSimpleSQL(createBeforeLoad),
    /NS_ERROR_FAILURE/,
    "Should not be able to use vec without loading the extension"
  );

  await loadExtension(conn);
  conn.executeSimpleSQL(createAfterLoad);

  // Store a single vector under rowid 1.
  const insertStmt = conn.createStatement(insertSql);
  insertStmt.bindBlobByName("vector", tensorToBlob(vector));
  insertStmt.executeStep();
  insertStmt.reset();
  insertStmt.finalize();

  // Searching with the identical vector must return rowid 1 at distance 0.
  const selectStmt = conn.createStatement(nearestSql);
  selectStmt.bindBlobByName("vector", tensorToBlob(vector));
  Assert.ok(selectStmt.executeStep());
  Assert.equal(selectStmt.getInt32(0), 1);
  Assert.equal(selectStmt.getDouble(1), 0.0);
  selectStmt.reset();
  selectStmt.finalize();

  cleanup();
});
add_task(async function test_asynchronous() {
  info("Testing asynchronous connection");

  // The same statement is attempted before and after loading the extension.
  const createSql = `
CREATE VIRTUAL TABLE test USING vec0(
embedding float[4]
)
`;
  const isMissingModule = err => err.message.startsWith("no such module");

  const conn = await openAsyncDatabase(getTestDB());

  await Assert.rejects(
    executeSimpleSQLAsync(conn, createSql),
    isMissingModule,
    "Should not be able to use vec without loading the extension"
  );

  await loadExtension(conn);
  await executeSimpleSQLAsync(conn, createSql);

  // Close the connection and delete the database file this task created.
  await asyncClose(conn);
  await IOUtils.remove(getTestDB().path, { ignoreAbsent: true });
});
add_task(async function test_clone() {
  info("Testing cloning synchronous connection loads extensions in clone");

  const source = getOpenedUnsharedDatabase();
  await loadExtension(source);

  // The extension is loaded only on `source`; using vec0 on the clone checks
  // that the clone picked it up. clone(false) requests a non-read-only clone.
  const cloned = source.clone(false);
  cloned.executeSimpleSQL(`
CREATE VIRTUAL TABLE test USING vec0(
embedding float[4]
)
`);
  cloned.close();

  cleanup();
});
add_task(async function test_asyncClone() {
  info("Testing asynchronously cloning connection loads extensions in clone");

  const source = getOpenedUnsharedDatabase();
  await loadExtension(source);

  // The extension is loaded only on `source`; creating a vec0 table through
  // the asynchronous clone checks that the clone picked it up.
  const cloned = await asyncClone(source, false);
  const createSql = `
CREATE VIRTUAL TABLE test USING vec0(
embedding float[4]
)
`;
  await executeSimpleSQLAsync(cloned, createSql);
  await asyncClose(cloned);

  await asyncCleanup();
});
/**
 * Loads a SQLite extension into a storage connection.
 *
 * Adapts the callback-based `conn.loadExtension(name, callback)` API to a
 * promise.
 *
 * @param {object} conn - Connection exposing `loadExtension(name, callback)`.
 * @param {string} [ext="vec"] - Name of the extension to load.
 * @returns {Promise<void>} Resolves once the callback reports a success
 *   status; otherwise rejects with the raw status value it received.
 */
async function loadExtension(conn, ext = "vec") {
  return new Promise((resolve, reject) => {
    conn.loadExtension(ext, status => {
      if (!Components.isSuccessCode(status)) {
        reject(status);
        return;
      }
      resolve();
    });
  });
}
add_task(async function test_invariants() {
  // Test some invariants of the vec extension that we rely upon, so that if
  // the behavior changes we can catch it:
  //  - a rowid that was deleted from a vec0 table can be inserted again;
  //  - a plain SELECT over the vec0 table yields the rows in the order
  //    asserted below.
  const conn = getOpenedUnsharedDatabase();
  await loadExtension(conn);

  // vec_to_json() may print floats with trailing zeros (presumably fixed
  // precision, e.g. "0.200000"); strip zeros that follow a digit so the text
  // can be compared with short literals. This also strips interior zeros
  // after a digit (e.g. "10.5" -> "1.5"), which is fine for the values below.
  const normalizeVecJson = json => json.replace(/(?<=[0-9])0+/g, "");

  conn.executeSimpleSQL(`
    CREATE VIRTUAL TABLE vectors USING vec0(
      embedding FLOAT[4]
    )
  `);
  conn.executeSimpleSQL(`
    CREATE TABLE relations (
      rowid INTEGER PRIMARY KEY,
      content TEXT
    )
  `);

  // Each execution inserts a new relations row and returns its rowid.
  // 'test' is single-quoted: in SQL, double quotes denote identifiers and only
  // fall back to string literals through SQLite's legacy double-quoted-string
  // quirk, which builds can disable (SQLITE_DQS).
  const rowids = [];
  const insertRelStmt = conn.createStatement(`
    INSERT INTO relations (rowid, content)
    VALUES (NULL, 'test')
    RETURNING rowid
  `);
  Assert.ok(insertRelStmt.executeStep());
  rowids.push(insertRelStmt.getInt32(0));
  insertRelStmt.reset();
  Assert.ok(insertRelStmt.executeStep());
  rowids.push(insertRelStmt.getInt32(0));
  insertRelStmt.reset();

  // Insert a vector for rowids[0], delete it, then insert a different vector
  // reusing the same rowid: the vec table must accept the reused rowid.
  const insertVecStmt = conn.createStatement(`
    INSERT INTO vectors (rowid, embedding)
    VALUES (:rowid, :vector)
  `);
  insertVecStmt.bindByName("rowid", rowids[0]);
  insertVecStmt.bindBlobByName("vector", tensorToBlob([0.1, 0.1, 0.1, 0.1]));
  insertVecStmt.executeStep();
  insertVecStmt.reset();

  const deleteStmt = conn.createStatement(`
    DELETE FROM vectors WHERE rowid = :rowid
  `);
  deleteStmt.bindByName("rowid", rowids[0]);
  deleteStmt.executeStep();
  deleteStmt.finalize();

  insertVecStmt.bindByName("rowid", rowids[0]);
  insertVecStmt.bindBlobByName("vector", tensorToBlob([0.2, 0.2, 0.2, 0.2]));
  insertVecStmt.executeStep();
  insertVecStmt.reset();

  // Only the re-inserted vector should remain.
  const selectStmt = conn.createStatement(`
    SELECT
      rowid,
      vec_to_json(embedding)
    FROM vectors
  `);
  let count = 0;
  while (selectStmt.executeStep()) {
    count++;
    Assert.equal(selectStmt.getInt32(0), rowids[0]);
    Assert.equal(
      normalizeVecJson(selectStmt.getUTF8String(1)),
      "[0.2,0.2,0.2,0.2]"
    );
  }
  Assert.equal(count, 1, "Should have one row in the vec table");
  selectStmt.reset();

  // Add a vector for a third relations row. rowids[1] never gets a vector,
  // so the vec table has a gap in its rowids.
  Assert.ok(insertRelStmt.executeStep());
  rowids.push(insertRelStmt.getInt32(0));
  insertRelStmt.finalize();
  insertVecStmt.bindByName("rowid", rowids[2]);
  insertVecStmt.bindBlobByName("vector", tensorToBlob([0.3, 0.3, 0.3, 0.3]));
  insertVecStmt.executeStep();
  insertVecStmt.finalize();

  const expected = [
    { rowid: rowids[0], vector: "[0.2,0.2,0.2,0.2]" },
    { rowid: rowids[2], vector: "[0.3,0.3,0.3,0.3]" },
  ];
  count = 0;
  for (let i = 0; selectStmt.executeStep(); i++) {
    count++;
    // Report an unexpected extra row clearly instead of failing with a
    // TypeError on expected[i].
    Assert.ok(i < expected.length, `Unexpected extra row at index ${i}`);
    Assert.equal(selectStmt.getInt32(0), expected[i].rowid);
    Assert.equal(
      normalizeVecJson(selectStmt.getUTF8String(1)),
      expected[i].vector
    );
  }
  Assert.equal(count, 2, "Should have two rows in the vec table");
  selectStmt.finalize();

  // TODO: In the future add testing for RETURNING and UPSERT, as those are
  // currently broken (presumably on vec0 tables; RETURNING on the plain
  // `relations` table above works). The reference that followed "See:" is
  // missing here; restore it.
  cleanup();
});