DBZ-8121 Support for pgvector datatypes

This commit is contained in:
Jiri Pechanec 2024-08-02 05:34:01 +02:00
parent 3e1bb6cbef
commit 7cf7af5765
11 changed files with 619 additions and 0 deletions

View File

@ -10,6 +10,9 @@
import io.debezium.config.CommonConnectorConfig;
import io.debezium.connector.postgresql.data.Ltree;
import io.debezium.connector.postgresql.data.vector.HalfVector;
import io.debezium.connector.postgresql.data.vector.SparseVector;
import io.debezium.connector.postgresql.data.vector.Vector;
import io.debezium.data.Envelope;
import io.debezium.schema.SchemaFactory;
import io.debezium.schema.SchemaNameAdjuster;
@ -72,4 +75,26 @@ public SchemaBuilder datatypeLtreeSchema() {
.name(Ltree.LOGICAL_NAME)
.version(Ltree.SCHEMA_VERSION);
}
/**
 * Builds the schema for the PgVector {@code vector} type: an array of FLOAT64 elements.
 *
 * @return the schema builder, for further customization by the caller
 */
public SchemaBuilder datatypeVectorSchema() {
    final SchemaBuilder vectorSchema = SchemaBuilder.array(Schema.FLOAT64_SCHEMA);
    return vectorSchema.name(Vector.LOGICAL_NAME).version(Vector.SCHEMA_VERSION);
}
/**
 * Builds the schema for the PgVector {@code halfvec} type: an array of FLOAT32 elements.
 *
 * @return the schema builder, for further customization by the caller
 */
public SchemaBuilder datatypeHalfVectorSchema() {
    final SchemaBuilder halfVectorSchema = SchemaBuilder.array(Schema.FLOAT32_SCHEMA);
    return halfVectorSchema.name(HalfVector.LOGICAL_NAME).version(HalfVector.SCHEMA_VERSION);
}
/**
 * Builds the schema for the PgVector {@code sparsevec} type: a struct with the total
 * number of dimensions and a map of index -> value for the non-zero entries.
 *
 * @return the schema builder, for further customization by the caller
 */
public SchemaBuilder datatypeSparseVectorSchema() {
    // NOTE: the original called .name(SparseVector.LOGICAL_NAME) twice; the duplicate
    // call was redundant and has been removed.
    return SchemaBuilder.struct()
            .name(SparseVector.LOGICAL_NAME)
            .version(SparseVector.SCHEMA_VERSION)
            .doc("Sparse vector")
            .field(SparseVector.DIMENSIONS_FIELD, Schema.INT16_SCHEMA)
            .field(SparseVector.VECTOR_FIELD, SchemaBuilder.map(Schema.INT16_SCHEMA, Schema.FLOAT64_SCHEMA).build());
}
}

View File

@ -55,6 +55,9 @@
import io.debezium.connector.postgresql.PostgresConnectorConfig.HStoreHandlingMode;
import io.debezium.connector.postgresql.PostgresConnectorConfig.IntervalHandlingMode;
import io.debezium.connector.postgresql.data.Ltree;
import io.debezium.connector.postgresql.data.vector.HalfVector;
import io.debezium.connector.postgresql.data.vector.SparseVector;
import io.debezium.connector.postgresql.data.vector.Vector;
import io.debezium.connector.postgresql.proto.PgProto;
import io.debezium.data.Bits;
import io.debezium.data.Json;
@ -321,6 +324,15 @@ else if (oidValue == typeRegistry.hstoreOid()) {
else if (oidValue == typeRegistry.ltreeOid()) {
return Ltree.builder();
}
else if (oidValue == typeRegistry.vectorOid()) {
return Vector.builder();
}
else if (oidValue == typeRegistry.halfVectorOid()) {
return HalfVector.builder();
}
else if (oidValue == typeRegistry.sparseVectorOid()) {
return SparseVector.builder();
}
else if (oidValue == typeRegistry.hstoreArrayOid()) {
return SchemaBuilder.array(hstoreSchema().optional().build());
}
@ -525,6 +537,15 @@ else if (oidValue == typeRegistry.hstoreOid()) {
else if (oidValue == typeRegistry.ltreeOid()) {
return data -> convertLtree(column, fieldDefn, data);
}
else if (oidValue == typeRegistry.vectorOid()) {
return data -> convertPgVector(column, fieldDefn, data);
}
else if (oidValue == typeRegistry.halfVectorOid()) {
return data -> convertPgHalfVector(column, fieldDefn, data);
}
else if (oidValue == typeRegistry.sparseVectorOid()) {
return data -> convertPgSparseVector(column, fieldDefn, data);
}
else if (oidValue == typeRegistry.ltreeArrayOid()) {
return data -> convertLtreeArray(column, fieldDefn, data);
}
@ -659,6 +680,48 @@ else if (data instanceof PGobject) {
});
}
/**
 * Converts a value read for a PgVector {@code vector} column into its Connect representation
 * (a list of doubles, see {@link Vector#fromLogical}).
 * The value arrives as the text literal {@code [x,y,z,...]} carried either as raw bytes
 * (decoded with the database charset), a String, or a PGobject.
 * Fixed: the String check was a bare {@code if} rather than {@code else if}, breaking the
 * mutually-exclusive dispatch chain (the byte[] case fell through into the PGobject branch's
 * evaluation); the chain is now consistent with the other converters.
 */
