feat(animal-chat): LanceDB 벡터 검색 RAG 통합
- LanceDB로 MD 문서 252개 청크 인덱싱 - /api/animal-chat에 벡터 검색 컨텍스트 주입 - 마지막 사용자 메시지로 관련 문서 검색 (top 3) - ChromaDB Windows crash로 LanceDB 채택
This commit is contained in:
339
backend/utils/animal_rag.py
Normal file
339
backend/utils/animal_rag.py
Normal file
@@ -0,0 +1,339 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
동물약 벡터 DB RAG 모듈
|
||||
- LanceDB + OpenAI text-embedding-3-small
|
||||
- MD 파일 청킹 및 임베딩
|
||||
- 유사도 검색
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
# .env 로드
|
||||
from dotenv import load_dotenv
|
||||
env_path = Path(__file__).parent.parent / ".env"
|
||||
load_dotenv(env_path)
|
||||
|
||||
# LanceDB
|
||||
import lancedb
|
||||
from openai import OpenAI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 설정
|
||||
LANCE_DB_PATH = Path(__file__).parent.parent / "db" / "lance_animal_drugs"
|
||||
MD_DOCS_PATH = Path("C:/Users/청춘약국/source/new_anipharm")
|
||||
TABLE_NAME = "animal_drugs"
|
||||
CHUNK_SIZE = 1500 # 약 500 토큰
|
||||
CHUNK_OVERLAP = 300 # 약 100 토큰
|
||||
EMBEDDING_DIM = 1536 # text-embedding-3-small
|
||||
|
||||
|
||||
class AnimalDrugRAG:
|
||||
"""동물약 RAG 클래스 (LanceDB 버전)"""
|
||||
|
||||
def __init__(self, openai_api_key: str = None):
|
||||
"""
|
||||
Args:
|
||||
openai_api_key: OpenAI API 키 (없으면 환경변수에서 가져옴)
|
||||
"""
|
||||
self.api_key = openai_api_key or os.getenv('OPENAI_API_KEY')
|
||||
self.db = None
|
||||
self.table = None
|
||||
self.openai_client = None
|
||||
self._initialized = False
|
||||
|
||||
def _init_db(self):
|
||||
"""DB 초기화 (lazy loading)"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
# LanceDB 연결
|
||||
LANCE_DB_PATH.mkdir(parents=True, exist_ok=True)
|
||||
self.db = lancedb.connect(str(LANCE_DB_PATH))
|
||||
|
||||
# OpenAI 클라이언트
|
||||
if self.api_key:
|
||||
self.openai_client = OpenAI(api_key=self.api_key)
|
||||
else:
|
||||
logger.warning("OpenAI API 키 없음")
|
||||
|
||||
# 기존 테이블 열기
|
||||
if TABLE_NAME in self.db.table_names():
|
||||
self.table = self.db.open_table(TABLE_NAME)
|
||||
logger.info(f"기존 테이블 열림 (행 수: {len(self.table)})")
|
||||
else:
|
||||
logger.info("테이블 없음 - index_md_files() 호출 필요")
|
||||
|
||||
self._initialized = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AnimalDrugRAG 초기화 실패: {e}")
|
||||
raise
|
||||
|
||||
def _get_embedding(self, text: str) -> List[float]:
|
||||
"""OpenAI 임베딩 생성"""
|
||||
if not self.openai_client:
|
||||
raise ValueError("OpenAI 클라이언트 없음")
|
||||
|
||||
response = self.openai_client.embeddings.create(
|
||||
model="text-embedding-3-small",
|
||||
input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
def _get_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""배치 임베딩 생성"""
|
||||
if not self.openai_client:
|
||||
raise ValueError("OpenAI 클라이언트 없음")
|
||||
|
||||
# OpenAI는 한 번에 최대 2048개 텍스트 처리
|
||||
embeddings = []
|
||||
batch_size = 100
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i:i+batch_size]
|
||||
response = self.openai_client.embeddings.create(
|
||||
model="text-embedding-3-small",
|
||||
input=batch
|
||||
)
|
||||
embeddings.extend([d.embedding for d in response.data])
|
||||
logger.info(f"임베딩 생성: {i+len(batch)}/{len(texts)}")
|
||||
|
||||
return embeddings
|
||||
|
||||
def chunk_markdown(self, content: str, source_file: str) -> List[Dict]:
|
||||
"""
|
||||
마크다운 청킹 (섹션 기반)
|
||||
"""
|
||||
chunks = []
|
||||
|
||||
# ## 헤더 기준 분리
|
||||
sections = re.split(r'\n(?=## )', content)
|
||||
|
||||
for i, section in enumerate(sections):
|
||||
if not section.strip():
|
||||
continue
|
||||
|
||||
# 섹션 제목 추출
|
||||
title_match = re.match(r'^## (.+?)(?:\n|$)', section)
|
||||
section_title = title_match.group(1).strip() if title_match else f"섹션{i+1}"
|
||||
|
||||
# 큰 섹션은 추가 분할
|
||||
if len(section) > CHUNK_SIZE:
|
||||
sub_chunks = self._split_by_size(section, CHUNK_SIZE, CHUNK_OVERLAP)
|
||||
for j, sub_chunk in enumerate(sub_chunks):
|
||||
chunk_id = f"{source_file}#{section_title}#{j}"
|
||||
chunks.append({
|
||||
"id": chunk_id,
|
||||
"text": sub_chunk,
|
||||
"source": source_file,
|
||||
"section": section_title,
|
||||
"chunk_index": j
|
||||
})
|
||||
else:
|
||||
chunk_id = f"{source_file}#{section_title}"
|
||||
chunks.append({
|
||||
"id": chunk_id,
|
||||
"text": section,
|
||||
"source": source_file,
|
||||
"section": section_title,
|
||||
"chunk_index": 0
|
||||
})
|
||||
|
||||
return chunks
|
||||
|
||||
def _split_by_size(self, text: str, size: int, overlap: int) -> List[str]:
|
||||
"""텍스트를 크기 기준으로 분할"""
|
||||
chunks = []
|
||||
start = 0
|
||||
|
||||
while start < len(text):
|
||||
end = start + size
|
||||
|
||||
# 문장 경계에서 자르기
|
||||
if end < len(text):
|
||||
last_break = text.rfind('\n', start, end)
|
||||
if last_break == -1:
|
||||
last_break = text.rfind('. ', start, end)
|
||||
if last_break > start + size // 2:
|
||||
end = last_break + 1
|
||||
|
||||
chunks.append(text[start:end])
|
||||
start = end - overlap
|
||||
|
||||
return chunks
|
||||
|
||||
def index_md_files(self, md_path: Path = None) -> int:
|
||||
"""
|
||||
MD 파일들을 인덱싱
|
||||
"""
|
||||
self._init_db()
|
||||
|
||||
md_path = md_path or MD_DOCS_PATH
|
||||
if not md_path.exists():
|
||||
logger.error(f"MD 파일 경로 없음: {md_path}")
|
||||
return 0
|
||||
|
||||
# 기존 테이블 삭제
|
||||
if TABLE_NAME in self.db.table_names():
|
||||
self.db.drop_table(TABLE_NAME)
|
||||
logger.info("기존 테이블 삭제")
|
||||
|
||||
# 모든 청크 수집
|
||||
all_chunks = []
|
||||
md_files = list(md_path.glob("*.md"))
|
||||
|
||||
for md_file in md_files:
|
||||
try:
|
||||
content = md_file.read_text(encoding='utf-8')
|
||||
chunks = self.chunk_markdown(content, md_file.name)
|
||||
all_chunks.extend(chunks)
|
||||
logger.info(f"청킹: {md_file.name} ({len(chunks)}개)")
|
||||
except Exception as e:
|
||||
logger.error(f"청킹 실패 ({md_file.name}): {e}")
|
||||
|
||||
if not all_chunks:
|
||||
logger.warning("청크 없음")
|
||||
return 0
|
||||
|
||||
# 임베딩 생성
|
||||
texts = [c["text"] for c in all_chunks]
|
||||
logger.info(f"총 {len(texts)}개 청크 임베딩 시작...")
|
||||
embeddings = self._get_embeddings_batch(texts)
|
||||
|
||||
# 데이터 준비
|
||||
data = []
|
||||
for chunk, emb in zip(all_chunks, embeddings):
|
||||
data.append({
|
||||
"id": chunk["id"],
|
||||
"text": chunk["text"],
|
||||
"source": chunk["source"],
|
||||
"section": chunk["section"],
|
||||
"chunk_index": chunk["chunk_index"],
|
||||
"vector": emb
|
||||
})
|
||||
|
||||
# 테이블 생성
|
||||
self.table = self.db.create_table(TABLE_NAME, data)
|
||||
logger.info(f"인덱싱 완료: {len(data)}개 청크")
|
||||
|
||||
return len(data)
|
||||
|
||||
def search(self, query: str, n_results: int = 3) -> List[Dict]:
|
||||
"""
|
||||
유사도 검색
|
||||
"""
|
||||
self._init_db()
|
||||
|
||||
if self.table is None:
|
||||
logger.warning("테이블 없음 - index_md_files() 필요")
|
||||
return []
|
||||
|
||||
try:
|
||||
# 쿼리 임베딩
|
||||
query_emb = self._get_embedding(query)
|
||||
|
||||
# 검색
|
||||
results = self.table.search(query_emb).limit(n_results).to_list()
|
||||
|
||||
output = []
|
||||
for r in results:
|
||||
output.append({
|
||||
"text": r["text"],
|
||||
"source": r["source"],
|
||||
"section": r["section"],
|
||||
"score": 1 - r.get("_distance", 0) # 거리 → 유사도
|
||||
})
|
||||
|
||||
return output
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"검색 실패: {e}")
|
||||
return []
|
||||
|
||||
def get_context_for_chat(self, query: str, n_results: int = 3) -> str:
|
||||
"""
|
||||
챗봇용 컨텍스트 생성
|
||||
"""
|
||||
results = self.search(query, n_results)
|
||||
|
||||
if not results:
|
||||
return ""
|
||||
|
||||
context_parts = ["## 📚 관련 문서 (RAG 검색 결과)"]
|
||||
|
||||
for i, r in enumerate(results, 1):
|
||||
source = r["source"].replace(".md", "")
|
||||
section = r["section"]
|
||||
score = r["score"]
|
||||
text = r["text"][:1500]
|
||||
|
||||
context_parts.append(f"\n### [{i}] {source} - {section} (관련도: {score:.0%})")
|
||||
context_parts.append(text)
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
"""통계 정보 반환"""
|
||||
self._init_db()
|
||||
|
||||
count = len(self.table) if self.table else 0
|
||||
return {
|
||||
"table_name": TABLE_NAME,
|
||||
"document_count": count,
|
||||
"db_path": str(LANCE_DB_PATH)
|
||||
}
|
||||
|
||||
|
||||
# 싱글톤 인스턴스
|
||||
_rag_instance: Optional[AnimalDrugRAG] = None
|
||||
|
||||
|
||||
def get_animal_rag(api_key: str = None) -> AnimalDrugRAG:
|
||||
"""싱글톤 RAG 인스턴스 반환"""
|
||||
global _rag_instance
|
||||
if _rag_instance is None:
|
||||
_rag_instance = AnimalDrugRAG(api_key)
|
||||
return _rag_instance
|
||||
|
||||
|
||||
# CLI 테스트
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
rag = AnimalDrugRAG()
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
cmd = sys.argv[1]
|
||||
|
||||
if cmd == "index":
|
||||
count = rag.index_md_files()
|
||||
print(f"\n✅ {count}개 청크 인덱싱 완료")
|
||||
|
||||
elif cmd == "search" and len(sys.argv) > 2:
|
||||
query = " ".join(sys.argv[2:])
|
||||
results = rag.search(query)
|
||||
print(f"\n🔍 검색: {query}")
|
||||
for r in results:
|
||||
print(f"\n[{r['score']:.0%}] {r['source']} - {r['section']}")
|
||||
print(r['text'][:300] + "...")
|
||||
|
||||
elif cmd == "stats":
|
||||
stats = rag.get_stats()
|
||||
print(f"\n📊 통계:")
|
||||
print(f" - 테이블: {stats['table_name']}")
|
||||
print(f" - 문서 수: {stats['document_count']}")
|
||||
print(f" - DB 경로: {stats['db_path']}")
|
||||
|
||||
else:
|
||||
print("사용법:")
|
||||
print(" python animal_rag.py index # MD 파일 인덱싱")
|
||||
print(" python animal_rag.py search 질문 # 검색")
|
||||
print(" python animal_rag.py stats # 통계")
|
||||
Reference in New Issue
Block a user