Skip to content

Commit ce45617

Browse files
feat: support add and delete from MySQLVectorStore (#53)
1 parent a1c9411 commit ce45617

File tree

3 files changed

+83
-1
lines changed

3 files changed

+83
-1
lines changed

src/langchain_google_cloud_sql_mysql/engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def init_vectorstore_table(
342342
table_name (str): The MySQL database table name.
343343
vector_size (int): Vector size for the embedding model to be used.
344344
content_column (str): Name of the column to store document content.
345-
Deafult: `page_content`.
345+
Default: `page_content`.
346346
embedding_column (str) : Name of the column to store vector embeddings.
347347
Default: `embedding`.
348348
metadata_columns (List[Column]): A list of Columns to create for custom

src/langchain_google_cloud_sql_mysql/vectorstore.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,32 @@ def add_texts(
184184
)
185185
return ids
186186

187+
def add_documents(
188+
self,
189+
documents: List[Document],
190+
ids: Optional[List[str]] = None,
191+
**kwargs: Any,
192+
) -> List[str]:
193+
texts = [doc.page_content for doc in documents]
194+
metadatas = [doc.metadata for doc in documents]
195+
ids = self.add_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
196+
return ids
197+
198+
def delete(
199+
self,
200+
ids: Optional[List[str]] = None,
201+
**kwargs: Any,
202+
) -> bool:
203+
if not ids:
204+
return False
205+
206+
id_list = ", ".join([f"'{id}'" for id in ids])
207+
query = (
208+
f"DELETE FROM `{self.table_name}` WHERE `{self.id_column}` in ({id_list})"
209+
)
210+
self.engine._execute(query)
211+
return True
212+
187213
@classmethod
188214
def from_texts( # type: ignore[override]
189215
cls: Type[MySQLVectorStore],

tests/integration/test_mysql_vectorstore.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,30 @@ def test_add_texts_edge_cases(self, engine, vs):
146146
assert len(results) == 3
147147
engine._execute(f"TRUNCATE TABLE `{DEFAULT_TABLE}`")
148148

149+
def test_add_docs(self, engine, vs):
150+
ids = [str(uuid.uuid4()) for i in range(len(texts))]
151+
vs.add_documents(docs, ids=ids)
152+
results = engine._fetch(f"SELECT * FROM `{DEFAULT_TABLE}`")
153+
assert len(results) == 3
154+
engine._execute(f"TRUNCATE TABLE `{DEFAULT_TABLE}`")
155+
149156
def test_add_embedding(self, engine, vs):
150157
ids = [str(uuid.uuid4()) for _ in range(len(texts))]
151158
vs._add_embeddings(texts, embeddings, metadatas, ids)
152159
results = engine._fetch(f"SELECT * FROM `{DEFAULT_TABLE}`")
153160
assert len(results) == 3
154161
engine._execute(f"TRUNCATE TABLE `{DEFAULT_TABLE}`")
155162

163+
def test_delete(self, engine, vs):
164+
ids = [str(uuid.uuid4()) for _ in range(len(texts))]
165+
vs.add_texts(texts, ids=ids)
166+
results = engine._fetch(f"SELECT * FROM `{DEFAULT_TABLE}`")
167+
assert len(results) == 3
168+
# delete an ID
169+
vs.delete([ids[0]])
170+
results = engine._fetch(f"SELECT * FROM `{DEFAULT_TABLE}`")
171+
assert len(results) == 2
172+
156173
def test_add_texts_custom(self, engine, vs_custom):
157174
ids = [str(uuid.uuid4()) for _ in range(len(texts))]
158175
vs_custom.add_texts(texts, ids=ids)
@@ -172,11 +189,50 @@ def test_add_texts_custom(self, engine, vs_custom):
172189
assert len(results) == 6
173190
engine._execute(f"TRUNCATE TABLE `{CUSTOM_TABLE}`")
174191

192+
def test_add_docs_custom(self, engine, vs_custom):
193+
ids = [str(uuid.uuid4()) for i in range(len(texts))]
194+
docs = [
195+
Document(
196+
page_content=texts[i],
197+
metadata={"page": str(i), "source": "google.com"},
198+
)
199+
for i in range(len(texts))
200+
]
201+
vs_custom.add_documents(docs, ids=ids)
202+
203+
results = engine._fetch(f"SELECT * FROM `{CUSTOM_TABLE}`")
204+
content = [result["mycontent"] for result in results]
205+
assert len(results) == 3
206+
assert "foo" in content
207+
assert "bar" in content
208+
assert "baz" in content
209+
assert results[0]["myembedding"]
210+
pages = [result["page"] for result in results]
211+
assert "0" in pages
212+
assert "1" in pages
213+
assert "2" in pages
214+
assert results[0]["source"] == "google.com"
215+
engine._execute(f"TRUNCATE TABLE `{CUSTOM_TABLE}`")
216+
175217
def test_add_embedding_custom(self, engine, vs_custom):
176218
ids = [str(uuid.uuid4()) for _ in range(len(texts))]
177219
vs_custom._add_embeddings(texts, embeddings, metadatas, ids)
178220
results = engine._fetch(f"SELECT * FROM `{CUSTOM_TABLE}`")
179221
assert len(results) == 3
180222
engine._execute(f"TRUNCATE TABLE `{CUSTOM_TABLE}`")
181223

224+
def test_delete_custom(self, engine, vs_custom):
225+
ids = [str(uuid.uuid4()) for _ in range(len(texts))]
226+
vs_custom.add_texts(texts, ids=ids)
227+
results = engine._fetch(f"SELECT * FROM `{CUSTOM_TABLE}`")
228+
content = [result["mycontent"] for result in results]
229+
assert len(results) == 3
230+
assert "foo" in content
231+
# delete an ID
232+
vs_custom.delete([ids[0]])
233+
results = engine._fetch(f"SELECT * FROM `{CUSTOM_TABLE}`")
234+
content = [result["mycontent"] for result in results]
235+
assert len(results) == 2
236+
assert "foo" not in content
237+
182238
# Need tests for store metadata=False

0 commit comments

Comments
 (0)