feat(animal-chat): LanceDB 벡터 검색 RAG 통합

- LanceDB로 MD 문서 252개 청크 인덱싱
- /api/animal-chat에 벡터 검색 컨텍스트 주입
- 마지막 사용자 메시지로 관련 문서 검색 (top 3)
- ChromaDB Windows crash로 LanceDB 채택
This commit is contained in:
thug0bin 2026-03-08 15:00:39 +09:00
parent 3631da2953
commit be1e6c2bb7
16 changed files with 400 additions and 1 deletions

View File

@ -3192,10 +3192,22 @@ def api_animal_chat():
{chr(10).join(product_lines)}
"""
# 벡터 DB 검색 (LanceDB RAG)
vector_context = ""
try:
from utils.animal_rag import get_animal_rag
# 마지막 사용자 메시지로 검색
last_user_msg = next((m['content'] for m in reversed(messages) if m.get('role') == 'user'), '')
if last_user_msg:
rag = get_animal_rag()
vector_context = rag.get_context_for_chat(last_user_msg, n_results=3)
except Exception as e:
logging.warning(f"벡터 검색 실패 (무시): {e}")
# System Prompt 구성
system_prompt = ANIMAL_CHAT_SYSTEM_PROMPT.format(
available_products=available_products_text,
knowledge_base=ANIMAL_DRUG_KNOWLEDGE
knowledge_base=ANIMAL_DRUG_KNOWLEDGE + "\n\n" + vector_context if vector_context else ANIMAL_DRUG_KNOWLEDGE
)
# OpenAI API 호출

21
backend/test_chroma.py Normal file
View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
import os
from dotenv import load_dotenv
load_dotenv()
import chromadb
print('1. creating client...', flush=True)
client = chromadb.PersistentClient(path='./db/chroma_test3')
print('2. client created', flush=True)
# 임베딩 없이 컬렉션 생성
col = client.get_or_create_collection('test3')
print('3. collection created (no ef)', flush=True)
col.add(ids=['1'], documents=['hello world'], embeddings=[[0.1]*384])
print('4. document added with manual embedding', flush=True)
result = col.query(query_embeddings=[[0.1]*384], n_results=1)
print(f'5. query result: {len(result["documents"][0])} docs', flush=True)
print('Done!')

27
backend/test_rag.py Normal file
View File

@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
print("1. Starting...")
print(f" CWD: {os.getcwd()}")
from dotenv import load_dotenv
load_dotenv()
print(f"2. API Key: {os.getenv('OPENAI_API_KEY', 'NOT SET')[:20]}...")
from utils.animal_rag import AnimalDrugRAG
print("3. Module imported")
rag = AnimalDrugRAG()
print("4. RAG created")
try:
count = rag.index_md_files()
print(f"5. Indexed: {count} chunks")
except Exception as e:
print(f"5. Error: {e}")
import traceback
traceback.print_exc()
print("6. Done")

View File

339
backend/utils/animal_rag.py Normal file
View File

@ -0,0 +1,339 @@
# -*- coding: utf-8 -*-
"""
동물약 벡터 DB RAG 모듈
- LanceDB + OpenAI text-embedding-3-small
- MD 파일 청킹 임베딩
- 유사도 검색
"""
import os
import re
import logging
from pathlib import Path
from typing import List, Dict, Optional
# .env 로드
from dotenv import load_dotenv
env_path = Path(__file__).parent.parent / ".env"
load_dotenv(env_path)
# LanceDB
import lancedb
from openai import OpenAI
logger = logging.getLogger(__name__)
# 설정
LANCE_DB_PATH = Path(__file__).parent.parent / "db" / "lance_animal_drugs"
MD_DOCS_PATH = Path("C:/Users/청춘약국/source/new_anipharm")
TABLE_NAME = "animal_drugs"
CHUNK_SIZE = 1500 # 약 500 토큰
CHUNK_OVERLAP = 300 # 약 100 토큰
EMBEDDING_DIM = 1536 # text-embedding-3-small
class AnimalDrugRAG:
"""동물약 RAG 클래스 (LanceDB 버전)"""
def __init__(self, openai_api_key: str = None):
"""
Args:
openai_api_key: OpenAI API (없으면 환경변수에서 가져옴)
"""
self.api_key = openai_api_key or os.getenv('OPENAI_API_KEY')
self.db = None
self.table = None
self.openai_client = None
self._initialized = False
def _init_db(self):
"""DB 초기화 (lazy loading)"""
if self._initialized:
return
try:
# LanceDB 연결
LANCE_DB_PATH.mkdir(parents=True, exist_ok=True)
self.db = lancedb.connect(str(LANCE_DB_PATH))
# OpenAI 클라이언트
if self.api_key:
self.openai_client = OpenAI(api_key=self.api_key)
else:
logger.warning("OpenAI API 키 없음")
# 기존 테이블 열기
if TABLE_NAME in self.db.table_names():
self.table = self.db.open_table(TABLE_NAME)
logger.info(f"기존 테이블 열림 (행 수: {len(self.table)})")
else:
logger.info("테이블 없음 - index_md_files() 호출 필요")
self._initialized = True
except Exception as e:
logger.error(f"AnimalDrugRAG 초기화 실패: {e}")
raise
def _get_embedding(self, text: str) -> List[float]:
"""OpenAI 임베딩 생성"""
if not self.openai_client:
raise ValueError("OpenAI 클라이언트 없음")
response = self.openai_client.embeddings.create(
model="text-embedding-3-small",
input=text
)
return response.data[0].embedding
def _get_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
"""배치 임베딩 생성"""
if not self.openai_client:
raise ValueError("OpenAI 클라이언트 없음")
# OpenAI는 한 번에 최대 2048개 텍스트 처리
embeddings = []
batch_size = 100
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
response = self.openai_client.embeddings.create(
model="text-embedding-3-small",
input=batch
)
embeddings.extend([d.embedding for d in response.data])
logger.info(f"임베딩 생성: {i+len(batch)}/{len(texts)}")
return embeddings
def chunk_markdown(self, content: str, source_file: str) -> List[Dict]:
"""
마크다운 청킹 (섹션 기반)
"""
chunks = []
# ## 헤더 기준 분리
sections = re.split(r'\n(?=## )', content)
for i, section in enumerate(sections):
if not section.strip():
continue
# 섹션 제목 추출
title_match = re.match(r'^## (.+?)(?:\n|$)', section)
section_title = title_match.group(1).strip() if title_match else f"섹션{i+1}"
# 큰 섹션은 추가 분할
if len(section) > CHUNK_SIZE:
sub_chunks = self._split_by_size(section, CHUNK_SIZE, CHUNK_OVERLAP)
for j, sub_chunk in enumerate(sub_chunks):
chunk_id = f"{source_file}#{section_title}#{j}"
chunks.append({
"id": chunk_id,
"text": sub_chunk,
"source": source_file,
"section": section_title,
"chunk_index": j
})
else:
chunk_id = f"{source_file}#{section_title}"
chunks.append({
"id": chunk_id,
"text": section,
"source": source_file,
"section": section_title,
"chunk_index": 0
})
return chunks
def _split_by_size(self, text: str, size: int, overlap: int) -> List[str]:
"""텍스트를 크기 기준으로 분할"""
chunks = []
start = 0
while start < len(text):
end = start + size
# 문장 경계에서 자르기
if end < len(text):
last_break = text.rfind('\n', start, end)
if last_break == -1:
last_break = text.rfind('. ', start, end)
if last_break > start + size // 2:
end = last_break + 1
chunks.append(text[start:end])
start = end - overlap
return chunks
def index_md_files(self, md_path: Path = None) -> int:
"""
MD 파일들을 인덱싱
"""
self._init_db()
md_path = md_path or MD_DOCS_PATH
if not md_path.exists():
logger.error(f"MD 파일 경로 없음: {md_path}")
return 0
# 기존 테이블 삭제
if TABLE_NAME in self.db.table_names():
self.db.drop_table(TABLE_NAME)
logger.info("기존 테이블 삭제")
# 모든 청크 수집
all_chunks = []
md_files = list(md_path.glob("*.md"))
for md_file in md_files:
try:
content = md_file.read_text(encoding='utf-8')
chunks = self.chunk_markdown(content, md_file.name)
all_chunks.extend(chunks)
logger.info(f"청킹: {md_file.name} ({len(chunks)}개)")
except Exception as e:
logger.error(f"청킹 실패 ({md_file.name}): {e}")
if not all_chunks:
logger.warning("청크 없음")
return 0
# 임베딩 생성
texts = [c["text"] for c in all_chunks]
logger.info(f"{len(texts)}개 청크 임베딩 시작...")
embeddings = self._get_embeddings_batch(texts)
# 데이터 준비
data = []
for chunk, emb in zip(all_chunks, embeddings):
data.append({
"id": chunk["id"],
"text": chunk["text"],
"source": chunk["source"],
"section": chunk["section"],
"chunk_index": chunk["chunk_index"],
"vector": emb
})
# 테이블 생성
self.table = self.db.create_table(TABLE_NAME, data)
logger.info(f"인덱싱 완료: {len(data)}개 청크")
return len(data)
def search(self, query: str, n_results: int = 3) -> List[Dict]:
"""
유사도 검색
"""
self._init_db()
if self.table is None:
logger.warning("테이블 없음 - index_md_files() 필요")
return []
try:
# 쿼리 임베딩
query_emb = self._get_embedding(query)
# 검색
results = self.table.search(query_emb).limit(n_results).to_list()
output = []
for r in results:
output.append({
"text": r["text"],
"source": r["source"],
"section": r["section"],
"score": 1 - r.get("_distance", 0) # 거리 → 유사도
})
return output
except Exception as e:
logger.error(f"검색 실패: {e}")
return []
def get_context_for_chat(self, query: str, n_results: int = 3) -> str:
"""
챗봇용 컨텍스트 생성
"""
results = self.search(query, n_results)
if not results:
return ""
context_parts = ["## 📚 관련 문서 (RAG 검색 결과)"]
for i, r in enumerate(results, 1):
source = r["source"].replace(".md", "")
section = r["section"]
score = r["score"]
text = r["text"][:1500]
context_parts.append(f"\n### [{i}] {source} - {section} (관련도: {score:.0%})")
context_parts.append(text)
return "\n".join(context_parts)
def get_stats(self) -> Dict:
"""통계 정보 반환"""
self._init_db()
count = len(self.table) if self.table else 0
return {
"table_name": TABLE_NAME,
"document_count": count,
"db_path": str(LANCE_DB_PATH)
}
# 싱글톤 인스턴스
_rag_instance: Optional[AnimalDrugRAG] = None
def get_animal_rag(api_key: str = None) -> AnimalDrugRAG:
"""싱글톤 RAG 인스턴스 반환"""
global _rag_instance
if _rag_instance is None:
_rag_instance = AnimalDrugRAG(api_key)
return _rag_instance
# CLI 테스트
if __name__ == "__main__":
import sys
logging.basicConfig(level=logging.INFO)
rag = AnimalDrugRAG()
if len(sys.argv) > 1:
cmd = sys.argv[1]
if cmd == "index":
count = rag.index_md_files()
print(f"\n{count}개 청크 인덱싱 완료")
elif cmd == "search" and len(sys.argv) > 2:
query = " ".join(sys.argv[2:])
results = rag.search(query)
print(f"\n🔍 검색: {query}")
for r in results:
print(f"\n[{r['score']:.0%}] {r['source']} - {r['section']}")
print(r['text'][:300] + "...")
elif cmd == "stats":
stats = rag.get_stats()
print(f"\n📊 통계:")
print(f" - 테이블: {stats['table_name']}")
print(f" - 문서 수: {stats['document_count']}")
print(f" - DB 경로: {stats['db_path']}")
else:
print("사용법:")
print(" python animal_rag.py index # MD 파일 인덱싱")
print(" python animal_rag.py search 질문 # 검색")
print(" python animal_rag.py stats # 통계")