WebSocket Adapter - Detailed Architecture

Technical Implementation Plan

Status: 📋 Planning
Complexity: High
Performance Impact: Very High (estimated +90% better perceived UX)


🎯 Objective

Build a robust WebSocketAdapter that provides:
1. Real-time streaming of agent responses
2. Visual feedback while tools are running
3. Automatic reconnection for better reliability
4. Horizontal scalability via Redis Pub/Sub


🏗️ Proposed File Structure

messaging/
├── adapters/
│   ├── websocket_adapter.py           # ✨ NEW: WebSocket adapter
│   ├── sse_adapter.py                 # ✨ NEW: SSE adapter (alternative)
│   └── webchat_adapter.py             # ✅ Existing: synchronous HTTP
├── websocket/                         # ✨ NEW: WebSocket module
│   ├── __init__.py
│   ├── connection_manager.py          # Manages WS connections
│   ├── event_streamer.py              # Streams ADK→WS events
│   ├── protocol.py                    # Protocol definitions
│   └── redis_pubsub.py                # Pub/Sub for horizontal scaling
└── README.md                          # Update with WebSocket docs

unified_bot.py                         # Add WS endpoints

examples/
├── websocket-chat.html                # ✨ NEW: WebSocket client
└── websocket-react-example.tsx        # ✨ NEW: React example

tests/
└── test_websocket_adapter.py          # ✨ NEW: WS tests

📋 WebSocket Protocol

Message Format

Client → Server

type ClientMessage =
  // Send a chat message
  | {
      type: 'chat_message';
      message: string;
      session_id?: string;  // currently ignored: the session is fixed at connect time
      metadata?: {
        page_url?: string;
        user_agent?: string;
        [key: string]: unknown;
      };
    }

  // Ping (keepalive)
  | {
      type: 'ping';
      timestamp: number;  // client clock; echoed back in the pong
    }

  // Cancel the reply currently being processed
  | {
      type: 'cancel';
      message_id: string;  // from the server's message_received event
    }

  // Graceful disconnect
  | {
      type: 'disconnect';
      reason?: string;
    };
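`protocol.py` appears in the file structure, but its contents are not shown. One possible shape for its server-side half is a validator that turns each raw text frame into a checked dict, or into a protocol error that maps onto the `error` event's `error_code`. This is a sketch only: `parse_client_message`, `ProtocolError`, `REQUIRED_FIELDS`, and the error codes `invalid_json`, `invalid_message`, and `unknown_type` are hypothetical names, not part of the original plan.

```python
import json
from typing import Any, Dict, Tuple

# Hypothetical: required fields and accepted Python types per ClientMessage variant
REQUIRED_FIELDS: Dict[str, Dict[str, Tuple[type, ...]]] = {
    "chat_message": {"message": (str,)},
    "ping": {"timestamp": (int, float)},
    "cancel": {"message_id": (str,)},
    "disconnect": {},
}


class ProtocolError(ValueError):
    """A client frame that does not match the ClientMessage protocol."""

    def __init__(self, error_code: str, error_message: str):
        super().__init__(error_message)
        self.error_code = error_code


def parse_client_message(raw: str) -> Dict[str, Any]:
    """Decode one text frame and check it against REQUIRED_FIELDS."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise ProtocolError("invalid_json", f"Frame is not valid JSON: {exc}") from exc
    if not isinstance(data, dict):
        raise ProtocolError("invalid_message", "Frame must be a JSON object")

    msg_type = data.get("type")
    # isinstance check first: an unhashable "type" value cannot be a dict key
    if not isinstance(msg_type, str) or msg_type not in REQUIRED_FIELDS:
        raise ProtocolError("unknown_type", f"Unknown message type: {msg_type!r}")

    for field, accepted in REQUIRED_FIELDS[msg_type].items():
        value = data.get(field)
        # bool is a subclass of int in Python, so reject it explicitly
        if isinstance(value, bool) or not isinstance(value, accepted):
            raise ProtocolError(
                "invalid_message", f"Field {field!r} is missing or has the wrong type"
            )
    return data
```

A receive loop can catch `ProtocolError` and answer with `{"type": "error", "error_code": err.error_code, "error_message": str(err), "recoverable": True}` instead of dropping the connection.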

Server → Client

type ServerEvent =
  // Connection established
  | {
      type: 'connected';
      user_id: string;
      session_id: string;
      connection_id: string;
      server_time: number;  // Unix epoch seconds
    }

  // Message received; processing started
  | {
      type: 'message_received';
      message_id: string;
      timestamp: number;
    }

  // Response started
  | {
      type: 'response_start';
      message_id: string;
    }

  // Tool call started
  | {
      type: 'tool_call_start';
      message_id: string;
      tool_name: string;
      tool_display_name: string;
      tool_description: string;
      timestamp: number;
    }

  // Tool call finished
  | {
      type: 'tool_call_complete';
      message_id: string;
      tool_name: string;
      success: boolean;
      duration_ms: number;
      error?: string;
    }

  // Text chunk (streaming)
  | {
      type: 'text_chunk';
      message_id: string;
      text: string;
      chunk_index: number;
    }

  // Response complete
  | {
      type: 'response_complete';
      message_id: string;
      full_text: string;
      total_chunks: number;
      duration_ms: number;
      metadata?: {
        tokens_used?: number;
        model?: string;
      };
    }

  // Error
  | {
      type: 'error';
      message_id?: string;
      error_code: string;
      error_message: string;
      recoverable: boolean;
    }

  // Server heartbeat (sent every heartbeat_interval seconds)
  | {
      type: 'ping';
      timestamp: number;
    }

  // Pong (reply to a client ping)
  | {
      type: 'pong';
      timestamp: number;
      server_time: number;
    }

  // Server is about to disconnect
  | {
      type: 'disconnecting';
      reason: string;
      reconnect_allowed: boolean;
      retry_after_ms?: number;
    };
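The union above lists the event types but not their order. Reading the EventStreamer code in this plan, a successful reply for one `message_id` appears to follow `message_received` → `response_start` → any mix of tool and text events → `response_complete`. The hypothetical test helper below (`check_response_sequence` is an illustrative name) encodes that assumed contract so it can be asserted in `tests/test_websocket_adapter.py`; it does not cover the `error` path.

```python
from typing import Any, Dict, List


def check_response_sequence(events: List[Dict[str, Any]]) -> str:
    """Validate the events of one successful reply and return the rebuilt text.

    Assumed contract (inferred, not specified by the plan): message_received,
    response_start, then tool_call_start / tool_call_complete / text_chunk in
    any order, then response_complete.
    """
    if len(events) < 3:
        raise ValueError("a reply needs at least three events")
    kinds = [event["type"] for event in events]
    if kinds[:2] != ["message_received", "response_start"]:
        raise ValueError("reply must open with message_received, response_start")
    if kinds[-1] != "response_complete":
        raise ValueError("reply must end with response_complete")
    if len({event["message_id"] for event in events}) != 1:
        raise ValueError("all events must share one message_id")

    chunks: List[str] = []
    open_tool_calls = 0
    for event in events[2:-1]:
        kind = event["type"]
        if kind == "text_chunk":
            # chunk_index counts from 0 with no gaps
            if event["chunk_index"] != len(chunks):
                raise ValueError(f"unexpected chunk_index {event['chunk_index']}")
            chunks.append(event["text"])
        elif kind == "tool_call_start":
            open_tool_calls += 1
        elif kind == "tool_call_complete":
            open_tool_calls -= 1
            if open_tool_calls < 0:
                raise ValueError("tool_call_complete without a matching start")
        else:
            raise ValueError(f"unexpected {kind!r} inside a reply")

    if open_tool_calls != 0:
        raise ValueError("some tool calls never completed")
    final = events[-1]
    full_text = "".join(chunks)
    # response_complete must agree with what was streamed
    if final["full_text"] != full_text or final["total_chunks"] != len(chunks):
        raise ValueError("response_complete disagrees with the streamed chunks")
    return full_text
```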

🔧 Detailed Implementation

1. WebSocketAdapter

# messaging/adapters/websocket_adapter.py

import json
import logging
import time
import uuid
from typing import Dict, Optional

from fastapi import WebSocket, WebSocketDisconnect

from ..base import MessagingAdapter, IncomingMessage
from ..websocket.connection_manager import ConnectionManager
from ..websocket.event_streamer import EventStreamer
from ..websocket.protocol import ClientMessage, ServerEvent

logger = logging.getLogger(__name__)


class WebSocketAdapter(MessagingAdapter):
    """
    WebSocket adapter for real-time chat.

    Features:
    - Real-time response streaming
    - Tool call feedback
    - Automatic reconnection (driven by the client)
    - Multiple connections per user (multi-device)
    """

    def __init__(self, config: Dict):
        super().__init__(config)
        self.connection_manager = ConnectionManager(config)
        self.event_streamer = EventStreamer()
        # ADK Runner, injected after construction (see unified_bot.py)
        self._runner = None

    def _validate_config(self) -> None:
        """Validates the configuration."""
        # Minimal configuration: every key is optional
        pass

    async def setup(self) -> None:
        """Initializes the adapter (starts the heartbeat loop)."""
        await self.connection_manager.start_heartbeat()
        logger.info("✅ WebSocket adapter initialized")

    async def cleanup(self) -> None:
        """Releases resources."""
        # Stop the heartbeat first so it does not race with the mass disconnect
        await self.connection_manager.stop_heartbeat()
        await self.connection_manager.disconnect_all()

    async def handle_websocket(
        self,
        websocket: WebSocket,
        user_id: str,
        initial_session_id: Optional[str] = None
    ):
        """
        Handles the full lifecycle of one WebSocket connection.

        Args:
            websocket: FastAPI WebSocket connection
            user_id: User ID
            initial_session_id: Initial session ID (optional)
        """
        # Accept connection
        await websocket.accept()

        # Reuse the requested session or create a new one
        session_id = initial_session_id or self._generate_session_id(user_id)

        # Register connection
        connection_id = await self.connection_manager.connect(
            user_id=user_id,
            session_id=session_id,
            websocket=websocket
        )

        # Send connected event
        await self._send_event(websocket, {
            "type": "connected",
            "user_id": user_id,
            "session_id": session_id,
            "connection_id": connection_id,
            "server_time": time.time()  # Unix epoch seconds
        })

        try:
            # Message loop
            while True:
                raw_data = await websocket.receive_text()

                # A malformed frame is recoverable: report it and keep reading
                try:
                    data = json.loads(raw_data)
                except json.JSONDecodeError:
                    data = None
                if not isinstance(data, dict):
                    await self._send_event(websocket, {
                        "type": "error",
                        "error_code": "invalid_message",
                        "error_message": "Message must be a JSON object",
                        "recoverable": True
                    })
                    continue

                # Handle message
                await self._handle_client_message(
                    websocket=websocket,
                    user_id=user_id,
                    session_id=session_id,
                    message=data
                )

        except WebSocketDisconnect:
            logger.info(f"🔌 WebSocket disconnected: {user_id}")
        except Exception as e:
            logger.error(f"❌ WebSocket error: {e}", exc_info=True)
            await self._send_event(websocket, {
                "type": "error",
                "error_code": "internal_error",
                "error_message": str(e),
                "recoverable": False
            })
        finally:
            # Always unregister, even after unexpected errors
            await self.connection_manager.disconnect(connection_id)

    async def _handle_client_message(
        self,
        websocket: WebSocket,
        user_id: str,
        session_id: str,
        message: Dict
    ):
        """Dispatches one client message by type."""

        msg_type = message.get("type")

        if msg_type == "chat_message":
            await self._handle_chat_message(
                websocket=websocket,
                user_id=user_id,
                session_id=session_id,
                message=message
            )

        elif msg_type == "ping":
            await self._send_event(websocket, {
                "type": "pong",
                "timestamp": message.get("timestamp"),
                "server_time": time.time()
            })

        elif msg_type == "disconnect":
            await websocket.close(code=1000, reason="Client requested disconnect")
            # Leave the receive loop: reading from a closed socket would raise
            raise WebSocketDisconnect(code=1000)

        else:
            # Also reached by 'cancel', which this adapter does not handle yet
            logger.warning(f"⚠️ Unknown message type: {msg_type}")

    async def _handle_chat_message(
        self,
        websocket: WebSocket,
        user_id: str,
        session_id: str,
        message: Dict
    ):
        """Processes a chat message and streams the response."""

        message_id = self._generate_message_id()
        message_text = message.get("message", "")

        if not isinstance(message_text, str) or not message_text.strip():
            await self._send_event(websocket, {
                "type": "error",
                "error_code": "empty_message",
                "error_message": "Message must be a non-empty string",
                "recoverable": True
            })
            return

        # Acknowledge receipt
        await self._send_event(websocket, {
            "type": "message_received",
            "message_id": message_id,
            "timestamp": time.time()
        })

        # Build the platform-agnostic IncomingMessage
        incoming_message = IncomingMessage(
            platform=self.platform_name,
            user_id=user_id,
            channel_id=session_id,
            thread_id=session_id,
            text=message_text,
            metadata=message.get("metadata") or {}
        )

        # Stream the response via EventStreamer.
        # Awaited inline: while it runs, further frames from this client
        # (pings, cancel) are not read.
        await self.event_streamer.stream_to_websocket(
            websocket=websocket,
            message_id=message_id,
            incoming_message=incoming_message,
            runner=self._get_runner()
        )

    def _get_runner(self):
        """Returns the ADK Runner injected in unified_bot.py."""
        if self._runner is None:
            raise RuntimeError("ADK runner not injected into WebSocketAdapter")
        return self._runner

    async def _send_event(self, websocket: WebSocket, event: Dict):
        """Sends an event to the client (errors are logged, not raised)."""
        try:
            await websocket.send_json(event)
        except Exception as e:
            logger.error(f"❌ Error sending event: {e}")

    def _generate_session_id(self, user_id: str) -> str:
        """Generates a unique session ID."""
        return f"ws_{user_id}_{uuid.uuid4().hex[:8]}"

    def _generate_message_id(self) -> str:
        """Generates a unique message ID."""
        return f"msg_{uuid.uuid4().hex[:12]}"

    @property
    def platform_name(self) -> str:
        return "websocket"

    @property
    def webhook_path(self) -> str:
        return "/ws/chat"  # Base path
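The protocol defines a `cancel` message, but `handle_websocket` awaits each chat reply inline and `_handle_client_message` has no `cancel` branch, so a cancel request cannot even be read until the reply is over. A minimal sketch of one fix, assuming replies run as background `asyncio` tasks keyed by `message_id`; `ActiveRequests` and its methods are hypothetical names, not part of the plan.

```python
import asyncio
from typing import Any, Coroutine, Dict


class ActiveRequests:
    """Hypothetical registry of in-flight reply tasks, keyed by message_id."""

    def __init__(self) -> None:
        self._tasks: Dict[str, asyncio.Task] = {}

    def start(self, message_id: str, coro: Coroutine[Any, Any, None]) -> asyncio.Task:
        """Run the streaming coroutine in the background so the receive loop stays free."""
        task = asyncio.create_task(coro)
        self._tasks[message_id] = task
        # Forget the entry once the task finishes, fails, or is cancelled
        task.add_done_callback(lambda _done: self._tasks.pop(message_id, None))
        return task

    def cancel(self, message_id: str) -> bool:
        """Handle a client 'cancel'; returns False if nothing is running for that id."""
        task = self._tasks.get(message_id)
        if task is None or task.done():
            return False
        task.cancel()
        return True

    async def cancel_all(self) -> None:
        """Call during connection cleanup so no task outlives its socket."""
        tasks = list(self._tasks.values())
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
```

With this in place, the `chat_message` branch would call `start(message_id, ...)` instead of awaiting the stream, a `cancel` branch would call `cancel(message["message_id"])`, and the `finally` block would await `cancel_all()`. Because `asyncio.CancelledError` derives from `BaseException` (Python 3.8+), the `except Exception` in `EventStreamer.stream_to_websocket` does not swallow the cancellation.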

2. ConnectionManager

# messaging/websocket/connection_manager.py

import asyncio
import logging
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional, Set

from fastapi import WebSocket

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    # Timezone-aware UTC: .timestamp() on a naive utcnow() is read as local time
    return datetime.now(timezone.utc)


@dataclass
class Connection:
    """An active WebSocket connection."""
    connection_id: str
    user_id: str
    session_id: str
    websocket: WebSocket
    connected_at: datetime
    last_ping: Optional[datetime] = None


class ConnectionManager:
    """
    Manages active WebSocket connections.

    Features:
    - Multiple connections per user (multi-device)
    - Heartbeat (server-sent ping)
    - Cleanup of dead connections
    - Broadcast to a specific user
    """

    def __init__(self, config: Dict):
        self.config = config

        # connection_id → Connection
        self.connections: Dict[str, Connection] = {}

        # user_id → Set[connection_id]
        self.user_connections: Dict[str, Set[str]] = {}

        # Heartbeat
        self.heartbeat_task: Optional[asyncio.Task] = None
        self.heartbeat_interval = config.get("heartbeat_interval", 30)  # seconds
        self.connection_timeout = config.get("connection_timeout", 300)  # seconds (5 min)

    async def connect(
        self,
        user_id: str,
        session_id: str,
        websocket: WebSocket
    ) -> str:
        """Registers a new connection."""

        connection_id = f"conn_{uuid.uuid4().hex[:12]}"

        connection = Connection(
            connection_id=connection_id,
            user_id=user_id,
            session_id=session_id,
            websocket=websocket,
            connected_at=_utcnow()
        )

        self.connections[connection_id] = connection
        self.user_connections.setdefault(user_id, set()).add(connection_id)

        logger.info(
            f"✅ WebSocket connected: {connection_id} "
            f"(user={user_id}, total_connections={len(self.connections)})"
        )

        return connection_id

    async def disconnect(self, connection_id: str):
        """Removes a connection (idempotent)."""

        connection = self.connections.pop(connection_id, None)
        if not connection:
            return

        # Remove from user_connections
        if connection.user_id in self.user_connections:
            self.user_connections[connection.user_id].discard(connection_id)

            # Drop the user entry once it has no connections left
            if not self.user_connections[connection.user_id]:
                del self.user_connections[connection.user_id]

        logger.info(
            f"🔌 WebSocket disconnected: {connection_id} "
            f"(user={connection.user_id}, remaining={len(self.connections)})"
        )

    async def disconnect_all(self):
        """Closes every connection (shutdown)."""

        logger.info(f"🛑 Disconnecting all WebSocket connections ({len(self.connections)})")

        for connection in list(self.connections.values()):
            try:
                await connection.websocket.close(
                    code=1001,  # Going Away
                    reason="Server shutting down"
                )
            except Exception as e:
                logger.error(f"Error disconnecting {connection.connection_id}: {e}")

        self.connections.clear()
        self.user_connections.clear()

    async def send_to_connection(self, connection_id: str, event: Dict) -> bool:
        """Sends an event to one connection; returns False if it failed."""

        connection = self.connections.get(connection_id)
        if not connection:
            return False

        try:
            await connection.websocket.send_json(event)
            return True
        except Exception as e:
            logger.error(f"Error sending to {connection_id}: {e}")
            # Drop the broken connection
            await self.disconnect(connection_id)
            return False

    async def broadcast_to_user(self, user_id: str, event: Dict):
        """Broadcasts an event to all of the user's connections on this instance."""

        connection_ids = self.user_connections.get(user_id, set())

        for conn_id in list(connection_ids):
            await self.send_to_connection(conn_id, event)

    async def start_heartbeat(self):
        """Starts the heartbeat loop."""

        if self.heartbeat_task:
            return

        self.heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        logger.info(f"💓 Heartbeat started (interval={self.heartbeat_interval}s)")

    async def stop_heartbeat(self):
        """Stops the heartbeat loop."""

        if self.heartbeat_task:
            self.heartbeat_task.cancel()
            try:
                await self.heartbeat_task
            except asyncio.CancelledError:
                pass
            self.heartbeat_task = None

    async def _heartbeat_loop(self):
        """Heartbeat loop (server ping + cleanup)."""

        while True:
            try:
                await asyncio.sleep(self.heartbeat_interval)

                now = _utcnow()
                timeout_threshold = now - timedelta(seconds=self.connection_timeout)

                # Check connections
                for conn_id, connection in list(self.connections.items()):
                    # Ping
                    try:
                        await connection.websocket.send_json({
                            "type": "ping",
                            "timestamp": now.timestamp()
                        })
                        connection.last_ping = now
                    except Exception:
                        # Connection is dead
                        logger.warning(f"⚠️ Dead connection detected: {conn_id}")
                        await self.disconnect(conn_id)
                        continue

                    # Timeout check.
                    # Caveat: last_ping was just refreshed above, so while sends
                    # succeed this branch never fires. Detecting silent clients
                    # requires the time of the last frame *received* from them.
                    last_activity = connection.last_ping or connection.connected_at
                    if last_activity < timeout_threshold:
                        logger.warning(
                            f"⏱️ Connection timeout: {conn_id} "
                            f"(inactive for {int((now - last_activity).total_seconds())}s)"
                        )
                        try:
                            await connection.websocket.close(
                                code=1000,
                                reason="Connection timeout"
                            )
                        except Exception:
                            pass
                        await self.disconnect(conn_id)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"❌ Heartbeat error: {e}", exc_info=True)
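In the heartbeat loop, `last_ping` is refreshed right after each successful send, so the timeout branch cannot fire while sends succeed, and a successful send does not prove the client is still there. A sketch of receive-based liveness follows, assuming the adapter would call `touch()` on every frame it receives (including the client's own `ping`) and `forget()` on disconnect; `ActivityTracker` and its methods are hypothetical, and the clock is injectable so the logic can be tested without sleeping.

```python
import time
from typing import Callable, Dict, List


class ActivityTracker:
    """Hypothetical liveness tracker: a connection is alive only if the
    *client* sent a frame within the last timeout_s seconds."""

    def __init__(self, timeout_s: float, clock: Callable[[], float] = time.monotonic):
        self.timeout_s = timeout_s
        self._clock = clock
        self._last_seen: Dict[str, float] = {}

    def touch(self, connection_id: str) -> None:
        """Call on connect and on every frame received from the client."""
        self._last_seen[connection_id] = self._clock()

    def forget(self, connection_id: str) -> None:
        """Call when the connection is removed."""
        self._last_seen.pop(connection_id, None)

    def expired(self) -> List[str]:
        """Connection ids whose last client frame is older than timeout_s."""
        now = self._clock()
        return [cid for cid, seen in self._last_seen.items() if now - seen > self.timeout_s]
```

The heartbeat loop would then close and disconnect every id in `expired()`, while its outgoing `ping` only serves to give idle clients a reason to answer.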

3. EventStreamer

# messaging/websocket/event_streamer.py

import logging
import time
from typing import Dict, Tuple

from fastapi import WebSocket

from google.adk.runners import Runner
from google.genai import types

from ..base import IncomingMessage

logger = logging.getLogger(__name__)


class EventStreamer:
    """
    Converts ADK Runner events into WebSocket events.

    Mapping:
    - Tool calls → tool_call_start/complete
    - Text parts → text_chunk (streaming)
    - Completion → response_complete
    """

    # Tool name → (display name, description); user-facing Portuguese UI labels
    TOOL_DISPLAY_NAMES: Dict[str, Tuple[str, str]] = {
        "search_agent": ("🔍 Busca na Internet", "Realizando busca na web"),
        "busca_produtos_tool": ("🔍 Busca de Produtos", "Procurando experiências disponíveis"),
        "tem_variacao_tool": ("🔄 Variações", "Verificando variações de produto"),
        # ... rest of the mapping
    }

    async def stream_to_websocket(
        self,
        websocket: WebSocket,
        message_id: str,
        incoming_message: IncomingMessage,
        runner: Runner
    ):
        """
        Streams Runner events to the WebSocket.

        Args:
            websocket: WebSocket connection
            message_id: Message ID
            incoming_message: User message
            runner: ADK Runner
        """

        start_time = time.monotonic()
        full_text = ""
        chunk_index = 0
        # Start time per tool call, keyed by call id (falls back to tool name)
        tool_start_times: Dict[str, float] = {}

        try:
            # Response start
            await self._send_event(websocket, {
                "type": "response_start",
                "message_id": message_id
            })

            # Build the user Content
            user_content = types.Content(
                role="user",
                parts=[types.Part.from_text(text=incoming_message.text)]
            )

            # Stream from the Runner.
            # Assumes the session already exists in the runner's session
            # service (ADK's Runner typically raises if it is missing).
            # Without a RunConfig that enables streaming mode, ADK yields
            # complete events, so each text part is a whole model turn rather
            # than a token-level chunk.
            async for event in runner.run_async(
                session_id=incoming_message.thread_id,
                user_id=incoming_message.user_id,
                new_message=user_content
            ):
                if not (event.content and event.content.parts):
                    continue

                for part in event.content.parts:
                    # Tool call started
                    if part.function_call:
                        tool_name = part.function_call.name
                        display_name, description = self.TOOL_DISPLAY_NAMES.get(
                            tool_name,
                            (f"⚙️ {tool_name}", "Executando ferramenta")
                        )

                        # Parallel calls to the same tool are told apart by id
                        call_key = part.function_call.id or tool_name
                        tool_start_times[call_key] = time.monotonic()

                        await self._send_event(websocket, {
                            "type": "tool_call_start",
                            "message_id": message_id,
                            "tool_name": tool_name,
                            "tool_display_name": display_name,
                            "tool_description": description,
                            "timestamp": time.time()
                        })

                    # Tool result
                    if part.function_response:
                        tool_name = part.function_response.name
                        call_key = part.function_response.id or tool_name
                        started = tool_start_times.pop(call_key, None)
                        duration_ms = (
                            int((time.monotonic() - started) * 1000)
                            if started is not None else 0
                        )

                        # Simple heuristic: any "error" in the payload counts as a failure
                        success = "error" not in str(part.function_response.response).lower()

                        await self._send_event(websocket, {
                            "type": "tool_call_complete",
                            "message_id": message_id,
                            "tool_name": tool_name,
                            "success": success,
                            "duration_ms": duration_ms
                        })

                    # Text chunks
                    if part.text:
                        full_text += part.text

                        await self._send_event(websocket, {
                            "type": "text_chunk",
                            "message_id": message_id,
                            "text": part.text,
                            "chunk_index": chunk_index
                        })

                        chunk_index += 1

            # Response complete
            duration_ms = int((time.monotonic() - start_time) * 1000)

            await self._send_event(websocket, {
                "type": "response_complete",
                "message_id": message_id,
                "full_text": full_text,
                "total_chunks": chunk_index,
                "duration_ms": duration_ms
            })

        except Exception as e:
            logger.error(f"❌ Streaming error: {e}", exc_info=True)

            try:
                await self._send_event(websocket, {
                    "type": "error",
                    "message_id": message_id,
                    "error_code": "streaming_error",
                    "error_message": str(e),
                    "recoverable": False
                })
            except Exception:
                # The socket is probably gone; the original error is already logged
                pass

    async def _send_event(self, websocket: WebSocket, event: dict):
        """Sends an event; re-raises so a dead socket aborts the stream."""
        try:
            await websocket.send_json(event)
        except Exception as e:
            logger.error(f"❌ Error sending event: {e}")
            raise
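Token-level `text_chunk` events need ADK's streaming mode. Assuming the runner is called with `RunConfig(streaming_mode=StreamingMode.SSE)`, our understanding is that ADK yields partial events (`event.partial` is true) for each text fragment, followed by one non-partial event that repeats the aggregated text of that model turn. Concatenating every text part, as the loop above does, would then duplicate each turn in `full_text`. The sketch below isolates a de-duplication rule on a hypothetical `TextEvent` stand-in (only `text` and `partial`), so it can be checked without ADK; `select_chunks` is an illustrative name.

```python
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple


@dataclass
class TextEvent:
    """Hypothetical stand-in for the two ADK Event fields used here."""
    text: str
    partial: Optional[bool] = None


def select_chunks(events: Iterable[TextEvent]) -> Tuple[List[str], str]:
    """Return (texts to send as text_chunk, full_text) without double counting."""
    chunks: List[str] = []
    streamed_partials = False
    for event in events:
        if event.partial:
            # Streaming fragment: forward it immediately
            chunks.append(event.text)
            streamed_partials = True
        elif streamed_partials:
            # Final aggregate of fragments already sent: skip it and reset
            # for the next model turn (e.g. the answer after a tool call)
            streamed_partials = False
        else:
            # Non-streaming mode: each complete event is one chunk
            chunks.append(event.text)
    return chunks, "".join(chunks)
```

The same rule works unchanged when streaming is off, because every event is then non-partial and is forwarded as one chunk.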

🔌 Endpoint in the Unified Bot

# unified_bot.py

# ... existing imports ...
from typing import Optional

from fastapi import Query, WebSocket

from messaging.adapters.websocket_adapter import WebSocketAdapter

# ... existing code ...

# Register the WebSocket adapter
websocket_adapter = WebSocketAdapter({
    "heartbeat_interval": 30,   # seconds
    "connection_timeout": 300   # seconds
})
websocket_adapter._runner = runner  # Inject the existing ADK runner
adapter_factory.register_adapter("websocket", websocket_adapter)
# websocket_adapter.setup() must be awaited at startup (e.g. in the app's
# lifespan handler) unless adapter_factory already does it; otherwise the
# heartbeat loop never starts.


@app.websocket("/ws/chat/{user_id}")
async def websocket_chat_endpoint(
    websocket: WebSocket,
    user_id: str,
    session_id: Optional[str] = Query(None)
):
    """
    WebSocket endpoint for real-time chat.

    Args:
        user_id: User ID (taken from the URL; not authenticated here)
        session_id: Session ID (optional; generated if absent)

    Example:
        ws://localhost:8080/ws/chat/user123
        ws://localhost:8080/ws/chat/user123?session_id=existing_session
    """
    adapter = adapter_factory.get_adapter("websocket")

    if not adapter:
        await websocket.close(code=1008, reason="WebSocket not configured")
        return

    await adapter.handle_websocket(
        websocket=websocket,
        user_id=user_id,
        initial_session_id=session_id
    )
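The objective includes automatic reconnection, and the `disconnecting` event carries `reconnect_allowed` and `retry_after_ms`, but the plan does not fix a client retry policy. A common choice is exponential backoff with full jitter that never retries sooner than the server asks. The sketch is in Python to match the rest of this document; the same arithmetic ports directly to `websocket-chat.html` or the React example. `reconnect_delay_ms` and its defaults (500 ms base, 30 s cap) are assumptions, and a client should not reconnect at all when `reconnect_allowed` is false.

```python
import random
from typing import Optional


def reconnect_delay_ms(
    attempt: int,
    retry_after_ms: Optional[int] = None,
    base_ms: int = 500,
    cap_ms: int = 30_000,
    rng: Optional[random.Random] = None,
) -> int:
    """Delay before reconnect attempt number `attempt` (0 = first retry)."""
    source = rng if rng is not None else random
    # Full jitter: pick uniformly in [0, min(cap, base * 2^attempt)]
    ceiling = min(cap_ms, base_ms * (2 ** attempt))
    delay = source.uniform(0, ceiling)
    # Never come back sooner than the server's retry_after_ms hint
    if retry_after_ms is not None:
        delay = max(delay, retry_after_ms)
    return int(delay)
```

With the defaults, the upper bound grows 500, 1000, 2000, 4000 ms and so on until it reaches the 30 s cap; the jitter spreads out clients that all lost their connection in the same server restart.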

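📊 Metrics and Monitoring

Metrics to Collect

# Add Prometheus/CloudWatch metrics (shown here with prometheus_client)
from prometheus_client import Counter, Gauge, Histogram

websocket_connections_total = Counter('websocket_connections_total', 'WebSocket connections opened')
websocket_connections_active = Gauge('websocket_connections_active', 'Currently open WebSocket connections')
websocket_messages_total = Counter('websocket_messages_total', 'Chat messages received over WebSocket')
websocket_errors_total = Counter('websocket_errors_total', 'Error events sent to WebSocket clients')
websocket_message_duration = Histogram('websocket_message_duration_seconds', 'Time from message receipt to response_complete')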

Health Check

@app.get("/ws/health")
async def websocket_health():
    """Health check for the WebSocket adapter"""
    adapter = adapter_factory.get_adapter("websocket")

    if not adapter:
        return {"status": "disabled"}

    active_connections = len(adapter.connection_manager.connections)
    heartbeat_task = adapter.connection_manager.heartbeat_task

    return {
        "status": "healthy",
        "active_connections": active_connections,
        # A task that crashed or finished is not running
        "heartbeat_running": heartbeat_task is not None and not heartbeat_task.done()
    }

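⚡ Scalability with Redis Pub/Sub

For multiple server instances (horizontal scaling):

# messaging/websocket/redis_pubsub.py

import json
from typing import Awaitable, Callable

import redis.asyncio as redis


class RedisPubSubManager:
    """
    Manages pub/sub via Redis across multiple server instances.

    Lets users connected to different instances receive broadcast messages.
    """

    def __init__(self, redis_url: str):
        # decode_responses=True: channels and payloads arrive as str
        self.redis = redis.from_url(redis_url, decode_responses=True)
        self.pubsub = self.redis.pubsub()

    async def subscribe_user(self, user_id: str, callback: Callable[[dict], Awaitable[None]]):
        """Subscribes to the user's channel; callback receives the decoded event."""
        channel = f"ws:user:{user_id}"

        async def handler(message: dict) -> None:
            await callback(json.loads(message["data"]))

        await self.pubsub.subscribe(**{channel: handler})

    async def unsubscribe_user(self, user_id: str):
        """Stops receiving the user's events (e.g. last local connection closed)."""
        await self.pubsub.unsubscribe(f"ws:user:{user_id}")

    async def publish_to_user(self, user_id: str, event: dict):
        """Publishes an event to the user (all instances)."""
        channel = f"ws:user:{user_id}"
        await self.redis.publish(channel, json.dumps(event))

    async def listen(self):
        """
        Reads messages and dispatches them to the registered handlers.

        Handlers only run while this loop is active; start it as a background
        task after the first subscribe_user() call.
        """
        while True:
            # With a handler registered, get_message() calls it and returns None
            await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)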
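The fan-out pattern behind `RedisPubSubManager` can be checked without a Redis server. In this sketch, `InMemoryBus` and `Instance` are hypothetical stand-ins: the bus plays Redis, and each `Instance` plays one server process that subscribes to `ws:user:{user_id}` (the channel naming used above) once per locally connected user, however many tabs or devices that user has there.

```python
import asyncio
import json
from typing import Awaitable, Callable, Dict, List, Set, Tuple

Handler = Callable[[dict], Awaitable[None]]


class InMemoryBus:
    """Hypothetical stand-in for Redis pub/sub: channel -> handlers."""

    def __init__(self) -> None:
        self._handlers: Dict[str, List[Handler]] = {}

    def subscribe(self, channel: str, handler: Handler) -> None:
        self._handlers.setdefault(channel, []).append(handler)

    async def publish(self, channel: str, payload: str) -> int:
        """Deliver to every subscriber; returns the receiver count, like PUBLISH."""
        handlers = list(self._handlers.get(channel, []))
        for handler in handlers:
            await handler(json.loads(payload))
        return len(handlers)


class Instance:
    """One server process with its own local WebSocket connections."""

    def __init__(self, name: str, bus: InMemoryBus) -> None:
        self.name = name
        self.bus = bus
        self._subscribed: Set[str] = set()
        # Stands in for ConnectionManager.broadcast_to_user on this instance
        self.delivered: List[Tuple[str, dict]] = []

    def on_user_connected(self, user_id: str) -> None:
        # One subscription per user per instance avoids duplicate delivery
        if user_id in self._subscribed:
            return
        self._subscribed.add(user_id)

        async def deliver(event: dict) -> None:
            self.delivered.append((user_id, event))

        self.bus.subscribe(f"ws:user:{user_id}", deliver)

    async def broadcast_to_user(self, user_id: str, event: dict) -> int:
        # Always publish, even for local users: a single delivery path
        return await self.bus.publish(f"ws:user:{user_id}", json.dumps(event))


async def _demo() -> None:
    bus = InMemoryBus()
    a, b = Instance("a", bus), Instance("b", bus)
    a.on_user_connected("u1")  # phone connected to instance a
    b.on_user_connected("u1")  # laptop connected to instance b
    receivers = await a.broadcast_to_user("u1", {"type": "pong", "timestamp": 1, "server_time": 2})
    print(receivers, a.delivered == b.delivered)


if __name__ == "__main__":
    asyncio.run(_demo())
```

Publishing for every broadcast keeps one code path whether the user is local or remote; the trade-off is a Redis round trip per event even when every connection is on the sending instance.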

Next step: decide whether to implement this design or do more analysis/prototyping first.