feat(animal-chat): LanceDB 벡터 검색 RAG 통합
- LanceDB로 MD 문서 252개 청크 인덱싱 - /api/animal-chat에 벡터 검색 컨텍스트 주입 - 마지막 사용자 메시지로 관련 문서 검색 (top 3) - ChromaDB Windows crash로 LanceDB 채택
This commit is contained in:
parent
3631da2953
commit
be1e6c2bb7
@ -3192,10 +3192,22 @@ def api_animal_chat():
|
||||
{chr(10).join(product_lines)}
|
||||
"""
|
||||
|
||||
# 벡터 DB 검색 (LanceDB RAG)
|
||||
vector_context = ""
|
||||
try:
|
||||
from utils.animal_rag import get_animal_rag
|
||||
# 마지막 사용자 메시지로 검색
|
||||
last_user_msg = next((m['content'] for m in reversed(messages) if m.get('role') == 'user'), '')
|
||||
if last_user_msg:
|
||||
rag = get_animal_rag()
|
||||
vector_context = rag.get_context_for_chat(last_user_msg, n_results=3)
|
||||
except Exception as e:
|
||||
logging.warning(f"벡터 검색 실패 (무시): {e}")
|
||||
|
||||
# System Prompt 구성
|
||||
system_prompt = ANIMAL_CHAT_SYSTEM_PROMPT.format(
|
||||
available_products=available_products_text,
|
||||
knowledge_base=ANIMAL_DRUG_KNOWLEDGE
|
||||
knowledge_base=ANIMAL_DRUG_KNOWLEDGE + "\n\n" + vector_context if vector_context else ANIMAL_DRUG_KNOWLEDGE
|
||||
)
|
||||
|
||||
# OpenAI API 호출
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
21
backend/test_chroma.py
Normal file
21
backend/test_chroma.py
Normal file
@ -0,0 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
import chromadb
|
||||
|
||||
print('1. creating client...', flush=True)
|
||||
client = chromadb.PersistentClient(path='./db/chroma_test3')
|
||||
print('2. client created', flush=True)
|
||||
|
||||
# 임베딩 없이 컬렉션 생성
|
||||
col = client.get_or_create_collection('test3')
|
||||
print('3. collection created (no ef)', flush=True)
|
||||
|
||||
col.add(ids=['1'], documents=['hello world'], embeddings=[[0.1]*384])
|
||||
print('4. document added with manual embedding', flush=True)
|
||||
|
||||
result = col.query(query_embeddings=[[0.1]*384], n_results=1)
|
||||
print(f'5. query result: {len(result["documents"][0])} docs', flush=True)
|
||||
print('Done!')
|
||||
27
backend/test_rag.py
Normal file
27
backend/test_rag.py
Normal file
@ -0,0 +1,27 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
print("1. Starting...")
|
||||
print(f" CWD: {os.getcwd()}")
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
print(f"2. API Key: {os.getenv('OPENAI_API_KEY', 'NOT SET')[:20]}...")
|
||||
|
||||
from utils.animal_rag import AnimalDrugRAG
|
||||
print("3. Module imported")
|
||||
|
||||
rag = AnimalDrugRAG()
|
||||
print("4. RAG created")
|
||||
|
||||
try:
|
||||
count = rag.index_md_files()
|
||||
print(f"5. Indexed: {count} chunks")
|
||||
except Exception as e:
|
||||
print(f"5. Error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print("6. Done")
|
||||
0
backend/test_rag_output.txt
Normal file
0
backend/test_rag_output.txt
Normal file
339
backend/utils/animal_rag.py
Normal file
339
backend/utils/animal_rag.py
Normal file
@ -0,0 +1,339 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
동물약 벡터 DB RAG 모듈
|
||||
- LanceDB + OpenAI text-embedding-3-small
|
||||
- MD 파일 청킹 및 임베딩
|
||||
- 유사도 검색
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
# .env 로드
|
||||
from dotenv import load_dotenv
|
||||
env_path = Path(__file__).parent.parent / ".env"
|
||||
load_dotenv(env_path)
|
||||
|
||||
# LanceDB
|
||||
import lancedb
|
||||
from openai import OpenAI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 설정
|
||||
LANCE_DB_PATH = Path(__file__).parent.parent / "db" / "lance_animal_drugs"
|
||||
MD_DOCS_PATH = Path("C:/Users/청춘약국/source/new_anipharm")
|
||||
TABLE_NAME = "animal_drugs"
|
||||
CHUNK_SIZE = 1500 # 약 500 토큰
|
||||
CHUNK_OVERLAP = 300 # 약 100 토큰
|
||||
EMBEDDING_DIM = 1536 # text-embedding-3-small
|
||||
|
||||
|
||||
class AnimalDrugRAG:
|
||||
"""동물약 RAG 클래스 (LanceDB 버전)"""
|
||||
|
||||
def __init__(self, openai_api_key: str = None):
|
||||
"""
|
||||
Args:
|
||||
openai_api_key: OpenAI API 키 (없으면 환경변수에서 가져옴)
|
||||
"""
|
||||
self.api_key = openai_api_key or os.getenv('OPENAI_API_KEY')
|
||||
self.db = None
|
||||
self.table = None
|
||||
self.openai_client = None
|
||||
self._initialized = False
|
||||
|
||||
def _init_db(self):
|
||||
"""DB 초기화 (lazy loading)"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
# LanceDB 연결
|
||||
LANCE_DB_PATH.mkdir(parents=True, exist_ok=True)
|
||||
self.db = lancedb.connect(str(LANCE_DB_PATH))
|
||||
|
||||
# OpenAI 클라이언트
|
||||
if self.api_key:
|
||||
self.openai_client = OpenAI(api_key=self.api_key)
|
||||
else:
|
||||
logger.warning("OpenAI API 키 없음")
|
||||
|
||||
# 기존 테이블 열기
|
||||
if TABLE_NAME in self.db.table_names():
|
||||
self.table = self.db.open_table(TABLE_NAME)
|
||||
logger.info(f"기존 테이블 열림 (행 수: {len(self.table)})")
|
||||
else:
|
||||
logger.info("테이블 없음 - index_md_files() 호출 필요")
|
||||
|
||||
self._initialized = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AnimalDrugRAG 초기화 실패: {e}")
|
||||
raise
|
||||
|
||||
def _get_embedding(self, text: str) -> List[float]:
|
||||
"""OpenAI 임베딩 생성"""
|
||||
if not self.openai_client:
|
||||
raise ValueError("OpenAI 클라이언트 없음")
|
||||
|
||||
response = self.openai_client.embeddings.create(
|
||||
model="text-embedding-3-small",
|
||||
input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
def _get_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""배치 임베딩 생성"""
|
||||
if not self.openai_client:
|
||||
raise ValueError("OpenAI 클라이언트 없음")
|
||||
|
||||
# OpenAI는 한 번에 최대 2048개 텍스트 처리
|
||||
embeddings = []
|
||||
batch_size = 100
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i:i+batch_size]
|
||||
response = self.openai_client.embeddings.create(
|
||||
model="text-embedding-3-small",
|
||||
input=batch
|
||||
)
|
||||
embeddings.extend([d.embedding for d in response.data])
|
||||
logger.info(f"임베딩 생성: {i+len(batch)}/{len(texts)}")
|
||||
|
||||
return embeddings
|
||||
|
||||
def chunk_markdown(self, content: str, source_file: str) -> List[Dict]:
|
||||
"""
|
||||
마크다운 청킹 (섹션 기반)
|
||||
"""
|
||||
chunks = []
|
||||
|
||||
# ## 헤더 기준 분리
|
||||
sections = re.split(r'\n(?=## )', content)
|
||||
|
||||
for i, section in enumerate(sections):
|
||||
if not section.strip():
|
||||
continue
|
||||
|
||||
# 섹션 제목 추출
|
||||
title_match = re.match(r'^## (.+?)(?:\n|$)', section)
|
||||
section_title = title_match.group(1).strip() if title_match else f"섹션{i+1}"
|
||||
|
||||
# 큰 섹션은 추가 분할
|
||||
if len(section) > CHUNK_SIZE:
|
||||
sub_chunks = self._split_by_size(section, CHUNK_SIZE, CHUNK_OVERLAP)
|
||||
for j, sub_chunk in enumerate(sub_chunks):
|
||||
chunk_id = f"{source_file}#{section_title}#{j}"
|
||||
chunks.append({
|
||||
"id": chunk_id,
|
||||
"text": sub_chunk,
|
||||
"source": source_file,
|
||||
"section": section_title,
|
||||
"chunk_index": j
|
||||
})
|
||||
else:
|
||||
chunk_id = f"{source_file}#{section_title}"
|
||||
chunks.append({
|
||||
"id": chunk_id,
|
||||
"text": section,
|
||||
"source": source_file,
|
||||
"section": section_title,
|
||||
"chunk_index": 0
|
||||
})
|
||||
|
||||
return chunks
|
||||
|
||||
def _split_by_size(self, text: str, size: int, overlap: int) -> List[str]:
|
||||
"""텍스트를 크기 기준으로 분할"""
|
||||
chunks = []
|
||||
start = 0
|
||||
|
||||
while start < len(text):
|
||||
end = start + size
|
||||
|
||||
# 문장 경계에서 자르기
|
||||
if end < len(text):
|
||||
last_break = text.rfind('\n', start, end)
|
||||
if last_break == -1:
|
||||
last_break = text.rfind('. ', start, end)
|
||||
if last_break > start + size // 2:
|
||||
end = last_break + 1
|
||||
|
||||
chunks.append(text[start:end])
|
||||
start = end - overlap
|
||||
|
||||
return chunks
|
||||
|
||||
def index_md_files(self, md_path: Path = None) -> int:
|
||||
"""
|
||||
MD 파일들을 인덱싱
|
||||
"""
|
||||
self._init_db()
|
||||
|
||||
md_path = md_path or MD_DOCS_PATH
|
||||
if not md_path.exists():
|
||||
logger.error(f"MD 파일 경로 없음: {md_path}")
|
||||
return 0
|
||||
|
||||
# 기존 테이블 삭제
|
||||
if TABLE_NAME in self.db.table_names():
|
||||
self.db.drop_table(TABLE_NAME)
|
||||
logger.info("기존 테이블 삭제")
|
||||
|
||||
# 모든 청크 수집
|
||||
all_chunks = []
|
||||
md_files = list(md_path.glob("*.md"))
|
||||
|
||||
for md_file in md_files:
|
||||
try:
|
||||
content = md_file.read_text(encoding='utf-8')
|
||||
chunks = self.chunk_markdown(content, md_file.name)
|
||||
all_chunks.extend(chunks)
|
||||
logger.info(f"청킹: {md_file.name} ({len(chunks)}개)")
|
||||
except Exception as e:
|
||||
logger.error(f"청킹 실패 ({md_file.name}): {e}")
|
||||
|
||||
if not all_chunks:
|
||||
logger.warning("청크 없음")
|
||||
return 0
|
||||
|
||||
# 임베딩 생성
|
||||
texts = [c["text"] for c in all_chunks]
|
||||
logger.info(f"총 {len(texts)}개 청크 임베딩 시작...")
|
||||
embeddings = self._get_embeddings_batch(texts)
|
||||
|
||||
# 데이터 준비
|
||||
data = []
|
||||
for chunk, emb in zip(all_chunks, embeddings):
|
||||
data.append({
|
||||
"id": chunk["id"],
|
||||
"text": chunk["text"],
|
||||
"source": chunk["source"],
|
||||
"section": chunk["section"],
|
||||
"chunk_index": chunk["chunk_index"],
|
||||
"vector": emb
|
||||
})
|
||||
|
||||
# 테이블 생성
|
||||
self.table = self.db.create_table(TABLE_NAME, data)
|
||||
logger.info(f"인덱싱 완료: {len(data)}개 청크")
|
||||
|
||||
return len(data)
|
||||
|
||||
def search(self, query: str, n_results: int = 3) -> List[Dict]:
|
||||
"""
|
||||
유사도 검색
|
||||
"""
|
||||
self._init_db()
|
||||
|
||||
if self.table is None:
|
||||
logger.warning("테이블 없음 - index_md_files() 필요")
|
||||
return []
|
||||
|
||||
try:
|
||||
# 쿼리 임베딩
|
||||
query_emb = self._get_embedding(query)
|
||||
|
||||
# 검색
|
||||
results = self.table.search(query_emb).limit(n_results).to_list()
|
||||
|
||||
output = []
|
||||
for r in results:
|
||||
output.append({
|
||||
"text": r["text"],
|
||||
"source": r["source"],
|
||||
"section": r["section"],
|
||||
"score": 1 - r.get("_distance", 0) # 거리 → 유사도
|
||||
})
|
||||
|
||||
return output
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"검색 실패: {e}")
|
||||
return []
|
||||
|
||||
def get_context_for_chat(self, query: str, n_results: int = 3) -> str:
|
||||
"""
|
||||
챗봇용 컨텍스트 생성
|
||||
"""
|
||||
results = self.search(query, n_results)
|
||||
|
||||
if not results:
|
||||
return ""
|
||||
|
||||
context_parts = ["## 📚 관련 문서 (RAG 검색 결과)"]
|
||||
|
||||
for i, r in enumerate(results, 1):
|
||||
source = r["source"].replace(".md", "")
|
||||
section = r["section"]
|
||||
score = r["score"]
|
||||
text = r["text"][:1500]
|
||||
|
||||
context_parts.append(f"\n### [{i}] {source} - {section} (관련도: {score:.0%})")
|
||||
context_parts.append(text)
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
"""통계 정보 반환"""
|
||||
self._init_db()
|
||||
|
||||
count = len(self.table) if self.table else 0
|
||||
return {
|
||||
"table_name": TABLE_NAME,
|
||||
"document_count": count,
|
||||
"db_path": str(LANCE_DB_PATH)
|
||||
}
|
||||
|
||||
|
||||
# 싱글톤 인스턴스
|
||||
_rag_instance: Optional[AnimalDrugRAG] = None
|
||||
|
||||
|
||||
def get_animal_rag(api_key: str = None) -> AnimalDrugRAG:
|
||||
"""싱글톤 RAG 인스턴스 반환"""
|
||||
global _rag_instance
|
||||
if _rag_instance is None:
|
||||
_rag_instance = AnimalDrugRAG(api_key)
|
||||
return _rag_instance
|
||||
|
||||
|
||||
# CLI 테스트
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
rag = AnimalDrugRAG()
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
cmd = sys.argv[1]
|
||||
|
||||
if cmd == "index":
|
||||
count = rag.index_md_files()
|
||||
print(f"\n✅ {count}개 청크 인덱싱 완료")
|
||||
|
||||
elif cmd == "search" and len(sys.argv) > 2:
|
||||
query = " ".join(sys.argv[2:])
|
||||
results = rag.search(query)
|
||||
print(f"\n🔍 검색: {query}")
|
||||
for r in results:
|
||||
print(f"\n[{r['score']:.0%}] {r['source']} - {r['section']}")
|
||||
print(r['text'][:300] + "...")
|
||||
|
||||
elif cmd == "stats":
|
||||
stats = rag.get_stats()
|
||||
print(f"\n📊 통계:")
|
||||
print(f" - 테이블: {stats['table_name']}")
|
||||
print(f" - 문서 수: {stats['document_count']}")
|
||||
print(f" - DB 경로: {stats['db_path']}")
|
||||
|
||||
else:
|
||||
print("사용법:")
|
||||
print(" python animal_rag.py index # MD 파일 인덱싱")
|
||||
print(" python animal_rag.py search 질문 # 검색")
|
||||
print(" python animal_rag.py stats # 통계")
|
||||
Loading…
Reference in New Issue
Block a user