private Object convertPgVector(Column column, Field fieldDefn, Object data) {
    return convertValue(column, fieldDefn, data, Collections.emptyList(), r -> {
        if (data instanceof byte[] typedData) {
            r.deliver(Vector.fromLogical(fieldDefn.schema(), new String(typedData, databaseCharset)));
        }
        else if (data instanceof String typedData) {
            r.deliver(Vector.fromLogical(fieldDefn.schema(), typedData));
        }
        else if (data instanceof PGobject typedData) {
            r.deliver(Vector.fromLogical(fieldDefn.schema(), typedData.getValue()));
        }
    });
}
/**
 * Converts a value read for a PgVector {@code halfvec} column into its Connect representation
 * (a list of floats, see {@link HalfVector#fromLogical}).
 * The value arrives as the text literal {@code [x,y,z,...]} carried either as raw bytes
 * (decoded with the database charset), a String, or a PGobject.
 * Fixed: the String check was a bare {@code if} rather than {@code else if}, breaking the
 * mutually-exclusive dispatch chain; it is now a consistent else-if chain.
 */
private Object convertPgHalfVector(Column column, Field fieldDefn, Object data) {
    return convertValue(column, fieldDefn, data, Collections.emptyList(), r -> {
        if (data instanceof byte[] typedData) {
            r.deliver(HalfVector.fromLogical(fieldDefn.schema(), new String(typedData, databaseCharset)));
        }
        else if (data instanceof String typedData) {
            r.deliver(HalfVector.fromLogical(fieldDefn.schema(), typedData));
        }
        else if (data instanceof PGobject typedData) {
            r.deliver(HalfVector.fromLogical(fieldDefn.schema(), typedData.getValue()));
        }
    });
}
/**
 * Converts a value read for a PgVector {@code sparsevec} column into its Connect representation
 * (a Struct of dimensions + index->value map, see {@link SparseVector#fromLogical}).
 * The value arrives as the text literal {@code {i1:v1,i2:v2,...}/dimensions} carried either as
 * raw bytes (decoded with the database charset), a String, or a PGobject.
 * Fixed: the String check was a bare {@code if} rather than {@code else if}, breaking the
 * mutually-exclusive dispatch chain; it is now a consistent else-if chain.
 */
