// db.rs
use crate::error::{AppError, AppResult};
use crate::void_schema::{SchemasMap, VoidSchema};
use arrow::array::{Array, FixedSizeListArray, Float32Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::{RecordBatch, RecordBatchIterator};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use futures_util::StreamExt;
use lancedb::index::Index;
use lancedb::index::vector::IvfPqIndexBuilder;
use lancedb::query::{ExecutableQuery, QueryBase as _};
use lancedb::{Connection, DistanceType, connect};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Structure representing a SPARQL query example to be stored in LanceDB
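/// A minimal construction sketch; the URL and `doc_type` value are illustrative,
/// and `vector` can be left as `None` so `SearchDB::insert_docs` fills it in:
///
/// ```ignore
/// let doc = SearchDocument {
///     question: "Which proteins are encoded by the BRCA1 gene?".to_string(),
///     answer: "SELECT * WHERE { ?s ?p ?o } LIMIT 10".to_string(),
///     endpoint_url: "https://sparql.example.org/sparql".to_string(),
///     doc_type: "query_example".to_string(),
///     vector: None,
/// };
/// ```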
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchDocument {
pub question: String,
pub answer: String,
pub endpoint_url: String,
pub doc_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub vector: Option<Vec<f32>>, // Vector embedding for semantic search
}
/// LanceDB db manager for SPARQL query indexing
pub struct SearchDB {
connection: Arc<Connection>,
table_name: String,
schema_table: String,
embedding_model: Arc<Mutex<TextEmbedding>>,
embedding_dimension: usize,
}
impl SearchDB {
    /// Create a new SearchDB instance with a shared FastEmbed model
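    ///
    /// A usage sketch, assuming this module is reachable as `crate::db` and that
    /// `"data/lancedb"` and `"documents"` are an illustrative path and table name:
    ///
    /// ```ignore
    /// let db = SearchDB::new("data/lancedb", "documents").await?;
    /// db.init_table().await?;
    /// ```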
pub async fn new(db_path: &str, table_name: &str) -> AppResult<Self> {
let connection = connect(db_path)
.execute()
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to connect to LanceDB: {e}")))?;
// Create a single FastEmbed model wrapped in a Mutex for thread safety
let embedding_model =
TextEmbedding::try_new(InitOptions::new(EmbeddingModel::BGEBaseENV15)).map_err(
|e| AppError::DatabaseError(format!("Failed to initialize FastEmbed model: {e}")),
)?;
        let embedding_dimension = 768; // BAAI/bge-base-en-v1.5 embedding dimension (bge-small-en-v1.5 would be 384)
Ok(Self {
connection: Arc::new(connection),
table_name: table_name.to_string(),
schema_table: format!("{table_name}_schema"),
embedding_model: Arc::new(Mutex::new(embedding_model)),
embedding_dimension,
})
}
/// Generate embedding for a given text using the shared FastEmbed model
async fn generate_embedding(&self, text: &str) -> AppResult<Vec<f32>> {
// Lock the model for the duration of embedding generation
let mut model = self.embedding_model.lock().await;
let embeddings = model
.embed(vec![format!("passage: {}", text)], None)
.map_err(|e| AppError::DatabaseError(format!("Failed to generate embedding: {e}")))?;
if embeddings.is_empty() {
return Err(AppError::DatabaseError(
"No embeddings generated".to_string(),
));
}
Ok(embeddings[0].clone())
}
    /// Initialize the database tables with the required schemas
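    ///
    /// Creates both the documents table and the schema table if they do not already exist,
    /// so it is safe to call on every startup.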
pub async fn init_table(&self) -> AppResult<()> {
// Check if table exists, if not create it with empty data
let table_names = self
.connection
.table_names()
.execute()
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to list tables: {e}")))?;
if !table_names.contains(&self.table_name) {
// Create the table with the Arrow schema
let schema = Schema::new(vec![
Field::new("question", DataType::Utf8, false),
Field::new("answer", DataType::Utf8, false),
Field::new("endpoint_url", DataType::Utf8, false),
Field::new("doc_type", DataType::Utf8, false),
Field::new(
"vector",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
self.embedding_dimension as i32,
),
true,
),
]);
self.connection
.create_empty_table(&self.table_name, Arc::new(schema))
.execute()
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to create table: {e}")))?;
}
if !table_names.contains(&self.schema_table) {
let schema = Schema::new(vec![
Field::new("endpoint_url", DataType::Utf8, false),
Field::new("schema_map", DataType::LargeBinary, false),
Field::new("label_map", DataType::LargeBinary, false),
Field::new("predicates_list", DataType::LargeBinary, false),
Field::new("classes_list", DataType::LargeBinary, false),
Field::new("prefix_map", DataType::LargeBinary, false),
]);
self.connection
.create_empty_table(&self.schema_table, Arc::new(schema))
.execute()
.await
.map_err(|e| {
AppError::DatabaseError(format!("Failed to create schema table: {e}"))
})?;
}
Ok(())
}
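    /// Load all stored VoID schemas from the schema table, keyed by endpoint URL.
    /// Rows that cannot be read or deserialized are skipped rather than reported as errors.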
pub async fn get_schemas(
&self,
// endpoints_config: &EndpointsConfig,
) -> SchemasMap {
        use arrow::array::LargeBinaryArray;
let table = self
.connection
.open_table(&self.schema_table)
.execute()
.await;
let mut schemas_map: SchemasMap = HashMap::new();
if let Ok(table) = table {
let results = table.query().execute().await;
if let Ok(mut results) = results {
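                // Scan every row and deserialize each JSON-encoded column back into a VoidSchema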
while let Some(batch) = results.next().await {
let batch = match batch {
Ok(b) => b,
Err(_) => continue,
};
let endpoint_url_col = batch
.column_by_name("endpoint_url")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let schema_map_col = batch
.column_by_name("schema_map")
.unwrap()
.as_any()
.downcast_ref::<LargeBinaryArray>()
.unwrap();
let label_map_col = batch
.column_by_name("label_map")
.unwrap()
.as_any()
.downcast_ref::<LargeBinaryArray>()
.unwrap();
let predicates_list_col = batch
.column_by_name("predicates_list")
.unwrap()
.as_any()
.downcast_ref::<LargeBinaryArray>()
.unwrap();
let classes_list_col = batch
.column_by_name("classes_list")
.unwrap()
.as_any()
.downcast_ref::<LargeBinaryArray>()
.unwrap();
let prefix_map_col = batch
.column_by_name("prefix_map")
.unwrap()
.as_any()
.downcast_ref::<LargeBinaryArray>()
.unwrap();
for i in 0..batch.num_rows() {
let endpoint_url = endpoint_url_col.value(i).to_string();
let void_schema = VoidSchema {
endpoint_url: endpoint_url.clone(),
schema_map: serde_json::from_slice(schema_map_col.value(i))
.unwrap_or_default(),
label_map: serde_json::from_slice(label_map_col.value(i))
.unwrap_or_default(),
predicates_list: serde_json::from_slice(predicates_list_col.value(i))
.unwrap_or_default(),
classes_list: serde_json::from_slice(classes_list_col.value(i))
.unwrap_or_default(),
prefix_map: serde_json::from_slice(prefix_map_col.value(i))
.unwrap_or_default(),
};
schemas_map.insert(endpoint_url, void_schema);
}
}
}
}
if !schemas_map.is_empty() {
tracing::info!(
"✨ Loaded {} SPARQL endpoints from database\n{}",
schemas_map.len(),
schemas_map
.values()
.map(|void_schema| {
format!(
"{} · {} classes | {} predicates | {} prefixes",
void_schema.endpoint_url,
void_schema.classes_list.len(),
void_schema.predicates_list.len(),
void_schema.prefix_map.len()
)
                    })
.collect::<Vec<String>>()
.join("\n")
);
}
schemas_map
}
/// Insert documents into the database
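    ///
    /// Embeddings are generated automatically for any document whose `vector` is `None`.
    ///
    /// A usage sketch, assuming `docs: Vec<SearchDocument>` has already been built
    /// (see the `SearchDocument` docs above):
    ///
    /// ```ignore
    /// let n = docs.len();
    /// db.insert_docs(docs).await?;
    /// assert!(db.count_docs().await? >= n);
    /// ```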
pub async fn insert_docs(&self, mut docs: Vec<SearchDocument>) -> AppResult<()> {
if docs.is_empty() {
return Ok(());
}
// Generate embeddings for questions if not already present
for record in &mut docs {
if record.vector.is_none() {
let embedding = self.generate_embedding(&record.question).await?;
record.vector = Some(embedding);
}
}
// Open the existing table to check its schema
let table = self
.connection
.open_table(&self.table_name)
.execute()
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to open table: {e}")))?;
// Get the existing table schema
let existing_schema = table
.schema()
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to get table schema: {e}")))?;
// Check if the table has the embedding column
let has_embedding_column = existing_schema
.fields()
.iter()
.any(|field| field.name() == "vector");
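        // Build Arrow column arrays in the same order as the schema fields declared below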
let questions: Vec<String> = docs.iter().map(|r| r.question.clone()).collect();
let answers: Vec<String> = docs.iter().map(|r| r.answer.clone()).collect();
let endpoint_urls: Vec<String> = docs.iter().map(|r| r.endpoint_url.clone()).collect();
let doc_types: Vec<String> = docs.iter().map(|r| r.doc_type.clone()).collect();
let question_array = StringArray::from(questions);
let answer_array = StringArray::from(answers);
let endpoint_url_array = StringArray::from(endpoint_urls);
let doc_type_array = StringArray::from(doc_types);
let mut columns: Vec<Arc<dyn Array>> = vec![
Arc::new(question_array),
Arc::new(answer_array),
Arc::new(endpoint_url_array),
Arc::new(doc_type_array),
];
let mut schema_fields = vec![
Field::new("question", DataType::Utf8, false),
Field::new("answer", DataType::Utf8, false),
Field::new("endpoint_url", DataType::Utf8, false),
Field::new("doc_type", DataType::Utf8, false),
];
// Only add embedding column if it exists in the table
if has_embedding_column {
let mut embedding_values = Vec::new();
for record in &docs {
if let Some(ref embedding) = record.vector {
embedding_values.extend_from_slice(embedding);
} else {
// This shouldn't happen as we generate embeddings above, but handle it gracefully
embedding_values.extend(vec![0.0f32; self.embedding_dimension]);
}
}
let embedding_data = Float32Array::from(embedding_values);
let embedding_array = FixedSizeListArray::new(
Arc::new(Field::new("item", DataType::Float32, true)),
self.embedding_dimension as i32,
Arc::new(embedding_data),
None,
);
columns.push(Arc::new(embedding_array));
schema_fields.push(Field::new(
"vector",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
self.embedding_dimension as i32,
),
true,
));
}
let schema = Arc::new(Schema::new(schema_fields));
let record_batch = RecordBatch::try_new(schema.clone(), columns)
.map_err(|e| AppError::DatabaseError(format!("Failed to create record batch: {e}")))?;
// Create a `RecordBatchIterator` from the record batch
let batch_iter = RecordBatchIterator::new(vec![Ok(record_batch)].into_iter(), schema);
table
.add(Box::new(batch_iter))
.execute()
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to insert documents: {e}")))?;
tracing::debug!(
"✅ Inserted {} documents into LanceDB table `{}`",
docs.len(),
self.table_name
);
Ok(())
}
/// Get the total count of documents in the table
pub async fn count_docs(&self) -> AppResult<usize> {
let table = self
.connection
.open_table(&self.table_name)
.execute()
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to open table: {e}")))?;
let count = table
.count_rows(None)
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to count documents: {e}")))?;
Ok(count)
}
/// Search for similar queries using semantic vector similarity
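    ///
    /// A usage sketch; the filter field and values are illustrative:
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    ///
    /// let mut filters = HashMap::new();
    /// filters.insert("endpoint_url", "https://sparql.example.org/sparql");
    /// let hits = db
    ///     .search_similar("proteins encoded by BRCA1", 5, Some(&filters))
    ///     .await?;
    /// ```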
pub async fn search_similar(
&self,
query: &str,
limit: usize,
filters: Option<&HashMap<&str, &str>>,
) -> AppResult<Vec<SearchDocument>> {
let table = self
.connection
.open_table(&self.table_name)
.execute()
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to open table: {e}")))?;
let query_embedding = self.generate_embedding(query).await?;
        // Build the query with flexible filters. All filters are combined into a single
        // predicate so every condition applies; values are interpolated verbatim, so they
        // must not contain single quotes.
        let mut query_builder = table.query();
        if let Some(filters) = filters {
            if !filters.is_empty() {
                let predicate = filters
                    .iter()
                    .map(|(field, value)| format!("{field} = '{value}'"))
                    .collect::<Vec<_>>()
                    .join(" AND ");
                query_builder = query_builder.only_if(predicate);
            }
        }
// Use vector similarity search with `nearest_to`
let mut results = query_builder
.limit(limit)
.nearest_to(&*query_embedding)
.map_err(|e| {
AppError::DatabaseError(format!("Failed to create nearest_to query: {e}"))
})?
.execute()
.await
.map_err(|e| {
AppError::DatabaseError(format!("Failed to execute vector search: {e}"))
})?;
// Convert results to `SearchDocument`
let mut docs = Vec::new();
while let Some(batch) = results.next().await {
let batch =
batch.map_err(|e| AppError::DatabaseError(format!("Failed to read batch: {e}")))?;
let question_col = batch
.column_by_name("question")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let answer_col = batch
.column_by_name("answer")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let endpoint_url_col = batch
.column_by_name("endpoint_url")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let doc_type_col = batch
.column_by_name("doc_type")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
for i in 0..batch.num_rows() {
let question = question_col.value(i).to_string();
let answer = answer_col.value(i).to_string();
let endpoint_url = endpoint_url_col.value(i).to_string();
let doc_type = doc_type_col.value(i).to_string();
docs.push(SearchDocument {
question,
answer,
endpoint_url,
doc_type,
vector: None, // Don't include embeddings in search results to save memory
});
}
}
Ok(docs)
}
/// Create a vector index on the question embeddings for faster similarity search
/// This should be called after inserting data into the table
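    ///
    /// A typical call order, as a sketch:
    ///
    /// ```ignore
    /// db.insert_docs(docs).await?;
    /// db.create_vector_index().await?;
    /// ```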
pub async fn create_vector_index(&self) -> AppResult<()> {
let table = self
.connection
.open_table(&self.table_name)
.execute()
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to open table: {e}")))?;
// Check if table has data before creating index
let count = table
.count_rows(None)
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to count rows: {e}")))?;
if count == 0 {
tracing::warn!("Table is empty, skipping vector index creation");
return Ok(());
}
        // Create an IVF_PQ index with cosine distance for efficient similarity search.
        // (`lancedb::index::Index::Auto` is an alternative that lets LanceDB pick the parameters.)
        table
.create_index(
&["vector"],
Index::IvfPq(
IvfPqIndexBuilder::default()
.distance_type(DistanceType::Cosine)
.num_partitions(50)
.num_sub_vectors(16),
),
)
.execute()
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to create vector index: {e}")))?;
Ok(())
}
    /// Store a single endpoint's VoID schema in the LanceDB schema table
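    ///
    /// Note: this appends a new row on every call; existing rows for the same
    /// endpoint are not replaced or deduplicated.
    ///
    /// ```ignore
    /// db.store_void_schemas(&void_schema).await?; // `void_schema: VoidSchema`
    /// ```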
pub async fn store_void_schemas(&self, void_schema: &VoidSchema) -> AppResult<()> {
        use arrow::array::LargeBinaryArray;
let mut endpoint_urls = Vec::new();
let mut schema_maps = Vec::new();
let mut label_maps = Vec::new();
let mut predicates_lists = Vec::new();
let mut classes_lists = Vec::new();
let mut prefix_maps = Vec::new();
endpoint_urls.push(void_schema.endpoint_url.clone());
schema_maps.push(serde_json::to_vec(&void_schema.schema_map).map_err(|e| {
AppError::DatabaseError(format!("Failed to serialize schema_map: {e}"))
})?);
label_maps.push(
serde_json::to_vec(&void_schema.label_map).map_err(|e| {
AppError::DatabaseError(format!("Failed to serialize label_map: {e}"))
})?,
);
predicates_lists.push(
serde_json::to_vec(&void_schema.predicates_list).map_err(|e| {
AppError::DatabaseError(format!("Failed to serialize predicates_list: {e}"))
})?,
);
classes_lists.push(serde_json::to_vec(&void_schema.classes_list).map_err(|e| {
AppError::DatabaseError(format!("Failed to serialize classes_list: {e}"))
})?);
prefix_maps.push(serde_json::to_vec(&void_schema.prefix_map).map_err(|e| {
AppError::DatabaseError(format!("Failed to serialize prefix_map: {e}"))
})?);
let schema = Arc::new(Schema::new(vec![
Field::new("endpoint_url", DataType::Utf8, false),
Field::new("schema_map", DataType::LargeBinary, false),
Field::new("label_map", DataType::LargeBinary, false),
Field::new("predicates_list", DataType::LargeBinary, false),
Field::new("classes_list", DataType::LargeBinary, false),
Field::new("prefix_map", DataType::LargeBinary, false),
]));
let record_batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(endpoint_urls)),
Arc::new(LargeBinaryArray::from(
schema_maps
.iter()
.map(|v| Some(v.as_slice()))
.collect::<Vec<_>>(),
)),
Arc::new(LargeBinaryArray::from(
label_maps
.iter()
.map(|v| Some(v.as_slice()))
.collect::<Vec<_>>(),
)),
Arc::new(LargeBinaryArray::from(
predicates_lists
.iter()
.map(|v| Some(v.as_slice()))
.collect::<Vec<_>>(),
)),
Arc::new(LargeBinaryArray::from(
classes_lists
.iter()
.map(|v| Some(v.as_slice()))
.collect::<Vec<_>>(),
)),
Arc::new(LargeBinaryArray::from(
prefix_maps
.iter()
.map(|v| Some(v.as_slice()))
.collect::<Vec<_>>(),
)),
],
)
.map_err(|e| {
AppError::DatabaseError(format!("Failed to create schema record batch: {e}"))
})?;
let table = self
.connection
.open_table(&self.schema_table)
.execute()
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to open schema table: {e}")))?;
        let batch_iter = RecordBatchIterator::new(vec![Ok(record_batch)].into_iter(), schema);
table
.add(Box::new(batch_iter))
.execute()
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to insert schemas_map: {e}")))?;
Ok(())
}
    /// Clear all data by dropping the documents table and the schema table (call `init_table` to recreate them)
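    ///
    /// A typical reset sequence, as a sketch:
    ///
    /// ```ignore
    /// db.clear_table().await?;
    /// db.init_table().await?; // recreate the empty tables
    /// ```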
pub async fn clear_table(&self) -> AppResult<()> {
// Check if table exists
let table_names = self
.connection
.table_names()
.execute()
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to list tables: {e}")))?;
if table_names.contains(&self.table_name) {
self.connection
.drop_table(&self.table_name)
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to drop table: {e}")))?;
tracing::debug!("🗑️ Dropped table {}", self.table_name);
}
if table_names.contains(&self.schema_table) {
self.connection
.drop_table(&self.schema_table)
.await
.map_err(|e| AppError::DatabaseError(format!("Failed to drop table: {e}")))?;
tracing::debug!("🗑️ Dropped table {}", self.schema_table);
}
Ok(())
}
}