- LanceDB로 MD 문서 252개 청크 인덱싱 - /api/animal-chat에 벡터 검색 컨텍스트 주입 - 마지막 사용자 메시지로 관련 문서 검색 (top 3) - ChromaDB Windows crash로 LanceDB 채택
340 lines
11 KiB
Python
340 lines
11 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
동물약 벡터 DB RAG 모듈
|
|
- LanceDB + OpenAI text-embedding-3-small
|
|
- MD 파일 청킹 및 임베딩
|
|
- 유사도 검색
|
|
"""
|
|
|
|
import os
|
|
import re
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import List, Dict, Optional
|
|
|
|
# .env 로드
|
|
from dotenv import load_dotenv
|
|
env_path = Path(__file__).parent.parent / ".env"
|
|
load_dotenv(env_path)
|
|
|
|
# LanceDB
|
|
import lancedb
|
|
from openai import OpenAI
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# 설정
|
|
LANCE_DB_PATH = Path(__file__).parent.parent / "db" / "lance_animal_drugs"
|
|
MD_DOCS_PATH = Path("C:/Users/청춘약국/source/new_anipharm")
|
|
TABLE_NAME = "animal_drugs"
|
|
CHUNK_SIZE = 1500 # 약 500 토큰
|
|
CHUNK_OVERLAP = 300 # 약 100 토큰
|
|
EMBEDDING_DIM = 1536 # text-embedding-3-small
|
|
|
|
|
|
class AnimalDrugRAG:
|
|
"""동물약 RAG 클래스 (LanceDB 버전)"""
|
|
|
|
def __init__(self, openai_api_key: str = None):
|
|
"""
|
|
Args:
|
|
openai_api_key: OpenAI API 키 (없으면 환경변수에서 가져옴)
|
|
"""
|
|
self.api_key = openai_api_key or os.getenv('OPENAI_API_KEY')
|
|
self.db = None
|
|
self.table = None
|
|
self.openai_client = None
|
|
self._initialized = False
|
|
|
|
def _init_db(self):
|
|
"""DB 초기화 (lazy loading)"""
|
|
if self._initialized:
|
|
return
|
|
|
|
try:
|
|
# LanceDB 연결
|
|
LANCE_DB_PATH.mkdir(parents=True, exist_ok=True)
|
|
self.db = lancedb.connect(str(LANCE_DB_PATH))
|
|
|
|
# OpenAI 클라이언트
|
|
if self.api_key:
|
|
self.openai_client = OpenAI(api_key=self.api_key)
|
|
else:
|
|
logger.warning("OpenAI API 키 없음")
|
|
|
|
# 기존 테이블 열기
|
|
if TABLE_NAME in self.db.table_names():
|
|
self.table = self.db.open_table(TABLE_NAME)
|
|
logger.info(f"기존 테이블 열림 (행 수: {len(self.table)})")
|
|
else:
|
|
logger.info("테이블 없음 - index_md_files() 호출 필요")
|
|
|
|
self._initialized = True
|
|
|
|
except Exception as e:
|
|
logger.error(f"AnimalDrugRAG 초기화 실패: {e}")
|
|
raise
|
|
|
|
def _get_embedding(self, text: str) -> List[float]:
|
|
"""OpenAI 임베딩 생성"""
|
|
if not self.openai_client:
|
|
raise ValueError("OpenAI 클라이언트 없음")
|
|
|
|
response = self.openai_client.embeddings.create(
|
|
model="text-embedding-3-small",
|
|
input=text
|
|
)
|
|
return response.data[0].embedding
|
|
|
|
def _get_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
|
|
"""배치 임베딩 생성"""
|
|
if not self.openai_client:
|
|
raise ValueError("OpenAI 클라이언트 없음")
|
|
|
|
# OpenAI는 한 번에 최대 2048개 텍스트 처리
|
|
embeddings = []
|
|
batch_size = 100
|
|
|
|
for i in range(0, len(texts), batch_size):
|
|
batch = texts[i:i+batch_size]
|
|
response = self.openai_client.embeddings.create(
|
|
model="text-embedding-3-small",
|
|
input=batch
|
|
)
|
|
embeddings.extend([d.embedding for d in response.data])
|
|
logger.info(f"임베딩 생성: {i+len(batch)}/{len(texts)}")
|
|
|
|
return embeddings
|
|
|
|
def chunk_markdown(self, content: str, source_file: str) -> List[Dict]:
|
|
"""
|
|
마크다운 청킹 (섹션 기반)
|
|
"""
|
|
chunks = []
|
|
|
|
# ## 헤더 기준 분리
|
|
sections = re.split(r'\n(?=## )', content)
|
|
|
|
for i, section in enumerate(sections):
|
|
if not section.strip():
|
|
continue
|
|
|
|
# 섹션 제목 추출
|
|
title_match = re.match(r'^## (.+?)(?:\n|$)', section)
|
|
section_title = title_match.group(1).strip() if title_match else f"섹션{i+1}"
|
|
|
|
# 큰 섹션은 추가 분할
|
|
if len(section) > CHUNK_SIZE:
|
|
sub_chunks = self._split_by_size(section, CHUNK_SIZE, CHUNK_OVERLAP)
|
|
for j, sub_chunk in enumerate(sub_chunks):
|
|
chunk_id = f"{source_file}#{section_title}#{j}"
|
|
chunks.append({
|
|
"id": chunk_id,
|
|
"text": sub_chunk,
|
|
"source": source_file,
|
|
"section": section_title,
|
|
"chunk_index": j
|
|
})
|
|
else:
|
|
chunk_id = f"{source_file}#{section_title}"
|
|
chunks.append({
|
|
"id": chunk_id,
|
|
"text": section,
|
|
"source": source_file,
|
|
"section": section_title,
|
|
"chunk_index": 0
|
|
})
|
|
|
|
return chunks
|
|
|
|
def _split_by_size(self, text: str, size: int, overlap: int) -> List[str]:
|
|
"""텍스트를 크기 기준으로 분할"""
|
|
chunks = []
|
|
start = 0
|
|
|
|
while start < len(text):
|
|
end = start + size
|
|
|
|
# 문장 경계에서 자르기
|
|
if end < len(text):
|
|
last_break = text.rfind('\n', start, end)
|
|
if last_break == -1:
|
|
last_break = text.rfind('. ', start, end)
|
|
if last_break > start + size // 2:
|
|
end = last_break + 1
|
|
|
|
chunks.append(text[start:end])
|
|
start = end - overlap
|
|
|
|
return chunks
|
|
|
|
def index_md_files(self, md_path: Path = None) -> int:
|
|
"""
|
|
MD 파일들을 인덱싱
|
|
"""
|
|
self._init_db()
|
|
|
|
md_path = md_path or MD_DOCS_PATH
|
|
if not md_path.exists():
|
|
logger.error(f"MD 파일 경로 없음: {md_path}")
|
|
return 0
|
|
|
|
# 기존 테이블 삭제
|
|
if TABLE_NAME in self.db.table_names():
|
|
self.db.drop_table(TABLE_NAME)
|
|
logger.info("기존 테이블 삭제")
|
|
|
|
# 모든 청크 수집
|
|
all_chunks = []
|
|
md_files = list(md_path.glob("*.md"))
|
|
|
|
for md_file in md_files:
|
|
try:
|
|
content = md_file.read_text(encoding='utf-8')
|
|
chunks = self.chunk_markdown(content, md_file.name)
|
|
all_chunks.extend(chunks)
|
|
logger.info(f"청킹: {md_file.name} ({len(chunks)}개)")
|
|
except Exception as e:
|
|
logger.error(f"청킹 실패 ({md_file.name}): {e}")
|
|
|
|
if not all_chunks:
|
|
logger.warning("청크 없음")
|
|
return 0
|
|
|
|
# 임베딩 생성
|
|
texts = [c["text"] for c in all_chunks]
|
|
logger.info(f"총 {len(texts)}개 청크 임베딩 시작...")
|
|
embeddings = self._get_embeddings_batch(texts)
|
|
|
|
# 데이터 준비
|
|
data = []
|
|
for chunk, emb in zip(all_chunks, embeddings):
|
|
data.append({
|
|
"id": chunk["id"],
|
|
"text": chunk["text"],
|
|
"source": chunk["source"],
|
|
"section": chunk["section"],
|
|
"chunk_index": chunk["chunk_index"],
|
|
"vector": emb
|
|
})
|
|
|
|
# 테이블 생성
|
|
self.table = self.db.create_table(TABLE_NAME, data)
|
|
logger.info(f"인덱싱 완료: {len(data)}개 청크")
|
|
|
|
return len(data)
|
|
|
|
def search(self, query: str, n_results: int = 3) -> List[Dict]:
|
|
"""
|
|
유사도 검색
|
|
"""
|
|
self._init_db()
|
|
|
|
if self.table is None:
|
|
logger.warning("테이블 없음 - index_md_files() 필요")
|
|
return []
|
|
|
|
try:
|
|
# 쿼리 임베딩
|
|
query_emb = self._get_embedding(query)
|
|
|
|
# 검색
|
|
results = self.table.search(query_emb).limit(n_results).to_list()
|
|
|
|
output = []
|
|
for r in results:
|
|
output.append({
|
|
"text": r["text"],
|
|
"source": r["source"],
|
|
"section": r["section"],
|
|
"score": 1 - r.get("_distance", 0) # 거리 → 유사도
|
|
})
|
|
|
|
return output
|
|
|
|
except Exception as e:
|
|
logger.error(f"검색 실패: {e}")
|
|
return []
|
|
|
|
def get_context_for_chat(self, query: str, n_results: int = 3) -> str:
|
|
"""
|
|
챗봇용 컨텍스트 생성
|
|
"""
|
|
results = self.search(query, n_results)
|
|
|
|
if not results:
|
|
return ""
|
|
|
|
context_parts = ["## 📚 관련 문서 (RAG 검색 결과)"]
|
|
|
|
for i, r in enumerate(results, 1):
|
|
source = r["source"].replace(".md", "")
|
|
section = r["section"]
|
|
score = r["score"]
|
|
text = r["text"][:1500]
|
|
|
|
context_parts.append(f"\n### [{i}] {source} - {section} (관련도: {score:.0%})")
|
|
context_parts.append(text)
|
|
|
|
return "\n".join(context_parts)
|
|
|
|
def get_stats(self) -> Dict:
|
|
"""통계 정보 반환"""
|
|
self._init_db()
|
|
|
|
count = len(self.table) if self.table else 0
|
|
return {
|
|
"table_name": TABLE_NAME,
|
|
"document_count": count,
|
|
"db_path": str(LANCE_DB_PATH)
|
|
}
|
|
|
|
|
|
# 싱글톤 인스턴스
|
|
_rag_instance: Optional[AnimalDrugRAG] = None
|
|
|
|
|
|
def get_animal_rag(api_key: str = None) -> AnimalDrugRAG:
|
|
"""싱글톤 RAG 인스턴스 반환"""
|
|
global _rag_instance
|
|
if _rag_instance is None:
|
|
_rag_instance = AnimalDrugRAG(api_key)
|
|
return _rag_instance
|
|
|
|
|
|
# CLI 테스트
|
|
if __name__ == "__main__":
|
|
import sys
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
rag = AnimalDrugRAG()
|
|
|
|
if len(sys.argv) > 1:
|
|
cmd = sys.argv[1]
|
|
|
|
if cmd == "index":
|
|
count = rag.index_md_files()
|
|
print(f"\n✅ {count}개 청크 인덱싱 완료")
|
|
|
|
elif cmd == "search" and len(sys.argv) > 2:
|
|
query = " ".join(sys.argv[2:])
|
|
results = rag.search(query)
|
|
print(f"\n🔍 검색: {query}")
|
|
for r in results:
|
|
print(f"\n[{r['score']:.0%}] {r['source']} - {r['section']}")
|
|
print(r['text'][:300] + "...")
|
|
|
|
elif cmd == "stats":
|
|
stats = rag.get_stats()
|
|
print(f"\n📊 통계:")
|
|
print(f" - 테이블: {stats['table_name']}")
|
|
print(f" - 문서 수: {stats['document_count']}")
|
|
print(f" - DB 경로: {stats['db_path']}")
|
|
|
|
else:
|
|
print("사용법:")
|
|
print(" python animal_rag.py index # MD 파일 인덱싱")
|
|
print(" python animal_rag.py search 질문 # 검색")
|
|
print(" python animal_rag.py stats # 통계")
|