ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • #4 LangGraph 심화: 스레드 및 병렬 처리
    LangChain & LangGraph 2025. 3. 18. 20:00

     

    이전 포스트에서는 LangGraph의 분기 및 조건부 실행 패턴과 오류 처리 전략에 대해 살펴보았습니다. 이번 포스트에서는 LangGraph의 성능을 극대화하기 위한 스레드 및 병렬 처리 패턴에 대해 알아보겠습니다.

    병렬 처리의 중요성

    LLM 애플리케이션은 API 호출, 데이터베이스 쿼리, 파일 I/O 등 다양한 I/O 바운드 작업을 포함합니다. 이러한 작업을 순차적으로 실행하면 다음과 같은 문제가 발생합니다:

    1. 지연 시간 증가: 각 작업이 순차적으로 실행되어 총 실행 시간이 길어집니다.
    2. 리소스 비효율성: CPU나 메모리가 유휴 상태로 낭비됩니다.
    3. 확장성 제한: 작업량이 증가할 때 성능이 선형적으로 저하됩니다.

    병렬 처리를 통해 이러한 문제를 해결하고 LangGraph 애플리케이션의 성능을 크게 향상시킬 수 있습니다.

    스레드 및 병렬 처리 패턴

    1. 비동기 노드 구현

    LangGraph는 비동기 노드를 지원하여 I/O 작업을 효율적으로 처리할 수 있습니다:

    from langgraph.graph import StateGraph
    from typing import TypedDict, List, Dict, Any
    import asyncio
    
    class AsyncState(TypedDict):
        query: str
        search_results: List[Dict]
        database_results: Dict
        combined_results: Dict
    
    async def async_web_search(state: AsyncState) -> Dict:
        """비동기 웹 검색 함수"""
        query = state["query"]
        # 비동기 API 호출
        results = await async_search_api(query)
        return {"search_results": results}
    
    async def async_database_query(state: AsyncState) -> Dict:
        """비동기 데이터베이스 쿼리 함수"""
        query = state["query"]
        # 비동기 데이터베이스 쿼리
        results = await async_db_query(query)
        return {"database_results": results}
    
    async def combine_results_async(state: AsyncState) -> Dict:
        """결과 병합 함수"""
        # 검색 결과와 데이터베이스 결과 병합
        combined_results = {
            "search": state["search_results"],
            "database": state["database_results"]
        }
        return {"combined_results": combined_results}
    
    # 비동기 그래프 구성
    async_graph = StateGraph(AsyncState)
    async_graph.add_node("web_search", async_web_search)
    async_graph.add_node("db_query", async_database_query)
    async_graph.add_node("combine_results", combine_results_async)
    
    # 병렬 엣지 설정
    async_graph.add_edge("root", "web_search")
    async_graph.add_edge("root", "db_query")
    async_graph.add_edge("web_search", "combine_results")
    async_graph.add_edge("db_query", "combine_results")
    
    # 비동기 실행
    async def run_async_workflow(query: str):
        state = {"query": query}
        async_executor = async_graph.acompile()
        final_state = await async_executor.ainvoke(state)
        return final_state
    

    이 패턴의 장점:

    • 웹 검색과 데이터베이스 쿼리가 병렬로 실행되어 전체 실행 시간이 크게 단축됩니다.
    • I/O 대기 시간 동안 다른 작업이 진행되어 리소스 활용도가 향상됩니다.
    • 비동기 코드는 동기 코드와 유사한 구조를 유지하므로 이해하기 쉽습니다.

    2. 병렬 브랜치 패턴

    여러 독립적인 작업을 병렬로 실행하고 결과를 병합하는 패턴입니다:

    from langgraph.graph import StateGraph, END
    
    class ParallelState(TypedDict):
        input_text: str
        summarization: str
        translation: str
        sentiment: str
        final_output: Dict
    
    # 그래프 생성
    parallel_graph = StateGraph(ParallelState)
    
    # 노드 추가
    parallel_graph.add_node("parse_input", parse_input_node)
    parallel_graph.add_node("summarize", summarize_node)
    parallel_graph.add_node("translate", translate_node)
    parallel_graph.add_node("analyze_sentiment", sentiment_analysis_node)
    parallel_graph.add_node("combine_results", combine_results_node)
    
    # 병렬 분기 설정
    parallel_graph.add_edge("parse_input", "summarize")
    parallel_graph.add_edge("parse_input", "translate")
    parallel_graph.add_edge("parse_input", "analyze_sentiment")
    
    # 결과 병합
    parallel_graph.add_edge("summarize", "combine_results")
    parallel_graph.add_edge("translate", "combine_results")
    parallel_graph.add_edge("analyze_sentiment", "combine_results")
    parallel_graph.add_edge("combine_results", END)
    
    # 상태 조인 함수
    def join_results(states: List[Dict]) -> Dict:
        """여러 상태를 병합"""
        result = states[0].copy()
        for state in states[1:]:
            result.update(state)
        return result
    
    parallel_graph.set_join(join_results)
    

    이 패턴의 핵심은 set_join 메서드로 설정한 상태 병합 함수입니다. 이 함수는 여러 병렬 브랜치에서 반환된 상태를 하나의 상태로 병합합니다.

    상태 병합 전략

    병렬 브랜치의 결과를 병합할 때 다양한 전략을 사용할 수 있습니다:

    # 1. 단순 병합 (충돌 발생 시 마지막 값 사용)
    def simple_join(states):
        result = {}
        for state in states:
            result.update(state)
        return result
    
    # 2. 키별 병합 (특정 키만 각 상태에서 추출)
    def selective_join(states):
        result = states[0].copy()  # 기본 상태
        
        # 각 상태에서 특정 키만 추출
        result["summarization"] = states[1].get("summarization")
        result["translation"] = states[2].get("translation")
        result["sentiment"] = states[3].get("sentiment")
        
        return result
    
    # 3. 충돌 해결 병합 (충돌 시 커스텀 로직 적용)
    def conflict_resolving_join(states):
        result = states[0].copy()
        
        for state in states[1:]:
            for key, value in state.items():
                if key in result:
                    # 특정 키는 리스트로 병합
                    if key in ["results", "outputs"]:
                        if isinstance(result[key], list):
                            result[key].extend(value if isinstance(value, list) else [value])
                        else:
                            result[key] = [result[key], value]
                    # 특정 키는 딕셔너리 병합
                    elif key in ["metadata", "config"] and isinstance(value, dict):
                        result[key].update(value)
                    # 그 외에는 마지막 값 사용
                    else:
                        result[key] = value
                else:
                    result[key] = value
        
        return result
    

    3. 작업 분할 및 병합 패턴 (Map-Reduce)

    대규모 작업을 작은 단위로 분할하여 병렬 처리한 후 결과를 병합하는 패턴입니다:

    class MapReduceState(TypedDict):
        documents: List[str]
        chunks: List[List[str]]
        processed_chunks: List[Dict]
        final_result: Dict
    
    def split_documents(state: MapReduceState) -> Dict:
        """문서를 청크로 분할 (Map 준비)"""
        docs = state["documents"]
        # 각 문서를 더 작은 청크로 분할
        chunks = [split_into_chunks(doc) for doc in docs]
        return {"chunks": chunks}
    
    async def process_chunks(state: MapReduceState) -> Dict:
        """각 청크를 병렬로 처리 (Map)"""
        chunks = state["chunks"]
        flat_chunks = [chunk for sublist in chunks for chunk in sublist]
        
        # 병렬 처리를 위한 코루틴 생성
        coroutines = [process_chunk(chunk) for chunk in flat_chunks]
        
        # 모든 코루틴 병렬 실행
        processed_chunks = await asyncio.gather(*coroutines)
        
        return {"processed_chunks": processed_chunks}
    
    def reduce_results(state: MapReduceState) -> Dict:
        """처리된 청크 결과 병합 (Reduce)"""
        processed_chunks = state["processed_chunks"]
        # 결과 병합 (예: 통계 계산, 요약 등)
        final_result = combine_chunk_results(processed_chunks)
        return {"final_result": final_result}
    
    # 그래프 구성
    map_reduce_graph = StateGraph(MapReduceState)
    map_reduce_graph.add_node("split", split_documents)
    map_reduce_graph.add_node("process", process_chunks)
    map_reduce_graph.add_node("reduce", reduce_results)
    
    # 엣지 설정
    map_reduce_graph.add_edge("split", "process")
    map_reduce_graph.add_edge("process", "reduce")
    

    이 패턴은 대량의 문서 처리, 대규모 데이터셋 분석 등에 유용합니다. 주요 단계는 다음과 같습니다:

    1. 분할(Split): 큰 작업을 작은 청크로 나눕니다.
    2. 매핑(Map): 각 청크를 병렬로 처리합니다.
    3. 축소(Reduce): 처리된 결과를 병합합니다.

    4. 스레드 풀 실행기

    CPU 바운드 작업을 효율적으로 처리하기 위한 스레드 풀 패턴입니다:

    from concurrent.futures import ThreadPoolExecutor
    import functools
    
    class ThreadPoolState(TypedDict):
        items: List
        processed_items: List
        max_workers: int
    
    def process_with_thread_pool(state: ThreadPoolState) -> Dict:
        """스레드 풀을 사용한 병렬 처리"""
        items = state["items"]
        max_workers = state.get("max_workers", 10)
        
        def process_item(item):
            # 각 항목 처리 로직
            return compute_intensive_task(item)
        
        # 스레드 풀 생성 및 실행
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            processed_items = list(executor.map(process_item, items))
        
        return {"processed_items": processed_items}
    
    # 그래프에 노드 추가
    thread_pool_graph = StateGraph(ThreadPoolState)
    thread_pool_graph.add_node("thread_pool_processor", process_with_thread_pool)
    

    이 패턴은 CPU 바운드 작업에 적합하며, 스레드 풀 크기를 조정하여 시스템 리소스에 맞게 최적화할 수 있습니다.

    5. 분산 작업 큐 통합

    대규모 시스템을 위한 분산 작업 큐 통합 패턴입니다:

    import celery
    from celery import group
    
    # Celery 작업 정의
    @celery.task
    def process_task(item):
        return process_logic(item)
    
    class DistributedState(TypedDict):
        items: List
        task_ids: List
        results: List
    
    def submit_to_queue(state: DistributedState) -> Dict:
        """Celery 작업 큐에 작업 제출"""
        items = state["items"]
        
        # 병렬 작업 그룹 생성 및 실행
        job = group(process_task.s(item) for item in items)
        result = job.apply_async()
        
        # 작업 ID 저장
        return {"task_ids": result.id}
    
    def check_results(state: DistributedState) -> Dict:
        """작업 결과 확인"""
        task_id = state["task_ids"]
        
        # 결과 조회
        result = celery.GroupResult.restore(task_id)
        
        if result.ready():
            return {"results": result.get()}
        else:
            # 아직 완료되지 않음, 상태 유지
            return {}
    
    def is_complete(state: DistributedState) -> bool:
        """작업 완료 여부 확인"""
        return "results" in state and state["results"] is not None
    
    # 그래프 구성
    distributed_graph = StateGraph(DistributedState)
    distributed_graph.add_node("submit", submit_to_queue)
    distributed_graph.add_node("check", check_results)
    distributed_graph.add_node("process_results", process_final_results)
    
    # 엣지 설정
    distributed_graph.add_edge("submit", "check")
    distributed_graph.add_conditional_edges(
        "check",
        is_complete,
        {
            True: "process_results",
            False: "check"  # 완료될 때까지 반복 확인
        }
    )
    

    이 패턴은 대규모 작업을 여러 워커 노드에 분산하여 처리하므로 단일 서버의 한계를 넘어 확장할 수 있습니다.

    실제 적용 사례

    1. 대규모 문서 처리 시스템

    대량의 문서를 처리하는 시스템에서 병렬 처리를 적용한 예시입니다:

    class DocumentProcessingState(TypedDict):
        documents: List[Dict]
        processed_documents: List[Dict]
        analysis_results: Dict
    
    async def process_documents_parallel(state: DocumentProcessingState) -> Dict:
        """여러 문서를 병렬로 처리"""
        documents = state["documents"]
        
        # 병렬 처리 함수
        async def process_single_document(doc):
            # 문서 텍스트 추출
            text = extract_text(doc)
            
            # 병렬로 여러 분석 수행
            tasks = [
                analyze_sentiment(text),
                extract_entities(text),
                classify_document(text),
                summarize_text(text)
            ]
            
            sentiment, entities, category, summary = await asyncio.gather(*tasks)
            
            return {
                "id": doc["id"],
                "sentiment": sentiment,
                "entities": entities,
                "category": category,
                "summary": summary
            }
        
        # 모든 문서 병렬 처리
        tasks = [process_single_document(doc) for doc in documents]
        processed_documents = await asyncio.gather(*tasks)
        
        return {"processed_documents": processed_documents}
    
    def aggregate_results(state: DocumentProcessingState) -> Dict:
        """처리된 문서 결과 집계"""
        documents = state["processed_documents"]
        
        # 결과 집계
        sentiment_counts = {}
        entity_counts = {}
        category_counts = {}
        
        for doc in documents:
            # 감성 집계
            sentiment = doc["sentiment"]
            sentiment_counts[sentiment] = sentiment_counts.get(sentiment, 0) + 1
            
            # 엔티티 집계
            for entity in doc["entities"]:
                entity_counts[entity] = entity_counts.get(entity, 0) + 1
            
            # 카테고리 집계
            category = doc["category"]
            category_counts[category] = category_counts.get(category, 0) + 1
        
        return {
            "analysis_results": {
                "sentiment_distribution": sentiment_counts,
                "top_entities": sorted(entity_counts.items(), key=lambda x: x[1], reverse=True)[:10],
                "category_distribution": category_counts
            }
        }
    
    # 그래프 구성
    document_graph = StateGraph(DocumentProcessingState)
    document_graph.add_node("process", process_documents_parallel)
    document_graph.add_node("aggregate", aggregate_results)
    
    document_graph.add_edge("process", "aggregate")
    

    이 시스템은 두 가지 수준의 병렬화를 적용합니다:

    1. 여러 문서를 동시에 처리합니다.
    2. 각 문서에 대해 여러 분석 작업을 병렬로 수행합니다.

    2. 대화형 에이전트 시스템

    사용자 쿼리를 처리하는 대화형 에이전트에 병렬 처리를 적용한 예시입니다:

    class ConversationalAgentState(TypedDict):
        user_query: str
        context: Dict
        search_results: List
        knowledge_base_results: List
        llm_response: str
        final_response: str
    
    async def retrieve_information(state: ConversationalAgentState) -> Dict:
        """정보 검색을 병렬로 수행"""
        query = state["user_query"]
        
        # 여러 소스에서 정보 검색 병렬화
        search_task = search_web(query)
        knowledge_base_task = query_knowledge_base(query)
        
        # 병렬 실행
        search_results, kb_results = await asyncio.gather(
            search_task, 
            knowledge_base_task
        )
        
        return {
            "search_results": search_results,
            "knowledge_base_results": kb_results
        }
    
    def generate_response(state: ConversationalAgentState) -> Dict:
        """검색 결과를 바탕으로 응답 생성"""
        query = state["user_query"]
        context = state["context"]
        search_results = state["search_results"]
        kb_results = state["knowledge_base_results"]
        
        # 컨텍스트 구성
        prompt = f"""
        User Query: {query}
        
        Web Search Results:
        {format_search_results(search_results)}
        
        Knowledge Base Results:
        {format_kb_results(kb_results)}
        
        Previous Conversation Context:
        {format_context(context)}
        
        Generate a helpful response based on the above information:
        """
        
        # LLM 호출
        llm_response = call_llm(prompt)
        
        return {"llm_response": llm_response}
    
    def post_process_response(state: ConversationalAgentState) -> Dict:
        """응답 후처리"""
        response = state["llm_response"]
        
        # 포맷팅, 참조 추가 등
        final_response = format_response(response)
        
        return {"final_response": final_response}
    
    # 그래프 구성
    agent_graph = StateGraph(ConversationalAgentState)
    agent_graph.add_node("retrieve", retrieve_information)
    agent_graph.add_node("generate", generate_response)
    agent_graph.add_node("post_process", post_process_response)
    
    # 엣지 설정
    agent_graph.add_edge("retrieve", "generate")
    agent_graph.add_edge("generate", "post_process")
    

    이 에이전트는 웹 검색과 지식 베이스 쿼리를 병렬로 수행하여 응답 생성 시간을 단축합니다.

    3. 실시간 데이터 처리 파이프라인

    스트리밍 데이터를 실시간으로 처리하는 파이프라인에 병렬 처리를 적용한 예시입니다:

    class StreamProcessingState(TypedDict):
        data_batch: List[Dict]
        filtered_data: List[Dict]
        enriched_data: List[Dict]
        transformed_data: List[Dict]
        final_output: List[Dict]
    
    async def filter_data(state: StreamProcessingState) -> Dict:
        """데이터 필터링"""
        data = state["data_batch"]
        
        # 병렬 필터링
        async def filter_item(item):
            if meets_criteria(item):
                return item
            return None
        
        # 병렬 실행
        filtered_items = await asyncio.gather(*[filter_item(item) for item in data])
        
        # None 제거
        filtered_data = [item for item in filtered_items if item is not None]
        
        return {"filtered_data": filtered_data}
    
    async def enrich_data(state: StreamProcessingState) -> Dict:
        """데이터 보강"""
        data = state["filtered_data"]
        
        # 병렬 보강
        async def enrich_item(item):
            # 외부 API로 데이터 보강
            additional_info = await fetch_additional_info(item["id"])
            return {**item, "additional_info": additional_info}
        
        # 병렬 실행
        enriched_data = await asyncio.gather(*[enrich_item(item) for item in data])
        
        return {"enriched_data": enriched_data}
    
    def transform_data(state: StreamProcessingState) -> Dict:
        """데이터 변환"""
        data = state["enriched_data"]
        
        # 변환 로직
        transformed_data = list(map(transform_item, data))
        
        return {"transformed_data": transformed_data}
    
    def output_results(state: StreamProcessingState) -> Dict:
        """결과 출력"""
        data = state["transformed_data"]
        
        # 최종 처리
        final_output = prepare_for_output(data)
        
        return {"final_output": final_output}
    
    # 그래프 구성
    stream_graph = StateGraph(StreamProcessingState)
    stream_graph.add_node("filter", filter_data)
    stream_graph.add_node("enrich", enrich_data)
    stream_graph.add_node("transform", transform_data)
    stream_graph.add_node("output", output_results)
    
    # 엣지 설정
    stream_graph.add_edge("filter", "enrich")
    stream_graph.add_edge("enrich", "transform")
    stream_graph.add_edge("transform", "output")
    

    이 파이프라인은 데이터 배치 내의 각 항목을 병렬로 처리하여 실시간 처리 성능을 향상시킵니다.

    병렬 처리의 주의사항

    병렬 처리를 구현할 때 고려해야 할 몇 가지 주의사항이 있습니다:

    1. 상태 관리 복잡성

    병렬 처리는 상태 관리를 복잡하게 만들 수 있습니다. 다음 사항에 유의하세요:

    • 불변성 유지: 상태를 직접 수정하지 말고 새 상태를 반환하세요.
    • 상태 병합 전략: 병렬 브랜치의 결과를 병합하는 명확한 전략을 정의하세요.
    • 경쟁 조건 방지: 여러 노드가 동일한 상태 키를 업데이트할 때 충돌이 발생할 수 있습니다.

    2. 리소스 관리

    병렬 처리는 리소스 사용량을 증가시킬 수 있습니다:

    • 병렬 수준 제한: 시스템 리소스에 맞게 동시 실행 수를 제한하세요.
    • 메모리 사용량 모니터링: 병렬 작업이 메모리를 과도하게 사용하지 않도록 하세요.
    • 타임아웃 설정: 병렬 작업에 적절한 타임아웃을 설정하여 무한정 대기하지 않도록 하세요.

    3. 오류 처리

    병렬 실행 중 발생하는 오류를 적절히 처리해야 합니다:

    • 부분 실패 처리: 일부 병렬 작업이 실패해도 전체 실행이 중단되지 않도록 하세요.
    • 재시도 메커니즘: 실패한 작업에 대한 재시도 전략을 구현하세요.
    • 우아한 성능 저하: 일부 기능이 실패해도 핵심 기능은 계속 작동하도록 설계하세요.

    결론

    LangGraph의 스레드 및 병렬 처리 패턴을 활용하면 LLM 애플리케이션의 성능을 크게 향상시킬 수 있습니다. 비동기 노드, 병렬 브랜치, Map-Reduce 패턴, 스레드 풀, 분산 작업 큐 등 다양한 방법을 통해 독립적인 작업을 병렬로 실행하고 I/O 대기 시간을 최소화할 수 있습니다.

    병렬 처리는 특히 외부 API 호출, 데이터베이스 작업, 파일 처리 등 I/O 바운드 작업이 많은 LLM 애플리케이션에서 큰 성능 향상을 가져옵니다. 그러나 상태 관리, 리소스 사용, 오류 처리와 같은 복잡성도 증가하므로 신중하게 설계해야 합니다.

    다음 포스트에서는 LangGraph 애플리케이션의 성능을 더욱 극대화하기 위한 다양한 최적화 기법에 대해 알아보겠습니다.


     

Designed by Tistory.