private Object convertPgSparseVector(Column column, Field fieldDefn, Object data) {
    return convertValue(column, fieldDefn, data, Collections.emptyList(), r -> {
        if (data instanceof byte[] typedData) {
            r.deliver(SparseVector.fromLogical(fieldDefn.schema(), new String(typedData, databaseCharset)));
        }
        else if (data instanceof String typedData) {
            r.deliver(SparseVector.fromLogical(fieldDefn.schema(), typedData));
        }
        else if (data instanceof PGobject typedData) {
            r.deliver(SparseVector.fromLogical(fieldDefn.schema(), typedData.getValue()));
        }
    });
}
private Object convertLtreeArray(Column column, Field fieldDefn, Object data) {
return convertValue(column, fieldDefn, data, Collections.emptyList(), r -> {
if (data instanceof byte[]) {

View File

@ -47,6 +47,9 @@ public class TypeRegistry {
public static final String TYPE_NAME_HSTORE = "hstore";
public static final String TYPE_NAME_LTREE = "ltree";
public static final String TYPE_NAME_ISBN = "isbn";
public static final String TYPE_NAME_VECTOR = "vector";
public static final String TYPE_NAME_HALF_VECTOR = "halfvec";
public static final String TYPE_NAME_SPARSE_VECTOR = "sparsevec";
public static final String TYPE_NAME_HSTORE_ARRAY = "_hstore";
public static final String TYPE_NAME_GEOGRAPHY_ARRAY = "_geography";
@ -111,6 +114,10 @@ private static Map<String, String> getLongTypeNames() {
private int ltreeOid = Integer.MIN_VALUE;
private int isbnOid = Integer.MIN_VALUE;
private int vectorOid = Integer.MIN_VALUE;
private int halfVectorOid = Integer.MIN_VALUE;
private int sparseVectorOid = Integer.MIN_VALUE;
private int hstoreArrayOid = Integer.MIN_VALUE;
private int geometryArrayOid = Integer.MIN_VALUE;
private int geographyArrayOid = Integer.MIN_VALUE;
@ -171,6 +178,15 @@ else if (TYPE_NAME_LTREE_ARRAY.equals(type.getName())) {
else if (TYPE_NAME_ISBN.equals(type.getName())) {
isbnOid = type.getOid();
}
else if (TYPE_NAME_VECTOR.equals(type.getName())) {
vectorOid = type.getOid();
}
else if (TYPE_NAME_HALF_VECTOR.equals(type.getName())) {
halfVectorOid = type.getOid();
}
else if (TYPE_NAME_SPARSE_VECTOR.equals(type.getName())) {
sparseVectorOid = type.getOid();
}
}
/**
@ -317,6 +333,30 @@ public int ltreeArrayOid() {
return ltreeArrayOid;
}
/**
 * @return OID for PgVector's {@code vector} type of this PostgreSQL instance;
 *         remains {@code Integer.MIN_VALUE} when no such type was registered
 */
public int vectorOid() {
return vectorOid;
}
/**
 * @return OID for PgVector's {@code halfvec} type of this PostgreSQL instance;
 *         remains {@code Integer.MIN_VALUE} when no such type was registered
 */
public int halfVectorOid() {
return halfVectorOid;
}
/**
 * @return OID for PgVector's {@code sparsevec} type of this PostgreSQL instance;
 *         remains {@code Integer.MIN_VALUE} when no such type was registered
 */
public int sparseVectorOid() {
return sparseVectorOid;
}
/**
* Converts a type name in long (readable) format like <code>boolean</code> to s standard
* data type name like <code>bool</code>.

View File

@ -172,6 +172,13 @@ public static Object resolveValue(String columnName, PostgresType type, String f
case "isbn":
return value.asString();
// PgVector types are string encoded values
// ValueConverter turns them into the correct types
case "vector":
case "halfvec":
case "sparsevec":
return value.asString();
// catch-all for other known/builtin PG types
// TODO: improve with more specific/useful classes here?
case "pg_lsn":

View File

@ -0,0 +1,57 @@
/*
* Copyright Debezium Authors.
*
* Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0
*/
package io.debezium.connector.postgresql.data.vector;
import java.util.List;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
import io.debezium.connector.postgresql.PostgresSchemaFactory;
/**
 * A semantic type for a PgVector halfvec type.
 *
 * @author Jiri Pechanec
 */
public class HalfVector {

    public static final String LOGICAL_NAME = "io.debezium.data.HalfVector";

    // Fixed: was a mutable "public static int"; a constant named in UPPER_SNAKE_CASE
    // must be final so callers cannot rebind the schema version at runtime.
    public static final int SCHEMA_VERSION = 1;

    /**
     * Returns a {@link SchemaBuilder} for a halfvec field. You can use the resulting SchemaBuilder
     * to set additional schema settings such as required/optional, default value, and documentation.
     *
     * @return the schema builder
     */
    public static SchemaBuilder builder() {
        return PostgresSchemaFactory.get().datatypeHalfVectorSchema();
    }

    /**
     * Returns a {@link SchemaBuilder} for a halfvec field, with all other default Schema settings.
     *
     * @return the schema
     * @see #builder()
     */
    public static Schema schema() {
        return builder().build();
    }

    /**
     * Converts a value from its logical format - {@link String} of {@code [x,y,z,...]}
     * to its encoded format - a Connect array represented by list of numbers.
     *
     * @param schema of the encoded value
     * @param value the value of the vector
     *
     * @return the encoded value
     */
    public static List<Float> fromLogical(Schema schema, String value) {
        return Vectors.fromVectorString(schema, value, Float::parseFloat);
    }
}

View File

@ -0,0 +1,58 @@
/*
* Copyright Debezium Authors.
*
* Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0
*/
package io.debezium.connector.postgresql.data.vector;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
import org.apache.kafka.connect.data.Struct;
import io.debezium.connector.postgresql.PostgresSchemaFactory;
/**
 * A semantic type for a PgVector sparsevec type.
 *
 * @author Mincong Huang
 */
public class SparseVector {

    public static final String LOGICAL_NAME = "io.debezium.data.SparseVector";
    public static final String DIMENSIONS_FIELD = "dimensions";
    public static final String VECTOR_FIELD = "vector";

    // Fixed: was a mutable "public static int"; a constant named in UPPER_SNAKE_CASE
    // must be final so callers cannot rebind the schema version at runtime.
    public static final int SCHEMA_VERSION = 1;

    /**
     * Returns a {@link SchemaBuilder} for a {@code sparsevec} field. You can use the resulting SchemaBuilder
     * to set additional schema settings such as required/optional, default value, and documentation.
     *
     * @return the schema builder
     */
    public static SchemaBuilder builder() {
        return PostgresSchemaFactory.get().datatypeSparseVectorSchema();
    }

    /**
     * Returns a {@link SchemaBuilder} for a {@code sparsevec} field, with all other default Schema settings.
     *
     * @return the schema
     * @see #builder()
     */
    public static Schema schema() {
        return builder().build();
    }

    /**
     * Converts a value from its logical format - {@link String} of {@code {i1: v1, i2: v2, ...}/dimensions}
     * to its encoded format - a {@link Struct} with a number of dimensions and a map of index to value
     *
     * @param schema of the encoded value
     * @param value the value of the vector
     *
     * @return the encoded value
     */
    public static Struct fromLogical(Schema schema, String value) {
        return Vectors.fromSparseVectorString(schema, value, Double::parseDouble);
    }
}

View File

@ -0,0 +1,57 @@
/*
* Copyright Debezium Authors.
*
* Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0
*/
package io.debezium.connector.postgresql.data.vector;
import java.util.List;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
import io.debezium.connector.postgresql.PostgresSchemaFactory;
/**
 * A semantic type for a PgVector vector type.
 *
 * @author Jiri Pechanec
 */
public class Vector {

    public static final String LOGICAL_NAME = "io.debezium.data.Vector";

    // Fixed: was a mutable "public static int"; a constant named in UPPER_SNAKE_CASE
    // must be final so callers cannot rebind the schema version at runtime.
    public static final int SCHEMA_VERSION = 1;

    /**
     * Returns a {@link SchemaBuilder} for a vector field. You can use the resulting SchemaBuilder
     * to set additional schema settings such as required/optional, default value, and documentation.
     *
     * @return the schema builder
     */
    public static SchemaBuilder builder() {
        return PostgresSchemaFactory.get().datatypeVectorSchema();
    }

    /**
     * Returns a {@link SchemaBuilder} for a vector field, with all other default Schema settings.
     *
     * @return the schema
     * @see #builder()
     */
    public static Schema schema() {
        return builder().build();
    }

    /**
     * Converts a value from its logical format - {@link String} of {@code [x,y,z,...]}
     * to its encoded format - a Connect array represented by list of numbers.
     *
     * @param schema of the encoded value
     * @param value the value of the vector
     *
     * @return the encoded value
     */
    public static List<Double> fromLogical(Schema schema, String value) {
        return Vectors.fromVectorString(schema, value, Double::parseDouble);
    }
}

View File

@ -0,0 +1,78 @@
/*
* Copyright Debezium Authors.
*
* Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0
*/
package io.debezium.connector.postgresql.data.vector;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.Struct;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Shared parsing helpers for the PgVector text literals (dense {@code [x,y,z,...]}
 * and sparse {@code {i:v,...}/dimensions} forms). On a malformed outer structure the
 * helpers log a warning and return {@code null}; malformed numbers propagate as
 * {@link NumberFormatException}.
 */
public final class Vectors {

    private static final String SPARSE_VECTOR_ERROR = "Cannot convert sparse vector {}, expected format is {i1:x,i2:y,i3:z,...}/dimensions";

    private static final Logger LOGGER = LoggerFactory.getLogger(Vectors.class);

    /**
     * Parses a dense vector literal {@code [x,y,z,...]} into a list of numbers.
     *
     * @param schema schema of the encoded value (currently unused by the parser)
     * @param value the textual vector, must not be null
     * @param elementMapper converts one trimmed element token into the target number type
     * @return the parsed list, or null when the value is not bracket-delimited
     */
    static <T> List<T> fromVectorString(Schema schema, String value, Function<String, T> elementMapper) {
        Objects.requireNonNull(value, "value may not be null");

        final String trimmed = value.trim();
        if (!(trimmed.startsWith("[") && trimmed.endsWith("]"))) {
            LOGGER.warn("Cannot convert vector {}, expected format is [x,y,z,...]", trimmed);
            return null;
        }

        final String[] tokens = trimmed.substring(1, trimmed.length() - 1).split(",");
        final List<T> parsed = new ArrayList<>(tokens.length);
        for (String token : tokens) {
            parsed.add(elementMapper.apply(token.trim()));
        }
        return parsed;
    }

    /**
     * Parses a sparse vector literal {@code {i1:v1,i2:v2,...}/dimensions} into a Struct
     * carrying the dimension count and an index-to-value map.
     *
     * @param schema schema of the encoded value, used to build the resulting Struct
     * @param value the textual sparse vector, must not be null
     * @param elementMapper converts one trimmed value token into the target number type
     * @return the parsed Struct, or null when the outer structure is malformed
     */
    static <T> Struct fromSparseVectorString(Schema schema, String value, Function<String, T> elementMapper) {
        Objects.requireNonNull(value, "value may not be null");

        final String trimmed = value.trim();
        final String[] vectorAndDims = trimmed.split("/");
        if (vectorAndDims.length != 2) {
            LOGGER.warn(SPARSE_VECTOR_ERROR, trimmed);
            return null;
        }

        // Dimensions are parsed before the brace check so that a non-numeric dimension
        // part fails fast with NumberFormatException, matching the documented behavior.
        final String entriesPart = vectorAndDims[0].trim();
        final short dimensions = Short.parseShort(vectorAndDims[1].trim());

        if (!(entriesPart.startsWith("{") && entriesPart.endsWith("}"))) {
            LOGGER.warn(SPARSE_VECTOR_ERROR, trimmed);
            return null;
        }

        final String[] pairs = entriesPart.substring(1, entriesPart.length() - 1).split(",");
        final Map<Short, T> entries = new HashMap<>(pairs.length);
        for (String pair : pairs) {
            final String[] indexAndValue = pair.split(":");
            if (indexAndValue.length != 2) {
                LOGGER.warn(SPARSE_VECTOR_ERROR, trimmed);
                return null;
            }
            entries.put(Short.parseShort(indexAndValue[0].trim()), elementMapper.apply(indexAndValue[1].trim()));
        }

        final Struct encoded = new Struct(schema);
        encoded.put(SparseVector.DIMENSIONS_FIELD, dimensions);
        encoded.put(SparseVector.VECTOR_FIELD, entries);
        return encoded;
    }
}

View File

@ -0,0 +1,110 @@
/*
* Copyright Debezium Authors.
*
* Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0
*/
package io.debezium.connector.postgresql;
import static io.debezium.junit.EqualityCheck.LESS_THAN;
import java.util.List;
import java.util.Map;
import org.apache.kafka.connect.data.Struct;
import org.assertj.core.api.Assertions;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import io.debezium.connector.postgresql.PostgresConnectorConfig.SnapshotMode;
import io.debezium.connector.postgresql.connection.PostgresConnection;
import io.debezium.connector.postgresql.connection.ReplicationConnection;
import io.debezium.junit.SkipTestRule;
import io.debezium.junit.SkipWhenDatabaseVersion;
import io.debezium.util.Testing;
/**
 * Integration test to verify PgVector types.
 * Requires a running PostgreSQL instance with the pgvector extension available;
 * the schema and extension are provisioned by init_pgvector.ddl in {@link #before()}.
 *
 * @author Jiri Pechanec
 */
@SkipWhenDatabaseVersion(check = LESS_THAN, major = 15, reason = "PgVector is tested only with PostgreSQL 15+")
public class VectorDatabaseIT extends AbstractRecordsProducerTest {
@Rule
public final SkipTestRule skipTest = new SkipTestRule();
@Before
public void before() throws Exception {
// ensure the slot is deleted for each test
try (PostgresConnection conn = TestHelper.create()) {
conn.dropReplicationSlot(ReplicationConnection.Builder.DEFAULT_SLOT_NAME);
}
// reset the database and install the pgvector schema/extension, then create a table
// with one column per pgvector type and a single seed row for the snapshot phase
TestHelper.dropAllSchemas();
TestHelper.executeDDL("init_pgvector.ddl");
TestHelper.execute(
"CREATE TABLE pgvector.table_vector (pk SERIAL, f_vector pgvector.vector(3), f_halfvec pgvector.halfvec(3), f_sparsevec pgvector.sparsevec(3000), PRIMARY KEY(pk));",
"INSERT INTO pgvector.table_vector (f_vector, f_halfvec, f_sparsevec) VALUES ('[1,2,3]', '[101,102,103]', '{1: 201, 9: 209}/3000');");
initializeConnectorTestFramework();
}
@Test
public void shouldSnapshotAndStreamData() throws Exception {
Testing.Print.enable();
// INITIAL snapshot mode: the first record comes from the snapshot, the second from streaming
start(PostgresConnector.class, TestHelper.defaultConfig()
.with(PostgresConnectorConfig.SNAPSHOT_MODE, SnapshotMode.INITIAL)
.build());
assertConnectorIsRunning();
waitForStreamingRunning("postgres", TestHelper.TEST_SERVER);
TestHelper.execute("INSERT INTO pgvector.table_vector (f_vector, f_halfvec, f_sparsevec) VALUES ('[10,20,30]', '[110,120,130]', '{1: 301, 9: 309}/3000');");
var actualRecords = consumeRecordsByTopic(2);
var recs = actualRecords.recordsForTopic("test_server.pgvector.table_vector");
Assertions.assertThat(recs).hasSize(2);
// snapshot record: verify the logical schema names and the decoded values of the seed row
var rec = ((Struct) recs.get(0).value());
Assertions.assertThat(rec.schema().field("after").schema().field("f_vector").schema().name()).isEqualTo("io.debezium.data.Vector");
Assertions.assertThat(rec.schema().field("after").schema().field("f_halfvec").schema().name()).isEqualTo("io.debezium.data.HalfVector");
Assertions.assertThat(rec.schema().field("after").schema().field("f_sparsevec").schema().name()).isEqualTo("io.debezium.data.SparseVector");
Assertions.assertThat(rec.getStruct("after").getArray("f_vector")).isEqualTo(List.of(1.0, 2.0, 3.0));
Assertions.assertThat(rec.getStruct("after").getArray("f_halfvec")).isEqualTo(List.of(101.0f, 102.0f, 103.0f));
Assertions.assertThat(rec.getStruct("after").getStruct("f_sparsevec").getInt16("dimensions")).isEqualTo((short) 3000);
Assertions.assertThat(rec.getStruct("after").getStruct("f_sparsevec").getMap("vector")).isEqualTo(Map.of((short) 1, 201.0, (short) 9, 209.0));
// streamed record: same schema names, values from the post-snapshot INSERT
rec = ((Struct) recs.get(1).value());
Assertions.assertThat(rec.schema().field("after").schema().field("f_vector").schema().name()).isEqualTo("io.debezium.data.Vector");
Assertions.assertThat(rec.schema().field("after").schema().field("f_halfvec").schema().name()).isEqualTo("io.debezium.data.HalfVector");
Assertions.assertThat(rec.schema().field("after").schema().field("f_sparsevec").schema().name()).isEqualTo("io.debezium.data.SparseVector");
Assertions.assertThat(rec.getStruct("after").getArray("f_vector")).isEqualTo(List.of(10.0, 20.0, 30.0));
Assertions.assertThat(rec.getStruct("after").getArray("f_halfvec")).isEqualTo(List.of(110.0f, 120.0f, 130.0f));
Assertions.assertThat(rec.getStruct("after").getStruct("f_sparsevec").getInt16("dimensions")).isEqualTo((short) 3000);
Assertions.assertThat(rec.getStruct("after").getStruct("f_sparsevec").getMap("vector")).isEqualTo(Map.of((short) 1, 301.0, (short) 9, 309.0));
}
@Test
public void shouldStreamData() throws Exception {
Testing.Print.enable();
// NO_DATA snapshot mode: only streamed changes are captured, including for a table
// created after the connector has started
start(PostgresConnector.class, TestHelper.defaultConfig()
.with(PostgresConnectorConfig.SNAPSHOT_MODE, SnapshotMode.NO_DATA)
.build());
assertConnectorIsRunning();
waitForStreamingRunning("postgres", TestHelper.TEST_SERVER);
TestHelper.execute(
"DROP TABLE IF EXISTS pgvector.table_vector_str;",
"CREATE TABLE pgvector.table_vector_str (pk SERIAL, f_vector pgvector.vector(3), f_halfvec pgvector.halfvec(3), f_sparsevec pgvector.sparsevec(3000), PRIMARY KEY(pk));",
"INSERT INTO pgvector.table_vector_str (f_vector, f_halfvec, f_sparsevec) VALUES ('[1,2,3]', '[101,102,103]', '{1: 201, 9: 209}/3000');");
var actualRecords = consumeRecordsByTopic(1);
var recs = actualRecords.recordsForTopic("test_server.pgvector.table_vector_str");
Assertions.assertThat(recs).hasSize(1);
var rec = ((Struct) recs.get(0).value());
Assertions.assertThat(rec.schema().field("after").schema().field("f_vector").schema().name()).isEqualTo("io.debezium.data.Vector");
Assertions.assertThat(rec.getStruct("after").getArray("f_vector")).isEqualTo(List.of(1.0, 2.0, 3.0));
}
}

View File

@ -0,0 +1,117 @@
/*
* Copyright Debezium Authors.
*
* Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0
*/
package io.debezium.connector.postgresql;
import java.util.List;
import java.util.Map;
import org.assertj.core.api.Assertions;
import org.junit.Test;
import io.debezium.connector.postgresql.data.vector.HalfVector;
import io.debezium.connector.postgresql.data.vector.SparseVector;
import io.debezium.connector.postgresql.data.vector.Vector;
/**
 * Unit tests for the PgVector logical-type parsers ({@link Vector}, {@link HalfVector},
 * {@link SparseVector}). Verifies whitespace tolerance, null returns on malformed outer
 * structure, and NumberFormatException propagation for malformed numbers.
 */
public class VectorDatabaseTest {
@Test
public void shouldParseVector() {
final var expectedVector = List.of(10.0, 20.0, 30.0);
// whitespace may appear around the brackets, the elements, and the commas
Assertions.assertThat(Vector.fromLogical(Vector.schema(), "[10,20,30]")).isEqualTo(expectedVector);
Assertions.assertThat(Vector.fromLogical(Vector.schema(), "[ 10,20,30] ")).isEqualTo(expectedVector);
Assertions.assertThat(Vector.fromLogical(Vector.schema(), " [ 10,20,30 ]")).isEqualTo(expectedVector);
Assertions.assertThat(Vector.fromLogical(Vector.schema(), "[10 ,20 ,30]")).isEqualTo(expectedVector);
// decimal and scientific notation are accepted, per Double.parseDouble
Assertions.assertThat(Vector.fromLogical(Vector.schema(), "[10.2 , 20, 30]")).isEqualTo(List.of(10.2, 20.0, 30.0));
Assertions.assertThat(Vector.fromLogical(Vector.schema(), "[10.2e-1 , 20, 30]")).isEqualTo(List.of(1.02, 20.0, 30.0));
}
@Test
public void shouldIgnoreErrorInVectorFormat() {
// missing or wrong delimiters produce null rather than throwing
Assertions.assertThat(Vector.fromLogical(Vector.schema(), "10,20,30]")).isNull();
Assertions.assertThat(Vector.fromLogical(Vector.schema(), "[10,20,30")).isNull();
Assertions.assertThat(Vector.fromLogical(Vector.schema(), "{10,20,30}")).isNull();
}
@Test(expected = NumberFormatException.class)
public void shouldFailOnNumberInVectorFormat() {
// a non-numeric element propagates NumberFormatException to the caller
Vector.fromLogical(Vector.schema(), "[a10,20,30]");
}
@Test
public void shouldParseHalfVector() {
final var expectedVector = List.of(10.0f, 20.0f, 30.0f);
// same tolerance as vector, but elements decode to Float
Assertions.assertThat(HalfVector.fromLogical(HalfVector.schema(), "[10,20,30]")).isEqualTo(expectedVector);
Assertions.assertThat(HalfVector.fromLogical(HalfVector.schema(), "[ 10,20,30] ")).isEqualTo(expectedVector);
Assertions.assertThat(HalfVector.fromLogical(HalfVector.schema(), " [ 10,20,30 ]")).isEqualTo(expectedVector);
Assertions.assertThat(HalfVector.fromLogical(HalfVector.schema(), "[10 ,20 ,30]")).isEqualTo(expectedVector);
Assertions.assertThat(HalfVector.fromLogical(HalfVector.schema(), "[10.2 , 20, 30]")).isEqualTo(List.of(10.2f, 20.0f, 30.0f));
Assertions.assertThat(HalfVector.fromLogical(HalfVector.schema(), "[10.2e-1 , 20, 30]")).isEqualTo(List.of(1.02f, 20.0f, 30.0f));
}
@Test
public void shouldIgnoreErrorInHalfVectorFormat() {
Assertions.assertThat(HalfVector.fromLogical(HalfVector.schema(), "10,20,30]")).isNull();
Assertions.assertThat(HalfVector.fromLogical(HalfVector.schema(), "[10,20,30")).isNull();
Assertions.assertThat(HalfVector.fromLogical(HalfVector.schema(), "{10,20,30}")).isNull();
}
@Test(expected = NumberFormatException.class)
public void shouldFailOnNumberInHalfVectorFormat() {
HalfVector.fromLogical(HalfVector.schema(), "[a10,20,30]");
}
@Test
public void shouldParseSparseVector() {
final var expectedVector = Map.of((short) 1, 10.0, (short) 11, 20.0, (short) 111, 30.0);
final var expectedDimensions = (short) 1000;
// baseline: canonical "{i:v,...}/dims" form
var vector = SparseVector.fromLogical(SparseVector.schema(), "{1:10,11:20,111:30}/1000");
Assertions.assertThat(vector.getInt16("dimensions")).isEqualTo(expectedDimensions);
Assertions.assertThat(vector.getMap("vector")).isEqualTo(expectedVector);
// whitespace is tolerated after commas, around the braces, and around the slash
vector = SparseVector.fromLogical(SparseVector.schema(), "{1:10, 11:20, 111:30}/1000");
Assertions.assertThat(vector.getInt16("dimensions")).isEqualTo(expectedDimensions);
Assertions.assertThat(vector.getMap("vector")).isEqualTo(expectedVector);
vector = SparseVector.fromLogical(SparseVector.schema(), " {1:10,11:20,111:30}/1000");
Assertions.assertThat(vector.getInt16("dimensions")).isEqualTo(expectedDimensions);
Assertions.assertThat(vector.getMap("vector")).isEqualTo(expectedVector);
vector = SparseVector.fromLogical(SparseVector.schema(), "{1:10,11:20,111:30} /1000");
Assertions.assertThat(vector.getInt16("dimensions")).isEqualTo(expectedDimensions);
Assertions.assertThat(vector.getMap("vector")).isEqualTo(expectedVector);
vector = SparseVector.fromLogical(SparseVector.schema(), "{1:10,11:20,111:30}/ 1000");
Assertions.assertThat(vector.getInt16("dimensions")).isEqualTo(expectedDimensions);
Assertions.assertThat(vector.getMap("vector")).isEqualTo(expectedVector);
vector = SparseVector.fromLogical(SparseVector.schema(), "{1:10,11:20,111:30}/1000 ");
Assertions.assertThat(vector.getInt16("dimensions")).isEqualTo(expectedDimensions);
Assertions.assertThat(vector.getMap("vector")).isEqualTo(expectedVector);
vector = SparseVector.fromLogical(SparseVector.schema(), "{1:10,11:20,111:30 }/1000");
Assertions.assertThat(vector.getInt16("dimensions")).isEqualTo(expectedDimensions);
Assertions.assertThat(vector.getMap("vector")).isEqualTo(expectedVector);
}
@Test
public void shouldIgnoreErrorInSparseVectorFormat() {
// any malformed outer structure (missing braces, slash, dimensions, or colon) yields null
Assertions.assertThat(SparseVector.fromLogical(SparseVector.schema(), "{1:10,11:20,111:30}")).isNull();
Assertions.assertThat(SparseVector.fromLogical(SparseVector.schema(), "{1:10,11:20,111:30/1000")).isNull();
Assertions.assertThat(SparseVector.fromLogical(SparseVector.schema(), "1:10,11:20,111:30}/1000")).isNull();
Assertions.assertThat(SparseVector.fromLogical(SparseVector.schema(), "{1:10,11:20,111:30}1000")).isNull();
Assertions.assertThat(SparseVector.fromLogical(SparseVector.schema(), "/1000")).isNull();
Assertions.assertThat(SparseVector.fromLogical(SparseVector.schema(), "{10,11:20,111:30}/1000")).isNull();
Assertions.assertThat(SparseVector.fromLogical(SparseVector.schema(), "{1:10,11#20,111:30}/1000")).isNull();
}
@Test(expected = NumberFormatException.class)
public void shouldFailOnNumberInSparseVectorFormat() {
// a non-numeric entry value propagates NumberFormatException to the caller
SparseVector.fromLogical(SparseVector.schema(), "{1:10,11:20,111:x}/1000");
}
}

View File

@ -0,0 +1,7 @@
-- noinspection SqlNoDataSourceInspectionForFile
-- Separate file because pgvector is tested since PostgreSQL 15
-- Provisions a clean, dedicated "pgvector" schema and installs the pgvector
-- extension into it, so test tables can reference pgvector.vector/halfvec/sparsevec.
CREATE SCHEMA IF NOT EXISTS public;
-- CASCADE drops any tables left over from a previous test run
DROP SCHEMA IF EXISTS pgvector CASCADE;
CREATE SCHEMA pgvector;
CREATE EXTENSION IF NOT EXISTS vector SCHEMA pgvector;