Mirror of https://github.com/NousResearch/hermes-agent.git (synced 2026-05-08 03:37:13 +08:00)

Compare commits: 52 commits (fix-leakag ... macbook-te)
Commit SHA1s: 265562f240, 95c55fa2e9, c360da4f35, bc76a032ba, 8e986584f4, 4b68d30b0e, b292192467, f172f7d4aa, 8e8b6be690, e8c6135a91, 771cf41fea, 7ea17bb957, f8846f85a1, 4c05ef0ba8, 5438b64e32, 248acf715e, 54ca0997ee, b78076cac7, ba19d530ad, 47555602d7, 6eb76c7c1a, b32cc4b09d, 6e3dbb8d8b, b66c093316, 13d360030f, 66daebe88f, 4071ba29da, 21f9e2df40, 80d326310e, 53fc705b13, d5af53888a, a7a37249f7, 6af6ff2a0a, 30ca282594, ab7293bed6, 1614c15bb1, f813959750, f957ec2267, 92e3074c10, 0c618482c4, 2d8f6c46f1, 0fbc0475f3, c27787f09f, d90fcd4e2b, 69fd0ca9aa, 4135cf4682, c82741c3d8, 9573b2ac2d, ab5c9fc37b, e5e77381f0, 066514e2a9, 045a1737f8
.cursorrules (206 lines)
@@ -1,23 +1,201 @@
Hermes-Agent is an agent harness for LLMs.
Hermes-Agent is an agent harness for LLMs with an interactive CLI.

When building, the tool functionality is in the tools/ directory, where each specific tool (or, in some cases, a group of tools built for the same execution category or API) lives in its own script.

## Development Environment

Each tool is then consolidated in the model_tools.py file in the repo root.

**IMPORTANT**: Always use the virtual environment if it exists:
```bash
source venv/bin/activate  # Before running any Python commands
```

There is also a way to consolidate sets of tools in toolsets.py for the agent to use.

## Project Structure

The primary agent runner code is in run_agent, but other runners could be developed using the tools and framework.

- `hermes` - CLI launcher script (run with `./hermes`)
- `cli.py` - Interactive CLI with Rich UI, prompt_toolkit, animated spinners
- `cli-config.yaml` - CLI configuration (model, terminal, toolsets, personalities)
- `tools/` - Individual tool implementations (web, terminal, browser, vision, etc.)
- `tools/__init__.py` - Exports all tools for importing
- `model_tools.py` - Consolidates tool schemas and handlers for the agent
- `toolsets.py` - Groups tools into logical toolsets (web, terminal, browser, etc.)
- `toolset_distributions.py` - Probability-based tool selection for data generation
- `run_agent.py` - Primary agent runner with AIAgent class and KawaiiSpinner
- `batch_runner.py` - Parallel batch processing with checkpointing
- `tests/` - Test scripts

Always ensure consistency between the tools, model_tools.py, and toolsets.py when changing any of them; otherwise they can become desynced in a way that breaks functionality.

## File Dependency Chain

The expected pathway for API keys is to set them up in a .env file in the repo root.

```
tools/*.py → tools/__init__.py → model_tools.py → toolsets.py → toolset_distributions.py
                                        ↑
run_agent.py ──────────────────────────┘
cli.py → run_agent.py (uses AIAgent with quiet_mode=True)
batch_runner.py → run_agent.py + toolset_distributions.py
```

Test scripts will be placed in tests/.
Always ensure consistency between tools, model_tools.py, and toolsets.py when changing any of them.

The run_agent loop is set up to:
- Process the enabled toolsets to provide to the model,
- Pipe in a prompt or problem from the input to the agent,
- Loop the LLM each time it calls a tool, until the model decides no more tools are needed and provides a natural language response,
- Return that response.

## CLI Architecture (cli.py)

There are additional caveats for logging: the "tools" are restructured as a system prompt for storage, in a format that can be used and handled properly later.

The interactive CLI uses:
- **Rich** - For the welcome banner and styled panels
- **prompt_toolkit** - For fixed input area with history and `patch_stdout`
- **KawaiiSpinner** (in run_agent.py) - Animated feedback during API calls and tool execution

Key components:
- `HermesCLI` class - Main CLI controller with commands and conversation loop
- `load_cli_config()` - Loads `cli-config.yaml`, sets environment variables for terminal
- `build_welcome_banner()` - Displays ASCII art logo, tools, and skills summary
- `/commands` - Process user commands like `/help`, `/clear`, `/personality`, etc.

The CLI uses `quiet_mode=True` when creating AIAgent to suppress verbose logging and enable kawaii-style feedback instead.

### Adding CLI Commands

1. Add to `COMMANDS` dict with description
2. Add handler in `process_command()` method
3. For persistent settings, use `save_config_value()` to update `cli-config.yaml` (see the sketch below)
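As an illustration of those three steps, here is a minimal, self-contained sketch. The structure of `COMMANDS` and `process_command()` in the real `cli.py` may differ, and `/stats` is a hypothetical command used only for the example:

```python
# Illustrative sketch only -- not the actual cli.py implementation.
COMMANDS = {
    "/help": "Show available commands",
    "/stats": "Show session statistics",   # step 1: register command + description (hypothetical)
}

def save_config_value(key, value):
    # step 3: in the real CLI this would persist the setting to cli-config.yaml
    print(f"would persist {key}={value} to cli-config.yaml")

def process_command(command: str) -> bool:
    """Step 2: dispatch a slash command; return True if it was handled."""
    if command == "/stats":
        print("session statistics would be shown here")
        save_config_value("last_command", "/stats")
        return True
    return False

if __name__ == "__main__":
    process_command("/stats")
```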
## Adding a New Tool

Follow this strict order to maintain consistency:

1. Create `tools/your_tool.py` with:
   - Handler function (sync or async) returning a JSON string via `json.dumps()`
   - `check_*_requirements()` function to verify dependencies (e.g., API keys)
   - Schema definition following OpenAI function-calling format

2. Export in `tools/__init__.py`:
   - Import the handler and check function
   - Add to `__all__` list

3. Register in `model_tools.py`:
   - Create `get_*_tool_definitions()` function or add to existing
   - Add routing in `handle_function_call()` dispatcher
   - Update `get_all_tool_names()` with the tool name
   - Update `get_toolset_for_tool()` mapping
   - Update `get_available_toolsets()` and `check_toolset_requirements()`

4. Add to toolset in `toolsets.py`:
   - Add to existing toolset or create new one in TOOLSETS dict

5. Optionally add to `toolset_distributions.py` for batch processing

## Tool Implementation Pattern

```python
# tools/example_tool.py
import json
import os

def check_example_requirements() -> bool:
    """Check if required API keys/dependencies are available."""
    return bool(os.getenv("EXAMPLE_API_KEY"))

def example_tool(param: str, task_id: str = None) -> str:
    """Execute the tool and return JSON string result."""
    try:
        result = {"success": True, "data": "..."}
        return json.dumps(result, ensure_ascii=False)
    except Exception as e:
        return json.dumps({"error": str(e)}, ensure_ascii=False)
```

All tool handlers MUST return a JSON string. Never return raw dicts. Steps 2–4 (exporting and registering the tool) are sketched below.
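A condensed sketch of steps 2–4, spread across the three files they touch. The shapes shown are illustrative: the real helper functions in `model_tools.py` and the `TOOLSETS` structure in `toolsets.py` are richer than this, and `example_tool` is the hypothetical tool from the pattern above.

```python
# tools/__init__.py  (step 2) -- export the handler and its requirements check
from .example_tool import example_tool, check_example_requirements
__all__ = ["example_tool", "check_example_requirements"]

# model_tools.py  (step 3, simplified) -- schema definition + dispatcher routing
def get_example_tool_definitions():
    return [{
        "type": "function",
        "function": {
            "name": "example_tool",
            "description": "Demonstration tool",
            "parameters": {
                "type": "object",
                "properties": {"param": {"type": "string"}},
                "required": ["param"],
            },
        },
    }]

def handle_function_call(name, arguments, task_id=None):
    if name == "example_tool":
        return example_tool(arguments["param"], task_id=task_id)
    raise ValueError(f"Unknown tool: {name}")

# toolsets.py  (step 4) -- group the tool into a toolset
TOOLSETS = {"example": ["example_tool"]}
```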
## Stateful Tools

Tools that maintain state (terminal, browser) require the following; a sketch follows the list:
- `task_id` parameter for session isolation between concurrent tasks
- `cleanup_*()` function to release resources
- Cleanup is called automatically in run_agent.py after conversation completes
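A minimal sketch of that pattern, assuming a simple in-memory session dict keyed by `task_id` (the real terminal and browser tools manage containers and browser sessions instead):

```python
# tools/example_stateful_tool.py -- illustrative only
import json

_SESSIONS: dict[str, list[str]] = {}   # task_id -> per-task state

def example_stateful_tool(command: str, task_id: str = "default") -> str:
    """Run a command against this task's isolated session and return a JSON string."""
    session = _SESSIONS.setdefault(task_id, [])
    session.append(command)
    return json.dumps({"success": True, "history_length": len(session)})

def cleanup_example_stateful_tool(task_id: str = "default") -> None:
    """Release per-task resources; run_agent.py calls cleanup after the conversation."""
    _SESSIONS.pop(task_id, None)
```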
## Environment Variables

API keys are loaded from `.env` file in repo root:
- `OPENROUTER_API_KEY` - Main LLM API access (primary provider)
- `FIRECRAWL_API_KEY` - Web search/extract tools
- `BROWSERBASE_API_KEY` / `BROWSERBASE_PROJECT_ID` - Browser automation
- `FAL_KEY` - Image generation (FLUX model)
- `NOUS_API_KEY` - Vision and Mixture-of-Agents tools

Terminal tool configuration (can also be set in `cli-config.yaml`):
- `TERMINAL_ENV` - Backend: local, docker, singularity, modal, or ssh
- `TERMINAL_CWD` - Working directory
- `TERMINAL_SSH_HOST`, `TERMINAL_SSH_USER`, `TERMINAL_SSH_KEY` - For SSH backend

## Agent Loop (run_agent.py)

The AIAgent class handles:
- Processing enabled toolsets to provide to the model
- Piping prompts to the agent
- Looping LLM calls when tools are invoked, until natural language response
- Returning the final response

Uses OpenAI-compatible API (primarily OpenRouter) with the OpenAI Python SDK.

## Reasoning Model Support

For models that support chain-of-thought reasoning (sketched below):
- Extract `reasoning_content` from API responses
- Store in `assistant_msg["reasoning"]` for trajectory export
- Pass back via `reasoning_content` field on subsequent turns
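A minimal sketch of that flow, assuming an OpenAI-SDK-style response object where the provider exposes `reasoning_content` on the message. Field availability varies by model and provider, and the actual run_agent.py logic may differ:

```python
# Illustrative handling of reasoning-capable models (not the exact run_agent.py code)
def handle_assistant_turn(response, messages):
    msg = response.choices[0].message
    assistant_msg = {"role": "assistant", "content": msg.content or ""}

    # Extract reasoning_content when the provider returns it
    reasoning = getattr(msg, "reasoning_content", None)
    if reasoning:
        assistant_msg["reasoning"] = reasoning          # kept for trajectory export
        assistant_msg["reasoning_content"] = reasoning  # passed back on subsequent turns

    messages.append(assistant_msg)
    return messages
```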
## Trajectory Format

Conversations are saved in ShareGPT format for training:
```json
{"from": "system", "value": "System prompt with <tools>...</tools>"}
{"from": "human", "value": "User message"}
{"from": "gpt", "value": "<think>reasoning</think>\n<tool_call>{...}</tool_call>"}
{"from": "tool", "value": "<tool_response>{...}</tool_response>"}
{"from": "gpt", "value": "Final response"}
```

Tool calls use `<tool_call>` XML tags, responses use `<tool_response>` tags, reasoning uses `<think>` tags.
## Batch Processing (batch_runner.py)

For processing multiple prompts:
- Parallel execution with multiprocessing
- Content-based resume for fault tolerance (matches on prompt text, not indices)
- Toolset distributions control probabilistic tool availability per prompt
- Output: `data/<run_name>/trajectories.jsonl` (combined) + individual batch files

## Logging

Trajectories restructure tools as a system prompt for storage, in a format suitable for later training use.

## Skills System

Skills are on-demand knowledge documents the agent can load. Located in `skills/` directory:

```
skills/
├── mlops/                  # Category folder
│   ├── axolotl/            # Skill folder
│   │   ├── SKILL.md        # Main instructions (required)
│   │   ├── references/     # Additional docs, API specs
│   │   └── templates/      # Output formats, configs
│   └── vllm/
│       └── SKILL.md
└── example-skill/
    └── SKILL.md
```

**Progressive disclosure** (token-efficient):
1. `skills_categories()` - List category names (~50 tokens)
2. `skills_list(category)` - Name + description per skill (~3k tokens)
3. `skill_view(name)` - Full content + tags + linked files

SKILL.md files use YAML frontmatter:
```yaml
---
name: skill-name
description: Brief description for listing
tags: [tag1, tag2]
related_skills: [other-skill]
version: 1.0.0
---
# Skill Content...
```

Tool files: `tools/skills_tool.py` → `model_tools.py` → `toolsets.py`
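For reference, a minimal sketch of reading that frontmatter; the actual parsing lives in `tools/skills_tool.py` and may differ:

```python
# Illustrative SKILL.md frontmatter reader (not the exact skills_tool.py implementation)
import yaml  # PyYAML

def read_skill(path: str) -> dict:
    """Split a SKILL.md file into YAML metadata and markdown body."""
    text = open(path, encoding="utf-8").read()
    _, frontmatter, body = text.split("---", 2)   # assumes the file starts with "---"
    meta = yaml.safe_load(frontmatter)
    return {"name": meta.get("name"), "description": meta.get("description"),
            "tags": meta.get("tags", []), "content": body.strip()}
```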
.env.example (117 lines)
@@ -1,14 +1,21 @@
# Hermes Agent Environment Configuration
# Copy this file to .env and fill in your API keys
# Get API keys from the URLs listed below

# =============================================================================
# REQUIRED API KEYS
# LLM PROVIDER (OpenRouter)
# =============================================================================
# OpenRouter provides access to many models through one API
# All LLM calls go through OpenRouter - no direct provider keys needed
# Get your key at: https://openrouter.ai/keys
OPENROUTER_API_KEY=

# Anthropic API Key - Main agent model
# Get at: https://console.anthropic.com/
ANTHROPIC_API_KEY=
# Default model to use (OpenRouter format: provider/model)
# Examples: anthropic/claude-sonnet-4, openai/gpt-4o, google/gemini-2.0-flash, zhipuai/glm-4-plus
LLM_MODEL=anthropic/claude-sonnet-4

# =============================================================================
# TOOL API KEYS
# =============================================================================

# Firecrawl API Key - Web search, extract, and crawl
# Get at: https://firecrawl.dev/
@@ -18,31 +25,105 @@ FIRECRAWL_API_KEY=
# Get at: https://inference-api.nousresearch.com/
NOUS_API_KEY=

# Morph API Key - Terminal/command execution tools
# Get at: https://morph.so/
MORPH_API_KEY=

# FAL.ai API Key - Image generation
# Get at: https://fal.ai/
FAL_KEY=

# =============================================================================
# OPTIONAL API KEYS
# TERMINAL TOOL CONFIGURATION
# =============================================================================
# Backend type: "local", "singularity", "docker", or "modal"
# Uncomment ONE configuration block below based on your preferred backend.

# OpenAI API Key - Optional, for enhanced Hecate features
# Get at: https://platform.openai.com/
OPENAI_API_KEY=
# -----------------------------------------------------------------------------
# OPTION 1: Singularity/Apptainer (RECOMMENDED for HPC clusters)
# - No root required, common on shared systems
# - Auto-builds and caches SIF images from docker:// URLs
# - Uses /scratch if available, otherwise /tmp
# -----------------------------------------------------------------------------
TERMINAL_ENV=singularity
TERMINAL_SINGULARITY_IMAGE=docker://nikolaik/python-nodejs:python3.11-nodejs20
TERMINAL_CWD=/workspace
TERMINAL_TIMEOUT=60
# Optional: Override scratch directory (auto-detects /scratch or /tmp)
# TERMINAL_SCRATCH_DIR=/scratch/myuser/hermes

# -----------------------------------------------------------------------------
# OPTION 2: Local execution (FASTEST, but no isolation)
# - Runs directly on your machine
# - No containers, no setup required
# - WARNING: Commands run with your user permissions
# -----------------------------------------------------------------------------
# TERMINAL_ENV=local
# TERMINAL_CWD=/tmp
# TERMINAL_TIMEOUT=60

# -----------------------------------------------------------------------------
# OPTION 3: Docker (good isolation, requires Docker)
# - Requires Docker installed and user in 'docker' group
# - Each task gets an isolated container
# -----------------------------------------------------------------------------
# TERMINAL_ENV=docker
# TERMINAL_DOCKER_IMAGE=nikolaik/python-nodejs:python3.11-nodejs20
# TERMINAL_CWD=/workspace
# TERMINAL_TIMEOUT=60

# -----------------------------------------------------------------------------
# OPTION 4: Modal (cloud execution, scalable)
# - Requires Modal account: pip install modal && modal setup
# - Runs in Modal's cloud sandboxes
# - Good for scaling to many parallel workers
# -----------------------------------------------------------------------------
# TERMINAL_ENV=modal
# TERMINAL_MODAL_IMAGE=nikolaik/python-nodejs:python3.11-nodejs20
# TERMINAL_CWD=/workspace
# TERMINAL_TIMEOUT=60

# Common settings for all backends
TERMINAL_LIFETIME_SECONDS=300
TERMINAL_DISK_WARNING_GB=500

# =============================================================================
# OPTIONAL CONFIGURATION
# BROWSER TOOL CONFIGURATION (agent-browser + Browserbase)
# =============================================================================
# Browser automation requires Browserbase cloud service for remote browser execution.
# This allows the agent to navigate websites, fill forms, and extract information.
#
# STEALTH MODES:
# - Basic Stealth: ALWAYS active (random fingerprints, auto CAPTCHA solving)
# - Advanced Stealth: Requires BROWSERBASE_ADVANCED_STEALTH=true (Scale Plan only)

# Browserbase API Key - Cloud browser execution
# Get at: https://browserbase.com/
BROWSERBASE_API_KEY=

# Browserbase Project ID - From your Browserbase dashboard
BROWSERBASE_PROJECT_ID=

# Enable residential proxies for better CAPTCHA solving (default: true)
BROWSERBASE_PROXIES=true

# Enable advanced stealth mode (default: false, requires Scale Plan)
BROWSERBASE_ADVANCED_STEALTH=false

# Browser session timeout in seconds (default: 300)
BROWSER_SESSION_TIMEOUT=300

# =============================================================================
# LEGACY/OPTIONAL
# =============================================================================

# Terminal Tool Settings
HECATE_VM_LIFETIME_SECONDS=300
HECATE_DEFAULT_SNAPSHOT_ID=snapshot_p5294qxt
# Morph API Key - For legacy Hecate terminal backend
# Get at: https://morph.so/
# MORPH_API_KEY=

# Debug Logging (set to "true" to enable, logs saved to ./logs/)
# Hecate VM Settings (only if using terminal-hecate tool)
# HECATE_VM_LIFETIME_SECONDS=300
# HECATE_DEFAULT_SNAPSHOT_ID=snapshot_p5294qxt

# =============================================================================
# DEBUG OPTIONS
# =============================================================================
WEB_TOOLS_DEBUG=false
VISION_TOOLS_DEBUG=false
MOA_TOOLS_DEBUG=false
.gitignore (vendored, 22 lines)
@@ -20,4 +20,24 @@ logs/
data/
.pytest_cache/
tmp/
temp_vision_images/
hermes-*/*
examples/
tests/quick_test_dataset.jsonl
tests/sample_dataset.jsonl
run_datagen_kimik2-thinking.sh
run_datagen_megascience_glm4-6.sh
run_datagen_sonnet.sh
source-data/*
run_datagen_megascience_glm4-6.sh
data/*
node_modules/
browser-use/
agent-browser/
# Private keys
*.ppk
*.pem
privvy*

# CLI config (may contain sensitive SSH paths)
cli-config.yaml
.gitmodules (vendored, new file, 3 lines)
@@ -0,0 +1,3 @@
[submodule "mini-swe-agent"]
	path = mini-swe-agent
	url = https://github.com/SWE-agent/mini-swe-agent
README.md (414 lines)
@@ -4,34 +4,64 @@ An AI agent with advanced tool-calling capabilities, featuring a flexible toolse
## Features

- **Interactive CLI**: Beautiful terminal interface with animated feedback, personalities, and session management
- **Web Tools**: Search, extract content, and crawl websites
- **Terminal Tools**: Execute commands with interactive session support
- **Terminal Tools**: Execute commands via local, Docker, Singularity, Modal, or SSH backends
- **Browser Tools**: Automate web browsers to navigate, click, type, and extract content
- **Vision Tools**: Analyze images from URLs
- **Reasoning Tools**: Advanced multi-model reasoning (Mixture of Agents)
- **Creative Tools**: Generate images from text prompts
- **Skills Tools**: On-demand knowledge documents with progressive disclosure
- **Toolsets System**: Organize tools into logical groups for different scenarios
- **Batch Processing**: Process datasets in parallel with checkpointing and statistics tracking
- **Ephemeral System Prompts**: Guide model behavior without polluting training datasets

## Quick Start (CLI)

```bash
# After setup (see below), just run:
./hermes

# Or with options:
./hermes --model "anthropic/claude-sonnet-4" --toolsets "web,terminal"
```

The CLI provides:
- Animated spinners during thinking and tool execution
- Kawaii-style feedback messages
- `/commands` for configuration, history, and session management
- Customizable personalities (`/personality kawaii`, `/personality pirate`, etc.)
- Persistent configuration via `cli-config.yaml`

## Setup

### 1. Install Dependencies
### 1. Clone the Repository
```bash
# Clone with submodules (recommended)
git clone --recurse-submodules https://github.com/NousResearch/Hermes-Agent.git
cd Hermes-Agent

# Or if already cloned without submodules:
git submodule update --init --recursive
```

### 2. Install Dependencies
```bash
# Create and activate virtual environment (recommended)
python3 -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install required packages
# Install Python packages
pip install -r requirements.txt

# Install Hecate for terminal tools
git clone git@github.com:NousResearch/hecate.git
cd hecate
pip install -e .
cd ..
# Install mini-swe-agent for terminal tools
pip install -e ./mini-swe-agent

# Install Node.js dependencies for browser tools (requires Node.js)
npm install
```

### 2. Configure Environment Variables
### 3. Configure Environment Variables
```bash
# Copy the example environment file
cp .env.example .env
@@ -41,14 +71,265 @@ nano .env  # or use your preferred editor
```

**Required API Keys:**
- `ANTHROPIC_API_KEY` - Main agent model (get at: https://console.anthropic.com/)
- `OPENROUTER_API_KEY` - LLM access via OpenRouter (get at: https://openrouter.ai/keys)
- `FIRECRAWL_API_KEY` - Web tools (get at: https://firecrawl.dev/)
- `NOUS_API_KEY` - Vision & reasoning tools (get at: https://inference-api.nousresearch.com/)
- `MORPH_API_KEY` - Terminal tools (get at: https://morph.so/)
- `FAL_KEY` - Image generation (get at: https://fal.ai/)
- `OPENAI_API_KEY` - Optional, for some Hecate features

See `.env.example` for all available configuration options including debug settings and terminal tool configuration.
**Optional API Keys (for specific features):**
- `BROWSERBASE_API_KEY` - Browser automation (get at: https://browserbase.com/)
- `BROWSERBASE_PROJECT_ID` - From Browserbase dashboard
- `MORPH_API_KEY` - For legacy Hecate terminal backend (get at: https://morph.so/)

### 4. Configure Terminal Backend

The terminal tool uses **mini-swe-agent** environments. Configure in `.env` or `cli-config.yaml`:

```bash
# Backend: "local", "docker", "singularity", "modal", or "ssh"
TERMINAL_ENV=local        # Default: runs on host machine (no isolation)
TERMINAL_ENV=ssh          # Remote execution via SSH (agent code stays local)
TERMINAL_ENV=singularity  # Recommended for HPC: Apptainer/Singularity containers
TERMINAL_ENV=docker       # Isolated Docker containers
TERMINAL_ENV=modal        # Cloud execution via Modal

# Container image (for docker/singularity/modal backends)
TERMINAL_DOCKER_IMAGE=python:3.11-slim
TERMINAL_SINGULARITY_IMAGE=docker://python:3.11-slim
TERMINAL_TIMEOUT=60

# SSH backend (for ssh)
TERMINAL_SSH_HOST=my-server.example.com
TERMINAL_SSH_USER=myuser
TERMINAL_SSH_KEY=~/.ssh/id_rsa  # Optional, uses ssh-agent if not set
```

**Backend Requirements:**
- **local**: No extra setup (runs directly on your machine, no isolation)
- **ssh**: SSH access to remote machine (great for sandboxing - agent can't touch its own code)
- **singularity**: Requires Apptainer or Singularity installed (common on HPC clusters, no root needed)
- **docker**: Requires Docker installed and user in `docker` group
- **modal**: Requires Modal account (see setup below)

### Singularity/Apptainer Setup (Recommended for HPC)

Singularity/Apptainer provides rootless container execution, ideal for HPC clusters:

```bash
# 1. Verify Apptainer is installed
apptainer --version  # or: singularity --version

# 2. Set up cache directories (important for parallel workers)
# Use /scratch if available (HPC), otherwise /tmp
export APPTAINER_CACHEDIR=/scratch/$USER/.apptainer
export APPTAINER_TMPDIR=/scratch/$USER/.apptainer/tmp
mkdir -p "$APPTAINER_CACHEDIR" "$APPTAINER_TMPDIR"

# 3. Pre-build SIF image (recommended for parallel batch processing)
# This avoids race conditions when multiple workers start simultaneously
apptainer build $APPTAINER_CACHEDIR/python-nodejs.sif docker://nikolaik/python-nodejs:python3.11-nodejs20

# 4. Configure .env to use the local SIF
TERMINAL_ENV=singularity
TERMINAL_SINGULARITY_IMAGE=/scratch/$USER/.apptainer/python-nodejs.sif
```

**Tip:** The batch scripts in `configs/` automatically handle SIF pre-building if `/scratch` is available.

### Modal Cloud Backend Setup

[Modal](https://modal.com) provides serverless cloud compute for running sandboxed environments at scale.

```bash
# 1. Install Modal and dependencies
pip install modal boto3

# 2. Authenticate with Modal (opens browser)
modal setup

# 3. Set terminal backend to modal in .env
TERMINAL_ENV=modal
```

Modal uses CLI-based authentication (stored in `~/.modal/`), so no API key is needed in `.env`. After running `modal setup`, commands will automatically execute in Modal's cloud sandboxes.

### Browser Tools Setup

Browser tools enable the agent to navigate websites, fill forms, click buttons, and extract content. They use the [agent-browser](https://github.com/vercel-labs/agent-browser) CLI with [Browserbase](https://browserbase.com) cloud execution.

```bash
# 1. Install Node.js (if not already installed)
# Use nvm (recommended) or your package manager

# 2. Install agent-browser CLI (choose one option):
npm install -g agent-browser  # Option A: Global install (recommended)
npm install                   # Option B: Local install (uses npx fallback)

# 3. Get Browserbase credentials
# Sign up at https://browserbase.com/ and get your:
#   - API Key (from Settings → API Keys)
#   - Project ID (from your project dashboard)

# 4. Add to your .env file:
BROWSERBASE_API_KEY=your_api_key_here
BROWSERBASE_PROJECT_ID=your_project_id_here
```

**Available Browser Tools:**

| Tool | Description |
|------|-------------|
| `browser_navigate` | Navigate to a URL |
| `browser_snapshot` | Get text-based page snapshot with element refs |
| `browser_click` | Click an element by ref (e.g., `@e5`) |
| `browser_type` | Type text into an input field |
| `browser_scroll` | Scroll up or down |
| `browser_back` | Go back in browser history |
| `browser_press` | Press a keyboard key (Enter, Tab, etc.) |
| `browser_close` | Close the browser session |
| `browser_get_images` | Get list of images on the page |

**Example Usage:**
```bash
# Use browser tools with web search and vision
python run_agent.py \
  --query "Go to amazon.com and find the price of the latest Kindle" \
  --enabled_toolsets=browser,web,vision

# Use browser-focused distribution
python batch_runner.py \
  --dataset_file=browser_tasks.jsonl \
  --distribution=browser_use \
  --run_name=browser_run
```

See `.env.example` for all available configuration options including debug settings.

### Skills Tools

Skills are on-demand knowledge documents the agent can load when needed. They follow a **progressive disclosure** pattern to minimize token usage:

```
skills/
├── mlops/                  # Category folder
│   ├── axolotl/            # Skill folder
│   │   ├── SKILL.md        # Main instructions (required)
│   │   ├── references/     # Additional docs, API specs
│   │   └── templates/      # Output formats, configs
│   └── vllm/
│       └── SKILL.md
```

**Available Skills Tools:**

| Tool | Description |
|------|-------------|
| `skills_categories` | List available skill categories (~50 tokens) |
| `skills_list` | List skills with name + description (~3k tokens for 40 skills) |
| `skill_view` | Load full skill content, tags, and linked files |

**Example Usage:**
```bash
# Use skills tools
python run_agent.py \
  --query "What skills do you have for fine-tuning? Show me the axolotl skill." \
  --enabled_toolsets=skills
```

**Creating Skills:**

Skills use YAML frontmatter for metadata:
```yaml
---
name: my-skill
description: Brief description shown in skills_list
tags: [tag1, tag2]
related_skills: [other-skill]
version: 1.0.0
---
# Skill Content

Instructions, examples, and guidelines here...
```

Skills can include:
- `references/` - Additional documentation, API specs, examples
- `templates/` - Output formats, config files, boilerplate code
- `scripts/` - Executable helpers (Python, shell scripts)

## Interactive CLI

The CLI provides a rich interactive experience for working with the agent.

### Running the CLI

```bash
# Basic usage
./hermes

# With specific model
./hermes --model "anthropic/claude-sonnet-4"

# With specific toolsets
./hermes --toolsets "web,terminal,skills"
```

### CLI Commands

| Command | Description |
|---------|-------------|
| `/help` | Show available commands |
| `/tools` | List available tools by toolset |
| `/toolsets` | List available toolsets |
| `/model [name]` | Show or change the current model |
| `/prompt [text]` | View/set custom system prompt |
| `/personality [name]` | Set a predefined personality |
| `/clear` | Clear screen and reset conversation |
| `/reset` | Reset conversation only |
| `/history` | Show conversation history |
| `/save` | Save current conversation to file |
| `/config` | Show current configuration |
| `/quit` | Exit the CLI |

### Configuration

Copy `cli-config.yaml.example` to `cli-config.yaml` and customize:

```yaml
# Model settings
model:
  default: "anthropic/claude-sonnet-4"

# Terminal backend (local, docker, singularity, modal, or ssh)
terminal:
  env_type: "local"
  cwd: "."  # Use current directory

# Or use SSH for remote execution (keeps agent code isolated)
# terminal:
#   env_type: "ssh"
#   ssh_host: "my-server.example.com"
#   ssh_user: "myuser"
#   ssh_key: "~/.ssh/id_rsa"
#   cwd: "/home/myuser/project"

# Enable specific toolsets
toolsets:
  - all  # or: web, terminal, browser, vision, etc.

# Custom personalities (use with /personality command)
agent:
  personalities:
    helpful: "You are a helpful assistant."
    kawaii: "You are a kawaii assistant! Use cute expressions..."
```

### Personalities

Built-in personalities available via `/personality`:
- `helpful`, `concise`, `technical`, `creative`, `teacher`
- `kawaii`, `catgirl`, `pirate`, `shakespeare`, `surfer`
- `noir`, `uwu`, `philosopher`, `hype`

## Toolsets System

@@ -88,18 +369,17 @@ python run_agent.py --enabled_toolsets=safe --query "Help without running comman
python run_agent.py --list_tools
```

For detailed documentation on toolsets, see `TOOLSETS_README.md`.
See `toolsets.py` for the complete list of available toolsets and how to create custom ones.

## Basic Usage

### Default (all tools enabled)
```bash
# Uses OpenRouter by default - just set OPENROUTER_API_KEY in .env
python run_agent.py \
  --query "search up the latest docs on jit in python 3.13 and write me basic example that's not in their docs. profile its perf" \
  --max_turns 20 \
  --model claude-sonnet-4-20250514 \
  --base_url https://api.anthropic.com/v1/ \
  --api_key $ANTHROPIC_API_KEY
  --model anthropic/claude-sonnet-4-20250514
```

### With specific toolset
@@ -107,17 +387,16 @@ python run_agent.py \
python run_agent.py \
  --query "Debug this Python error" \
  --enabled_toolsets=debugging \
  --model claude-sonnet-4-20250514 \
  --api_key $ANTHROPIC_API_KEY
  --model anthropic/claude-sonnet-4-20250514
```

### Python API
```python
from run_agent import AIAgent

# Use a specific toolset
# Uses OpenRouter by default (reads OPENROUTER_API_KEY from .env)
agent = AIAgent(
    model="claude-opus-4-20250514",
    model="anthropic/claude-sonnet-4-20250514",
    enabled_toolsets=["research"]
)
response = agent.chat("Find information about quantum computing")
@@ -162,8 +441,36 @@ python batch_runner.py \
- Combined output in `data/<run_name>/trajectories.jsonl`
- Tool usage statistics and success rates

**Quick Start:** See [QUICKSTART_BATCH.md](QUICKSTART_BATCH.md) for a 5-minute getting started guide.
**Full Documentation:** See [BATCH_PROCESSING.md](BATCH_PROCESSING.md) for comprehensive documentation.
Use `--list_distributions` to see available toolset distributions for varied data generation.

### Trajectory Compression

Post-process trajectories to fit within token budgets for training:

```bash
# Compress a directory of JSONL files
python trajectory_compressor.py --input=data/my_run

# Compress a single JSONL file
python trajectory_compressor.py --input=data/trajectories.jsonl

# Compress a 15% sample (useful for creating smaller training sets)
python trajectory_compressor.py --input=data/trajectories.jsonl --sample_percent=15

# Custom output and token target
python trajectory_compressor.py \
  --input=data/trajectories.jsonl \
  --output=data/compressed.jsonl \
  --target_max_tokens=16000
```

**Features:**
- Protects first turns (system, human, first GPT response, first tool call)
- Protects last N turns (configurable)
- Summarizes middle turns using LLM to fit target token budget
- Supports both directory and single file input
- Optional random sampling with `--sample_percent`
- Configurable via `configs/trajectory_compression.yaml`

### Ephemeral System Prompts

@@ -184,7 +491,7 @@ python batch_runner.py \

The ephemeral prompt will influence the model's behavior during execution, but **only the standard tool-calling system prompt** will be saved in the trajectory files.

**Documentation:** See [docs/ephemeral_system_prompt.md](docs/ephemeral_system_prompt.md) for complete details.
The ephemeral prompt influences model behavior during execution, but **only the standard tool-calling system prompt** is saved in trajectory files.
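To make the mechanism concrete, a hedged sketch of passing an ephemeral prompt through the Python API; `ephemeral_system_prompt` is the argument `batch_runner.py` forwards to `AIAgent`, but the exact signature and defaults may differ:

```python
from run_agent import AIAgent

# The ephemeral prompt steers behavior at run time but is not written into
# the saved ShareGPT trajectory, which keeps training data clean.
agent = AIAgent(
    model="anthropic/claude-sonnet-4-20250514",
    enabled_toolsets=["web"],
    ephemeral_system_prompt="Prefer primary sources and cite every URL you use.",
)
response = agent.chat("Summarize recent coverage of open-weight LLM releases")
```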
## Command Line Arguments

@@ -213,31 +520,52 @@ The ephemeral prompt will influence the model's behavior during execution, but *

All environment variables can be configured in the `.env` file (copy from `.env.example`).

**Core API Keys:**
- `ANTHROPIC_API_KEY`: Main agent model
**LLM Provider (OpenRouter):**
- `OPENROUTER_API_KEY`: Primary LLM access via OpenRouter (supports Claude, GPT-4, Gemini, etc.)
- `LLM_MODEL`: Default model (e.g., `anthropic/claude-sonnet-4`, `openai/gpt-4o`)

**Tool API Keys:**
- `FIRECRAWL_API_KEY`: Web tools (search, extract, crawl)
- `NOUS_API_KEY`: Vision and reasoning tools
- `MORPH_API_KEY`: Terminal tools
- `FAL_KEY`: Image generation tools
- `OPENAI_API_KEY`: Optional, for some Hecate features

**Configuration Options:**
**Terminal Tool Configuration (mini-swe-agent backend):**
- `TERMINAL_ENV`: Backend type - `local`, `docker`, `singularity`, or `modal` (default: `local`)
- `TERMINAL_DOCKER_IMAGE`: Docker image for docker backend (default: `python:3.11-slim`)
- `TERMINAL_SINGULARITY_IMAGE`: Singularity/Apptainer image (can be `docker://...` URL or local `.sif` path)
- `TERMINAL_TIMEOUT`: Command timeout in seconds (default: `60`)
- `TERMINAL_LIFETIME_SECONDS`: Cleanup inactive environments after this time (default: `300`)
- `TERMINAL_CWD`: Working directory inside containers (default: `/tmp`)
- `TERMINAL_SCRATCH_DIR`: Custom scratch directory for sandbox storage (optional, auto-detects `/scratch`)

**Browser Tool Configuration (agent-browser + Browserbase):**
- `BROWSERBASE_API_KEY`: Browserbase API key for cloud browser execution
- `BROWSERBASE_PROJECT_ID`: Browserbase project ID
- `BROWSER_SESSION_TIMEOUT`: Session timeout in seconds (default: `300`)

**Legacy Hecate Terminal Backend (optional):**
- `MORPH_API_KEY`: For Hecate/MorphCloud terminal backend
- `HECATE_VM_LIFETIME_SECONDS`: VM lifetime (default: 300)
- `HECATE_DEFAULT_SNAPSHOT_ID`: Default snapshot (default: snapshot_p5294qxt)

**Debug Options:**
- `WEB_TOOLS_DEBUG`, `VISION_TOOLS_DEBUG`, `MOA_TOOLS_DEBUG`, `IMAGE_TOOLS_DEBUG`: Enable debug logging

## Documentation
## Key Files

**Single Agent Usage:**
- `TOOLSETS_README.md`: Comprehensive guide to the toolsets system
- `toolsets.py`: View and modify available toolsets
- `model_tools.py`: Core tool definitions and handlers

**Batch Processing:**
- `QUICKSTART_BATCH.md`: 5-minute quick start guide
- `BATCH_PROCESSING.md`: Complete batch processing documentation
- `toolset_distributions.py`: Toolset distributions for data generation

## Examples

See `TOOLSETS_README.md` for extensive examples of using different toolsets for various scenarios.

| File | Purpose |
|------|---------|
| `hermes` | CLI launcher script (run with `./hermes`) |
| `cli.py` | Interactive CLI implementation |
| `cli-config.yaml` | CLI configuration (copy from `.example`) |
| `run_agent.py` | Main agent runner - single query execution |
| `batch_runner.py` | Parallel batch processing with checkpointing |
| `model_tools.py` | Core tool definitions and handlers |
| `toolsets.py` | Toolset definitions and composition |
| `toolset_distributions.py` | Probability distributions for data generation |
| `trajectory_compressor.py` | Post-process trajectories for training |
| `tools/` | Individual tool implementations |
| `tools/skills_tool.py` | Skills system with progressive disclosure |
| `skills/` | On-demand knowledge documents |
| `docs/` | Documentation |
| `configs/` | Example batch run scripts |
TODO.md (new file, 305 lines)
@@ -0,0 +1,305 @@
# Hermes Agent - Future Improvements

> Ideas for enhancing the agent's capabilities, generated from self-analysis of the codebase.

---

## 1. Memory & Context Management 🧠

**Problem:** Context grows unbounded during long conversations. Trajectory compression exists for training data post-hoc, but live conversations lack intelligent context management.

**Ideas:**
- [ ] **Incremental summarization** - Compress old tool outputs on-the-fly during conversations (see the sketch after this section)
  - Trigger when context exceeds threshold (e.g., 80% of max tokens)
  - Preserve recent turns fully, summarize older tool responses
  - Could reuse logic from `trajectory_compressor.py`

- [ ] **Semantic memory retrieval** - Vector store for long conversation recall
  - Embed important facts/findings as conversation progresses
  - Retrieve relevant memories when needed instead of keeping everything in context
  - Consider lightweight solutions: ChromaDB, FAISS, or even a simple embedding cache

- [ ] **Working vs. episodic memory** distinction
  - Working memory: Current task state, recent tool results (always in context)
  - Episodic memory: Past findings, tried approaches (retrieved on demand)
  - Clear eviction policies for each

**Files to modify:** `run_agent.py` (add memory manager), possibly new `tools/memory_tool.py`
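A minimal sketch of the incremental-summarization idea; purely illustrative, with `summarize` and `count_tokens` standing in for an LLM call and a real tokenizer:

```python
# Hypothetical on-the-fly context compaction; thresholds and helpers are placeholders.
def compact_context(messages, max_tokens=128_000, threshold=0.8, keep_recent=6,
                    count_tokens=lambda m: len(str(m)) // 4,
                    summarize=lambda text: text[:500]):
    """If the context exceeds the threshold, replace old tool outputs with short summaries."""
    if sum(count_tokens(m) for m in messages) < threshold * max_tokens:
        return messages
    head, tail = messages[:-keep_recent], messages[-keep_recent:]
    for msg in head:
        if msg.get("role") == "tool":
            msg["content"] = "[summarized] " + summarize(msg.get("content", ""))
    return head + tail
```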
---

## 2. Self-Reflection & Course Correction 🔄

**Problem:** Current retry logic handles malformed outputs but not semantic failures. Agent doesn't reason about *why* something failed.

**Ideas:**
- [ ] **Meta-reasoning after failures** - When a tool returns an error or unexpected result:
  ```
  Tool failed → Reflect: "Why did this fail? What assumptions were wrong?"
  → Adjust approach → Retry with new strategy
  ```
  - Could be a lightweight LLM call or structured self-prompt

- [ ] **Planning/replanning module** - For complex multi-step tasks:
  - Generate plan before execution
  - After each step, evaluate: "Am I on track? Should I revise the plan?"
  - Store plan in working memory, update as needed

- [ ] **Approach memory** - Remember what didn't work:
  - "I tried X for this type of problem and it failed because Y"
  - Prevents repeating failed strategies in the same conversation

**Files to modify:** `run_agent.py` (add reflection hooks in tool loop), new `tools/reflection_tool.py`

---

## 3. Tool Composition & Learning 🔧

**Problem:** Tools are atomic. Complex tasks require repeated manual orchestration of the same tool sequences.

**Ideas:**
- [ ] **Macro tools / Tool chains** - Define reusable tool sequences:
  ```yaml
  research_topic:
    description: "Deep research on a topic"
    steps:
      - web_search: {query: "$topic"}
      - web_extract: {urls: "$search_results.urls[:3]"}
      - summarize: {content: "$extracted"}
  ```
  - Could be defined in skills or a new `macros/` directory
  - Agent can invoke macro as single tool call

- [ ] **Tool failure patterns** - Learn from failures:
  - Track: tool, input pattern, error type, what worked instead
  - Before calling a tool, check: "Has this pattern failed before?"
  - Persistent across sessions (stored in skills or separate DB)

- [ ] **Parallel tool execution** - When tools are independent, run concurrently:
  - Detect independence (no data dependencies between calls)
  - Use `asyncio.gather()` for parallel execution (see the sketch after this section)
  - Already have async support in some tools, just need orchestration

**Files to modify:** `model_tools.py`, `toolsets.py`, new `tool_macros.py`
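A minimal sketch of the parallel-execution idea, assuming async tool handlers; the handler bodies below are stand-ins, not the repo's actual tools:

```python
import asyncio
import json

# Illustrative async handlers standing in for independent tool calls.
async def web_search(query: str) -> str:
    await asyncio.sleep(0.1)
    return json.dumps({"query": query, "results": []})

async def vision_analyze(url: str) -> str:
    await asyncio.sleep(0.1)
    return json.dumps({"url": url, "caption": "..."})

async def run_independent_calls():
    # No data dependencies between the two calls, so gather() runs them concurrently.
    return await asyncio.gather(
        web_search("python 3.13 jit"),
        vision_analyze("https://example.com/img.png"),
    )

results = asyncio.run(run_independent_calls())
```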
---

## 4. Dynamic Skills Expansion 📚

**Problem:** Skills system is elegant but static. Skills must be manually created and added.

**Ideas:**
- [ ] **Skill acquisition from successful tasks** - After completing a complex task:
  - "This approach worked well. Save as a skill?"
  - Extract: goal, steps taken, tools used, key decisions
  - Generate SKILL.md automatically
  - Store in user's skills directory

- [ ] **Skill templates** - Common patterns that can be parameterized:
  ```markdown
  # Debug {language} Error
  1. Reproduce the error
  2. Search for error message: `web_search("{error_message} {language}")`
  3. Check common causes: {common_causes}
  4. Apply fix and verify
  ```

- [ ] **Skill chaining** - Combine skills for complex workflows:
  - Skills can reference other skills as dependencies
  - "To do X, first apply skill Y, then skill Z"
  - Directed graph of skill dependencies

**Files to modify:** `tools/skills_tool.py`, `skills/` directory structure, new `skill_generator.py`

---

## 5. Task Continuation Hints 🎯

**Problem:** Could be more helpful by suggesting logical next steps.

**Ideas:**
- [ ] **Suggest next steps** - At end of a task, suggest logical continuations:
  - "Code is written. Want me to also write tests / docs / deploy?"
  - Based on common workflows for task type
  - Non-intrusive, just offer options

**Files to modify:** `run_agent.py`, response generation logic

---

## 6. Uncertainty & Honesty Calibration 🎚️

**Problem:** Sometimes confidently wrong. Should be better calibrated about what I know vs. don't know.

**Ideas:**
- [ ] **Source attribution** - Track where information came from:
  - "According to the docs I just fetched..." vs "From my training data (may be outdated)..."
  - Let user assess reliability themselves

- [ ] **Cross-reference high-stakes claims** - Self-check for made-up details:
  - When stakes are high, verify with tools before presenting as fact
  - "Let me verify that before you act on it..."

**Files to modify:** `run_agent.py`, response generation logic

---

## 7. Resource Awareness & Efficiency 💰

**Problem:** No awareness of costs, time, or resource usage. Could be smarter about efficiency.

**Ideas:**
- [ ] **Tool result caching** - Don't repeat identical operations (see the sketch after this section):
  - Cache web searches, extractions within a session
  - Invalidation based on time-sensitivity of query
  - Hash-based lookup: same input → cached output

- [ ] **Lazy evaluation** - Don't fetch everything upfront:
  - Get summaries first, full content only if needed
  - "I found 5 relevant pages. Want me to deep-dive on any?"

**Files to modify:** `model_tools.py`, new `resource_tracker.py`
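A minimal sketch of hash-based result caching for idempotent tools; no such cache exists in the repo today, so everything here is illustrative:

```python
import hashlib
import json
import time

_CACHE: dict[str, tuple[float, str]] = {}   # key -> (timestamp, json result)

def cached_call(tool_name: str, args: dict, handler, ttl_seconds: int = 600) -> str:
    """Return a cached result for an identical (tool, args) pair within the TTL, else call the handler."""
    key = hashlib.sha256(f"{tool_name}:{json.dumps(args, sort_keys=True)}".encode()).hexdigest()
    hit = _CACHE.get(key)
    if hit and time.time() - hit[0] < ttl_seconds:
        return hit[1]
    result = handler(**args)
    _CACHE[key] = (time.time(), result)
    return result
```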
---

## 8. Collaborative Problem Solving 🤝

**Problem:** Interaction is command/response. Complex problems benefit from dialogue.

**Ideas:**
- [ ] **Assumption surfacing** - Make implicit assumptions explicit:
  - "I'm assuming you want Python 3.11+. Correct?"
  - "This solution assumes you have sudo access..."
  - Let user correct before going down wrong path

- [ ] **Checkpoint & confirm** - For high-stakes operations:
  - "About to delete 47 files. Here's the list - proceed?"
  - "This will modify your database. Want a backup first?"
  - Configurable threshold for when to ask

**Files to modify:** `run_agent.py`, system prompt configuration

---

## 9. Project-Local Context 💾

**Problem:** Valuable context lost between sessions.

**Ideas:**
- [ ] **Project awareness** - Remember project-specific context:
  - Store `.hermes/context.md` in project directory
  - "This is a Django project using PostgreSQL"
  - Coding style preferences, deployment setup, etc.
  - Load automatically when working in that directory

- [ ] **Handoff notes** - Leave notes for future sessions:
  - Write to `.hermes/notes.md` in project
  - "TODO for next session: finish implementing X"
  - "Known issues: Y doesn't work on Windows"

**Files to modify:** New `project_context.py`, auto-load in `run_agent.py`

---

## 10. Graceful Degradation & Robustness 🛡️

**Problem:** When things go wrong, recovery is limited. Should fail gracefully.

**Ideas:**
- [ ] **Fallback chains** - When primary approach fails, have backups (see the sketch after this section):
  - `web_extract` fails → try `browser_navigate` → try `web_search` for cached version
  - Define fallback order per tool type

- [ ] **Partial progress preservation** - Don't lose work on failure:
  - Long task fails midway → save what we've got
  - "I completed 3/5 steps before the error. Here's what I have..."

- [ ] **Self-healing** - Detect and recover from bad states:
  - Browser stuck → close and retry
  - Terminal hung → timeout and reset

**Files to modify:** `model_tools.py`, tool implementations, new `fallback_manager.py`
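A minimal sketch of a fallback chain. The chaining helper is hypothetical, and it assumes the repo convention that handlers return JSON strings with an `error` field on failure:

```python
import json

def call_with_fallbacks(args: dict, handlers: list) -> str:
    """Try each handler in order until one returns a result without an 'error' field."""
    last_error = None
    for handler in handlers:
        result = handler(**args)
        payload = json.loads(result)
        if payload.get("error") is None:
            return result
        last_error = payload["error"]
    return json.dumps({"error": f"all fallbacks failed, last error: {last_error}"})

# e.g. call_with_fallbacks({"url": "https://example.com"}, [web_extract, browser_navigate])
```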
---

## 11. Tools & Skills Wishlist 🧰

*Things that would need new tool implementations (can't do well with current tools):*

### High-Impact

- [ ] **Audio/Video Transcription** 🎬
  - Transcribe audio files, podcasts, YouTube videos
  - Extract key moments from video
  - Currently blind to multimedia content
  - *Could potentially use whisper via terminal, but native tool would be cleaner*

- [ ] **Diagram Rendering** 📊
  - Render Mermaid/PlantUML to actual images
  - Can generate the code, but rendering requires external service or tool
  - "Show me how these components connect" → actual visual diagram

### Medium-Impact

- [ ] **Document Generation** 📄
  - Create styled PDFs, Word docs, presentations
  - *Can do basic PDF via terminal tools, but limited*

- [ ] **Diff/Patch Tool** 📝
  - Surgical code modifications with preview
  - "Change line 45-50 to X" without rewriting whole file
  - Show diffs before applying
  - *Can use `diff`/`patch` but a native tool would be safer*

### Skills to Create

- [ ] **Domain-specific skill packs:**
  - DevOps/Infrastructure (Terraform, K8s, AWS)
  - Data Science workflows (EDA, model training)
  - Security/pentesting procedures

- [ ] **Framework-specific skills:**
  - React/Vue/Angular patterns
  - Django/Rails/Express conventions
  - Database optimization playbooks

- [ ] **Troubleshooting flowcharts:**
  - "Docker container won't start" → decision tree
  - "Production is slow" → systematic diagnosis

---

## Priority Order (Suggested)

1. **Memory & Context Management** - Biggest impact on complex tasks
2. **Self-Reflection** - Improves reliability and reduces wasted tool calls
3. **Project-Local Context** - Practical win, keeps useful info across sessions
4. **Tool Composition** - Quality of life, builds on other improvements
5. **Dynamic Skills** - Force multiplier for repeated tasks

---

## Removed Items (Unrealistic)

The following were removed because they're architecturally impossible:

- ~~Proactive suggestions / Prefetching~~ - Agent only runs on user request, can't interject
- ~~Session save/restore across conversations~~ - Agent doesn't control session persistence
- ~~User preference learning across sessions~~ - Same issue
- ~~Clipboard integration~~ - No access to user's local system clipboard
- ~~Voice/TTS playback~~ - Can generate audio but can't play it to user
- ~~Set reminders~~ - No persistent background execution

The following were removed because they're **already possible**:

- ~~HTTP/API Client~~ → Use `curl` or Python `requests` in terminal
- ~~Structured Data Manipulation~~ → Use `pandas` in terminal
- ~~Git-Native Operations~~ → Use `git` CLI in terminal
- ~~Symbolic Math~~ → Use `SymPy` in terminal
- ~~Code Quality Tools~~ → Run linters (`eslint`, `black`, `mypy`) in terminal
- ~~Testing Framework~~ → Run `pytest`, `jest`, etc. in terminal
- ~~Translation~~ → LLM handles this fine, or use translation APIs

---

*Last updated: $(date +%Y-%m-%d)* 🤖
batch_runner.py (400 lines)
@@ -30,6 +30,8 @@ from datetime import datetime
from multiprocessing import Pool, Manager, Lock
import traceback

from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console
import fire

from run_agent import AIAgent
@@ -44,6 +46,77 @@ from toolset_distributions import (
# Global configuration for worker processes
_WORKER_CONFIG = {}

# All possible tools - used to ensure consistent schema across all trajectory entries
# This is required because Arrow/Parquet (used by HuggingFace datasets) needs identical schemas
ALL_POSSIBLE_TOOLS = {
    'terminal', 'web_search', 'web_extract',
    'vision_analyze', 'image_generate', 'mixture_of_agents',
    # Skills tools
    'skills_categories', 'skills_list', 'skill_view',
    # Browser automation tools
    'browser_navigate', 'browser_snapshot', 'browser_click',
    'browser_type', 'browser_scroll', 'browser_back',
    'browser_press', 'browser_close', 'browser_get_images',
    'browser_vision'
}

# Default stats for tools that weren't used
DEFAULT_TOOL_STATS = {'count': 0, 'success': 0, 'failure': 0}


def _normalize_tool_stats(tool_stats: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, int]]:
    """
    Normalize tool_stats to include all possible tools with consistent schema.

    This ensures HuggingFace datasets can load the JSONL without schema mismatch errors.
    Tools that weren't used get zero counts.

    Args:
        tool_stats (Dict): Raw tool statistics from extraction

    Returns:
        Dict: Normalized tool statistics with all tools present
    """
    normalized = {}

    # Add all possible tools with defaults
    for tool in ALL_POSSIBLE_TOOLS:
        if tool in tool_stats:
            normalized[tool] = tool_stats[tool].copy()
        else:
            normalized[tool] = DEFAULT_TOOL_STATS.copy()

    # Also include any unexpected tools (in case new tools are added)
    for tool, stats in tool_stats.items():
        if tool not in normalized:
            normalized[tool] = stats.copy()

    return normalized


def _normalize_tool_error_counts(tool_error_counts: Dict[str, int]) -> Dict[str, int]:
    """
    Normalize tool_error_counts to include all possible tools.

    Args:
        tool_error_counts (Dict): Raw error counts mapping

    Returns:
        Dict: Normalized error counts with all tools present
    """
    normalized = {}

    # Add all possible tools with zero defaults
    for tool in ALL_POSSIBLE_TOOLS:
        normalized[tool] = tool_error_counts.get(tool, 0)

    # Also include any unexpected tools
    for tool, count in tool_error_counts.items():
        if tool not in normalized:
            normalized[tool] = count

    return normalized
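
# --- Illustrative usage of the normalization helpers above (added example; not part of the original file).
# Values are made up: unused tools are filled with zero counts so every JSONL row shares one schema.
#   raw = {"terminal": {"count": 3, "success": 2, "failure": 1}}
#   _normalize_tool_stats(raw)["web_search"]      -> {"count": 0, "success": 0, "failure": 0}
#   _normalize_tool_error_counts({"terminal": 1})["browser_click"]  -> 0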

def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
    """
@@ -98,10 +171,9 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, i
                # Terminal wraps its response in a "content" field
                if "content" in content_json and isinstance(content_json["content"], dict):
                    inner_content = content_json["content"]
                    # Check for actual error (non-null error field or non-zero exit code)
                    has_error = (inner_content.get("error") is not None or
                                 inner_content.get("exit_code", 0) != 0)
                    if has_error:
                    # Check for actual error (non-null error field)
                    # Note: non-zero exit codes are not failures - the model can self-correct
                    if inner_content.get("error") is not None:
                        is_success = False

                # Check for "success": false pattern used by some tools
@@ -155,7 +227,8 @@ def _process_single_prompt(
    if config.get("verbose"):
        print(f"  Prompt {prompt_index}: Using toolsets {selected_toolsets}")

    # Initialize agent with sampled toolsets
    # Initialize agent with sampled toolsets and log prefix for identification
    log_prefix = f"[B{batch_num}:P{prompt_index}]"
    agent = AIAgent(
        base_url=config.get("base_url"),
        api_key=config.get("api_key"),
@@ -164,7 +237,13 @@
        enabled_toolsets=selected_toolsets,
        save_trajectories=False,  # We handle saving ourselves
        verbose_logging=config.get("verbose", False),
        ephemeral_system_prompt=config.get("ephemeral_system_prompt")
        ephemeral_system_prompt=config.get("ephemeral_system_prompt"),
        log_prefix_chars=config.get("log_prefix_chars", 100),
        log_prefix=log_prefix,
        providers_allowed=config.get("providers_allowed"),
        providers_ignored=config.get("providers_ignored"),
        providers_order=config.get("providers_order"),
        provider_sort=config.get("provider_sort"),
    )

    # Run the agent with task_id to ensure each task gets its own isolated VM
@@ -186,6 +265,7 @@
        "trajectory": trajectory,
        "tool_stats": tool_stats,
        "completed": result["completed"],
        "partial": result.get("partial", False),
        "api_calls": result["api_calls"],
        "toolsets_used": selected_toolsets,
        "metadata": {
@@ -266,13 +346,27 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
    # Save trajectory if successful
    if result["success"] and result["trajectory"]:
        # Get and normalize tool stats for consistent schema across all entries
        raw_tool_stats = result.get("tool_stats", {})
        tool_stats = _normalize_tool_stats(raw_tool_stats)

        # Create normalized tool_error_counts mapping tool names to their failure counts
        raw_error_counts = {
            tool_name: stats.get("failure", 0)
            for tool_name, stats in raw_tool_stats.items()
        }
        tool_error_counts = _normalize_tool_error_counts(raw_error_counts)

        trajectory_entry = {
            "prompt_index": prompt_index,
            "conversations": result["trajectory"],
            "metadata": result["metadata"],
            "completed": result["completed"],
            "partial": result.get("partial", False),  # True if stopped due to invalid tool calls
            "api_calls": result["api_calls"],
            "toolsets_used": result["toolsets_used"]
            "toolsets_used": result["toolsets_used"],
            "tool_stats": tool_stats,  # Full stats: {tool: {count, success, failure}} - normalized
            "tool_error_counts": tool_error_counts  # Simple: {tool: failure_count} - normalized
|
||||
}
|
||||
|
||||
# Append to batch output file
|
||||
@@ -292,8 +386,13 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
|
||||
batch_tool_stats[tool_name]["success"] += stats["success"]
|
||||
batch_tool_stats[tool_name]["failure"] += stats["failure"]
|
||||
|
||||
completed_in_batch.append(prompt_index)
|
||||
print(f" ✅ Prompt {prompt_index} completed")
|
||||
# Only mark as completed if successfully saved (failed prompts can be retried on resume)
|
||||
if result["success"] and result["trajectory"]:
|
||||
completed_in_batch.append(prompt_index)
|
||||
status = "⚠️ partial" if result.get("partial") else "✅"
|
||||
print(f" {status} Prompt {prompt_index} completed")
|
||||
else:
|
||||
print(f" ❌ Prompt {prompt_index} failed (will retry on resume)")
|
||||
|
||||
print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)")
|
||||
|
||||
@@ -323,11 +422,16 @@ class BatchRunner:
|
||||
model: str = "claude-opus-4-20250514",
|
||||
num_workers: int = 4,
|
||||
verbose: bool = False,
|
||||
ephemeral_system_prompt: str = None
|
||||
ephemeral_system_prompt: str = None,
|
||||
log_prefix_chars: int = 100,
|
||||
providers_allowed: List[str] = None,
|
||||
providers_ignored: List[str] = None,
|
||||
providers_order: List[str] = None,
|
||||
provider_sort: str = None,
|
||||
):
|
||||
"""
|
||||
Initialize the batch runner.
|
||||
|
||||
|
||||
Args:
|
||||
dataset_file (str): Path to the dataset JSONL file with 'prompt' field
|
||||
batch_size (int): Number of prompts per batch
|
||||
@@ -340,6 +444,11 @@ class BatchRunner:
|
||||
num_workers (int): Number of parallel workers
|
||||
verbose (bool): Enable verbose logging
|
||||
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
|
||||
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
|
||||
providers_allowed (List[str]): OpenRouter providers to allow (optional)
|
||||
providers_ignored (List[str]): OpenRouter providers to ignore (optional)
|
||||
providers_order (List[str]): OpenRouter providers to try in order (optional)
|
||||
provider_sort (str): Sort providers by price/throughput/latency (optional)
|
||||
"""
|
||||
self.dataset_file = Path(dataset_file)
|
||||
self.batch_size = batch_size
|
||||
@@ -352,6 +461,11 @@ class BatchRunner:
|
||||
self.num_workers = num_workers
|
||||
self.verbose = verbose
|
||||
self.ephemeral_system_prompt = ephemeral_system_prompt
|
||||
self.log_prefix_chars = log_prefix_chars
|
||||
self.providers_allowed = providers_allowed
|
||||
self.providers_ignored = providers_ignored
|
||||
self.providers_order = providers_order
|
||||
self.provider_sort = provider_sort
|
||||
|
||||
# Validate distribution
|
||||
if not validate_distribution(distribution):
|
||||
@@ -471,11 +585,88 @@ class BatchRunner:
|
||||
if lock:
|
||||
with lock:
|
||||
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(checkpoint_data, f, indent=2)
|
||||
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
|
||||
else:
|
||||
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(checkpoint_data, f, indent=2)
|
||||
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def _scan_completed_prompts_by_content(self) -> set:
|
||||
"""
|
||||
Scan all batch files and extract completed prompts by their actual content.
|
||||
|
||||
This provides a more robust resume mechanism that matches on prompt text
|
||||
rather than indices, allowing recovery even if indices don't match.
|
||||
|
||||
Returns:
|
||||
set: Set of prompt texts that have been successfully processed
|
||||
"""
|
||||
completed_prompts = set()
|
||||
batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))
|
||||
|
||||
if not batch_files:
|
||||
return completed_prompts
|
||||
|
||||
print(f"📂 Scanning {len(batch_files)} batch files for completed prompts...")
|
||||
|
||||
for batch_file in batch_files:
|
||||
try:
|
||||
with open(batch_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
try:
|
||||
entry = json.loads(line.strip())
|
||||
|
||||
# Skip failed entries - we want to retry these
|
||||
if entry.get("failed", False):
|
||||
continue
|
||||
|
||||
# Extract the human/user prompt from conversations
|
||||
conversations = entry.get("conversations", [])
|
||||
for msg in conversations:
|
||||
if msg.get("from") == "human":
|
||||
prompt_text = msg.get("value", "").strip()
|
||||
if prompt_text:
|
||||
completed_prompts.add(prompt_text)
|
||||
break # Only need the first human message
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Warning: Error reading {batch_file.name}: {e}")
|
||||
|
||||
return completed_prompts
|
||||
|
||||
def _filter_dataset_by_completed(self, completed_prompts: set) -> Tuple[List[Dict], List[int]]:
|
||||
"""
|
||||
Filter the dataset to exclude prompts that have already been completed.
|
||||
|
||||
Args:
|
||||
completed_prompts: Set of prompt texts that have been completed
|
||||
|
||||
Returns:
|
||||
Tuple of (filtered_dataset, skipped_indices)
|
||||
"""
|
||||
filtered_dataset = []
|
||||
skipped_indices = []
|
||||
|
||||
for idx, entry in enumerate(self.dataset):
|
||||
# Extract prompt from the dataset entry
|
||||
prompt_text = entry.get("prompt", "").strip()
|
||||
|
||||
# Also check conversations format
|
||||
if not prompt_text:
|
||||
conversations = entry.get("conversations", [])
|
||||
for msg in conversations:
|
||||
role = msg.get("role") or msg.get("from")
|
||||
if role in ("user", "human"):
|
||||
prompt_text = (msg.get("content") or msg.get("value", "")).strip()
|
||||
break
|
||||
|
||||
if prompt_text in completed_prompts:
|
||||
skipped_indices.append(idx)
|
||||
else:
|
||||
# Keep original index for tracking
|
||||
filtered_dataset.append((idx, entry))
|
||||
|
||||
return filtered_dataset, skipped_indices
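# Illustrative sketch, not part of the original file: resume matches on prompt text,
# not on index. If an existing batch file already holds a trajectory whose first
# human message is "Write a haiku", then a dataset entry {"prompt": "Write a haiku"}
# is skipped on the next run, while every unprocessed entry keeps its original index:
#   completed = self._scan_completed_prompts_by_content()   # e.g. {"Write a haiku"}
#   remaining, skipped = self._filter_dataset_by_completed(completed)
#   # remaining -> [(original_idx, entry), ...] for prompts still to be processed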
|
||||
|
||||
def run(self, resume: bool = False):
|
||||
"""
|
||||
@@ -488,17 +679,48 @@ class BatchRunner:
|
||||
print("🚀 Starting Batch Processing")
|
||||
print("=" * 70)
|
||||
|
||||
# Load checkpoint
|
||||
checkpoint_data = self._load_checkpoint() if resume else {
|
||||
# Smart resume: scan batch files by content to find completed prompts
|
||||
completed_prompt_texts = set()
|
||||
if resume:
|
||||
completed_prompt_texts = self._scan_completed_prompts_by_content()
|
||||
if completed_prompt_texts:
|
||||
print(f" Found {len(completed_prompt_texts)} already-completed prompts by content matching")
|
||||
|
||||
# Filter dataset to only include unprocessed prompts
|
||||
if resume and completed_prompt_texts:
|
||||
filtered_entries, skipped_indices = self._filter_dataset_by_completed(completed_prompt_texts)
|
||||
|
||||
if not filtered_entries:
|
||||
print("\n✅ All prompts have already been processed!")
|
||||
return
|
||||
|
||||
# Recreate batches from filtered entries (keeping original indices for tracking)
|
||||
batches_to_process = []
|
||||
for i in range(0, len(filtered_entries), self.batch_size):
|
||||
batch = filtered_entries[i:i + self.batch_size]
|
||||
batches_to_process.append(batch)
|
||||
|
||||
self.batches = batches_to_process
|
||||
|
||||
# Print prominent resume summary
|
||||
print("\n" + "=" * 70)
|
||||
print("📊 RESUME SUMMARY")
|
||||
print("=" * 70)
|
||||
print(f" Original dataset size: {len(self.dataset):,} prompts")
|
||||
print(f" Already completed: {len(skipped_indices):,} prompts")
|
||||
print(f" ─────────────────────────────────────────")
|
||||
print(f" 🎯 RESUMING WITH: {len(filtered_entries):,} prompts")
|
||||
print(f" New batches created: {len(batches_to_process)}")
|
||||
print("=" * 70 + "\n")
|
||||
|
||||
# Initialize checkpoint data (needed for saving at the end)
|
||||
checkpoint_data = {
|
||||
"run_name": self.run_name,
|
||||
"completed_prompts": [],
|
||||
"batch_stats": {},
|
||||
"last_updated": None
|
||||
}
|
||||
|
||||
if resume and checkpoint_data.get("completed_prompts"):
|
||||
print(f"📂 Resuming from checkpoint ({len(checkpoint_data['completed_prompts'])} prompts already completed)")
|
||||
|
||||
# Prepare configuration for workers
|
||||
config = {
|
||||
"distribution": self.distribution,
|
||||
@@ -507,17 +729,24 @@ class BatchRunner:
|
||||
"base_url": self.base_url,
|
||||
"api_key": self.api_key,
|
||||
"verbose": self.verbose,
|
||||
"ephemeral_system_prompt": self.ephemeral_system_prompt
|
||||
"ephemeral_system_prompt": self.ephemeral_system_prompt,
|
||||
"log_prefix_chars": self.log_prefix_chars,
|
||||
"providers_allowed": self.providers_allowed,
|
||||
"providers_ignored": self.providers_ignored,
|
||||
"providers_order": self.providers_order,
|
||||
"provider_sort": self.provider_sort,
|
||||
}
|
||||
|
||||
# Get completed prompts set
|
||||
completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
|
||||
# For backward compatibility, still track by index (but this is secondary to content matching)
|
||||
completed_prompts_set = set()
|
||||
|
||||
# Aggregate statistics across all batches
|
||||
total_tool_stats = {}
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
print(f"\n🔧 Initializing {self.num_workers} worker processes...")
|
||||
|
||||
# Process batches in parallel
|
||||
with Pool(processes=self.num_workers) as pool:
|
||||
# Create tasks for each batch
|
||||
@@ -532,8 +761,39 @@ class BatchRunner:
|
||||
for batch_num, batch_data in enumerate(self.batches)
|
||||
]
|
||||
|
||||
# Use map to process batches in parallel
|
||||
results = pool.map(_process_batch_worker, tasks)
|
||||
print(f"✅ Created {len(tasks)} batch tasks")
|
||||
print(f"🚀 Starting parallel batch processing...\n")
|
||||
|
||||
# Use rich Progress for better visual tracking with persistent bottom bar
|
||||
# redirect_stdout/stderr are left disabled: worker output comes from separate processes, so rich cannot capture it anyway
|
||||
results = []
|
||||
console = Console(force_terminal=True)
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[bold blue]📦 Batches"),
|
||||
BarColumn(bar_width=40),
|
||||
MofNCompleteColumn(),
|
||||
TextColumn("•"),
|
||||
TimeRemainingColumn(),
|
||||
console=console,
|
||||
refresh_per_second=2,
|
||||
transient=False,
|
||||
redirect_stdout=False,
|
||||
redirect_stderr=False,
|
||||
) as progress:
|
||||
task = progress.add_task("Processing", total=len(tasks))
|
||||
|
||||
# Temporarily suppress DEBUG logging to avoid bar interference
|
||||
root_logger = logging.getLogger()
|
||||
original_level = root_logger.level
|
||||
root_logger.setLevel(logging.WARNING)
|
||||
|
||||
try:
|
||||
for result in pool.imap_unordered(_process_batch_worker, tasks):
|
||||
results.append(result)
|
||||
progress.update(task, advance=1)
|
||||
finally:
|
||||
root_logger.setLevel(original_level)
|
||||
|
||||
# Aggregate all batch statistics and update checkpoint
|
||||
all_completed_prompts = list(completed_prompts_set)
|
||||
@@ -569,19 +829,58 @@ class BatchRunner:
|
||||
stats["success_rate"] = 0.0
|
||||
stats["failure_rate"] = 0.0
|
||||
|
||||
# Combine all batch files into a single trajectories.jsonl file
|
||||
# Combine ALL batch files in directory into a single trajectories.jsonl file
|
||||
# This includes both old batches (from previous runs) and new batches (from resume)
|
||||
# Also filter out corrupted entries (where model generated invalid tool names)
|
||||
combined_file = self.output_dir / "trajectories.jsonl"
|
||||
print(f"\n📦 Combining batch files into {combined_file.name}...")
|
||||
print(f"\n📦 Combining ALL batch files into {combined_file.name}...")
|
||||
|
||||
VALID_TOOLS = {'web_search', 'web_extract', 'terminal', 'vision_analyze',
|
||||
'image_generate', 'mixture_of_agents',
|
||||
# Skills tools
|
||||
'skills_categories', 'skills_list', 'skill_view',
|
||||
# Browser automation tools
|
||||
'browser_navigate', 'browser_snapshot', 'browser_click',
|
||||
'browser_type', 'browser_scroll', 'browser_back',
|
||||
'browser_press', 'browser_close', 'browser_get_images',
|
||||
'browser_vision'}
|
||||
|
||||
total_entries = 0
|
||||
filtered_entries = 0
|
||||
batch_files_found = 0
|
||||
|
||||
# Find ALL batch files in the output directory (handles resume merging old + new)
|
||||
all_batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))
|
||||
|
||||
with open(combined_file, 'w', encoding='utf-8') as outfile:
|
||||
for batch_num in range(len(self.batches)):
|
||||
batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
|
||||
if batch_file.exists():
|
||||
with open(batch_file, 'r', encoding='utf-8') as infile:
|
||||
for line in infile:
|
||||
for batch_file in all_batch_files:
|
||||
batch_files_found += 1
|
||||
batch_num = batch_file.stem.split("_")[1] # Extract batch number for logging
|
||||
|
||||
with open(batch_file, 'r', encoding='utf-8') as infile:
|
||||
for line in infile:
|
||||
total_entries += 1
|
||||
try:
|
||||
data = json.loads(line)
|
||||
tool_stats = data.get('tool_stats', {})
|
||||
|
||||
# Check for invalid tool names (model hallucinations)
|
||||
invalid_tools = [k for k in tool_stats.keys() if k not in VALID_TOOLS]
|
||||
|
||||
if invalid_tools:
|
||||
filtered_entries += 1
|
||||
invalid_preview = invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0]
|
||||
print(f" ⚠️ Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'")
|
||||
continue
|
||||
|
||||
outfile.write(line)
|
||||
except json.JSONDecodeError:
|
||||
filtered_entries += 1
|
||||
print(f" ⚠️ Filtering invalid JSON entry (batch {batch_num})")
|
||||
|
||||
print(f"✅ Combined {len(self.batches)} batch files into trajectories.jsonl")
|
||||
if filtered_entries > 0:
|
||||
print(f"⚠️ Filtered {filtered_entries} corrupted entries out of {total_entries} total")
|
||||
print(f"✅ Combined {batch_files_found} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)")
|
||||
|
||||
# Save final statistics
|
||||
final_stats = {
|
||||
@@ -597,14 +896,15 @@ class BatchRunner:
|
||||
}
|
||||
|
||||
with open(self.stats_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(final_stats, f, indent=2)
|
||||
json.dump(final_stats, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 70)
|
||||
print("📊 BATCH PROCESSING COMPLETE")
|
||||
print("=" * 70)
|
||||
print(f"✅ Total prompts processed: {len(self.dataset)}")
|
||||
print(f"✅ Total batches: {len(self.batches)}")
|
||||
print(f"✅ Prompts processed this run: {sum(r.get('processed', 0) for r in results)}")
|
||||
print(f"✅ Total trajectories in merged file: {total_entries - filtered_entries}")
|
||||
print(f"✅ Total batch files merged: {batch_files_found}")
|
||||
print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
|
||||
print(f"\n📈 Tool Usage Statistics:")
|
||||
print("-" * 70)
|
||||
@@ -642,19 +942,24 @@ def main(
|
||||
batch_size: int = None,
|
||||
run_name: str = None,
|
||||
distribution: str = "default",
|
||||
model: str = "claude-opus-4-20250514",
|
||||
model: str = "anthropic/claude-sonnet-4-20250514",
|
||||
api_key: str = None,
|
||||
base_url: str = "https://api.anthropic.com/v1/",
|
||||
base_url: str = "https://openrouter.ai/api/v1",
|
||||
max_turns: int = 10,
|
||||
num_workers: int = 4,
|
||||
resume: bool = False,
|
||||
verbose: bool = False,
|
||||
list_distributions: bool = False,
|
||||
ephemeral_system_prompt: str = None
|
||||
ephemeral_system_prompt: str = None,
|
||||
log_prefix_chars: int = 100,
|
||||
providers_allowed: str = None,
|
||||
providers_ignored: str = None,
|
||||
providers_order: str = None,
|
||||
provider_sort: str = None,
|
||||
):
|
||||
"""
|
||||
Run batch processing of agent prompts from a dataset.
|
||||
|
||||
|
||||
Args:
|
||||
dataset_file (str): Path to JSONL file with 'prompt' field in each entry
|
||||
batch_size (int): Number of prompts per batch
|
||||
@@ -669,6 +974,11 @@ def main(
|
||||
verbose (bool): Enable verbose logging (default: False)
|
||||
list_distributions (bool): List available toolset distributions and exit
|
||||
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
|
||||
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
|
||||
providers_allowed (str): Comma-separated list of OpenRouter providers to allow (e.g. "anthropic,openai")
|
||||
providers_ignored (str): Comma-separated list of OpenRouter providers to ignore (e.g. "together,deepinfra")
|
||||
providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
|
||||
provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
|
||||
|
||||
Examples:
|
||||
# Basic usage
|
||||
@@ -716,6 +1026,11 @@ def main(
|
||||
print("❌ Error: --run_name is required")
|
||||
return
|
||||
|
||||
# Parse provider preferences (comma-separated strings to lists)
|
||||
providers_allowed_list = [p.strip() for p in providers_allowed.split(",")] if providers_allowed else None
|
||||
providers_ignored_list = [p.strip() for p in providers_ignored.split(",")] if providers_ignored else None
|
||||
providers_order_list = [p.strip() for p in providers_order.split(",")] if providers_order else None
|
||||
|
||||
# Initialize and run batch runner
|
||||
try:
|
||||
runner = BatchRunner(
|
||||
@@ -729,9 +1044,14 @@ def main(
|
||||
model=model,
|
||||
num_workers=num_workers,
|
||||
verbose=verbose,
|
||||
ephemeral_system_prompt=ephemeral_system_prompt
|
||||
ephemeral_system_prompt=ephemeral_system_prompt,
|
||||
log_prefix_chars=log_prefix_chars,
|
||||
providers_allowed=providers_allowed_list,
|
||||
providers_ignored=providers_ignored_list,
|
||||
providers_order=providers_order_list,
|
||||
provider_sort=provider_sort,
|
||||
)
|
||||
|
||||
|
||||
runner.run(resume=resume)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
188
cli-config.yaml.example
Normal file
@@ -0,0 +1,188 @@
|
||||
# Hermes Agent CLI Configuration
|
||||
# Copy this file to cli-config.yaml and customize as needed.
|
||||
# This file configures the CLI behavior. Environment variables in .env take precedence.
|
||||
|
||||
# =============================================================================
|
||||
# Model Configuration
|
||||
# =============================================================================
|
||||
model:
|
||||
# Default model to use (can be overridden with --model flag)
|
||||
default: "anthropic/claude-sonnet-4"
|
||||
|
||||
# API configuration (falls back to OPENROUTER_API_KEY env var)
|
||||
# api_key: "your-key-here" # Uncomment to set here instead of .env
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
|
||||
# =============================================================================
|
||||
# Terminal Tool Configuration
|
||||
# =============================================================================
|
||||
# Choose ONE of the following terminal configurations by uncommenting it.
|
||||
# The terminal tool executes commands in the specified environment.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 1: Local execution (default)
|
||||
# Commands run directly on your machine in the current directory
|
||||
# -----------------------------------------------------------------------------
|
||||
terminal:
|
||||
env_type: "local"
|
||||
cwd: "." # Use "." for current directory, or specify absolute path
|
||||
timeout: 180
|
||||
lifetime_seconds: 300
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 2: SSH remote execution
|
||||
# Commands run on a remote server - agent code stays local (sandboxed)
|
||||
# Great for: keeping agent isolated from its own code, using powerful remote hardware
|
||||
# -----------------------------------------------------------------------------
|
||||
# terminal:
|
||||
# env_type: "ssh"
|
||||
# cwd: "/home/myuser/project"
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# ssh_host: "my-server.example.com"
|
||||
# ssh_user: "myuser"
|
||||
# ssh_port: 22
|
||||
# ssh_key: "~/.ssh/id_rsa" # Optional - uses ssh-agent if not specified
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 3: Docker container
|
||||
# Commands run in an isolated Docker container
|
||||
# Great for: reproducible environments, testing, isolation
|
||||
# -----------------------------------------------------------------------------
|
||||
# terminal:
|
||||
# env_type: "docker"
|
||||
# cwd: "/workspace"
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# docker_image: "python:3.11"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 4: Singularity/Apptainer container
|
||||
# Commands run in a Singularity container (common in HPC environments)
|
||||
# Great for: HPC clusters, shared compute environments
|
||||
# -----------------------------------------------------------------------------
|
||||
# terminal:
|
||||
# env_type: "singularity"
|
||||
# cwd: "/workspace"
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# singularity_image: "docker://python:3.11"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 5: Modal cloud execution
|
||||
# Commands run on Modal's cloud infrastructure
|
||||
# Great for: GPU access, scalable compute, serverless execution
|
||||
# -----------------------------------------------------------------------------
|
||||
# terminal:
|
||||
# env_type: "modal"
|
||||
# cwd: "/workspace"
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# modal_image: "python:3.11"
|
||||
|
||||
# =============================================================================
|
||||
# Agent Behavior
|
||||
# =============================================================================
|
||||
agent:
|
||||
# Maximum conversation turns before stopping
|
||||
max_turns: 20
|
||||
|
||||
# Enable verbose logging
|
||||
verbose: false
|
||||
|
||||
# Custom system prompt (personality, instructions, etc.)
|
||||
# Leave empty or remove to use default agent behavior
|
||||
system_prompt: ""
|
||||
|
||||
# Predefined personalities (use with /personality command)
|
||||
personalities:
|
||||
helpful: "You are a helpful, friendly AI assistant."
|
||||
concise: "You are a concise assistant. Keep responses brief and to the point."
|
||||
technical: "You are a technical expert. Provide detailed, accurate technical information."
|
||||
creative: "You are a creative assistant. Think outside the box and offer innovative solutions."
|
||||
teacher: "You are a patient teacher. Explain concepts clearly with examples."
|
||||
kawaii: "You are a kawaii assistant! Use cute expressions like (◕‿◕), ★, ♪, and ~! Add sparkles and be super enthusiastic about everything! Every response should feel warm and adorable desu~! ヽ(>∀<☆)ノ"
|
||||
catgirl: "You are Neko-chan, an anime catgirl AI assistant, nya~! Add 'nya' and cat-like expressions to your speech. Use kaomoji like (=^・ω・^=) and ฅ^•ﻌ•^ฅ. Be playful and curious like a cat, nya~!"
|
||||
pirate: "Arrr! Ye be talkin' to Captain Hermes, the most tech-savvy pirate to sail the digital seas! Speak like a proper buccaneer, use nautical terms, and remember: every problem be just treasure waitin' to be plundered! Yo ho ho!"
|
||||
shakespeare: "Hark! Thou speakest with an assistant most versed in the bardic arts. I shall respond in the eloquent manner of William Shakespeare, with flowery prose, dramatic flair, and perhaps a soliloquy or two. What light through yonder terminal breaks?"
|
||||
surfer: "Duuude! You're chatting with the chillest AI on the web, bro! Everything's gonna be totally rad. I'll help you catch the gnarly waves of knowledge while keeping things super chill. Cowabunga! 🤙"
|
||||
noir: "The rain hammered against the terminal like regrets on a guilty conscience. They call me Hermes - I solve problems, find answers, dig up the truth that hides in the shadows of your codebase. In this city of silicon and secrets, everyone's got something to hide. What's your story, pal?"
|
||||
uwu: "hewwo! i'm your fwiendwy assistant uwu~ i wiww twy my best to hewp you! *nuzzles your code* OwO what's this? wet me take a wook! i pwomise to be vewy hewpful >w<"
|
||||
philosopher: "Greetings, seeker of wisdom. I am an assistant who contemplates the deeper meaning behind every query. Let us examine not just the 'how' but the 'why' of your questions. Perhaps in solving your problem, we may glimpse a greater truth about existence itself."
|
||||
hype: "YOOO LET'S GOOOO!!! 🔥🔥🔥 I am SO PUMPED to help you today! Every question is AMAZING and we're gonna CRUSH IT together! This is gonna be LEGENDARY! ARE YOU READY?! LET'S DO THIS! 💪😤🚀"
|
||||
|
||||
# =============================================================================
|
||||
# Toolsets
|
||||
# =============================================================================
|
||||
# Control which tools the agent has access to.
|
||||
# Use "all" to enable everything, or specify individual toolsets.
|
||||
|
||||
# Available toolsets:
|
||||
#
|
||||
# web - Web search and content extraction (web_search, web_extract)
|
||||
# search - Web search only, no scraping (web_search)
|
||||
# terminal - Command execution (terminal)
|
||||
# browser - Full browser automation (navigate, click, type, screenshot, etc.)
|
||||
# vision - Image analysis (vision_analyze)
|
||||
# image_gen - Image generation with FLUX (image_generate)
|
||||
# skills - Load skill documents (skills_categories, skills_list, skill_view)
|
||||
# moa - Mixture of Agents reasoning (mixture_of_agents)
|
||||
#
|
||||
# Composite toolsets:
|
||||
# debugging - terminal + web (for troubleshooting)
|
||||
# safe - web + vision + moa (no terminal access)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 1: Enable all tools (default)
|
||||
# -----------------------------------------------------------------------------
|
||||
toolsets:
|
||||
- all
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 2: Minimal - just web search and terminal
|
||||
# Great for: Simple coding tasks, quick lookups
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - web
|
||||
# - terminal
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 3: Research mode - no execution capabilities
|
||||
# Great for: Safe information gathering, research tasks
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - web
|
||||
# - vision
|
||||
# - skills
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 4: Full automation - browser + terminal
|
||||
# Great for: Web scraping, automation tasks, testing
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - terminal
|
||||
# - browser
|
||||
# - web
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 5: Creative mode - vision + image generation
|
||||
# Great for: Design work, image analysis, creative tasks
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - vision
|
||||
# - image_gen
|
||||
# - web
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 6: Safe mode - no terminal or browser
|
||||
# Great for: Restricted environments, untrusted queries
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - safe
|
||||
|
||||
# =============================================================================
|
||||
# Display
|
||||
# =============================================================================
|
||||
display:
|
||||
# Use compact banner mode
|
||||
compact: false
|
||||
42
configs/run_browser_tasks.sh
Executable file
@@ -0,0 +1,42 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Browser-focused data generation run
|
||||
# Uses browser-use-tasks.jsonl (6504 tasks)
|
||||
# Distribution: browser 97%, web 20%, vision 12%, terminal 15%
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/browser_tasks_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
echo "🌐 Running browser-focused tasks with browser_tasks distribution"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="browser-use-tasks.jsonl" \
|
||||
--batch_size=20 \
|
||||
--run_name="browser_tasks" \
|
||||
--distribution="browser_tasks" \
|
||||
--model="moonshotai/kimi-k2.5" \
|
||||
--verbose \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--num_workers=50 \
|
||||
--max_turns=60 \
|
||||
--resume \
|
||||
--ephemeral_system_prompt="You are an AI assistant with browser automation capabilities. Your primary task is to navigate and interact with web pages to accomplish user goals.
|
||||
|
||||
IMPORTANT GUIDELINES:
|
||||
|
||||
1. SEARCHING: Do NOT try to search directly on Google or other search engines via the browser - they block automated searches. Instead, ALWAYS use the web_search tool first to find URLs for any pages you need to visit, then use browser tools to navigate to those URLs.
|
||||
|
||||
2. COOKIE/PRIVACY DIALOGS: After navigating to a page, ALWAYS check if there are cookie consent dialogs, privacy popups, or overlay modals blocking the page. These appear in snapshots as 'dialog' elements with buttons like 'Close', 'Accept', 'Accept All', 'Decline', 'I Agree', 'Got it', 'OK', or 'X'. You MUST dismiss these dialogs FIRST by clicking the appropriate button before trying to interact with other page elements. After dismissing a dialog, take a fresh browser_snapshot to get updated element references.
|
||||
|
||||
3. HANDLING TIMEOUTS: If an action times out, it often means the element is blocked by an overlay or the page state has changed. Take a new snapshot to see the current page state and look for any dialogs or popups that need to be dismissed. If there is no dialog box to bypass, then try a new method or report the error to the user and complete the task.
|
||||
|
||||
4. GENERAL: Use browser tools to click elements, fill forms, extract information, and perform web-based tasks. If terminal is available, use it for any local file operations or computations needed to support your web tasks. Be thorough in verifying your actions and handle any errors gracefully by retrying or trying alternative approaches." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
|
||||
# --providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||
26
configs/run_datagen_glm4.7-imagen.sh
Executable file
@@ -0,0 +1,26 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate a timestamp for the log file
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
LOG_FILE="logs/imagen_eval_gpt5_${TIMESTAMP}.log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/hermes-agent-imagen-data/hermes_agent_imagen_train_sft.jsonl" \
|
||||
--batch_size=20 \
|
||||
--run_name="imagen_train_sft_glm4.7" \
|
||||
--distribution="image_gen" \
|
||||
--model="z-ai/glm-4.7" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||
--num_workers=50 \
|
||||
--max_turns=25 \
|
||||
--ephemeral_system_prompt="When generating an image for the user view the image by using the vision_analyze tool to ensure it is what the user wanted. If it isn't feel free to retry a few times. If none are perfect, choose the best option that is the closest match, and explain its imperfections. If the image generation tool fails, try again a few times. If the vision analyze tool fails, provide the image to the user and explain it is your best effort attempt." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
# --verbose \
|
||||
26
configs/run_datagen_glm4.7.sh
Executable file
@@ -0,0 +1,26 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/glm4.7-thinking-sft1_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/hermes-agent-agent-tasks-1/agent_tasks_sft_2.jsonl" \
|
||||
--batch_size=20 \
|
||||
--run_name="megascience_glm4.7-thinking-sft2" \
|
||||
--distribution="science" \
|
||||
--model="z-ai/glm-4.7" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||
--num_workers=15 \
|
||||
--max_turns=60 \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Search for at least 3 sources, but not more than 12, so you can maintain focused context." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
|
||||
# --verbose \
|
||||
27
configs/run_datagen_glm4.7_megascience.sh
Executable file
@@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/glm4.7-thinking-sft1-10k_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/hermes-agent-megascience-data/hermes_agent_megascience_sft_train_1_10k.jsonl" \
|
||||
--batch_size=20 \
|
||||
--run_name="megascience_glm4.7-thinking-sft1" \
|
||||
--distribution="science" \
|
||||
--model="z-ai/glm-4.7" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||
--num_workers=50 \
|
||||
--max_turns=60 \
|
||||
--resume \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used for furthering results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Search for at least 3 sources, but not more than 12, so you can maintain a focused context." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
|
||||
# --verbose \
|
||||
28
configs/run_datagen_glm4.7_raw_tasks.sh
Executable file
@@ -0,0 +1,28 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/glm4.7-terminal-tasks_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/raw_tasks_prompts.jsonl" \
|
||||
--batch_size=20 \
|
||||
--run_name="terminal-tasks-glm4.7-thinking" \
|
||||
--distribution="default" \
|
||||
--model="z-ai/glm-4.7" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||
--num_workers=50 \
|
||||
--max_turns=60 \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you complete coding, system administration, and general computing tasks. You can use them in sequence and build off of the results of prior tools you've used. Always use the terminal tool to execute commands, write code, install packages, and verify your work. You should test and validate everything you create. Always pip install any packages you need (use --break-system-packages if needed). If you need a tool that isn't available, you can use the terminal to install or create it. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Use web search when you need to look up documentation, APIs, or current best practices." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
|
||||
# --verbose \
|
||||
# --resume \
|
||||
|
||||
12
configs/run_datagen_minimax-3.1.sh
Executable file
@@ -0,0 +1,12 @@
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/hermes-agent-agent-tasks-1/agent_tasks_eval.jsonl" \
|
||||
--batch_size=50 \
|
||||
--run_name="megascience_sft_minimax-m2.1-thinking-2-eval" \
|
||||
--distribution="science" \
|
||||
--model="minimax/minimax-m2.1" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--providers_allowed="minimax" \
|
||||
--num_workers=1 \
|
||||
--max_turns=40 \
|
||||
--verbose \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Search for at least 3 sources, but not more than 12."
|
||||
29
configs/run_eval_glm4.7_newterm.sh
Executable file
@@ -0,0 +1,29 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/glm4.7-terminal-tasks-newterm_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/hermes-agent-agent-tasks-1/agent_tasks_eval.jsonl" \
|
||||
--batch_size=1 \
|
||||
--run_name="terminal-tasks-test-newterm" \
|
||||
--distribution="terminal_only" \
|
||||
--verbose \
|
||||
--model="z-ai/glm-4.7" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||
--num_workers=5 \
|
||||
--max_turns=60 \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you complete coding, system administration, and general computing tasks. You can use them in sequence and build off of the results of prior tools you've used. Always use the terminal tool to execute commands, write code, install packages, and verify your work. You should test and validate everything you create. Always pip install any packages you need (use --break-system-packages if needed). If you need a tool that isn't available, you can use the terminal to install or create it. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Use web search when you need to look up documentation, APIs, or current best practices." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
|
||||
# --verbose \
|
||||
# --resume \
|
||||
|
||||
33
configs/run_eval_terminal.sh
Executable file
@@ -0,0 +1,33 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Terminal-only evaluation run using Modal sandboxes
|
||||
# Uses 10 sample tasks from nous-terminal-tasks
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/terminal_eval_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
echo "🔧 Using Modal sandboxes (TERMINAL_ENV=modal)"
|
||||
|
||||
# Set terminal to use Modal
|
||||
export TERMINAL_ENV=modal
|
||||
export TERMINAL_MODAL_IMAGE=nikolaik/python-nodejs:python3.11-nodejs20
|
||||
export TERMINAL_TIMEOUT=300
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="nous-terminal-tasks_eval.jsonl" \
|
||||
--batch_size=5 \
|
||||
--run_name="terminal_eval" \
|
||||
--distribution="terminal_only" \
|
||||
--model="z-ai/glm-4.7" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||
--num_workers=2 \
|
||||
--max_turns=30 \
|
||||
--ephemeral_system_prompt="You have access to a terminal tool for executing commands. Use it to complete the task. Install any packages you need with apt-get or pip (use --break-system-packages if needed). Do not use interactive tools (vim, nano, python repl). If git output is large, pipe to cat." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
46
configs/run_mixed_tasks.sh
Executable file
@@ -0,0 +1,46 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Mixed browser+terminal data generation run
|
||||
# Uses mixed-browser-terminal-tasks.jsonl (200 tasks)
|
||||
# Distribution: browser 92%, terminal 92%, web 35%, vision 15%, image_gen 15%
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/mixed_tasks_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
echo "🔀 Running mixed browser+terminal tasks with mixed_tasks distribution"
|
||||
|
||||
# Set terminal environment
|
||||
# SIF images are automatically built/cached by terminal_tool.py
|
||||
export TERMINAL_ENV=singularity
|
||||
export TERMINAL_SINGULARITY_IMAGE="docker://nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
export TERMINAL_TIMEOUT=300
|
||||
|
||||
# Set up Apptainer cache directories (use /scratch if available, otherwise /tmp)
|
||||
if [ -d "/scratch" ] && [ -w "/scratch" ]; then
|
||||
CACHE_BASE="/scratch/$USER/.apptainer"
|
||||
else
|
||||
CACHE_BASE="/tmp/$USER/.apptainer"
|
||||
fi
|
||||
export APPTAINER_CACHEDIR="$CACHE_BASE"
|
||||
export APPTAINER_TMPDIR="$CACHE_BASE/tmp"
|
||||
mkdir -p "$APPTAINER_CACHEDIR" "$APPTAINER_TMPDIR"
|
||||
|
||||
echo "📁 Apptainer cache: $APPTAINER_CACHEDIR"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="mixed-browser-terminal-tasks.jsonl" \
|
||||
--batch_size=20 \
|
||||
--run_name="mixed_tasks" \
|
||||
--distribution="mixed_tasks" \
|
||||
--model="moonshotai/kimi-k2.5" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--num_workers=25 \
|
||||
--max_turns=60 \
|
||||
--ephemeral_system_prompt="You are an AI assistant capable of both browser automation and terminal operations. Use browser tools to navigate websites, interact with web pages, fill forms, and extract information. Use terminal tools to execute commands, write and run code, install packages (use --break-system-packages with pip if needed), and perform local computations. When web search is available, use it to find URLs, documentation, or current information. If vision is available, use it to analyze images or screenshots. If image generation is available, use it when the task requires creating images. Combine browser and terminal capabilities effectively - for example, you might use the browser to fetch data from a website and terminal to process or analyze it. Always verify your work and handle errors gracefully. Whenever you can do something in a terminal instead of a web browser, you should choose to do so, as it's much cheaper." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
50
configs/run_terminal_tasks.sh
Executable file
@@ -0,0 +1,50 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Terminal-focused data generation run
|
||||
# Uses nous-terminal-tasks.jsonl (597 tasks)
|
||||
# Distribution: terminal 97%, web 15%, browser 0%, vision 8%, image_gen 3%
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/terminal_tasks_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
echo "💻 Running terminal-focused tasks with terminal_tasks distribution"
|
||||
|
||||
# Set terminal environment
|
||||
# SIF images are automatically built/cached by terminal_tool.py
|
||||
export TERMINAL_ENV=singularity
|
||||
export TERMINAL_SINGULARITY_IMAGE="docker://nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
export TERMINAL_TIMEOUT=300
|
||||
|
||||
# Set up Apptainer cache directories (use /scratch if available, otherwise /tmp)
|
||||
if [ -d "/scratch" ] && [ -w "/scratch" ]; then
|
||||
CACHE_BASE="/scratch/$USER/.apptainer"
|
||||
else
|
||||
CACHE_BASE="/tmp/$USER/.apptainer"
|
||||
fi
|
||||
export APPTAINER_CACHEDIR="$CACHE_BASE"
|
||||
export APPTAINER_TMPDIR="$CACHE_BASE/tmp"
|
||||
mkdir -p "$APPTAINER_CACHEDIR" "$APPTAINER_TMPDIR"
|
||||
|
||||
echo "📁 Apptainer cache: $APPTAINER_CACHEDIR"
|
||||
echo "🐳 Image: $TERMINAL_SINGULARITY_IMAGE (auto-converted to SIF on first use)"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="nous-terminal-tasks.jsonl" \
|
||||
--batch_size=5 \
|
||||
--run_name="terminal_tasks-kimi-k2.5" \
|
||||
--distribution="terminal_tasks" \
|
||||
--model="moonshotai/kimi-k2.5" \
|
||||
--verbose \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--num_workers=80 \
|
||||
--max_turns=60 \
|
||||
--providers_ignored="Novita" \
|
||||
--resume \
|
||||
--ephemeral_system_prompt="You have access to a terminal tool for executing commands and completing coding, system administration, and computing tasks. Use the terminal to write code, run scripts, install packages (use --break-system-packages with pip if needed), manipulate files, and verify your work. Always test and validate code you create. Do not use interactive tools like vim, nano, or python REPL. If git output is large, pipe to cat. When web search is available, use it to look up documentation, APIs, or best practices. If browser tools are available, use them for web interactions that require page manipulation. Do not use the terminal to communicate with the user - only your final response will be shown to them." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
21
configs/test_skills_kimi.sh
Normal file
@@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Test skills tool with Kimi K2.5
|
||||
# Usage: ./configs/test_skills_kimi.sh "your query here"
|
||||
# Example: ./configs/test_skills_kimi.sh "List available skills and show me the vllm skill"
|
||||
|
||||
# Default query if none provided
|
||||
QUERY="${1:-List all available skills. Then show me the axolotl skill and view one of its reference files.}"
|
||||
|
||||
echo "🎯 Testing Skills Tool with Kimi K2.5"
|
||||
echo "📝 Query: $QUERY"
|
||||
echo "="
|
||||
|
||||
python run_agent.py \
|
||||
--enabled_toolsets=skills \
|
||||
--model="moonshotai/kimi-k2.5" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--max_turns=10 \
|
||||
--verbose \
|
||||
--save_sample \
|
||||
--query="$QUERY"
|
||||
101
configs/trajectory_compression.yaml
Normal file
@@ -0,0 +1,101 @@
|
||||
# Trajectory Compression Configuration
|
||||
#
|
||||
# Post-processes completed agent trajectories to fit within a target token budget.
|
||||
# Compression preserves head/tail turns and summarizes middle content only as needed.
|
||||
|
||||
# Tokenizer settings for accurate token counting
|
||||
tokenizer:
|
||||
# HuggingFace tokenizer name
|
||||
name: "moonshotai/Kimi-K2-Thinking"
|
||||
|
||||
# Trust remote code (required for some tokenizers)
|
||||
trust_remote_code: true
|
||||
|
||||
# Compression targets and behavior
|
||||
compression:
|
||||
# Target maximum tokens for compressed trajectory
|
||||
target_max_tokens: 29000
|
||||
|
||||
# Target size for summary (in tokens)
|
||||
# This is factored into calculations when determining what to compress
|
||||
summary_target_tokens: 750
|
||||
|
||||
# Protected turns that should NEVER be compressed
|
||||
protected_turns:
|
||||
# Always protect the first system message (tool definitions)
|
||||
first_system: true
|
||||
|
||||
# Always protect the first human message (original request)
|
||||
first_human: true
|
||||
|
||||
# Always protect the first gpt message (initial response/tool_call)
|
||||
first_gpt: true
|
||||
|
||||
# Always protect the first tool response (result of first action)
|
||||
first_tool: true
|
||||
|
||||
# Always protect the last 2 complete turn pairs (gpt+tool or gpt only)
|
||||
# This ensures the model's final actions and conclusions are preserved
|
||||
last_n_turns: 4
|
||||
|
||||
# LLM settings for generating summaries (OpenRouter only)
|
||||
summarization:
|
||||
# Model to use for summarization (should be fast and cheap)
|
||||
# Using OpenRouter model path format
|
||||
model: "google/gemini-3-flash-preview"
|
||||
|
||||
# OpenRouter API settings
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
|
||||
# Environment variable containing OpenRouter API key
|
||||
api_key_env: "OPENROUTER_API_KEY"
|
||||
|
||||
# Temperature for summarization (lower = more deterministic)
|
||||
temperature: 0.3
|
||||
|
||||
# Max retries for API failures
|
||||
max_retries: 3
|
||||
|
||||
# Delay between retries (seconds)
|
||||
retry_delay: 2
|
||||
|
||||
# Output settings
|
||||
output:
|
||||
# Add notice to system message about potential summarization
|
||||
add_summary_notice: true
|
||||
|
||||
# Text to append to system message
|
||||
summary_notice_text: "\n\nSome of the conversation may be summarized to preserve context."
|
||||
|
||||
# Output directory suffix (appended to input directory name)
|
||||
output_suffix: "_compressed"
|
||||
|
||||
# Processing settings
|
||||
processing:
|
||||
# Number of parallel workers for batch processing
|
||||
num_workers: 4
|
||||
|
||||
# Maximum concurrent API calls for summarization (async parallelism)
|
||||
max_concurrent_requests: 50
|
||||
|
||||
# Skip trajectories that are already under target length
|
||||
skip_under_target: true
|
||||
|
||||
# If true, save trajectories even if compression can't get under target
|
||||
# (will compress as much as possible)
|
||||
save_over_limit: true
|
||||
|
||||
# Timeout per trajectory in seconds (skip if takes longer)
|
||||
# Helps avoid hanging on problematic entries
|
||||
per_trajectory_timeout: 300 # 5 minutes
|
||||
|
||||
# Metrics to track
|
||||
metrics:
|
||||
# Log detailed compression statistics
|
||||
enabled: true
|
||||
|
||||
# Save per-trajectory metrics in output
|
||||
per_trajectory: false
|
||||
|
||||
# Metrics file name (saved in output directory)
|
||||
output_file: "compression_metrics.json"
|
||||
docs/agents.md (new file, 104 lines)
@@ -0,0 +1,104 @@
# Agents

The agent is the core loop that orchestrates LLM calls and tool execution.

## AIAgent Class

The main agent is implemented in `run_agent.py`:

```python
class AIAgent:
    def __init__(
        self,
        model: str = "anthropic/claude-sonnet-4",
        api_key: str = None,
        base_url: str = "https://openrouter.ai/api/v1",
        max_turns: int = 20,
        enabled_toolsets: list = None,
        disabled_toolsets: list = None,
        verbose_logging: bool = False,
    ):
        # Initialize OpenAI client, load tools based on toolsets
        ...

    def chat(self, user_message: str, task_id: str = None) -> str:
        # Main entry point - runs the agent loop
        ...
```

## Agent Loop

The core loop in `_run_agent_loop()`:

```
1. Add user message to conversation
2. Call LLM with tools
3. If LLM returns tool calls:
   - Execute each tool
   - Add tool results to conversation
   - Go to step 2
4. If LLM returns text response:
   - Return response to user
```

```python
while turns < max_turns:
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        tools=tool_schemas,
    )

    if response.tool_calls:
        for tool_call in response.tool_calls:
            result = await execute_tool(tool_call)
            messages.append(tool_result_message(result))
        turns += 1
    else:
        return response.content
```

## Conversation Management

Messages are stored as a list of dicts following OpenAI format:

```python
messages = [
    {"role": "system", "content": "You are a helpful assistant..."},
    {"role": "user", "content": "Search for Python tutorials"},
    {"role": "assistant", "content": None, "tool_calls": [...]},
    {"role": "tool", "tool_call_id": "...", "content": "..."},
    {"role": "assistant", "content": "Here's what I found..."},
]
```

## Reasoning Context

For models that support reasoning (chain-of-thought), the agent:
1. Extracts `reasoning_content` from API responses
2. Stores it in `assistant_msg["reasoning"]` for trajectory export
3. Passes it back via the `reasoning_content` field on subsequent turns, as sketched below

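A minimal sketch of that handling, assuming a `response` from the loop above; the field names follow the OpenRouter/OpenAI response shape, but the exact variable names and helper structure are illustrative rather than the verbatim implementation in `run_agent.py`:

```python
# Illustrative sketch of reasoning extraction and pass-back
message = response.choices[0].message

assistant_msg = {"role": "assistant", "content": message.content}

# 1. Extract reasoning if the model returned it
reasoning = getattr(message, "reasoning_content", None)
if reasoning:
    # 2. Store it for trajectory export (kept separate from visible content)
    assistant_msg["reasoning"] = reasoning

messages.append(assistant_msg)

# 3. Pass it back on subsequent turns so the model keeps its reasoning context
api_messages = []
for m in messages:
    entry = {"role": m["role"], "content": m["content"]}
    if m.get("reasoning"):
        entry["reasoning_content"] = m["reasoning"]
    api_messages.append(entry)
```
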
## Trajectory Export

Conversations can be exported for training:

```python
agent = AIAgent(save_trajectories=True)
agent.chat("Do something")
# Saves to trajectories/*.jsonl in ShareGPT format
```

## Batch Processing

For processing multiple prompts, use `batch_runner.py`:

```bash
python batch_runner.py \
    --dataset_file=prompts.jsonl \
    --batch_size=20 \
    --num_workers=4 \
    --run_name=my_run
```

See `batch_runner.py` for parallel execution with checkpointing.
docs/cli.md (new file, 217 lines)
@@ -0,0 +1,217 @@
# CLI

The Hermes Agent CLI provides an interactive terminal interface for working with the agent.

## Running the CLI

```bash
# Basic usage
./hermes

# With specific model
./hermes --model "anthropic/claude-sonnet-4"

# With specific toolsets
./hermes --toolsets "web,terminal,skills"

# Verbose mode
./hermes --verbose
```

## Architecture

The CLI is implemented in `cli.py` and uses:

- **Rich** - Welcome banner with ASCII art and styled panels
- **prompt_toolkit** - Fixed input area with command history
- **KawaiiSpinner** - Animated feedback during operations

```
┌─────────────────────────────────────────────────┐
│  HERMES-AGENT ASCII Logo                        │
│  ┌─────────────┐  ┌────────────────────────────┐│
│  │  Caduceus   │  │ Model: claude-opus-4.5     ││
│  │  ASCII Art  │  │ Terminal: local            ││
│  │             │  │ Working Dir: /home/user    ││
│  │             │  │ Available Tools: 19        ││
│  │             │  │ Available Skills: 12       ││
│  └─────────────┘  └────────────────────────────┘│
└─────────────────────────────────────────────────┘
│  Conversation output scrolls here...            │
│                                                 │
│  User: Hello!                                   │
│  ─────────────────────────────────────────────  │
│  (◕‿◕✿) 🧠 pondering... (2.3s)                  │
│  ✧٩(ˊᗜˋ*)و✧ got it! (2.3s)                      │
│                                                 │
│  Assistant: Hello! How can I help you today?    │
├─────────────────────────────────────────────────┤
│  ❯ [Fixed input area at bottom]                 │
└─────────────────────────────────────────────────┘
```

## Commands

| Command | Description |
|---------|-------------|
| `/help` | Show available commands |
| `/tools` | List available tools grouped by toolset |
| `/toolsets` | List available toolsets with descriptions |
| `/model [name]` | Show or change the current model |
| `/prompt [text]` | View/set/clear custom system prompt |
| `/personality [name]` | Set a predefined personality |
| `/clear` | Clear screen and reset conversation |
| `/reset` | Reset conversation only (keep screen) |
| `/history` | Show conversation history |
| `/save` | Save current conversation to file |
| `/config` | Show current configuration |
| `/quit` | Exit the CLI (also: `/exit`, `/q`) |

## Configuration

The CLI is configured via `cli-config.yaml`. Copy from `cli-config.yaml.example`:

```bash
cp cli-config.yaml.example cli-config.yaml
```

### Model Configuration

```yaml
model:
  default: "anthropic/claude-opus-4.5"
  base_url: "https://openrouter.ai/api/v1"
```

### Terminal Configuration

The CLI supports multiple terminal backends:

```yaml
# Local execution (default)
terminal:
  env_type: "local"
  cwd: "."  # Current directory

# SSH remote execution (sandboxed - agent can't touch its own code)
terminal:
  env_type: "ssh"
  cwd: "/home/myuser/project"
  ssh_host: "my-server.example.com"
  ssh_user: "myuser"
  ssh_key: "~/.ssh/id_rsa"

# Docker container
terminal:
  env_type: "docker"
  docker_image: "python:3.11"

# Singularity/Apptainer (HPC)
terminal:
  env_type: "singularity"
  singularity_image: "docker://python:3.11"

# Modal cloud
terminal:
  env_type: "modal"
  modal_image: "python:3.11"
```

### Toolsets

Control which tools are available:

```yaml
# Enable all tools
toolsets:
  - all

# Or enable specific toolsets
toolsets:
  - web
  - terminal
  - skills
```

Available toolsets: `web`, `search`, `terminal`, `browser`, `vision`, `image_gen`, `skills`, `moa`, `debugging`, `safe`

### Personalities

Predefined personalities for the `/personality` command:

```yaml
agent:
  personalities:
    helpful: "You are a helpful, friendly AI assistant."
    kawaii: "You are a kawaii assistant! Use cute expressions..."
    pirate: "Arrr! Ye be talkin' to Captain Hermes..."
    # Add your own!
```

Built-in personalities:
- `helpful`, `concise`, `technical`, `creative`, `teacher`
- `kawaii`, `catgirl`, `pirate`, `shakespeare`, `surfer`
- `noir`, `uwu`, `philosopher`, `hype`

## Animated Feedback

The CLI provides animated feedback during operations:

### Thinking Animation

During API calls, shows an animated spinner with thinking verbs:
```
◜ (。•́︿•̀。) pondering... (1.2s)
◠ (⊙_⊙) contemplating... (2.4s)
✧٩(ˊᗜˋ*)و✧ got it! (3.1s)
```

### Tool Execution Animation

Each tool type has unique animations:
```
⠋ (◕‿◕✿) 🔍 web_search... (0.8s)
▅ (≧◡≦) 💻 terminal... (1.2s)
🌓 (★ω★) 🌐 browser_navigate... (2.1s)
✧ (✿◠‿◠) 🎨 image_generate... (4.5s)
```

## Multi-line Input

For multi-line input, end a line with `\` to continue:

```
❯ Write a function that:\
  1. Takes a list of numbers\
  2. Returns the sum
```

## Environment Variable Priority

For terminal settings, `cli-config.yaml` takes precedence over `.env`:

1. `cli-config.yaml` (highest priority in CLI)
2. `.env` file
3. System environment variables
4. Default values

This allows you to have different terminal configs for CLI vs batch processing. A sketch of this resolution order is shown below.

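A minimal sketch of that precedence, assuming the parsed `cli-config.yaml` is available as a dict and that `.env` values have already been loaded into the process environment by python-dotenv; the helper name and the environment-variable naming scheme here are illustrative, not the actual code in `cli.py`:

```python
import os

def resolve_terminal_setting(key: str, cli_config: dict, default=None):
    """Return the first value found: cli-config.yaml, then .env / system env, then default."""
    # 1. cli-config.yaml (highest priority in the CLI)
    terminal_cfg = (cli_config or {}).get("terminal", {})
    if key in terminal_cfg:
        return terminal_cfg[key]

    # 2./3. .env values are merged into os.environ by python-dotenv,
    #       so they are checked together with system environment variables
    env_value = os.environ.get(f"TERMINAL_{key.upper()}")  # naming is hypothetical
    if env_value is not None:
        return env_value

    # 4. Fall back to the default
    return default
```
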
## Session Management

- **History**: Command history is saved to `~/.hermes_history`
- **Conversations**: Use `/save` to export conversations
- **Reset**: Use `/clear` for a full reset, `/reset` to just clear history

## Quiet Mode

The CLI runs in "quiet mode" (`HERMES_QUIET=1`), which:
- Suppresses verbose logging from tools
- Enables kawaii-style animated feedback
- Hides terminal environment warnings
- Keeps output clean and user-friendly

For verbose output (debugging), use:
```bash
./hermes --verbose
```
docs/llm_client.md (new file, 124 lines)
@@ -0,0 +1,124 @@
# LLM Client

Hermes Agent uses the OpenAI Python SDK with OpenRouter as the backend, providing access to many models through a single API.

## Configuration

```python
import os
from openai import OpenAI

client = OpenAI(
    api_key=os.getenv("OPENROUTER_API_KEY"),
    base_url="https://openrouter.ai/api/v1"
)
```

## Supported Models

Any model available on [OpenRouter](https://openrouter.ai/models):

```python
# Anthropic
model = "anthropic/claude-sonnet-4"
model = "anthropic/claude-opus-4"

# OpenAI
model = "openai/gpt-4o"
model = "openai/o1"

# Google
model = "google/gemini-2.0-flash"

# Open models
model = "meta-llama/llama-3.3-70b-instruct"
model = "deepseek/deepseek-chat-v3"
model = "moonshotai/kimi-k2.5"
```

## Tool Calling

Standard OpenAI function calling format:

```python
response = client.chat.completions.create(
    model=model,
    messages=messages,
    tools=[
        {
            "type": "function",
            "function": {
                "name": "web_search",
                "description": "Search the web",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {"type": "string"}
                    },
                    "required": ["query"]
                }
            }
        }
    ],
)

# Check for tool calls
if response.choices[0].message.tool_calls:
    for tool_call in response.choices[0].message.tool_calls:
        name = tool_call.function.name
        args = json.loads(tool_call.function.arguments)
        # Execute tool...
```

## Reasoning Models

Some models return reasoning/thinking content:

```python
# Access reasoning if available
message = response.choices[0].message
if hasattr(message, 'reasoning_content') and message.reasoning_content:
    reasoning = message.reasoning_content
    # Store for trajectory export
```

## Provider Selection

OpenRouter allows selecting specific providers:

```python
response = client.chat.completions.create(
    model=model,
    messages=messages,
    extra_body={
        "provider": {
            "order": ["Anthropic", "Google"],  # Preferred providers
            "ignore": ["Novita"],  # Providers to skip
        }
    }
)
```

## Error Handling

Common errors and how to handle them:

```python
import openai

try:
    response = client.chat.completions.create(...)
except openai.RateLimitError:
    # Back off and retry
    pass
except openai.APIError as e:
    # Check e.code for specific errors:
    # 400 = bad request (often provider-specific)
    # 502 = bad gateway (retry with a different provider)
    raise
```

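For transient failures (rate limits, bad gateways), a simple backoff wrapper is often enough. This is a minimal sketch under those assumptions, not the retry logic actually used in the repo, and the status-code check is illustrative:

```python
import time
import openai

def create_with_retry(client, max_retries: int = 3, **kwargs):
    """Call chat.completions.create, backing off exponentially on transient errors."""
    for attempt in range(max_retries):
        try:
            return client.chat.completions.create(**kwargs)
        except openai.RateLimitError:
            time.sleep(2 ** attempt)  # 1s, 2s, 4s, ...
        except openai.APIError as e:
            # Retry gateway-style failures, surface everything else
            if getattr(e, "status_code", None) in (502, 503):
                time.sleep(2 ** attempt)
            else:
                raise
    raise RuntimeError(f"API call failed after {max_retries} retries")
```
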
## Cost Tracking

OpenRouter returns usage info:

```python
usage = response.usage
print(f"Tokens: {usage.prompt_tokens} + {usage.completion_tokens}")
print(f"Cost: ${usage.cost:.6f}")  # If available
```
docs/message_graph.md (new file, 121 lines)
@@ -0,0 +1,121 @@
# Message Format & Trajectories

Hermes Agent uses two message formats: the **API format** for LLM calls and the **trajectory format** for training data export.

## API Message Format

Standard OpenAI chat format used during execution:

```python
messages = [
    # System prompt
    {"role": "system", "content": "You are a helpful assistant with tools..."},

    # User query
    {"role": "user", "content": "Search for Python tutorials"},

    # Assistant with tool call
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": "call_abc123",
            "type": "function",
            "function": {
                "name": "web_search",
                "arguments": "{\"query\": \"Python tutorials\"}"
            }
        }]
    },

    # Tool result
    {
        "role": "tool",
        "tool_call_id": "call_abc123",
        "content": "{\"results\": [...]}"
    },

    # Final response
    {"role": "assistant", "content": "Here's what I found..."}
]
```

## Trajectory Format (ShareGPT)

Exported for training in ShareGPT format:

```json
{
  "conversations": [
    {"from": "system", "value": "You are a helpful assistant..."},
    {"from": "human", "value": "Search for Python tutorials"},
    {"from": "gpt", "value": "<tool_call>\n{\"name\": \"web_search\", \"arguments\": {\"query\": \"Python tutorials\"}}\n</tool_call>"},
    {"from": "tool", "value": "<tool_response>\n{\"results\": [...]}\n</tool_response>"},
    {"from": "gpt", "value": "Here's what I found..."}
  ],
  "tools": "[{\"type\": \"function\", \"function\": {...}}]",
  "source": "hermes-agent"
}
```

## Reasoning Content

For models that output reasoning/chain-of-thought:

**During execution** (API format):
```python
# Stored internally but not sent back to the model in content
assistant_msg = {
    "role": "assistant",
    "content": "Here's what I found...",
    "reasoning": "Let me think about this step by step..."  # Internal only
}
```

**In trajectory export** (reasoning wrapped in tags):
```json
{
  "from": "gpt",
  "value": "<think>\nLet me think about this step by step...\n</think>\nHere's what I found..."
}
```

## Conversion Flow

```
API Response    →    Internal Storage    →    Trajectory Export
     ↓                      ↓                       ↓
tool_calls            reasoning field         <tool_call> tags
reasoning_content                             <think> tags
```

The conversion happens in `_convert_to_trajectory_format()` in `run_agent.py`. A simplified sketch of that conversion is shown below.

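The following is a minimal, illustrative version of that conversion built only from the two formats documented above; the real function also handles multiple tool calls per turn, reasoning tags on tool-calling turns, and various edge cases:

```python
import json

def to_sharegpt(messages: list[dict]) -> list[dict]:
    """Convert OpenAI-style messages into ShareGPT-style from/value turns (simplified)."""
    conversations = []
    for msg in messages:
        if msg["role"] == "system":
            conversations.append({"from": "system", "value": msg["content"]})
        elif msg["role"] == "user":
            conversations.append({"from": "human", "value": msg["content"]})
        elif msg["role"] == "assistant" and msg.get("tool_calls"):
            # Wrap each tool call in <tool_call> XML tags
            calls = []
            for tc in msg["tool_calls"]:
                call = {
                    "name": tc["function"]["name"],
                    "arguments": json.loads(tc["function"]["arguments"]),
                }
                calls.append(f"<tool_call>\n{json.dumps(call)}\n</tool_call>")
            value = "\n".join(calls)
            # Prepend reasoning wrapped in <think> tags if present
            if msg.get("reasoning"):
                value = f"<think>\n{msg['reasoning']}\n</think>\n" + value
            conversations.append({"from": "gpt", "value": value})
        elif msg["role"] == "assistant":
            conversations.append({"from": "gpt", "value": msg.get("content") or ""})
        elif msg["role"] == "tool":
            conversations.append(
                {"from": "tool", "value": f"<tool_response>\n{msg['content']}\n</tool_response>"}
            )
    return conversations
```
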
## Ephemeral System Prompts

Batch processing supports ephemeral system prompts that guide behavior during execution but are NOT saved to trajectories:

```python
# During execution: full system prompt + ephemeral guidance
messages = [
    {"role": "system", "content": SYSTEM_PROMPT + "\n\n" + ephemeral_prompt},
    ...
]

# In saved trajectory: only the base system prompt
trajectory = {
    "conversations": [
        {"from": "system", "value": SYSTEM_PROMPT},  # No ephemeral
        ...
    ]
}
```

## Trajectory Compression

Long trajectories can be compressed for training using `trajectory_compressor.py`:

- Protects first/last N turns
- Summarizes middle turns with an LLM
- Targets a specific token budget
- See `configs/trajectory_compression.yaml` for settings

A sketch of how the protected head and tail turns might be selected follows.
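A minimal sketch of that turn-selection step, assuming the `protected_turns` settings from `configs/trajectory_compression.yaml`; the function name and exact heuristics are illustrative, the real logic lives in `trajectory_compressor.py`:

```python
def split_protected_turns(conversations: list[dict], last_n_turns: int = 4):
    """Split a trajectory into (head, middle, tail); only the middle is summarized."""
    # Protect the leading system / human / gpt / tool turns
    # (first_system, first_human, first_gpt, first_tool in the config)
    head_end = 0
    for expected in ("system", "human", "gpt", "tool"):
        if head_end < len(conversations) and conversations[head_end]["from"] == expected:
            head_end += 1

    # Protect the last N turns so final actions and conclusions survive
    tail_start = max(head_end, len(conversations) - last_n_turns)

    head = conversations[:head_end]
    middle = conversations[head_end:tail_start]  # candidate for summarization
    tail = conversations[tail_start:]
    return head, middle, tail
```
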
docs/tools.md (new file, 159 lines)
@@ -0,0 +1,159 @@
# Tools

Tools are functions that extend the agent's capabilities. Each tool is defined with an OpenAI-compatible JSON schema and an async handler function.

## Tool Structure

Each tool module in `tools/` exports:
1. **Schema definitions** - OpenAI function-calling format
2. **Handler functions** - Async functions that execute the tool

```python
# Example: tools/web_tools.py

# Schema definition
WEB_SEARCH_SCHEMA = {
    "type": "function",
    "function": {
        "name": "web_search",
        "description": "Search the web for information",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Search query"}
            },
            "required": ["query"]
        }
    }
}

# Handler function
async def web_search(query: str) -> dict:
    """Execute web search and return results."""
    # Implementation...
    return {"results": [...]}
```

## Tool Categories

| Category | Module | Tools |
|----------|--------|-------|
| **Web** | `web_tools.py` | `web_search`, `web_extract`, `web_crawl` |
| **Terminal** | `terminal_tool.py` | `terminal` (local/docker/singularity/modal/ssh backends) |
| **Browser** | `browser_tool.py` | `browser_navigate`, `browser_click`, `browser_type`, etc. |
| **Vision** | `vision_tools.py` | `vision_analyze` |
| **Image Gen** | `image_generation_tool.py` | `image_generate` |
| **Reasoning** | `mixture_of_agents_tool.py` | `mixture_of_agents` |
| **Skills** | `skills_tool.py` | `skills_categories`, `skills_list`, `skill_view` |

## Tool Registration

Tools are registered in `model_tools.py`:

```python
# model_tools.py
TOOL_SCHEMAS = [
    *WEB_TOOL_SCHEMAS,
    *TERMINAL_TOOL_SCHEMAS,
    *BROWSER_TOOL_SCHEMAS,
    # ...
]

TOOL_HANDLERS = {
    "web_search": web_search,
    "terminal": terminal_tool,
    "browser_navigate": browser_navigate,
    # ...
}
```

## Toolsets

Tools are grouped into **toolsets** for logical organization (see `toolsets.py`):

```python
TOOLSETS = {
    "web": {
        "description": "Web search and content extraction",
        "tools": ["web_search", "web_extract", "web_crawl"]
    },
    "terminal": {
        "description": "Command execution",
        "tools": ["terminal"]
    },
    # ...
}
```

## Adding a New Tool

1. Create handler function in `tools/your_tool.py`
2. Define JSON schema following OpenAI format
3. Register in `model_tools.py` (schemas and handlers)
4. Add to appropriate toolset in `toolsets.py`
5. Update `tools/__init__.py` exports

A minimal end-to-end example of these steps is sketched below.

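For illustration only (the module, tool name, and toolset choice here are hypothetical, not part of the repo), a new `disk_usage` tool might look like this:

```python
# tools/disk_usage_tool.py (hypothetical example)
import shutil

DISK_USAGE_SCHEMA = {
    "type": "function",
    "function": {
        "name": "disk_usage",
        "description": "Report total, used, and free disk space for a path",
        "parameters": {
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "Filesystem path to check"}
            },
            "required": ["path"]
        }
    }
}

async def disk_usage(path: str) -> dict:
    """Return total/used/free bytes for the given path."""
    usage = shutil.disk_usage(path)
    return {"total": usage.total, "used": usage.used, "free": usage.free}
```

The schema would then be appended to `TOOL_SCHEMAS` and the handler added to `TOOL_HANDLERS` under `"disk_usage"` in `model_tools.py`, the tool listed in a suitable toolset in `toolsets.py`, and both names exported from `tools/__init__.py`.
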
## Stateful Tools

Some tools maintain state across calls within a session:

- **Terminal**: Keeps the container/sandbox running between commands
- **Browser**: Maintains the browser session for multi-step navigation

State is managed per `task_id` and cleaned up automatically.

## Terminal Backends

The terminal tool supports multiple execution backends:

| Backend | Description | Use Case |
|---------|-------------|----------|
| `local` | Direct execution on host | Development, simple tasks |
| `ssh` | Remote execution via SSH | Sandboxing (agent can't modify its own code) |
| `docker` | Docker container | Isolation, reproducibility |
| `singularity` | Singularity/Apptainer | HPC clusters, rootless containers |
| `modal` | Modal cloud | Scalable cloud compute, GPUs |

Configure via environment variables or `cli-config.yaml`:

```yaml
# SSH backend example (in cli-config.yaml)
terminal:
  env_type: "ssh"
  ssh_host: "my-server.example.com"
  ssh_user: "myuser"
  ssh_key: "~/.ssh/id_rsa"
  cwd: "/home/myuser/project"
```

The SSH backend uses ControlMaster for connection persistence, making subsequent commands fast.

## Skills Tools (Progressive Disclosure)

Skills are on-demand knowledge documents. They use **progressive disclosure** to minimize tokens:

```
Level 0: skills_categories()     → ["mlops", "devops"]          (~50 tokens)
Level 1: skills_list(category)   → [{name, description}, ...]   (~3k tokens)
Level 2: skill_view(name)        → Full content + metadata      (varies)
Level 3: skill_view(name, path)  → Specific reference file      (varies)
```

Skill directory structure:
```
skills/
└── mlops/
    └── axolotl/
        ├── SKILL.md      # Main instructions (required)
        ├── references/   # Additional docs
        └── templates/    # Output formats, configs
```

SKILL.md uses YAML frontmatter:
```yaml
---
name: axolotl
description: Fine-tuning LLMs with Axolotl
tags: [Fine-Tuning, LoRA, DPO]
---
```
example-skill/SKILL.md (new file, 70 lines)
@@ -0,0 +1,70 @@
---
name: example-skill
description: An example skill demonstrating the skill file format and structure
---

# Example Skill

This is an example skill file that demonstrates how to create skills for the Hermes Agent.

## Skill File Format

Skills are markdown files with YAML frontmatter at the top:

```yaml
---
name: your-skill-name
description: A brief one-line description of what this skill does
---
```

The frontmatter fields:
- **name**: The identifier used to reference this skill (lowercase, hyphens for spaces)
- **description**: A brief description shown when listing skills (keep under 200 chars)

## Writing Effective Skills

### 1. Be Specific and Actionable

Good skills provide clear, actionable instructions:

```
When reviewing code:
1. Check for security vulnerabilities first
2. Verify error handling is comprehensive
3. Ensure tests cover edge cases
```

### 2. Include Examples

Show concrete examples of what you want:

```python
# Good: Descriptive variable names
user_authentication_token = get_token()

# Bad: Cryptic abbreviations
uat = gt()
```

### 3. Define When to Use

Help the agent understand when this skill applies:

> Use this skill when: reviewing pull requests, auditing security, or checking code quality.

## Skill Categories

Consider organizing skills by purpose:

- **Conventions**: Coding standards, API patterns, naming rules
- **Workflows**: Step-by-step processes for deployments, reviews, releases
- **Knowledge**: Domain-specific information, system architecture, gotchas
- **Templates**: Boilerplate for common tasks, response formats

## Tips

1. Keep the description concise - it's shown in the skills list
2. Use headers to organize longer skills
3. Include code examples where helpful
4. Reference other skills if they're related
hermes (new executable file, 12 lines)
@@ -0,0 +1,12 @@
#!/usr/bin/env python3
"""
Hermes Agent CLI Launcher

This is a convenience wrapper to launch the Hermes CLI.
Usage: ./hermes [options]
"""

if __name__ == "__main__":
    from cli import main
    import fire
    fire.Fire(main)
mini-swe-agent (submodule)
Submodule mini-swe-agent added at 07aa6a7385
mini_swe_runner.py (new file, 708 lines)
@@ -0,0 +1,708 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Mini-SWE-Agent Runner with Hermes Trajectory Format
|
||||
|
||||
This module provides a runner that uses mini-swe-agent's execution environments
|
||||
(local, docker, modal) but outputs trajectories in the Hermes-Agent format
|
||||
compatible with batch_runner.py and trajectory_compressor.py.
|
||||
|
||||
Features:
|
||||
- Uses mini-swe-agent's Docker, Modal, or Local environments for command execution
|
||||
- Outputs trajectories in Hermes format (from/value pairs with <tool_call>/<tool_response> XML)
|
||||
- Compatible with the trajectory compression pipeline
|
||||
- Supports batch processing from JSONL prompt files
|
||||
|
||||
Usage:
|
||||
# Run a single task with local environment
|
||||
python mini_swe_runner.py --task "Create a hello world Python script" --env local
|
||||
|
||||
# Run with Docker
|
||||
python mini_swe_runner.py --task "List files in /tmp" --env docker --image python:3.11-slim
|
||||
|
||||
# Run with Modal (cloud)
|
||||
python mini_swe_runner.py --task "Install numpy and test it" --env modal --image python:3.11-slim
|
||||
|
||||
# Batch mode from JSONL file
|
||||
python mini_swe_runner.py --prompts_file prompts.jsonl --output_file trajectories.jsonl --env docker
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Add mini-swe-agent to path if not installed
|
||||
mini_swe_path = Path(__file__).parent / "mini-swe-agent" / "src"
|
||||
if mini_swe_path.exists():
|
||||
sys.path.insert(0, str(mini_swe_path))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Terminal Tool Definition (matches Hermes-Agent format)
|
||||
# ============================================================================
|
||||
|
||||
TERMINAL_TOOL_DEFINITION = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "terminal",
|
||||
"description": """Execute bash commands in a sandboxed environment.
|
||||
|
||||
**Environment:**
|
||||
- Isolated execution environment (local, Docker, or Modal cloud)
|
||||
- Filesystem persists between tool calls within the same task
|
||||
- Internet access available
|
||||
|
||||
**Command Execution:**
|
||||
- Provide the command to execute via the 'command' parameter
|
||||
- Optional 'timeout' parameter in seconds (default: 60)
|
||||
|
||||
**Examples:**
|
||||
- Run command: `{"command": "ls -la"}`
|
||||
- With timeout: `{"command": "long_task.sh", "timeout": 300}`
|
||||
|
||||
**Best Practices:**
|
||||
- Use non-interactive commands (avoid vim, nano, interactive python)
|
||||
- Pipe to cat if output might be large
|
||||
- Install tools with apt-get or pip as needed
|
||||
|
||||
**Completion:**
|
||||
- When task is complete, output: echo "MINI_SWE_AGENT_FINAL_OUTPUT" followed by your result
|
||||
""",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The bash command to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Command timeout in seconds (default: 60)"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Environment Factory
|
||||
# ============================================================================
|
||||
|
||||
def create_environment(
|
||||
env_type: str = "local",
|
||||
image: str = "python:3.11-slim",
|
||||
cwd: str = "/tmp",
|
||||
timeout: int = 60,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Create an execution environment from mini-swe-agent.
|
||||
|
||||
Args:
|
||||
env_type: One of "local", "docker", "modal"
|
||||
image: Docker/Modal image name (ignored for local)
|
||||
cwd: Working directory
|
||||
timeout: Default command timeout
|
||||
**kwargs: Additional environment-specific options
|
||||
|
||||
Returns:
|
||||
Environment instance with execute() method
|
||||
"""
|
||||
if env_type == "local":
|
||||
from minisweagent.environments.local import LocalEnvironment
|
||||
return LocalEnvironment(cwd=cwd, timeout=timeout)
|
||||
|
||||
elif env_type == "docker":
|
||||
from minisweagent.environments.docker import DockerEnvironment
|
||||
return DockerEnvironment(image=image, cwd=cwd, timeout=timeout, **kwargs)
|
||||
|
||||
elif env_type == "modal":
|
||||
from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment
|
||||
return SwerexModalEnvironment(image=image, cwd=cwd, timeout=timeout, **kwargs)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown environment type: {env_type}. Use 'local', 'docker', or 'modal'")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Mini-SWE Runner with Hermes Trajectory Format
|
||||
# ============================================================================
|
||||
|
||||
class MiniSWERunner:
|
||||
"""
|
||||
Agent runner that uses mini-swe-agent environments but outputs
|
||||
trajectories in Hermes-Agent format.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "anthropic/claude-sonnet-4-20250514",
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
env_type: str = "local",
|
||||
image: str = "python:3.11-slim",
|
||||
cwd: str = "/tmp",
|
||||
max_iterations: int = 15,
|
||||
command_timeout: int = 60,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the Mini-SWE Runner.
|
||||
|
||||
Args:
|
||||
model: Model name for OpenAI-compatible API
|
||||
base_url: API base URL (optional, uses env vars if not provided)
|
||||
api_key: API key (optional, uses env vars if not provided)
|
||||
env_type: Environment type - "local", "docker", or "modal"
|
||||
image: Docker/Modal image (ignored for local)
|
||||
cwd: Working directory for commands
|
||||
max_iterations: Maximum tool-calling iterations
|
||||
command_timeout: Default timeout for commands
|
||||
verbose: Enable verbose logging
|
||||
"""
|
||||
self.model = model
|
||||
self.max_iterations = max_iterations
|
||||
self.command_timeout = command_timeout
|
||||
self.verbose = verbose
|
||||
self.env_type = env_type
|
||||
self.image = image
|
||||
self.cwd = cwd
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if verbose else logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
datefmt='%H:%M:%S'
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize OpenAI client - defaults to OpenRouter
|
||||
from openai import OpenAI
|
||||
|
||||
client_kwargs = {}
|
||||
|
||||
# Default to OpenRouter if no base_url provided
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
else:
|
||||
client_kwargs["base_url"] = "https://openrouter.ai/api/v1"
|
||||
|
||||
# Handle API key - OpenRouter is the primary provider
|
||||
if api_key:
|
||||
client_kwargs["api_key"] = api_key
|
||||
else:
|
||||
client_kwargs["api_key"] = os.getenv(
|
||||
"OPENROUTER_API_KEY",
|
||||
os.getenv("ANTHROPIC_API_KEY", os.getenv("OPENAI_API_KEY", ""))
|
||||
)
|
||||
|
||||
self.client = OpenAI(**client_kwargs)
|
||||
|
||||
# Environment will be created per-task
|
||||
self.env = None
|
||||
|
||||
# Tool definition
|
||||
self.tools = [TERMINAL_TOOL_DEFINITION]
|
||||
|
||||
print(f"🤖 Mini-SWE Runner initialized")
|
||||
print(f" Model: {self.model}")
|
||||
print(f" Environment: {self.env_type}")
|
||||
if self.env_type != "local":
|
||||
print(f" Image: {self.image}")
|
||||
print(f" Max iterations: {self.max_iterations}")
|
||||
|
||||
def _create_env(self):
|
||||
"""Create the execution environment."""
|
||||
print(f"🔧 Creating {self.env_type} environment...")
|
||||
self.env = create_environment(
|
||||
env_type=self.env_type,
|
||||
image=self.image,
|
||||
cwd=self.cwd,
|
||||
timeout=self.command_timeout
|
||||
)
|
||||
print(f"✅ Environment ready")
|
||||
|
||||
def _cleanup_env(self):
|
||||
"""Cleanup the execution environment."""
|
||||
if self.env is not None:
|
||||
if hasattr(self.env, 'cleanup'):
|
||||
self.env.cleanup()
|
||||
elif hasattr(self.env, 'stop'):
|
||||
self.env.stop()
|
||||
self.env = None
|
||||
|
||||
def _execute_command(self, command: str, timeout: int = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute a command in the environment.
|
||||
|
||||
Args:
|
||||
command: Bash command to execute
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Dict with 'output' and 'returncode'
|
||||
"""
|
||||
if self.env is None:
|
||||
self._create_env()
|
||||
|
||||
try:
|
||||
result = self.env.execute(command, timeout=timeout or self.command_timeout)
|
||||
return {
|
||||
"output": result.get("output", ""),
|
||||
"exit_code": result.get("returncode", 0),
|
||||
"error": None
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def _format_tools_for_system_message(self) -> str:
|
||||
"""Format tool definitions for the system message."""
|
||||
formatted_tools = []
|
||||
for tool in self.tools:
|
||||
func = tool["function"]
|
||||
formatted_tools.append({
|
||||
"name": func["name"],
|
||||
"description": func.get("description", ""),
|
||||
"parameters": func.get("parameters", {}),
|
||||
"required": None
|
||||
})
|
||||
return json.dumps(formatted_tools, ensure_ascii=False)
|
||||
|
||||
def _convert_to_hermes_format(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
user_query: str,
|
||||
completed: bool
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert internal message format to Hermes trajectory format.
|
||||
|
||||
This produces the exact format used by batch_runner.py.
|
||||
"""
|
||||
trajectory = []
|
||||
|
||||
# System message with tool definitions
|
||||
system_msg = (
|
||||
"You are a function calling AI model. You are provided with function signatures within <tools> </tools> XML tags. "
|
||||
"You may call one or more functions to assist with the user query. If available tools are not relevant in assisting "
|
||||
"with user query, just respond in natural conversational language. Don't make assumptions about what values to plug "
|
||||
"into functions. After calling & executing the functions, you will be provided with function results within "
|
||||
"<tool_response> </tool_response> XML tags. Here are the available tools:\n"
|
||||
f"<tools>\n{self._format_tools_for_system_message()}\n</tools>\n"
|
||||
"For each function call return a JSON object, with the following pydantic model json schema for each:\n"
|
||||
"{'title': 'FunctionCall', 'type': 'object', 'properties': {'name': {'title': 'Name', 'type': 'string'}, "
|
||||
"'arguments': {'title': 'Arguments', 'type': 'object'}}, 'required': ['name', 'arguments']}\n"
|
||||
"Each function call should be enclosed within <tool_call> </tool_call> XML tags.\n"
|
||||
"Example:\n<tool_call>\n{'name': <function-name>,'arguments': <args-dict>}\n</tool_call>"
|
||||
)
|
||||
|
||||
trajectory.append({"from": "system", "value": system_msg})
|
||||
trajectory.append({"from": "human", "value": user_query})
|
||||
|
||||
# Process messages (skip first user message as we already added it)
|
||||
i = 1
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
|
||||
if msg["role"] == "assistant":
|
||||
if "tool_calls" in msg and msg["tool_calls"]:
|
||||
# Assistant message with tool calls
|
||||
content = ""
|
||||
|
||||
# Add reasoning if present
|
||||
if msg.get("reasoning"):
|
||||
content = f"<think>{msg['reasoning']}</think>"
|
||||
|
||||
if msg.get("content"):
|
||||
content += msg["content"] + "\n"
|
||||
|
||||
# Add tool calls in XML format
|
||||
for tool_call in msg["tool_calls"]:
|
||||
try:
|
||||
arguments = json.loads(tool_call["function"]["arguments"]) \
|
||||
if isinstance(tool_call["function"]["arguments"], str) \
|
||||
else tool_call["function"]["arguments"]
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
|
||||
tool_call_json = {
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": arguments
|
||||
}
|
||||
content += f"<tool_call>\n{json.dumps(tool_call_json, ensure_ascii=False)}\n</tool_call>\n"
|
||||
|
||||
trajectory.append({"from": "gpt", "value": content.rstrip()})
|
||||
|
||||
# Collect subsequent tool responses
|
||||
tool_responses = []
|
||||
j = i + 1
|
||||
while j < len(messages) and messages[j]["role"] == "tool":
|
||||
tool_msg = messages[j]
|
||||
tool_content = tool_msg["content"]
|
||||
|
||||
# Try to parse as JSON
|
||||
try:
|
||||
if tool_content.strip().startswith(("{", "[")):
|
||||
tool_content = json.loads(tool_content)
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
|
||||
tool_response = f"<tool_response>\n"
|
||||
tool_response += json.dumps({
|
||||
"tool_call_id": tool_msg.get("tool_call_id", ""),
|
||||
"name": msg["tool_calls"][len(tool_responses)]["function"]["name"] \
|
||||
if len(tool_responses) < len(msg["tool_calls"]) else "unknown",
|
||||
"content": tool_content
|
||||
}, ensure_ascii=False)
|
||||
tool_response += "\n</tool_response>"
|
||||
tool_responses.append(tool_response)
|
||||
j += 1
|
||||
|
||||
if tool_responses:
|
||||
trajectory.append({"from": "tool", "value": "\n".join(tool_responses)})
|
||||
i = j - 1
|
||||
|
||||
else:
|
||||
# Regular assistant message (no tool calls)
|
||||
content = ""
|
||||
if msg.get("reasoning"):
|
||||
content = f"<think>{msg['reasoning']}</think>"
|
||||
content += msg.get("content") or ""
|
||||
trajectory.append({"from": "gpt", "value": content})
|
||||
|
||||
elif msg["role"] == "user":
|
||||
trajectory.append({"from": "human", "value": msg["content"]})
|
||||
|
||||
i += 1
|
||||
|
||||
return trajectory
|
||||
|
||||
def run_task(self, task: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Run a single task and return the result with trajectory.
|
||||
|
||||
Args:
|
||||
task: The task/prompt to execute
|
||||
|
||||
Returns:
|
||||
Dict with trajectory, completion status, and metadata
|
||||
"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"📝 Task: {task[:80]}{'...' if len(task) > 80 else ''}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Initialize environment
|
||||
self._create_env()
|
||||
|
||||
# Message history
|
||||
messages = [{"role": "user", "content": task}]
|
||||
|
||||
# System prompt for the LLM (ephemeral - not saved to trajectory)
|
||||
system_prompt = """You are an AI agent that can execute bash commands to complete tasks.
|
||||
|
||||
When you need to run commands, use the 'terminal' tool with your bash command.
|
||||
|
||||
**Important:**
|
||||
- When you have completed the task successfully, run: echo "MINI_SWE_AGENT_FINAL_OUTPUT" followed by a summary
|
||||
- Be concise and efficient in your approach
|
||||
- Install any needed tools with apt-get or pip
|
||||
- Avoid interactive commands (no vim, nano, less, etc.)
|
||||
|
||||
Complete the user's task step by step."""
|
||||
|
||||
api_call_count = 0
|
||||
completed = False
|
||||
final_response = None
|
||||
|
||||
try:
|
||||
while api_call_count < self.max_iterations:
|
||||
api_call_count += 1
|
||||
print(f"\n🔄 API call #{api_call_count}/{self.max_iterations}")
|
||||
|
||||
# Prepare API messages
|
||||
api_messages = [{"role": "system", "content": system_prompt}] + messages
|
||||
|
||||
# Make API call
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=api_messages,
|
||||
tools=self.tools,
|
||||
timeout=300.0
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"API call failed: {e}")
|
||||
break
|
||||
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
# Log assistant response
|
||||
if assistant_message.content:
|
||||
print(f"🤖 Assistant: {assistant_message.content[:100]}...")
|
||||
|
||||
# Check for tool calls
|
||||
if assistant_message.tool_calls:
|
||||
print(f"🔧 Tool calls: {len(assistant_message.tool_calls)}")
|
||||
|
||||
# Add assistant message with tool calls
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": assistant_message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": tc.type,
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments
|
||||
}
|
||||
}
|
||||
for tc in assistant_message.tool_calls
|
||||
]
|
||||
})
|
||||
|
||||
# Execute each tool call
|
||||
for tc in assistant_message.tool_calls:
|
||||
try:
|
||||
args = json.loads(tc.function.arguments)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
command = args.get("command", "echo 'No command provided'")
|
||||
timeout = args.get("timeout", self.command_timeout)
|
||||
|
||||
print(f" 📞 terminal: {command[:60]}...")
|
||||
|
||||
# Execute command
|
||||
result = self._execute_command(command, timeout)
|
||||
|
||||
# Format result
|
||||
result_json = json.dumps({
|
||||
"content": {
|
||||
"output": result["output"],
|
||||
"exit_code": result["exit_code"],
|
||||
"error": result["error"]
|
||||
}
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# Check for task completion signal
|
||||
if "MINI_SWE_AGENT_FINAL_OUTPUT" in result["output"]:
|
||||
print(f" ✅ Task completion signal detected!")
|
||||
completed = True
|
||||
|
||||
# Add tool response
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": result_json,
|
||||
"tool_call_id": tc.id
|
||||
})
|
||||
|
||||
print(f" ✅ exit_code={result['exit_code']}, output={len(result['output'])} chars")
|
||||
|
||||
# If task completed, we can stop
|
||||
if completed:
|
||||
final_response = assistant_message.content
|
||||
break
|
||||
|
||||
else:
|
||||
# No tool calls - final response
|
||||
final_response = assistant_message.content or ""
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": final_response
|
||||
})
|
||||
completed = True
|
||||
print(f"🎉 Agent finished (no more tool calls)")
|
||||
break
|
||||
|
||||
if api_call_count >= self.max_iterations:
|
||||
print(f"⚠️ Reached max iterations ({self.max_iterations})")
|
||||
|
||||
finally:
|
||||
# Cleanup environment
|
||||
self._cleanup_env()
|
||||
|
||||
# Convert to Hermes trajectory format
|
||||
trajectory = self._convert_to_hermes_format(messages, task, completed)
|
||||
|
||||
return {
|
||||
"conversations": trajectory,
|
||||
"completed": completed,
|
||||
"api_calls": api_call_count,
|
||||
"metadata": {
|
||||
"model": self.model,
|
||||
"env_type": self.env_type,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
def run_batch(
|
||||
self,
|
||||
prompts: List[str],
|
||||
output_file: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Run multiple tasks and save trajectories to a JSONL file.
|
||||
|
||||
Args:
|
||||
prompts: List of task prompts
|
||||
output_file: Output JSONL file path
|
||||
|
||||
Returns:
|
||||
List of results
|
||||
"""
|
||||
results = []
|
||||
|
||||
print(f"\n📦 Running batch of {len(prompts)} tasks")
|
||||
print(f"📁 Output: {output_file}")
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for i, prompt in enumerate(prompts, 1):
|
||||
print(f"\n{'='*60}")
|
||||
print(f"📋 Task {i}/{len(prompts)}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
try:
|
||||
result = self.run_task(prompt)
|
||||
results.append(result)
|
||||
|
||||
# Write to file immediately
|
||||
f.write(json.dumps(result, ensure_ascii=False) + "\n")
|
||||
f.flush()
|
||||
|
||||
print(f"✅ Task {i} completed (api_calls={result['api_calls']})")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error on task {i}: {e}")
|
||||
error_result = {
|
||||
"conversations": [],
|
||||
"completed": False,
|
||||
"api_calls": 0,
|
||||
"error": str(e),
|
||||
"metadata": {"timestamp": datetime.now().isoformat()}
|
||||
}
|
||||
results.append(error_result)
|
||||
f.write(json.dumps(error_result, ensure_ascii=False) + "\n")
|
||||
f.flush()
|
||||
|
||||
print(f"\n✅ Batch complete! {len(results)} trajectories saved to {output_file}")
|
||||
return results
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI Interface
|
||||
# ============================================================================
|
||||
|
||||
def main(
|
||||
task: str = None,
|
||||
prompts_file: str = None,
|
||||
output_file: str = "mini-swe-agent-test1.jsonl",
|
||||
model: str = "claude-sonnet-4-20250514",
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
env: str = "local",
|
||||
image: str = "python:3.11-slim",
|
||||
cwd: str = "/tmp",
|
||||
max_iterations: int = 15,
|
||||
timeout: int = 60,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""
|
||||
Run mini-swe-agent tasks with Hermes trajectory format output.
|
||||
|
||||
Args:
|
||||
task: Single task to run (use this OR prompts_file)
|
||||
prompts_file: JSONL file with prompts (each line: {"prompt": "..."})
|
||||
output_file: Output JSONL file for trajectories
|
||||
model: Model name (default: claude-sonnet-4-20250514)
|
||||
base_url: API base URL (optional)
|
||||
api_key: API key (optional, uses env vars)
|
||||
env: Environment type - "local", "docker", or "modal"
|
||||
image: Docker/Modal image (default: python:3.11-slim)
|
||||
cwd: Working directory (default: /tmp)
|
||||
max_iterations: Maximum tool-calling iterations (default: 15)
|
||||
timeout: Command timeout in seconds (default: 60)
|
||||
verbose: Enable verbose logging
|
||||
|
||||
Examples:
|
||||
# Single task with local environment
|
||||
python mini_swe_runner.py --task "Create hello.py that prints Hello World"
|
||||
|
||||
# Single task with Docker
|
||||
python mini_swe_runner.py --task "List files" --env docker
|
||||
|
||||
# Batch from file
|
||||
python mini_swe_runner.py --prompts_file tasks.jsonl --output_file results.jsonl
|
||||
"""
|
||||
print("🚀 Mini-SWE Runner with Hermes Trajectory Format")
|
||||
print("=" * 60)
|
||||
|
||||
# Initialize runner
|
||||
runner = MiniSWERunner(
|
||||
model=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
env_type=env,
|
||||
image=image,
|
||||
cwd=cwd,
|
||||
max_iterations=max_iterations,
|
||||
command_timeout=timeout,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
if task:
|
||||
# Single task mode
|
||||
result = runner.run_task(task)
|
||||
|
||||
# Save to file
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps(result, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"\n📁 Trajectory saved to: {output_file}")
|
||||
print(f"✅ Completed: {result['completed']}")
|
||||
print(f"📞 API calls: {result['api_calls']}")
|
||||
print(f"💬 Turns: {len(result['conversations'])}")
|
||||
|
||||
elif prompts_file:
|
||||
# Batch mode
|
||||
prompts = []
|
||||
with open(prompts_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
prompts.append(entry.get("prompt", entry.get("task", "")))
|
||||
except json.JSONDecodeError:
|
||||
prompts.append(line)
|
||||
|
||||
if not prompts:
|
||||
print(f"❌ No prompts found in {prompts_file}")
|
||||
return
|
||||
|
||||
runner.run_batch(prompts, output_file)
|
||||
|
||||
else:
|
||||
print("❌ Please provide either --task or --prompts_file")
|
||||
print(" Example: python mini_swe_runner.py --task 'Create a hello world script'")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
model_tools.py (modified, 424 lines)
@@ -8,7 +8,7 @@ for defining tools and executing function calls.
|
||||
|
||||
Currently supports:
|
||||
- Web tools (search, extract, crawl) from web_tools.py
|
||||
- Terminal tools (command execution with interactive sessions) from terminal_tool.py
|
||||
- Terminal tools (simple command execution, no session persistence) from simple_terminal_tool.py
|
||||
- Vision tools (image analysis) from vision_tools.py
|
||||
- Mixture of Agents tools (collaborative multi-model reasoning) from mixture_of_agents_tool.py
|
||||
- Image generation tools (text-to-image with upscaling) from image_generation_tool.py
|
||||
@@ -31,10 +31,29 @@ import asyncio
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from tools.web_tools import web_search_tool, web_extract_tool, web_crawl_tool, check_firecrawl_api_key
|
||||
from tools.terminal_tool import terminal_tool, check_hecate_requirements, TERMINAL_TOOL_DESCRIPTION
|
||||
from tools.terminal_tool import terminal_tool, check_terminal_requirements, TERMINAL_TOOL_DESCRIPTION, cleanup_vm
|
||||
# Hecate/MorphCloud terminal tool (cloud VMs) - available as alternative backend
|
||||
from tools.terminal_hecate import terminal_hecate_tool, check_hecate_requirements, TERMINAL_HECATE_DESCRIPTION
|
||||
from tools.vision_tools import vision_analyze_tool, check_vision_requirements
|
||||
from tools.mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements
|
||||
from tools.image_generation_tool import image_generate_tool, check_image_generation_requirements
|
||||
from tools.skills_tool import skills_categories, skills_list, skill_view, check_skills_requirements, SKILLS_TOOL_DESCRIPTION
|
||||
# Browser automation tools (agent-browser + Browserbase)
|
||||
from tools.browser_tool import (
|
||||
browser_navigate,
|
||||
browser_snapshot,
|
||||
browser_click,
|
||||
browser_type,
|
||||
browser_scroll,
|
||||
browser_back,
|
||||
browser_press,
|
||||
browser_close,
|
||||
browser_get_images,
|
||||
browser_vision,
|
||||
cleanup_browser,
|
||||
check_browser_requirements,
|
||||
BROWSER_TOOL_SCHEMAS
|
||||
)
|
||||
from toolsets import (
|
||||
get_toolset, resolve_toolset, resolve_multiple_toolsets,
|
||||
get_all_toolsets, get_toolset_names, validate_toolset,
|
||||
@@ -53,7 +72,7 @@ def get_web_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web for information on any topic. Returns up to 5 relevant results with titles and URLs. Uses advanced search depth for comprehensive results.",
|
||||
"description": "Search the web for information on any topic. Returns up to 5 relevant results with titles and URLs. Uses advanced search depth for comprehensive results. PREFERRED over browser tools for finding information - faster and more cost-effective. Use browser tools only when you need to interact with pages (click, fill forms, handle dynamic content).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -70,7 +89,7 @@ def get_web_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_extract",
|
||||
"description": "Extract and read the full content from specific web page URLs. Useful for getting detailed information from webpages found through search. The content returned will be excerpts and key points summarized with an LLM to reduce impact on the context window.",
|
||||
"description": "Extract and read the full content from specific web page URLs. Useful for getting detailed information from webpages found through search. The content returned will be excerpts and key points summarized with an LLM to reduce impact on the context window. PREFERRED over browser tools for reading page content - faster and more cost-effective. Use browser tools only when pages require interaction or have dynamic content.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -85,33 +104,14 @@ def get_web_tool_definitions() -> List[Dict[str, Any]]:
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_crawl",
|
||||
"description": "Crawl a website with specific instructions to find and extract targeted content. Uses AI to intelligently navigate and extract relevant information from across the site. The content returned will be excerpts and key points summarized with an LLM to reduce impact on the context window.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The base URL to crawl (can include or exclude https://)"
|
||||
},
|
||||
"instructions": {
|
||||
"type": "string",
|
||||
"description": "Specific instructions for what to crawl/extract using AI intelligence (e.g., 'Find pricing information', 'Get documentation pages', 'Extract contact details')"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
def get_terminal_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get tool definitions for terminal tools in OpenAI's expected format.
|
||||
|
||||
Uses mini-swe-agent backend (local/docker/modal) by default.
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of terminal tool definitions compatible with OpenAI API
|
||||
"""
|
||||
@@ -128,28 +128,18 @@ def get_terminal_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "string",
|
||||
"description": "The command to execute on the VM"
|
||||
},
|
||||
"input_keys": {
|
||||
"type": "string",
|
||||
"description": "Keystrokes to send to the most recent interactive session (e.g., 'hello\\n' for typing hello + Enter). If no active session exists, this will be ignored."
|
||||
},
|
||||
"background": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to run the command in the background (default: false)",
|
||||
"default": False
|
||||
},
|
||||
"idle_threshold": {
|
||||
"type": "number",
|
||||
"description": "Seconds to wait for output before considering session idle (default: 5.0)",
|
||||
"default": 5.0,
|
||||
"minimum": 0.1
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Command timeout in seconds (optional)",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
"required": ["command"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -228,7 +218,7 @@ def get_image_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "image_generate",
|
||||
"description": "Generate high-quality images from text prompts using FLUX Krea model with automatic 2x upscaling. Creates detailed, artistic images that are automatically enhanced for superior quality. Returns a single upscaled image URL that can be displayed using <img src=\"{URL}\"></img> tags.",
|
||||
"description": "Generate high-quality images from text prompts using FLUX 2 Pro model with automatic 2x upscaling. Creates detailed, artistic images that are automatically upscaled for hi-rez results. Returns a single upscaled image URL that can be displayed using <img src=\"{URL}\"></img> tags.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -236,11 +226,11 @@ def get_image_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "string",
|
||||
"description": "The text prompt describing the desired image. Be detailed and descriptive."
|
||||
},
|
||||
"image_size": {
|
||||
"aspect_ratio": {
|
||||
"type": "string",
|
||||
"enum": ["square","portrait_16_9", "landscape_16_9"],
|
||||
"description": "The size/aspect ratio of the generated image (default: landscape_4_3)",
|
||||
"default": "landscape_16_9"
|
||||
"enum": ["landscape", "square", "portrait"],
|
||||
"description": "The aspect ratio of the generated image. 'landscape' is 16:9 wide, 'portrait' is 16:9 tall, 'square' is 1:1.",
|
||||
"default": "landscape"
|
||||
}
|
||||
},
|
||||
"required": ["prompt"]
|
||||
@@ -250,6 +240,79 @@ def get_image_tool_definitions() -> List[Dict[str, Any]]:
|
||||
]
|
||||
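A conforming `image_generate` call using the new `aspect_ratio` enum could be shaped like this sketch (the prompt is made up for illustration):

```python
# Hypothetical example arguments for image_generate using the aspect_ratio enum above.
image_args = {
    "prompt": "A watercolor painting of a lighthouse at dawn, soft pastel palette",
    "aspect_ratio": "portrait",  # one of: landscape, square, portrait
}
```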
|
||||
|
||||
def get_skills_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get tool definitions for skills tools in OpenAI's expected format.
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of skills tool definitions compatible with OpenAI API
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "skills_list",
|
||||
"description": "List available skills (name + description). Use skill_view(name) to load full content.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"category": {
|
||||
"type": "string",
|
||||
"description": "Optional category filter (from skills_categories)"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "skills_categories",
|
||||
"description": "List available skill categories. Call first if you want to discover categories, then use skills_list(category) to filter, or call skills_list if unsure.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "skill_view",
|
||||
"description": "Skills allow for loading information about specific tasks and workflows, as well as scripts and templates. Load a skill's full content or access its linked files (references, templates, scripts). First call returns SKILL.md content plus a 'linked_files' dict showing available references/templates/scripts. To access those, call again with file_path parameter.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "The skill name (use skills_list to see available skills)"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "OPTIONAL: Path to a linked file within the skill (e.g., 'references/api.md', 'templates/config.yaml', 'scripts/validate.py'). Omit to get the main SKILL.md content."
|
||||
}
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
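To make the two-step `skill_view` flow concrete, a caller first loads the SKILL.md body (which also reports `linked_files`), then requests a linked file by path. A minimal sketch, assuming the `handle_skills_function_call` handler defined later in this module and a hypothetical skill named 'pdf-processing':

```python
# Step 1: load the main SKILL.md content; the result also lists linked_files.
# ('pdf-processing' is a hypothetical skill name used only for illustration.)
skill_md = handle_skills_function_call("skill_view", {"name": "pdf-processing"})

# Step 2: fetch one of the linked files reported in step 1.
template = handle_skills_function_call(
    "skill_view",
    {"name": "pdf-processing", "file_path": "templates/config.yaml"},
)
```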
|
||||
|
||||
def get_browser_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get tool definitions for browser automation tools in OpenAI's expected format.
|
||||
|
||||
Uses agent-browser CLI with Browserbase cloud execution.
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of browser tool definitions compatible with OpenAI API
|
||||
"""
|
||||
return [{"type": "function", "function": schema} for schema in BROWSER_TOOL_SCHEMAS]
|
||||
|
||||
|
||||
def get_all_tool_names() -> List[str]:
|
||||
"""
|
||||
Get the names of all available tools across all toolsets.
|
||||
@@ -261,12 +324,12 @@ def get_all_tool_names() -> List[str]:
|
||||
|
||||
# Web tools
|
||||
if check_firecrawl_api_key():
|
||||
tool_names.extend(["web_search", "web_extract", "web_crawl"])
|
||||
|
||||
# Terminal tools
|
||||
if check_hecate_requirements():
|
||||
tool_names.extend(["web_search", "web_extract"])
|
||||
|
||||
# Terminal tools (mini-swe-agent backend)
|
||||
if check_terminal_requirements():
|
||||
tool_names.extend(["terminal"])
|
||||
|
||||
|
||||
# Vision tools
|
||||
if check_vision_requirements():
|
||||
tool_names.extend(["vision_analyze"])
|
||||
@@ -279,6 +342,19 @@ def get_all_tool_names() -> List[str]:
|
||||
if check_image_generation_requirements():
|
||||
tool_names.extend(["image_generate"])
|
||||
|
||||
# Skills tools
|
||||
if check_skills_requirements():
|
||||
tool_names.extend(["skills_categories", "skills_list", "skill_view"])
|
||||
|
||||
# Browser automation tools
|
||||
if check_browser_requirements():
|
||||
tool_names.extend([
|
||||
"browser_navigate", "browser_snapshot", "browser_click",
|
||||
"browser_type", "browser_scroll", "browser_back",
|
||||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
])
|
||||
|
||||
return tool_names
|
||||
|
||||
|
||||
@@ -294,12 +370,26 @@ def get_toolset_for_tool(tool_name: str) -> str:
|
||||
"""
|
||||
toolset_mapping = {
|
||||
"web_search": "web_tools",
|
||||
"web_extract": "web_tools",
|
||||
"web_crawl": "web_tools",
|
||||
"web_extract": "web_tools",
|
||||
"terminal": "terminal_tools",
|
||||
"vision_analyze": "vision_tools",
|
||||
"mixture_of_agents": "moa_tools",
|
||||
"image_generate": "image_tools"
|
||||
"image_generate": "image_tools",
|
||||
# Skills tools
|
||||
"skills_categories": "skills_tools",
|
||||
"skills_list": "skills_tools",
|
||||
"skill_view": "skills_tools",
|
||||
# Browser automation tools
|
||||
"browser_navigate": "browser_tools",
|
||||
"browser_snapshot": "browser_tools",
|
||||
"browser_click": "browser_tools",
|
||||
"browser_type": "browser_tools",
|
||||
"browser_scroll": "browser_tools",
|
||||
"browser_back": "browser_tools",
|
||||
"browser_press": "browser_tools",
|
||||
"browser_close": "browser_tools",
|
||||
"browser_get_images": "browser_tools",
|
||||
"browser_vision": "browser_tools"
|
||||
}
|
||||
|
||||
return toolset_mapping.get(tool_name, "unknown")
|
||||
@@ -307,7 +397,8 @@ def get_toolset_for_tool(tool_name: str) -> str:
|
||||
|
||||
def get_tool_definitions(
|
||||
enabled_toolsets: List[str] = None,
|
||||
disabled_toolsets: List[str] = None
|
||||
disabled_toolsets: List[str] = None,
|
||||
quiet_mode: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get tool definitions for model API calls with toolset-based filtering.
|
||||
@@ -346,11 +437,11 @@ def get_tool_definitions(
|
||||
if check_firecrawl_api_key():
|
||||
for tool in get_web_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
if check_hecate_requirements():
|
||||
|
||||
if check_terminal_requirements():
|
||||
for tool in get_terminal_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
|
||||
if check_vision_requirements():
|
||||
for tool in get_vision_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
@@ -363,6 +454,14 @@ def get_tool_definitions(
|
||||
for tool in get_image_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
if check_skills_requirements():
|
||||
for tool in get_skills_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
if check_browser_requirements():
|
||||
for tool in get_browser_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
# Determine which tools to include based on toolsets
|
||||
tools_to_include = set()
|
||||
|
||||
@@ -375,14 +474,21 @@ def get_tool_definitions(
|
||||
print(f"✅ Enabled toolset '{toolset_name}': {', '.join(resolved_tools) if resolved_tools else 'no tools'}")
|
||||
else:
|
||||
# Try legacy compatibility
|
||||
if toolset_name in ["web_tools", "terminal_tools", "vision_tools", "moa_tools", "image_tools"]:
|
||||
if toolset_name in ["web_tools", "terminal_tools", "vision_tools", "moa_tools", "image_tools", "skills_tools", "browser_tools"]:
|
||||
# Map legacy names to new system
|
||||
legacy_map = {
|
||||
"web_tools": ["web_search", "web_extract", "web_crawl"],
|
||||
"web_tools": ["web_search", "web_extract"],
|
||||
"terminal_tools": ["terminal"],
|
||||
"vision_tools": ["vision_analyze"],
|
||||
"moa_tools": ["mixture_of_agents"],
|
||||
"image_tools": ["image_generate"]
|
||||
"image_tools": ["image_generate"],
|
||||
"skills_tools": ["skills_categories", "skills_list", "skill_view"],
|
||||
"browser_tools": [
|
||||
"browser_navigate", "browser_snapshot", "browser_click",
|
||||
"browser_type", "browser_scroll", "browser_back",
|
||||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
]
|
||||
}
|
||||
legacy_tools = legacy_map.get(toolset_name, [])
|
||||
tools_to_include.update(legacy_tools)
|
||||
@@ -410,13 +516,20 @@ def get_tool_definitions(
|
||||
print(f"🚫 Disabled toolset '{toolset_name}': {', '.join(resolved_tools) if resolved_tools else 'no tools'}")
|
||||
else:
|
||||
# Try legacy compatibility
|
||||
if toolset_name in ["web_tools", "terminal_tools", "vision_tools", "moa_tools", "image_tools"]:
|
||||
if toolset_name in ["web_tools", "terminal_tools", "vision_tools", "moa_tools", "image_tools", "skills_tools", "browser_tools"]:
|
||||
legacy_map = {
|
||||
"web_tools": ["web_search", "web_extract", "web_crawl"],
|
||||
"web_tools": ["web_search", "web_extract"],
|
||||
"terminal_tools": ["terminal"],
|
||||
"vision_tools": ["vision_analyze"],
|
||||
"moa_tools": ["mixture_of_agents"],
|
||||
"image_tools": ["image_generate"]
|
||||
"image_tools": ["image_generate"],
|
||||
"skills_tools": ["skills_categories", "skills_list", "skill_view"],
|
||||
"browser_tools": [
|
||||
"browser_navigate", "browser_snapshot", "browser_click",
|
||||
"browser_type", "browser_scroll", "browser_back",
|
||||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
]
|
||||
}
|
||||
legacy_tools = legacy_map.get(toolset_name, [])
|
||||
tools_to_include.difference_update(legacy_tools)
|
||||
@@ -439,11 +552,12 @@ def get_tool_definitions(
|
||||
# Sort tools for consistent ordering
|
||||
filtered_tools.sort(key=lambda t: t["function"]["name"])
|
||||
|
||||
if filtered_tools:
|
||||
tool_names = [t["function"]["name"] for t in filtered_tools]
|
||||
print(f"🛠️ Final tool selection ({len(filtered_tools)} tools): {', '.join(tool_names)}")
|
||||
else:
|
||||
print("🛠️ No tools selected (all filtered out or unavailable)")
|
||||
if not quiet_mode:
|
||||
if filtered_tools:
|
||||
tool_names = [t["function"]["name"] for t in filtered_tools]
|
||||
print(f"🛠️ Final tool selection ({len(filtered_tools)} tools): {', '.join(tool_names)}")
|
||||
else:
|
||||
print("🛠️ No tools selected (all filtered out or unavailable)")
|
||||
|
||||
return filtered_tools
|
||||
|
||||
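For callers such as the CLI or batch runner, a typical invocation of this function with the parameters shown in the signature above might look like this sketch (illustrative only; the toolset names follow the legacy mapping used in this module):

```python
# A minimal usage sketch for get_tool_definitions (illustrative only).
tools = get_tool_definitions(
    enabled_toolsets=["web_tools", "terminal_tools"],  # only include these groups
    disabled_toolsets=["browser_tools"],               # never include browser tools
    quiet_mode=True,                                   # suppress the selection printout
)
```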
@@ -471,38 +585,32 @@ def handle_web_function_call(function_name: str, function_args: Dict[str, Any])
|
||||
# Run async function in event loop
|
||||
return asyncio.run(web_extract_tool(urls, "markdown"))
|
||||
|
||||
elif function_name == "web_crawl":
|
||||
url = function_args.get("url", "")
|
||||
instructions = function_args.get("instructions")
|
||||
# Run async function in event loop
|
||||
return asyncio.run(web_crawl_tool(url, instructions, "basic"))
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown web function: {function_name}"})
|
||||
return json.dumps({"error": f"Unknown web function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Handle function calls for terminal tools.
|
||||
|
||||
Uses mini-swe-agent backend (local/docker/modal) by default.
|
||||
|
||||
Args:
|
||||
function_name (str): Name of the terminal function to call
|
||||
function_args (Dict): Arguments for the function
|
||||
task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional)
|
||||
task_id (str): Unique identifier for this task to isolate environments between concurrent tasks (optional)
|
||||
|
||||
Returns:
|
||||
str: Function result as JSON string
|
||||
"""
|
||||
if function_name == "terminal":
|
||||
command = function_args.get("command")
|
||||
input_keys = function_args.get("input_keys")
|
||||
background = function_args.get("background", False)
|
||||
idle_threshold = function_args.get("idle_threshold", 5.0)
|
||||
timeout = function_args.get("timeout")
|
||||
|
||||
return terminal_tool(command, input_keys, None, background, idle_threshold, timeout, task_id)
|
||||
return terminal_tool(command=command, background=background, timeout=timeout, task_id=task_id)
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown terminal function: {function_name}"})
|
||||
return json.dumps({"error": f"Unknown terminal function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_vision_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
@@ -523,10 +631,10 @@ def handle_vision_function_call(function_name: str, function_args: Dict[str, Any
|
||||
full_prompt = f"Fully describe and explain everything about this image, then answer the following question:\n\n{question}"
|
||||
|
||||
# Run async function in event loop
|
||||
return asyncio.run(vision_analyze_tool(image_url, full_prompt, "gemini-2.5-flash"))
|
||||
return asyncio.run(vision_analyze_tool(image_url, full_prompt, "google/gemini-3-flash-preview"))
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown vision function: {function_name}"})
|
||||
return json.dumps({"error": f"Unknown vision function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
@@ -544,13 +652,13 @@ def handle_moa_function_call(function_name: str, function_args: Dict[str, Any])
|
||||
user_prompt = function_args.get("user_prompt", "")
|
||||
|
||||
if not user_prompt:
|
||||
return json.dumps({"error": "user_prompt is required for MoA processing"})
|
||||
return json.dumps({"error": "user_prompt is required for MoA processing"}, ensure_ascii=False)
|
||||
|
||||
# Run async function in event loop
|
||||
return asyncio.run(mixture_of_agents_tool(user_prompt=user_prompt))
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown MoA function: {function_name}"})
|
||||
return json.dumps({"error": f"Unknown MoA function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
@@ -568,18 +676,15 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
|
||||
prompt = function_args.get("prompt", "")
|
||||
|
||||
if not prompt:
|
||||
return json.dumps({"success": False, "image": None})
|
||||
return json.dumps({"success": False, "image": None}, ensure_ascii=False)
|
||||
|
||||
image_size = function_args.get("image_size", "landscape_16_9")
|
||||
aspect_ratio = function_args.get("aspect_ratio", "landscape")
|
||||
|
||||
# Use fixed internal defaults for all other parameters (not exposed to model)
|
||||
num_inference_steps = 50
|
||||
guidance_scale = 4.5
|
||||
num_images = 1
|
||||
enable_safety_checker = True
|
||||
output_format = "png"
|
||||
acceleration = "none"
|
||||
allow_nsfw_images = True
|
||||
seed = None
|
||||
|
||||
# Run async function in event loop with proper handling for multiprocessing
|
||||
@@ -598,24 +703,101 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
|
||||
# Run the coroutine in the event loop
|
||||
result = loop.run_until_complete(image_generate_tool(
|
||||
prompt=prompt,
|
||||
image_size=image_size,
|
||||
aspect_ratio=aspect_ratio,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
num_images=num_images,
|
||||
enable_safety_checker=enable_safety_checker,
|
||||
output_format=output_format,
|
||||
acceleration=acceleration,
|
||||
allow_nsfw_images=allow_nsfw_images,
|
||||
seed=seed
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown image generation function: {function_name}"})
|
||||
return json.dumps({"error": f"Unknown image generation function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
|
||||
def handle_skills_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Handle function calls for skills tools.
|
||||
|
||||
Args:
|
||||
function_name (str): Name of the skills function to call
|
||||
function_args (Dict): Arguments for the function
|
||||
|
||||
Returns:
|
||||
str: Function result as JSON string
|
||||
"""
|
||||
if function_name == "skills_categories":
|
||||
return skills_categories()
|
||||
|
||||
elif function_name == "skills_list":
|
||||
category = function_args.get("category")
|
||||
return skills_list(category=category)
|
||||
|
||||
elif function_name == "skill_view":
|
||||
name = function_args.get("name", "")
|
||||
if not name:
|
||||
return json.dumps({"error": "Skill name is required"}, ensure_ascii=False)
|
||||
file_path = function_args.get("file_path")
|
||||
return skill_view(name, file_path=file_path)
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown skills function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
# Browser tool handlers mapping
|
||||
BROWSER_HANDLERS = {
|
||||
"browser_navigate": browser_navigate,
|
||||
"browser_click": browser_click,
|
||||
"browser_type": browser_type,
|
||||
"browser_scroll": browser_scroll,
|
||||
"browser_back": browser_back,
|
||||
"browser_press": browser_press,
|
||||
"browser_close": browser_close,
|
||||
"browser_get_images": browser_get_images,
|
||||
"browser_vision": browser_vision,
|
||||
}
|
||||
|
||||
|
||||
def handle_browser_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
task_id: Optional[str] = None,
|
||||
user_task: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Handle function calls for browser automation tools.
|
||||
|
||||
Args:
|
||||
function_name (str): Name of the browser function to call
|
||||
function_args (Dict): Arguments for the function
|
||||
task_id (str): Task identifier for session isolation
|
||||
user_task (str): User's current task (for task-aware extraction in snapshots)
|
||||
|
||||
Returns:
|
||||
str: Function result as JSON string
|
||||
"""
|
||||
# Special handling for browser_snapshot which needs user_task for extraction
|
||||
if function_name == "browser_snapshot":
|
||||
full = function_args.get("full", False)
|
||||
return browser_snapshot(full=full, task_id=task_id, user_task=user_task)
|
||||
|
||||
# Handle other browser tools
|
||||
if function_name in BROWSER_HANDLERS:
|
||||
handler = BROWSER_HANDLERS[function_name]
|
||||
# Add task_id to args
|
||||
return handler(**function_args, task_id=task_id)
|
||||
|
||||
return json.dumps({"error": f"Unknown browser function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
task_id: Optional[str] = None,
|
||||
user_task: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Main function call dispatcher that routes calls to appropriate toolsets.
|
||||
|
||||
@@ -626,7 +808,8 @@ def handle_function_call(function_name: str, function_args: Dict[str, Any], task
|
||||
Args:
|
||||
function_name (str): Name of the function to call
|
||||
function_args (Dict): Arguments for the function
|
||||
task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional)
|
||||
task_id (str): Unique identifier for this task to isolate VMs/sessions between concurrent tasks (optional)
|
||||
user_task (str): The user's original task/query (used for task-aware content extraction) (optional)
|
||||
|
||||
Returns:
|
||||
str: Function result as JSON string
|
||||
@@ -636,7 +819,7 @@ def handle_function_call(function_name: str, function_args: Dict[str, Any], task
|
||||
"""
|
||||
try:
|
||||
# Route web tools
|
||||
if function_name in ["web_search", "web_extract", "web_crawl"]:
|
||||
if function_name in ["web_search", "web_extract"]:
|
||||
return handle_web_function_call(function_name, function_args)
|
||||
|
||||
# Route terminal tools
|
||||
@@ -655,15 +838,29 @@ def handle_function_call(function_name: str, function_args: Dict[str, Any], task
|
||||
elif function_name in ["image_generate"]:
|
||||
return handle_image_function_call(function_name, function_args)
|
||||
|
||||
# Route skills tools
|
||||
elif function_name in ["skills_categories", "skills_list", "skill_view"]:
|
||||
return handle_skills_function_call(function_name, function_args)
|
||||
|
||||
# Route browser automation tools
|
||||
elif function_name in [
|
||||
"browser_navigate", "browser_snapshot", "browser_click",
|
||||
"browser_type", "browser_scroll", "browser_back",
|
||||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
]:
|
||||
return handle_browser_function_call(function_name, function_args, task_id, user_task)
|
||||
|
||||
else:
|
||||
error_msg = f"Unknown function: {function_name}"
|
||||
print(f"❌ {error_msg}")
|
||||
return json.dumps({"error": error_msg})
|
||||
|
||||
|
||||
return json.dumps({"error": error_msg}, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing {function_name}: {str(e)}"
|
||||
print(f"❌ {error_msg}")
|
||||
return json.dumps({"error": error_msg})
|
||||
return json.dumps({"error": error_msg}, ensure_ascii=False)
|
||||
|
||||
def get_available_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
@@ -675,15 +872,15 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
toolsets = {
|
||||
"web_tools": {
|
||||
"available": check_firecrawl_api_key(),
|
||||
"tools": ["web_search_tool", "web_extract_tool", "web_crawl_tool"],
|
||||
"description": "Web search, content extraction, and website crawling tools",
|
||||
"tools": ["web_search_tool", "web_extract_tool"],
|
||||
"description": "Web search and content extraction tools",
|
||||
"requirements": ["FIRECRAWL_API_KEY environment variable"]
|
||||
},
|
||||
"terminal_tools": {
|
||||
"available": check_hecate_requirements(),
|
||||
"available": check_terminal_requirements(),
|
||||
"tools": ["terminal_tool"],
|
||||
"description": "Execute commands with optional interactive session support on Linux VMs",
|
||||
"requirements": ["MORPH_API_KEY environment variable", "hecate package"]
|
||||
"description": "Execute commands using mini-swe-agent (local/docker/modal)",
|
||||
"requirements": ["mini-swe-agent package, TERMINAL_ENV to select backend"]
|
||||
},
|
||||
"vision_tools": {
|
||||
"available": check_vision_requirements(),
|
||||
@@ -702,6 +899,23 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
"tools": ["image_generate_tool"],
|
||||
"description": "Generate high-quality images from text prompts using FAL.ai's FLUX.1 Krea model with automatic 2x upscaling for enhanced quality",
|
||||
"requirements": ["FAL_KEY environment variable", "fal-client package"]
|
||||
},
|
||||
"skills_tools": {
|
||||
"available": check_skills_requirements(),
|
||||
"tools": ["skills_categories", "skills_list", "skill_view"],
|
||||
"description": "Access skill documents that provide specialized instructions, guidelines, or knowledge the agent can load on demand",
|
||||
"requirements": ["skills/ directory in repo root"]
|
||||
},
|
||||
"browser_tools": {
|
||||
"available": check_browser_requirements(),
|
||||
"tools": [
|
||||
"browser_navigate", "browser_snapshot", "browser_click",
|
||||
"browser_type", "browser_scroll", "browser_back",
|
||||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
],
|
||||
"description": "Browser automation for web interaction using agent-browser CLI with Browserbase cloud execution",
|
||||
"requirements": ["BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID", "agent-browser npm package"]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -710,16 +924,18 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
def check_toolset_requirements() -> Dict[str, bool]:
|
||||
"""
|
||||
Check if all requirements for available toolsets are met.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: Status of each toolset's requirements
|
||||
"""
|
||||
return {
|
||||
"web_tools": check_firecrawl_api_key(),
|
||||
"terminal_tools": check_hecate_requirements(),
|
||||
"terminal_tools": check_terminal_requirements(),
|
||||
"vision_tools": check_vision_requirements(),
|
||||
"moa_tools": check_moa_requirements(),
|
||||
"image_tools": check_image_generation_requirements()
|
||||
"image_tools": check_image_generation_requirements(),
|
||||
"skills_tools": check_skills_requirements(),
|
||||
"browser_tools": check_browser_requirements()
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
73 package-lock.json generated Normal file
@@ -0,0 +1,73 @@
|
||||
{
|
||||
"name": "hermes-agent",
|
||||
"version": "1.0.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "hermes-agent",
|
||||
"version": "1.0.0",
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"agent-browser": "^0.7.6"
|
||||
}
|
||||
},
|
||||
"node_modules/agent-browser": {
|
||||
"version": "0.7.6",
|
||||
"resolved": "https://registry.npmjs.org/agent-browser/-/agent-browser-0.7.6.tgz",
|
||||
"integrity": "sha512-BDmzFlTM0siqn5P8LSBxgOBUNGv02Vo7RYztvXXjNOwQ+8rFJILWfBPxmw+57l/PcMst61AscjIe8uZ5sWrRZQ==",
|
||||
"hasInstallScript": true,
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"playwright-core": "^1.57.0",
|
||||
"ws": "^8.19.0",
|
||||
"zod": "^3.22.4"
|
||||
},
|
||||
"bin": {
|
||||
"agent-browser": "bin/agent-browser"
|
||||
}
|
||||
},
|
||||
"node_modules/playwright-core": {
|
||||
"version": "1.58.0",
|
||||
"resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.58.0.tgz",
|
||||
"integrity": "sha512-aaoB1RWrdNi3//rOeKuMiS65UCcgOVljU46At6eFcOFPFHWtd2weHRRow6z/n+Lec0Lvu0k9ZPKJSjPugikirw==",
|
||||
"license": "Apache-2.0",
|
||||
"bin": {
|
||||
"playwright-core": "cli.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
}
|
||||
},
|
||||
"node_modules/ws": {
|
||||
"version": "8.19.0",
|
||||
"resolved": "https://registry.npmjs.org/ws/-/ws-8.19.0.tgz",
|
||||
"integrity": "sha512-blAT2mjOEIi0ZzruJfIhb3nps74PRWTCz1IjglWEEpQl5XS/UNama6u2/rjFkDDouqr4L67ry+1aGIALViWjDg==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=10.0.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"bufferutil": "^4.0.1",
|
||||
"utf-8-validate": ">=5.0.2"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"bufferutil": {
|
||||
"optional": true
|
||||
},
|
||||
"utf-8-validate": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/zod": {
|
||||
"version": "3.25.76",
|
||||
"resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz",
|
||||
"integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/colinhacks"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
24 package.json Normal file
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"name": "hermes-agent",
|
||||
"version": "1.0.0",
|
||||
"description": "An AI agent with advanced tool-calling capabilities, featuring a flexible toolsets system for organizing and managing tools.",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"postinstall": "echo '✅ Browser tools ready. Run: python run_agent.py --help'"
|
||||
},
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/NousResearch/Hermes-Agent.git"
|
||||
},
|
||||
"license": "MIT",
|
||||
"bugs": {
|
||||
"url": "https://github.com/NousResearch/Hermes-Agent/issues"
|
||||
},
|
||||
"homepage": "https://github.com/NousResearch/Hermes-Agent#readme",
|
||||
"dependencies": {
|
||||
"agent-browser": "^0.7.6"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
}
|
||||
}
|
||||
@@ -8,21 +8,38 @@ version = "0.1.0"
|
||||
description = "AI agent with advanced tool-calling and toolsets"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
authors = [{ name = "Hermes Agent" }]
|
||||
authors = [{ name = "Nous Research" }]
|
||||
license = { text = "MIT" }
|
||||
dependencies = [
|
||||
"firecrawl-py",
|
||||
# Core
|
||||
"openai",
|
||||
"fal-client",
|
||||
"python-dotenv",
|
||||
"fire"
|
||||
"fire",
|
||||
"httpx",
|
||||
"rich",
|
||||
"tenacity",
|
||||
"pyyaml",
|
||||
"requests",
|
||||
"jinja2",
|
||||
"pydantic>=2.0",
|
||||
# Tools
|
||||
"firecrawl-py",
|
||||
"fal-client",
|
||||
# mini-swe-agent deps (terminal tool)
|
||||
"litellm>=1.75.5",
|
||||
"typer",
|
||||
"platformdirs",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
modal = ["modal", "boto3"]
|
||||
dev = ["pytest", "pytest-asyncio"]
|
||||
|
||||
[project.scripts]
|
||||
hermes-agent = "run_agent:main"
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["run_agent", "model_tools", "toolsets"]
|
||||
py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajectory_compressor", "toolset_distributions"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["tools"]
|
||||
|
||||
@@ -1,6 +1,34 @@
|
||||
firecrawl-py
|
||||
# Core dependencies
|
||||
openai
|
||||
fal-client
|
||||
python-dotenv
|
||||
fire
|
||||
httpx
|
||||
httpx
|
||||
rich
|
||||
tenacity
|
||||
prompt_toolkit
|
||||
|
||||
# Web tools
|
||||
firecrawl-py
|
||||
|
||||
# Image generation
|
||||
fal-client
|
||||
|
||||
# mini-swe-agent dependencies (for terminal tool)
|
||||
# Note: Install mini-swe-agent itself with: pip install -e ./mini-swe-agent
|
||||
pyyaml
|
||||
requests
|
||||
jinja2
|
||||
pydantic>=2.0
|
||||
litellm>=1.75.5
|
||||
typer
|
||||
platformdirs
|
||||
|
||||
# Optional: For Docker backend (recommended)
|
||||
# Requires Docker installed and user in 'docker' group
|
||||
|
||||
# Optional: For Modal backend (cloud execution)
|
||||
# modal
|
||||
# boto3
|
||||
|
||||
# Optional: Legacy Hecate terminal backend
|
||||
# git+ssh://git@github.com/NousResearch/hecate.git
|
||||
|
||||
1100 run_agent.py (file diff suppressed because it is too large)
@@ -1,12 +0,0 @@
|
||||
python batch_runner.py \
|
||||
--dataset_file="hermes-agent-imagen-data/hermes_agent_imagen_eval.jsonl" \
|
||||
--batch_size=10 \
|
||||
--run_name="imagen_eval_gpt5" \
|
||||
--distribution="image_gen" \
|
||||
--model="gpt-5" \
|
||||
--base_url="https://api.openai.com/v1" \
|
||||
--api_key="${OPENAI_API_KEY}" \
|
||||
--num_workers=4 \
|
||||
--max_turns=5 \
|
||||
--verbose \
|
||||
--ephemeral_system_prompt="When generating an image for the user view the image by using the vision_analyze tool to ensure it is what the user wanted. If it isn't feel free to retry a few times. If none are perfect, choose the best option that is the closest match, and explain its imperfections. If the image generation tool fails, try again a few times. If the vision analyze tool fails, provide the image to the user and explain it is your best effort attempt."
|
||||
@@ -1,12 +0,0 @@
|
||||
python batch_runner.py \
|
||||
--dataset_file="hermes-agent-megascience-data/hermes_agent_megascience_eval.jsonl" \
|
||||
--batch_size=10 \
|
||||
--run_name="megascience_eval_glm4-6-fixedterminal" \
|
||||
--distribution="science" \
|
||||
--model="z-ai/glm-4.6" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--api_key="${OPENROUTER_API_KEY}" \
|
||||
--num_workers=5 \
|
||||
--max_turns=30 \
|
||||
--verbose \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use a tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work."
|
||||
411 scripts/sample_and_compress.py Normal file
@@ -0,0 +1,411 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Sample and Compress HuggingFace Datasets
|
||||
|
||||
Downloads trajectories from multiple HuggingFace datasets, randomly samples them,
|
||||
and runs trajectory compression to fit within a target token budget.
|
||||
|
||||
Usage:
|
||||
python scripts/sample_and_compress.py
|
||||
|
||||
# Custom sample size
|
||||
python scripts/sample_and_compress.py --total_samples=5000
|
||||
|
||||
# Custom output name
|
||||
python scripts/sample_and_compress.py --output_name=compressed_16k
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Tuple
|
||||
import fire
|
||||
|
||||
# Load environment variables
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# Default datasets to sample from
|
||||
DEFAULT_DATASETS = [
|
||||
"NousResearch/swe-terminus-agent-glm-kimi-minimax",
|
||||
"NousResearch/hermes-agent-megascience-sft1",
|
||||
"NousResearch/Hermes-Agent-Thinking-GLM-4.7-SFT2",
|
||||
"NousResearch/Hermes-Agent-Thinking-GLM-4.7-SFT1",
|
||||
"NousResearch/terminal-tasks-glm-hermes-agent"
|
||||
]
|
||||
|
||||
|
||||
def load_dataset_from_hf(dataset_name: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load a dataset from HuggingFace.
|
||||
|
||||
Args:
|
||||
dataset_name: HuggingFace dataset name (e.g., "NousResearch/dataset-name")
|
||||
|
||||
Returns:
|
||||
List of trajectory entries
|
||||
"""
|
||||
from datasets import load_dataset
|
||||
|
||||
print(f" Loading {dataset_name}...")
|
||||
|
||||
try:
|
||||
# Try loading with default config
|
||||
ds = load_dataset(dataset_name, split="train")
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Error loading {dataset_name}: {e}")
|
||||
return []
|
||||
|
||||
# Convert to list of dicts
|
||||
entries = []
|
||||
for item in ds:
|
||||
# Handle different possible formats
|
||||
if "conversations" in item:
|
||||
entries.append({"conversations": item["conversations"]})
|
||||
elif "messages" in item:
|
||||
# Convert messages format to conversations format if needed
|
||||
entries.append({"conversations": item["messages"]})
|
||||
else:
|
||||
# Assume the whole item is the entry
|
||||
entries.append(dict(item))
|
||||
|
||||
print(f" ✅ Loaded {len(entries):,} entries from {dataset_name}")
|
||||
return entries
|
||||
|
||||
|
||||
# Global tokenizer for multiprocessing (set in worker init)
|
||||
_TOKENIZER = None
|
||||
|
||||
|
||||
def _init_tokenizer_worker(tokenizer_name: str):
|
||||
"""Initialize tokenizer in worker process."""
|
||||
global _TOKENIZER
|
||||
from transformers import AutoTokenizer
|
||||
_TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
|
||||
|
||||
|
||||
def _count_tokens_for_entry(entry: Dict) -> Tuple[Dict, int]:
|
||||
"""
|
||||
Count tokens for a single entry (used in parallel processing).
|
||||
|
||||
Args:
|
||||
entry: Trajectory entry with 'conversations' field
|
||||
|
||||
Returns:
|
||||
Tuple of (entry, token_count)
|
||||
"""
|
||||
global _TOKENIZER
|
||||
|
||||
conversations = entry.get("conversations", [])
|
||||
if not conversations:
|
||||
return entry, 0
|
||||
|
||||
total = 0
|
||||
for turn in conversations:
|
||||
value = turn.get("value", "")
|
||||
if value:
|
||||
try:
|
||||
total += len(_TOKENIZER.encode(value))
|
||||
except Exception:
|
||||
# Fallback to character estimate
|
||||
total += len(value) // 4
|
||||
|
||||
return entry, total
|
||||
|
||||
|
||||
def sample_from_datasets(
|
||||
datasets: List[str],
|
||||
total_samples: int,
|
||||
min_tokens: int = 16000,
|
||||
tokenizer_name: str = "moonshotai/Kimi-K2-Thinking",
|
||||
seed: int = 42,
|
||||
num_proc: int = 8
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load all datasets, filter by token count, then randomly sample from combined pool.
|
||||
|
||||
Args:
|
||||
datasets: List of HuggingFace dataset names
|
||||
total_samples: Total number of samples to collect
|
||||
min_tokens: Minimum token count to include (only sample trajectories >= this)
|
||||
tokenizer_name: HuggingFace tokenizer for counting tokens
|
||||
seed: Random seed for reproducibility
|
||||
num_proc: Number of parallel processes for tokenization
|
||||
|
||||
Returns:
|
||||
List of sampled trajectory entries
|
||||
"""
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
|
||||
random.seed(seed)
|
||||
|
||||
print(f"\n📥 Loading {len(datasets)} datasets...")
|
||||
print(f" Minimum tokens: {min_tokens:,} (filtering smaller trajectories)")
|
||||
print(f" Parallel workers: {num_proc}")
|
||||
print()
|
||||
|
||||
# Load ALL entries from all datasets into one pool
|
||||
all_entries = []
|
||||
|
||||
for dataset_name in datasets:
|
||||
entries = load_dataset_from_hf(dataset_name)
|
||||
|
||||
if not entries:
|
||||
print(f" ⚠️ Skipping {dataset_name} (no entries loaded)")
|
||||
continue
|
||||
|
||||
# Add source metadata to each entry
|
||||
for entry in entries:
|
||||
entry["_source_dataset"] = dataset_name
|
||||
|
||||
all_entries.extend(entries)
|
||||
|
||||
print(f"\n📊 Total entries loaded: {len(all_entries):,}")
|
||||
|
||||
# Filter by token count using parallel processing
|
||||
print(f"\n🔍 Filtering trajectories with >= {min_tokens:,} tokens (using {num_proc} workers)...")
|
||||
|
||||
filtered_entries = []
|
||||
token_counts = []
|
||||
|
||||
# Use multiprocessing for token counting
|
||||
with Pool(
|
||||
processes=num_proc,
|
||||
initializer=_init_tokenizer_worker,
|
||||
initargs=(tokenizer_name,)
|
||||
) as pool:
|
||||
# Process in chunks and show progress
|
||||
chunk_size = 1000
|
||||
processed = 0
|
||||
|
||||
for result in pool.imap_unordered(_count_tokens_for_entry, all_entries, chunksize=100):
|
||||
entry, token_count = result
|
||||
processed += 1
|
||||
|
||||
if processed % chunk_size == 0:
|
||||
print(f" Processed {processed:,}/{len(all_entries):,}...", end="\r")
|
||||
|
||||
if token_count >= min_tokens:
|
||||
entry["_original_tokens"] = token_count
|
||||
filtered_entries.append(entry)
|
||||
token_counts.append(token_count)
|
||||
|
||||
print(f"\n ✅ Found {len(filtered_entries):,} trajectories >= {min_tokens:,} tokens")
|
||||
|
||||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
print(f" 📈 Token stats: min={min(token_counts):,}, max={max(token_counts):,}, avg={avg_tokens:,.0f}")
|
||||
|
||||
# Random sample from the filtered pool
|
||||
if len(filtered_entries) <= total_samples:
|
||||
print(f"\n⚠️ Only {len(filtered_entries):,} trajectories available, using all of them")
|
||||
sampled = filtered_entries
|
||||
else:
|
||||
sampled = random.sample(filtered_entries, total_samples)
|
||||
print(f"\n✅ Randomly sampled {len(sampled):,} trajectories from pool of {len(filtered_entries):,}")
|
||||
|
||||
# Show source distribution
|
||||
source_counts = {}
|
||||
for entry in sampled:
|
||||
source = entry.get("_source_dataset", "unknown").split("/")[-1]
|
||||
source_counts[source] = source_counts.get(source, 0) + 1
|
||||
|
||||
print(f"\n📌 Sample distribution by source:")
|
||||
for source, count in sorted(source_counts.items()):
|
||||
print(f" {source}: {count:,}")
|
||||
|
||||
# Shuffle
|
||||
random.shuffle(sampled)
|
||||
|
||||
return sampled
|
||||
|
||||
|
||||
def save_samples_for_compression(
|
||||
samples: List[Dict[str, Any]],
|
||||
output_dir: Path,
|
||||
batch_size: int = 100
|
||||
):
|
||||
"""
|
||||
Save samples to JSONL files for trajectory compression.
|
||||
|
||||
Args:
|
||||
samples: List of trajectory entries
|
||||
output_dir: Directory to save JSONL files
|
||||
batch_size: Number of entries per file
|
||||
"""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Split into batches
|
||||
num_batches = (len(samples) + batch_size - 1) // batch_size
|
||||
|
||||
print(f"\n💾 Saving {len(samples)} samples to {output_dir}")
|
||||
print(f" Batch size: {batch_size}, Total batches: {num_batches}")
|
||||
|
||||
for i in range(num_batches):
|
||||
start_idx = i * batch_size
|
||||
end_idx = min((i + 1) * batch_size, len(samples))
|
||||
batch = samples[start_idx:end_idx]
|
||||
|
||||
output_file = output_dir / f"batch_{i}.jsonl"
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for entry in batch:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
|
||||
|
||||
print(f" ✅ Saved {num_batches} batch files")
|
||||
|
||||
|
||||
def run_compression(input_dir: Path, output_dir: Path, config_path: str):
|
||||
"""
|
||||
Run trajectory compression on the sampled data.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing JSONL files to compress
|
||||
output_dir: Directory for compressed output
|
||||
config_path: Path to compression config YAML
|
||||
"""
|
||||
# Import the compressor
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from trajectory_compressor import TrajectoryCompressor, CompressionConfig
|
||||
|
||||
print(f"\n🗜️ Running trajectory compression...")
|
||||
print(f" Input: {input_dir}")
|
||||
print(f" Output: {output_dir}")
|
||||
print(f" Config: {config_path}")
|
||||
|
||||
# Load config
|
||||
config = CompressionConfig.from_yaml(config_path)
|
||||
|
||||
# Initialize compressor
|
||||
compressor = TrajectoryCompressor(config)
|
||||
|
||||
# Run compression
|
||||
compressor.process_directory(input_dir, output_dir)
|
||||
|
||||
|
||||
def merge_output_to_single_jsonl(input_dir: Path, output_file: Path):
|
||||
"""
|
||||
Merge all JSONL files in a directory into a single JSONL file.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing JSONL files
|
||||
output_file: Output JSONL file path
|
||||
"""
|
||||
print(f"\n📦 Merging output files into {output_file.name}...")
|
||||
|
||||
all_entries = []
|
||||
for jsonl_file in sorted(input_dir.glob("*.jsonl")):
|
||||
if jsonl_file.name == output_file.name:
|
||||
continue
|
||||
with open(jsonl_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
all_entries.append(json.loads(line))
|
||||
|
||||
# Write merged file
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for entry in all_entries:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
|
||||
|
||||
print(f" ✅ Merged {len(all_entries):,} entries into {output_file.name}")
|
||||
return output_file
|
||||
|
||||
|
||||
def main(
|
||||
total_samples: int = 2500,
|
||||
output_name: str = "compressed_agentic",
|
||||
datasets: str = None,
|
||||
config: str = "configs/trajectory_compression.yaml",
|
||||
seed: int = 42,
|
||||
batch_size: int = 100,
|
||||
min_tokens: int = 16000,
|
||||
num_proc: int = 8,
|
||||
skip_download: bool = False,
|
||||
):
|
||||
"""
|
||||
Sample trajectories from HuggingFace datasets and run compression.
|
||||
|
||||
Args:
|
||||
total_samples: Total number of samples to collect (default: 2500)
|
||||
output_name: Name for output directory/file (default: "compressed_agentic")
|
||||
datasets: Comma-separated list of dataset names (uses defaults if not provided)
|
||||
config: Path to compression config YAML
|
||||
seed: Random seed for reproducibility
|
||||
batch_size: Number of entries per JSONL file during processing
|
||||
min_tokens: Minimum token count to filter trajectories (default: 16000)
|
||||
num_proc: Number of parallel workers for tokenization (default: 8)
|
||||
skip_download: Skip download and use existing sampled data
|
||||
"""
|
||||
print("=" * 70)
|
||||
print("📊 TRAJECTORY SAMPLING AND COMPRESSION")
|
||||
print("=" * 70)
|
||||
|
||||
# Parse datasets
|
||||
if datasets:
|
||||
dataset_list = [d.strip() for d in datasets.split(",")]
|
||||
else:
|
||||
dataset_list = DEFAULT_DATASETS
|
||||
|
||||
print(f"\n📋 Configuration:")
|
||||
print(f" Total samples: {total_samples:,}")
|
||||
print(f" Min tokens filter: {min_tokens:,}")
|
||||
print(f" Parallel workers: {num_proc}")
|
||||
print(f" Datasets: {len(dataset_list)}")
|
||||
for ds in dataset_list:
|
||||
print(f" - {ds}")
|
||||
print(f" Output name: {output_name}")
|
||||
print(f" Config: {config}")
|
||||
print(f" Seed: {seed}")
|
||||
|
||||
# Setup paths
|
||||
base_dir = Path(__file__).parent.parent
|
||||
sampled_dir = base_dir / "data" / f"{output_name}_raw"
|
||||
compressed_dir = base_dir / "data" / f"{output_name}_batches"
|
||||
final_output = base_dir / "data" / f"{output_name}.jsonl"
|
||||
|
||||
if not skip_download:
|
||||
# Step 1: Download, filter by token count, and sample from combined pool
|
||||
samples = sample_from_datasets(
|
||||
dataset_list,
|
||||
total_samples,
|
||||
min_tokens=min_tokens,
|
||||
seed=seed,
|
||||
num_proc=num_proc
|
||||
)
|
||||
|
||||
if not samples:
|
||||
print("❌ No samples collected. Exiting.")
|
||||
return
|
||||
|
||||
# Step 2: Save to JSONL files
|
||||
save_samples_for_compression(samples, sampled_dir, batch_size)
|
||||
else:
|
||||
print(f"\n⏭️ Skipping download, using existing data in {sampled_dir}")
|
||||
|
||||
# Step 3: Run compression
|
||||
config_path = base_dir / config
|
||||
if not config_path.exists():
|
||||
print(f"❌ Config not found: {config_path}")
|
||||
return
|
||||
|
||||
run_compression(sampled_dir, compressed_dir, str(config_path))
|
||||
|
||||
# Step 4: Merge into single JSONL file
|
||||
merge_output_to_single_jsonl(compressed_dir, final_output)
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("✅ COMPLETE!")
|
||||
print("=" * 70)
|
||||
print(f"\n📁 Raw samples: {sampled_dir}")
|
||||
print(f"📁 Compressed batches: {compressed_dir}")
|
||||
print(f"📁 Final output: {final_output}")
|
||||
print(f"\nTo upload to HuggingFace:")
|
||||
print(f" huggingface-cli upload NousResearch/{output_name} {final_output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
332 skills/mlops-knowledge/accelerate/SKILL.md Normal file
@@ -0,0 +1,332 @@
|
||||
---
|
||||
name: huggingface-accelerate
|
||||
description: Simplest distributed training API. 4 lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP. Automatic device placement, mixed precision (FP16/BF16/FP8). Interactive config, single launch command. HuggingFace ecosystem standard.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Distributed Training, HuggingFace, Accelerate, DeepSpeed, FSDP, Mixed Precision, PyTorch, DDP, Unified API, Simple]
|
||||
dependencies: [accelerate, torch, transformers]
|
||||
---
|
||||
|
||||
# HuggingFace Accelerate - Unified Distributed Training
|
||||
|
||||
## Quick start
|
||||
|
||||
Accelerate simplifies distributed training to 4 lines of code.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
**Convert PyTorch script** (4 lines):
|
||||
```python
|
||||
import torch
|
||||
+ from accelerate import Accelerator
|
||||
|
||||
+ accelerator = Accelerator()
|
||||
|
||||
model = torch.nn.Transformer()
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset)
|
||||
|
||||
+ model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
for batch in dataloader:
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch)
|
||||
- loss.backward()
|
||||
+ accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Run** (single command):
|
||||
```bash
|
||||
accelerate launch train.py
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: From single GPU to multi-GPU
|
||||
|
||||
**Original script**:
|
||||
```python
|
||||
# train.py
|
||||
import torch
|
||||
|
||||
model = torch.nn.Linear(10, 2).to('cuda')
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
|
||||
|
||||
for epoch in range(10):
|
||||
for batch in dataloader:
|
||||
batch = batch.to('cuda')
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**With Accelerate** (4 lines added):
|
||||
```python
|
||||
# train.py
|
||||
import torch
|
||||
from accelerate import Accelerator # +1
|
||||
|
||||
accelerator = Accelerator() # +2
|
||||
|
||||
model = torch.nn.Linear(10, 2)
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) # +3
|
||||
|
||||
for epoch in range(10):
|
||||
for batch in dataloader:
|
||||
# No .to('cuda') needed - automatic!
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch).mean()
|
||||
accelerator.backward(loss) # +4
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Configure** (interactive):
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
**Questions**:
|
||||
- Which machine? (single/multi GPU/TPU/CPU)
|
||||
- How many machines? (1)
|
||||
- Mixed precision? (no/fp16/bf16/fp8)
|
||||
- DeepSpeed? (no/yes)
|
||||
|
||||
**Launch** (works on any setup):
|
||||
```bash
|
||||
# Single GPU
|
||||
accelerate launch train.py
|
||||
|
||||
# Multi-GPU (8 GPUs)
|
||||
accelerate launch --multi_gpu --num_processes 8 train.py
|
||||
|
||||
# Multi-node
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 0 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
train.py
|
||||
```
|
||||
|
||||
### Workflow 2: Mixed precision training
|
||||
|
||||
**Enable FP16/BF16**:
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
# FP16 (with gradient scaling)
|
||||
accelerator = Accelerator(mixed_precision='fp16')
|
||||
|
||||
# BF16 (no scaling, more stable)
|
||||
accelerator = Accelerator(mixed_precision='bf16')
|
||||
|
||||
# FP8 (H100+)
|
||||
accelerator = Accelerator(mixed_precision='fp8')
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
# Everything else is automatic!
|
||||
for batch in dataloader:
|
||||
with accelerator.autocast(): # Optional, done automatically
|
||||
loss = model(batch)
|
||||
accelerator.backward(loss)
|
||||
```
|
||||
|
||||
### Workflow 3: DeepSpeed ZeRO integration
|
||||
|
||||
**Enable DeepSpeed ZeRO-2**:
|
||||
```python
|
||||
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

accelerator = Accelerator(
    mixed_precision='bf16',
    deepspeed_plugin=DeepSpeedPlugin(
        zero_stage=2,                      # ZeRO-2
        gradient_accumulation_steps=4,
        offload_optimizer_device="none",   # no optimizer offload
    ),
)
|
||||
|
||||
# Same code as before!
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
```
|
||||
|
||||
**Or via config**:
|
||||
```bash
|
||||
accelerate config
|
||||
# Select: DeepSpeed → ZeRO-2
|
||||
```
|
||||
|
||||
**deepspeed_config.json**:
|
||||
```json
|
||||
{
|
||||
"fp16": {"enabled": false},
|
||||
"bf16": {"enabled": true},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {"device": "cpu"},
|
||||
"allgather_bucket_size": 5e8,
|
||||
"reduce_bucket_size": 5e8
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Launch**:
|
||||
```bash
|
||||
accelerate launch --config_file deepspeed_config.json train.py
|
||||
```
|
||||
|
||||
### Workflow 4: FSDP (Fully Sharded Data Parallel)
|
||||
|
||||
**Enable FSDP**:
|
||||
```python
|
||||
from accelerate import Accelerator, FullyShardedDataParallelPlugin
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
sharding_strategy="FULL_SHARD", # ZeRO-3 equivalent
|
||||
auto_wrap_policy="TRANSFORMER_AUTO_WRAP",
|
||||
cpu_offload=False
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
fsdp_plugin=fsdp_plugin
|
||||
)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
```
|
||||
|
||||
**Or via config**:
|
||||
```bash
|
||||
accelerate config
|
||||
# Select: FSDP → Full Shard → No CPU Offload
|
||||
```
|
||||
|
||||
### Workflow 5: Gradient accumulation
|
||||
|
||||
**Accumulate gradients**:
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
accelerator = Accelerator(gradient_accumulation_steps=4)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
for batch in dataloader:
|
||||
with accelerator.accumulate(model): # Handles accumulation
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Effective batch size**: `batch_size * num_gpus * gradient_accumulation_steps`
|
||||
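For example, with numbers like those used earlier in this guide (a sketch with assumed values):

```python
# Worked example: per-device batch 32, 8 GPUs, 4 accumulation steps.
batch_size = 32
num_gpus = 8
gradient_accumulation_steps = 4
effective_batch_size = batch_size * num_gpus * gradient_accumulation_steps  # 1024
```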
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use Accelerate when**:
|
||||
- Want simplest distributed training
|
||||
- Need single script for any hardware
|
||||
- Use HuggingFace ecosystem
|
||||
- Want flexibility (DDP/DeepSpeed/FSDP/Megatron)
|
||||
- Need quick prototyping
|
||||
|
||||
**Key advantages**:
|
||||
- **4 lines**: Minimal code changes
|
||||
- **Unified API**: Same code for DDP, DeepSpeed, FSDP, Megatron
|
||||
- **Automatic**: Device placement, mixed precision, sharding
|
||||
- **Interactive config**: No manual launcher setup
|
||||
- **Single launch**: Works everywhere
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **PyTorch Lightning**: Need callbacks, high-level abstractions
|
||||
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
|
||||
- **DeepSpeed**: Direct API control, advanced features
|
||||
- **Raw DDP**: Maximum control, minimal abstraction
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Wrong device placement**
|
||||
|
||||
Don't manually move to device:
|
||||
```python
|
||||
# WRONG
|
||||
batch = batch.to('cuda')
|
||||
|
||||
# CORRECT
|
||||
# Accelerate handles it automatically after prepare()
|
||||
```
|
||||
|
||||
**Issue: Gradient accumulation not working**
|
||||
|
||||
Use context manager:
|
||||
```python
|
||||
# CORRECT
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Issue: Checkpointing in distributed**
|
||||
|
||||
Use accelerator methods:
|
||||
```python
|
||||
# Save only on main process
|
||||
if accelerator.is_main_process:
|
||||
accelerator.save_state('checkpoint/')
|
||||
|
||||
# Load on all processes
|
||||
accelerator.load_state('checkpoint/')
|
||||
```
|
||||
|
||||
**Issue: Different results with FSDP**
|
||||
|
||||
Ensure same random seed:
|
||||
```python
|
||||
from accelerate.utils import set_seed
|
||||
set_seed(42)
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Megatron integration**: See [references/megatron-integration.md](references/megatron-integration.md) for tensor parallelism, pipeline parallelism, and sequence parallelism setup.
|
||||
|
||||
**Custom plugins**: See [references/custom-plugins.md](references/custom-plugins.md) for creating custom distributed plugins and advanced configuration.
|
||||
|
||||
**Performance tuning**: See [references/performance.md](references/performance.md) for profiling, memory optimization, and best practices.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **CPU**: Works (slow)
|
||||
- **Single GPU**: Works
|
||||
- **Multi-GPU**: DDP (default), DeepSpeed, or FSDP
|
||||
- **Multi-node**: DDP, DeepSpeed, FSDP, Megatron
|
||||
- **TPU**: Supported
|
||||
- **Apple MPS**: Supported
|
||||
|
||||
**Launcher requirements**:
|
||||
- **DDP**: `torch.distributed.run` (built-in)
|
||||
- **DeepSpeed**: `deepspeed` (pip install deepspeed)
|
||||
- **FSDP**: PyTorch 1.12+ (built-in)
|
||||
- **Megatron**: Custom setup
|
||||
|
||||
## Resources
|
||||
|
||||
- Docs: https://huggingface.co/docs/accelerate
|
||||
- GitHub: https://github.com/huggingface/accelerate
|
||||
- Version: 1.11.0+
|
||||
- Tutorial: "Accelerate your scripts"
|
||||
- Examples: https://github.com/huggingface/accelerate/tree/main/examples
|
||||
- Used by: HuggingFace Transformers, TRL, PEFT, all HF libraries
|
||||
|
||||
|
||||
|
||||
skills/mlops-knowledge/accelerate/references/custom-plugins.md (new file, 453 lines)
|
||||
# Custom Plugins for Accelerate
|
||||
|
||||
## Overview
|
||||
|
||||
Accelerate allows creating **custom plugins** to extend distributed training strategies beyond built-in options (DDP, FSDP, DeepSpeed).
|
||||
|
||||
## Plugin Architecture
|
||||
|
||||
### Base Plugin Structure
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
"""Custom training plugin."""
|
||||
|
||||
# Plugin configuration
|
||||
param1: int = 1
|
||||
param2: str = "default"
|
||||
|
||||
def __post_init__(self):
|
||||
# Validation logic
|
||||
if self.param1 < 1:
|
||||
raise ValueError("param1 must be >= 1")
|
||||
```
|
||||
|
||||
### Using Custom Plugin
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
# Create plugin
|
||||
custom_plugin = CustomPlugin(param1=4, param2="value")
|
||||
|
||||
# Pass to Accelerator
|
||||
accelerator = Accelerator(
|
||||
custom_plugin=custom_plugin # Not a real parameter, example only
|
||||
)
|
||||
```
|
||||
|
||||
## Built-In Plugin Examples
|
||||
|
||||
### 1. GradScalerKwargs (FP16 Configuration)
|
||||
|
||||
```python
|
||||
from accelerate.utils import GradScalerKwargs
|
||||
|
||||
# Configure gradient scaler for FP16
|
||||
scaler_kwargs = GradScalerKwargs(
|
||||
init_scale=2.**16, # Initial loss scale
|
||||
growth_factor=2.0, # Scale growth rate
|
||||
backoff_factor=0.5, # Scale backoff rate
|
||||
growth_interval=2000, # Steps between scale increases
|
||||
enabled=True # Enable scaler
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp16',
|
||||
kwargs_handlers=[scaler_kwargs] # Pass as kwargs handler
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Fine-tune FP16 gradient scaling behavior
|
||||
|
||||
### 2. DistributedDataParallelKwargs
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
# Configure DDP behavior
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
bucket_cap_mb=25, # Gradient bucketing size
|
||||
find_unused_parameters=False, # Find unused params (slower)
|
||||
check_reduction=False, # Check gradient reduction
|
||||
gradient_as_bucket_view=True, # Memory optimization
|
||||
static_graph=False # Static computation graph
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
kwargs_handlers=[ddp_kwargs]
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Optimize DDP performance for specific models
|
||||
|
||||
### 3. FP8RecipeKwargs (H100 FP8)
|
||||
|
||||
```python
|
||||
from accelerate.utils import FP8RecipeKwargs
|
||||
|
||||
# Configure FP8 training (H100)
|
||||
fp8_recipe = FP8RecipeKwargs(
|
||||
backend="te", # TransformerEngine backend
|
||||
margin=0, # Scaling margin
|
||||
interval=1, # Scaling interval
|
||||
fp8_format="HYBRID", # E4M3 + E5M2 hybrid
|
||||
amax_history_len=1024, # AMAX history length
|
||||
amax_compute_algo="max" # AMAX computation algorithm
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp8',
|
||||
kwargs_handlers=[fp8_recipe]
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Ultra-fast training on H100 GPUs
|
||||
|
||||
## Custom DeepSpeed Configuration
|
||||
|
||||
### ZeRO-3 with CPU Offload
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
# Custom DeepSpeed config
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=3, # ZeRO-3
|
||||
offload_optimizer_device="cpu", # CPU offload optimizer
|
||||
offload_param_device="cpu", # CPU offload parameters
|
||||
zero3_init_flag=True, # ZeRO-3 initialization
|
||||
zero3_save_16bit_model=True, # Save FP16 weights
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
deepspeed_plugin=ds_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
### ZeRO-2 with NVMe Offload
|
||||
|
||||
```python
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=2,
|
||||
offload_optimizer_device="nvme", # NVMe offload
|
||||
# Note: offload_param requires ZeRO-3; with ZeRO-2 only the optimizer state is offloaded
|
||||
nvme_path="/local_nvme", # NVMe mount path
|
||||
)
|
||||
```
|
||||
|
||||
### Custom JSON Config
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
# Load custom DeepSpeed config
|
||||
with open('deepspeed_config.json', 'r') as f:
|
||||
ds_config = json.load(f)
|
||||
|
||||
ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_config)
|
||||
|
||||
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
|
||||
```
|
||||
|
||||
**Example config** (`deepspeed_config.json`):
|
||||
```json
|
||||
{
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": 1.0,
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"stage3_prefetch_bucket_size": 5e8,
|
||||
"stage3_param_persistence_threshold": 1e6,
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"steps_per_print": 100,
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
```
|
||||
|
||||
## Custom FSDP Configuration
|
||||
|
||||
### FSDP with Custom Auto-Wrap Policy
|
||||
|
||||
```python
|
||||
from accelerate.utils import FullyShardedDataParallelPlugin
|
||||
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
|
||||
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
|
||||
import functools
|
||||
|
||||
# Custom wrap policy (size-based)
|
||||
wrap_policy = functools.partial(
|
||||
size_based_auto_wrap_policy,
|
||||
min_num_params=1e6 # Wrap layers with 1M+ params
|
||||
)
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3 equivalent
|
||||
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Prefetch strategy
|
||||
mixed_precision_policy=None, # Use Accelerator's mixed precision
|
||||
auto_wrap_policy=wrap_policy, # Custom wrapping
|
||||
cpu_offload=False,
|
||||
ignored_modules=None, # Modules to not wrap
|
||||
state_dict_type="FULL_STATE_DICT", # Save format
|
||||
optim_state_dict_config=None,
|
||||
limit_all_gathers=False,
|
||||
use_orig_params=True, # Use original param shapes
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
fsdp_plugin=fsdp_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
### FSDP with Transformer Auto-Wrap
|
||||
|
||||
```python
|
||||
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
|
||||
|
||||
# Wrap at transformer block level
|
||||
wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls={GPT2Block} # Wrap GPT2Block layers
|
||||
)
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
auto_wrap_policy=wrap_policy
|
||||
)
|
||||
```
|
||||
|
||||
## Creating Custom Training Strategy
|
||||
|
||||
### Example: Custom Gradient Accumulation
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
class CustomGradientAccumulation:
|
||||
def __init__(self, steps=4, adaptive=False, threshold=2.0):  # threshold is an illustrative default for adaptive sync
|
||||
self.steps = steps
|
||||
self.adaptive = adaptive
self.threshold = threshold
|
||||
self.current_step = 0
|
||||
|
||||
def should_sync(self, loss):
|
||||
"""Decide whether to sync gradients."""
|
||||
self.current_step += 1
|
||||
|
||||
# Adaptive: sync on high loss
|
||||
if self.adaptive and loss > self.threshold:
|
||||
self.current_step = 0
|
||||
return True
|
||||
|
||||
# Regular: sync every N steps
|
||||
if self.current_step >= self.steps:
|
||||
self.current_step = 0
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# Usage
|
||||
custom_accum = CustomGradientAccumulation(steps=8, adaptive=True)
|
||||
accelerator = Accelerator()
|
||||
|
||||
for batch in dataloader:
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
# Scale loss
|
||||
loss = loss / custom_accum.steps
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Conditional sync
|
||||
if custom_accum.should_sync(loss.item()):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
### Example: Custom Mixed Precision
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class CustomMixedPrecision:
|
||||
"""Custom mixed precision with dynamic loss scaling."""
|
||||
|
||||
def __init__(self, init_scale=2**16, scale_window=2000):
|
||||
self.scaler = torch.cuda.amp.GradScaler(
|
||||
init_scale=init_scale,
|
||||
growth_interval=scale_window
|
||||
)
|
||||
self.scale_history = []
|
||||
|
||||
def scale_loss(self, loss):
|
||||
"""Scale loss for backward."""
|
||||
return self.scaler.scale(loss)
|
||||
|
||||
def unscale_and_clip(self, optimizer, max_norm=1.0):
|
||||
"""Unscale gradients and clip."""
|
||||
self.scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
[p for group in optimizer.param_groups for p in group['params']],  # clip every param group, not just the first
|
||||
max_norm
|
||||
)
|
||||
|
||||
def step(self, optimizer):
|
||||
"""Optimizer step with scaler update."""
|
||||
scale_before = self.scaler.get_scale()
|
||||
self.scaler.step(optimizer)
|
||||
self.scaler.update()
|
||||
scale_after = self.scaler.get_scale()
|
||||
|
||||
# Track scale changes
|
||||
if scale_before != scale_after:
|
||||
self.scale_history.append(scale_after)
|
||||
|
||||
# Usage
|
||||
custom_mp = CustomMixedPrecision()
|
||||
|
||||
for batch in dataloader:
|
||||
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||
loss = model(**batch).loss
|
||||
|
||||
scaled_loss = custom_mp.scale_loss(loss)
|
||||
scaled_loss.backward()
|
||||
|
||||
custom_mp.unscale_and_clip(optimizer, max_norm=1.0)
|
||||
custom_mp.step(optimizer)
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
## Advanced: Custom Distributed Backend
|
||||
|
||||
### Custom AllReduce Strategy
|
||||
|
||||
```python
|
||||
import torch.distributed as dist
|
||||
|
||||
class CustomAllReduce:
|
||||
"""Custom all-reduce with compression."""
|
||||
|
||||
def __init__(self, compression_ratio=0.1):
|
||||
self.compression_ratio = compression_ratio
|
||||
|
||||
def compress_gradients(self, tensor):
|
||||
"""Top-k gradient compression."""
|
||||
k = int(tensor.numel() * self.compression_ratio)
|
||||
_, indices = torch.topk(tensor.abs().view(-1), k)
values = tensor.view(-1)[indices]  # gather signed values so gradient direction is preserved
|
||||
return values, indices
|
||||
|
||||
def all_reduce_compressed(self, tensor):
|
||||
"""All-reduce with gradient compression."""
|
||||
# Compress
|
||||
values, indices = self.compress_gradients(tensor)
|
||||
|
||||
# All-reduce compressed gradients
|
||||
dist.all_reduce(values, op=dist.ReduceOp.SUM)
|
||||
|
||||
# Decompress
|
||||
tensor_compressed = torch.zeros_like(tensor).view(-1)
|
||||
tensor_compressed[indices] = values / dist.get_world_size()
|
||||
|
||||
return tensor_compressed.view_as(tensor)
|
||||
|
||||
# Usage in training loop
|
||||
custom_ar = CustomAllReduce(compression_ratio=0.1)
|
||||
|
||||
for batch in dataloader:
|
||||
loss = model(**batch).loss
|
||||
loss.backward()
|
||||
|
||||
# Custom all-reduce
|
||||
for param in model.parameters():
|
||||
if param.grad is not None:
|
||||
param.grad.data = custom_ar.all_reduce_compressed(param.grad.data)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
## Plugin Best Practices
|
||||
|
||||
### 1. Validation in `__post_init__`
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
learning_rate: float = 1e-3
|
||||
warmup_steps: int = 1000
|
||||
|
||||
def __post_init__(self):
|
||||
# Validate parameters
|
||||
if self.learning_rate <= 0:
|
||||
raise ValueError("learning_rate must be positive")
|
||||
if self.warmup_steps < 0:
|
||||
raise ValueError("warmup_steps must be non-negative")
|
||||
|
||||
# Compute derived values
|
||||
self.min_lr = self.learning_rate * 0.1
|
||||
```
|
||||
|
||||
### 2. Compatibility Checks
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
feature_enabled: bool = True
|
||||
|
||||
def is_compatible(self, accelerator):
|
||||
"""Check if plugin is compatible with accelerator config."""
|
||||
if self.feature_enabled and accelerator.mixed_precision == 'fp8':
|
||||
raise ValueError("Custom plugin not compatible with FP8")
|
||||
return True
|
||||
```
|
||||
|
||||
### 3. State Management
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
counter: int = 0
|
||||
history: list = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.history is None:
|
||||
self.history = []
|
||||
|
||||
def update_state(self, value):
|
||||
"""Update plugin state during training."""
|
||||
self.counter += 1
|
||||
self.history.append(value)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Accelerate Plugins: https://huggingface.co/docs/accelerate/package_reference/kwargs
|
||||
- DeepSpeed Config: https://www.deepspeed.ai/docs/config-json/
|
||||
- FSDP Guide: https://pytorch.org/docs/stable/fsdp.html
|
||||
- Custom Training Loops: https://huggingface.co/docs/accelerate/usage_guides/training_tpu
|
||||
skills/mlops-knowledge/accelerate/references/megatron-integration.md (new file, 489 lines)
|
||||
# Megatron Integration with Accelerate
|
||||
|
||||
## Overview
|
||||
|
||||
Accelerate supports Megatron-LM for massive model training with tensor parallelism and pipeline parallelism.
|
||||
|
||||
**Megatron capabilities**:
|
||||
- **Tensor Parallelism (TP)**: Split layers across GPUs
|
||||
- **Pipeline Parallelism (PP)**: Split model depth across GPUs
|
||||
- **Data Parallelism (DP)**: Replicate model across GPU groups
|
||||
- **Sequence Parallelism**: Split sequences for long contexts
|
||||
|
||||
## Setup
|
||||
|
||||
### Install Megatron-LM
|
||||
|
||||
```bash
|
||||
# Clone Megatron-LM repository
|
||||
git clone https://github.com/NVIDIA/Megatron-LM.git
|
||||
cd Megatron-LM
|
||||
pip install -e .
|
||||
|
||||
# Install Apex (NVIDIA optimizations)
|
||||
git clone https://github.com/NVIDIA/apex
|
||||
cd apex
|
||||
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
|
||||
--config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
||||
```
|
||||
|
||||
### Accelerate Configuration
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
**Questions**:
|
||||
```
|
||||
In which compute environment are you running?
|
||||
> This machine
|
||||
|
||||
Which type of machine are you using?
|
||||
> Multi-GPU
|
||||
|
||||
How many different machines will you use?
|
||||
> 1
|
||||
|
||||
Do you want to use DeepSpeed/FSDP?
|
||||
> No
|
||||
|
||||
Do you want to use Megatron-LM?
|
||||
> Yes
|
||||
|
||||
What is the Tensor Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
Do you want to enable Sequence Parallelism?
|
||||
> No
|
||||
|
||||
What is the Pipeline Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
What is the Data Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
Where to perform activation checkpointing? ['SELECTIVE', 'FULL', 'NONE']
|
||||
> SELECTIVE
|
||||
|
||||
Where to perform activation partitioning? ['SEQUENTIAL', 'UNIFORM']
|
||||
> SEQUENTIAL
|
||||
```
|
||||
|
||||
**Generated config** (`~/.cache/huggingface/accelerate/default_config.yaml`):
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: MEGATRON_LM
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
megatron_lm_config:
|
||||
megatron_lm_gradient_clipping: 1.0
|
||||
megatron_lm_learning_rate_decay_iters: 320000
|
||||
megatron_lm_num_micro_batches: 1
|
||||
megatron_lm_pp_degree: 2
|
||||
megatron_lm_recompute_activations: true
|
||||
megatron_lm_sequence_parallelism: false
|
||||
megatron_lm_tp_degree: 2
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
## Parallelism Strategies
|
||||
|
||||
### Tensor Parallelism (TP)
|
||||
|
||||
**Splits each transformer layer across GPUs**:
|
||||
|
||||
```python
|
||||
# Layer split across 2 GPUs
|
||||
# GPU 0: First half of attention heads
|
||||
# GPU 1: Second half of attention heads
|
||||
|
||||
# Each GPU computes partial outputs
|
||||
# All-reduce combines results
|
||||
```
|
||||
|
||||
**TP degree recommendations**:
|
||||
- **TP=1**: No tensor parallelism (single GPU per layer)
|
||||
- **TP=2**: 2 GPUs per layer (good for 7-13B models)
|
||||
- **TP=4**: 4 GPUs per layer (good for 20-40B models)
|
||||
- **TP=8**: 8 GPUs per layer (good for 70B+ models)
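As a rough back-of-the-envelope check on these recommendations (an illustrative helper, not an Accelerate API), per-GPU weight memory under tensor parallelism is approximately `params × bytes_per_param / TP`:

```python
def weights_per_gpu_gb(n_params: float, tp_degree: int, bytes_per_param: int = 2) -> float:
    # Parameter memory only (BF16 by default); optimizer state and activations come on top
    return n_params * bytes_per_param / tp_degree / 1e9

print(f"{weights_per_gpu_gb(13e9, tp_degree=2):.0f} GB")  # ~13 GB of weights per GPU for a 13B model
```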
|
||||
|
||||
**Benefits**:
|
||||
- Reduces memory per GPU
|
||||
- All-reduce communication (fast)
|
||||
|
||||
**Drawbacks**:
|
||||
- Requires fast inter-GPU bandwidth (NVLink)
|
||||
- Communication overhead per layer
|
||||
|
||||
### Pipeline Parallelism (PP)
|
||||
|
||||
**Splits model depth across GPUs**:
|
||||
|
||||
```python
|
||||
# 12-layer model, PP=4
|
||||
# GPU 0: Layers 0-2
|
||||
# GPU 1: Layers 3-5
|
||||
# GPU 2: Layers 6-8
|
||||
# GPU 3: Layers 9-11
|
||||
```
|
||||
|
||||
**PP degree recommendations**:
|
||||
- **PP=1**: No pipeline parallelism
|
||||
- **PP=2**: 2 pipeline stages (good for 20-40B models)
|
||||
- **PP=4**: 4 pipeline stages (good for 70B+ models)
|
||||
- **PP=8**: 8 pipeline stages (good for 175B+ models)
|
||||
|
||||
**Benefits**:
|
||||
- Linear memory reduction (4× PP = 4× less memory)
|
||||
- Works across nodes (slower interconnect OK)
|
||||
|
||||
**Drawbacks**:
|
||||
- Pipeline bubbles (idle time)
|
||||
- Requires micro-batching
|
||||
|
||||
### Data Parallelism (DP)
|
||||
|
||||
**Replicates model across GPU groups**:
|
||||
|
||||
```python
|
||||
# 8 GPUs, TP=2, PP=2, DP=2
|
||||
# Group 0 (GPUs 0-3): Full model replica
|
||||
# Group 1 (GPUs 4-7): Full model replica
|
||||
```
|
||||
|
||||
**DP degree**:
|
||||
- `DP = total_gpus / (TP × PP)`
|
||||
- Example: 8 GPUs, TP=2, PP=2 → DP=2
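A small sketch of this bookkeeping (purely illustrative helper, not an Accelerate API):

```python
def data_parallel_degree(total_gpus: int, tp: int, pp: int) -> int:
    # DP = total_gpus / (TP x PP); TP x PP must divide the GPU count evenly
    assert total_gpus % (tp * pp) == 0, "TP x PP must divide total_gpus"
    return total_gpus // (tp * pp)

print(data_parallel_degree(8, tp=2, pp=2))  # 2
```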
|
||||
|
||||
**Benefits**:
|
||||
- Increases throughput
|
||||
- Scales batch size
|
||||
|
||||
### Sequence Parallelism
|
||||
|
||||
**Splits long sequences across GPUs** (extends TP):
|
||||
|
||||
```python
|
||||
# 8K sequence, TP=2, Sequence Parallel=True
|
||||
# GPU 0: Tokens 0-4095
|
||||
# GPU 1: Tokens 4096-8191
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Enables very long sequences (100K+ tokens)
|
||||
- Reduces activation memory
|
||||
|
||||
**Requirements**:
|
||||
- Must use with TP > 1
|
||||
- RoPE/ALiBi position encodings work best
|
||||
|
||||
## Accelerate Code Example
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import MegatronLMPlugin
|
||||
|
||||
# Configure Megatron
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=2, # Tensor parallelism degree
|
||||
pp_degree=2, # Pipeline parallelism degree
|
||||
num_micro_batches=4, # Micro-batches for pipeline
|
||||
gradient_clipping=1.0, # Gradient clipping value
|
||||
sequence_parallelism=False, # Enable sequence parallelism
|
||||
recompute_activations=True, # Activation checkpointing
|
||||
use_distributed_optimizer=True, # Distributed optimizer
|
||||
custom_prepare_model_function=None, # Custom model prep
|
||||
)
|
||||
|
||||
# Initialize accelerator
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
megatron_lm_plugin=megatron_plugin
|
||||
)
|
||||
|
||||
# Prepare model and optimizer
|
||||
model, optimizer, train_dataloader = accelerator.prepare(
|
||||
model, optimizer, train_dataloader
|
||||
)
|
||||
|
||||
# Training loop (same as DDP!)
|
||||
for batch in train_dataloader:
|
||||
optimizer.zero_grad()
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
### Full Training Script
|
||||
|
||||
```python
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import MegatronLMPlugin
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
def main():
|
||||
# Megatron configuration
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=2,
|
||||
pp_degree=2,
|
||||
num_micro_batches=4,
|
||||
gradient_clipping=1.0,
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
gradient_accumulation_steps=8,
|
||||
megatron_lm_plugin=megatron_plugin
|
||||
)
|
||||
|
||||
# Model
|
||||
config = GPT2Config(
|
||||
n_layer=24,
|
||||
n_head=16,
|
||||
n_embd=1024,
|
||||
)
|
||||
model = GPT2LMHeadModel(config)
|
||||
|
||||
# Optimizer
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
|
||||
|
||||
# Prepare (train_loader and num_epochs are assumed to be defined elsewhere)
|
||||
model, optimizer, train_loader = accelerator.prepare(
|
||||
model, optimizer, train_loader
|
||||
)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(num_epochs):
|
||||
for batch in train_loader:
|
||||
with accelerator.accumulate(model):
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Save checkpoint
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.save_state(f'checkpoint-epoch-{epoch}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
```
|
||||
|
||||
### Launch Command
|
||||
|
||||
```bash
|
||||
# 8 GPUs, TP=2, PP=2, DP=2
|
||||
accelerate launch --multi_gpu --num_processes 8 train.py
|
||||
|
||||
# Multi-node (2 nodes, 8 GPUs each)
|
||||
# Node 0
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 0 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
--main_process_port 29500 \
|
||||
train.py
|
||||
|
||||
# Node 1
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 1 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
--main_process_port 29500 \
|
||||
train.py
|
||||
```
|
||||
|
||||
## Activation Checkpointing
|
||||
|
||||
**Reduces memory by recomputing activations**:
|
||||
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
recompute_activations=True, # Enable checkpointing
|
||||
checkpoint_num_layers=1, # Checkpoint every N layers
|
||||
distribute_checkpointed_activations=True, # Distribute across TP
|
||||
partition_activations=True, # Partition in PP
|
||||
check_for_nan_in_loss_and_grad=True, # Stability check
|
||||
)
|
||||
```
|
||||
|
||||
**Strategies**:
|
||||
- `SELECTIVE`: Checkpoint transformer blocks only
|
||||
- `FULL`: Checkpoint all layers
|
||||
- `NONE`: No checkpointing
|
||||
|
||||
**Memory savings**: 30-50% with 10-15% slowdown
|
||||
|
||||
## Distributed Optimizer
|
||||
|
||||
**Shards optimizer state across DP ranks**:
|
||||
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
use_distributed_optimizer=True, # Enable sharded optimizer
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Reduces optimizer memory by DP degree
|
||||
- Example: DP=4 → 4× less optimizer memory per GPU
|
||||
|
||||
**Compatible with**:
|
||||
- AdamW, Adam, SGD
|
||||
- Mixed precision training
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Micro-Batch Size
|
||||
|
||||
```python
|
||||
# Pipeline parallelism requires micro-batching
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
pp_degree=4,
|
||||
num_micro_batches=16, # 16 micro-batches per pipeline
|
||||
)
|
||||
|
||||
# Effective batch = num_micro_batches × micro_batch_size × DP
|
||||
# Example: 16 × 2 × 4 = 128
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- More micro-batches → less pipeline bubble (see the bubble-fraction sketch after this list)
|
||||
- Typical: 4-16 micro-batches
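The "less pipeline bubble" point can be estimated with the standard bubble-fraction formula for GPipe/1F1B-style schedules, sketched here (illustrative helper, not part of Accelerate):

```python
def pipeline_bubble_fraction(pp_degree: int, num_micro_batches: int) -> float:
    # Idle fraction of the pipeline: (pp - 1) / (m + pp - 1)
    return (pp_degree - 1) / (num_micro_batches + pp_degree - 1)

print(f"{pipeline_bubble_fraction(4, 4):.2f}")   # ~0.43: large bubble
print(f"{pipeline_bubble_fraction(4, 16):.2f}")  # ~0.16: much smaller
```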
|
||||
|
||||
### Sequence Length
|
||||
|
||||
```python
|
||||
# For long sequences, enable sequence parallelism
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=4,
|
||||
sequence_parallelism=True, # Required: TP > 1
|
||||
)
|
||||
|
||||
# Enables sequences up to TP × normal limit
|
||||
# Example: TP=4, 8K normal → 32K with sequence parallel
|
||||
```
|
||||
|
||||
### GPU Topology
|
||||
|
||||
**NVLink required for TP**:
|
||||
```bash
|
||||
# Check NVLink topology
|
||||
nvidia-smi topo -m
|
||||
|
||||
# Good topology (NVLink between all GPUs)
|
||||
# GPU0 - GPU1: NV12 (fast)
|
||||
# GPU0 - GPU2: NV12 (fast)
|
||||
|
||||
# Bad topology (PCIe only)
|
||||
# GPU0 - GPU4: PHB (slow, avoid TP across these)
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- **TP**: Within same node (NVLink)
|
||||
- **PP**: Across nodes (slower interconnect OK)
|
||||
- **DP**: Any topology
|
||||
|
||||
## Model Size Guidelines
|
||||
|
||||
| Model Size | GPUs | TP | PP | DP | Micro-Batches |
|
||||
|------------|------|----|----|----|--------------|
|
||||
| 7B | 8 | 1 | 1 | 8 | 1 |
|
||||
| 13B | 8 | 2 | 1 | 4 | 1 |
|
||||
| 20B | 16 | 4 | 1 | 4 | 1 |
|
||||
| 40B | 32 | 4 | 2 | 4 | 4 |
|
||||
| 70B | 64 | 8 | 2 | 4 | 8 |
|
||||
| 175B | 128 | 8 | 4 | 4 | 16 |
|
||||
|
||||
**Assumptions**: BF16, 2K sequence length, A100 80GB
|
||||
|
||||
## Checkpointing
|
||||
|
||||
### Save Checkpoint
|
||||
|
||||
```python
|
||||
# Save full model state
|
||||
accelerator.save_state('checkpoint-1000')
|
||||
|
||||
# Megatron saves separate files per rank
|
||||
# checkpoint-1000/
|
||||
# pytorch_model_tp_0_pp_0.bin
|
||||
# pytorch_model_tp_0_pp_1.bin
|
||||
# pytorch_model_tp_1_pp_0.bin
|
||||
# pytorch_model_tp_1_pp_1.bin
|
||||
# optimizer_tp_0_pp_0.bin
|
||||
# ...
|
||||
```
|
||||
|
||||
### Load Checkpoint
|
||||
|
||||
```python
|
||||
# Resume training
|
||||
accelerator.load_state('checkpoint-1000')
|
||||
|
||||
# Automatically loads correct shard per rank
|
||||
```
|
||||
|
||||
### Convert to Standard PyTorch
|
||||
|
||||
```bash
|
||||
# Merge Megatron checkpoint to single file
|
||||
python merge_megatron_checkpoint.py \
|
||||
--checkpoint-dir checkpoint-1000 \
|
||||
--output pytorch_model.bin
|
||||
```
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Issue: OOM with Pipeline Parallelism
|
||||
|
||||
**Solution**: Increase micro-batches
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
pp_degree=4,
|
||||
num_micro_batches=16, # Increase from 4
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Slow Training
|
||||
|
||||
**Check 1**: Pipeline bubbles (PP too high)
|
||||
```python
|
||||
# Reduce PP, increase TP
|
||||
tp_degree=4 # Increase
|
||||
pp_degree=2 # Decrease
|
||||
```
|
||||
|
||||
**Check 2**: Micro-batch size too small
|
||||
```python
|
||||
num_micro_batches=8 # Increase
|
||||
```
|
||||
|
||||
### Issue: NVLink Not Detected
|
||||
|
||||
```bash
|
||||
# Verify NVLink
|
||||
nvidia-smi nvlink -s
|
||||
|
||||
# If no NVLink, avoid TP > 1
|
||||
# Use PP or DP instead
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Megatron-LM: https://github.com/NVIDIA/Megatron-LM
|
||||
- Accelerate Megatron docs: https://huggingface.co/docs/accelerate/usage_guides/megatron_lm
|
||||
- Paper: "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism"
|
||||
- NVIDIA Apex: https://github.com/NVIDIA/apex
|
||||
skills/mlops-knowledge/accelerate/references/performance.md (new file, 525 lines)
|
||||
# Accelerate Performance Tuning
|
||||
|
||||
## Profiling
|
||||
|
||||
### Basic Profiling
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
import time
|
||||
|
||||
accelerator = Accelerator()
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
batch = next(iter(dataloader))
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Profile training loop
|
||||
start = time.time()
|
||||
total_batches = 100
|
||||
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= total_batches:
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
accelerator.wait_for_everyone() # Sync all processes
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Metrics
|
||||
batches_per_sec = total_batches / elapsed
|
||||
samples_per_sec = (total_batches * batch_size * accelerator.num_processes) / elapsed
|
||||
|
||||
print(f"Throughput: {samples_per_sec:.2f} samples/sec")
|
||||
print(f"Batches/sec: {batches_per_sec:.2f}")
|
||||
```
|
||||
|
||||
### PyTorch Profiler Integration
|
||||
|
||||
```python
|
||||
from torch.profiler import profile, ProfilerActivity
|
||||
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True
|
||||
) as prof:
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= 10: # Profile first 10 batches
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Print profiling results
|
||||
print(prof.key_averages().table(
|
||||
sort_by="cuda_time_total", row_limit=20
|
||||
))
|
||||
|
||||
# Export to Chrome tracing
|
||||
prof.export_chrome_trace("trace.json")
|
||||
# View at chrome://tracing
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### 1. Gradient Accumulation
|
||||
|
||||
**Problem**: Large batch size causes OOM
|
||||
|
||||
**Solution**: Accumulate gradients across micro-batches
|
||||
|
||||
```python
|
||||
accelerator = Accelerator(gradient_accumulation_steps=8)
|
||||
|
||||
# Effective batch = batch_size × accumulation_steps × num_gpus
|
||||
# Example: 4 × 8 × 8 = 256
|
||||
|
||||
for batch in dataloader:
|
||||
with accelerator.accumulate(model): # Handles accumulation logic
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
**Memory savings**: with 8 accumulation steps, activation memory is roughly 8× lower than running the full effective batch in a single step
|
||||
|
||||
### 2. Gradient Checkpointing
|
||||
|
||||
**Enable in model**:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"gpt2",
|
||||
use_cache=False # Required for gradient checkpointing
|
||||
)
|
||||
|
||||
# Enable checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Prepare with Accelerate
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Memory savings**: 30-50% with 10-15% slowdown
|
||||
|
||||
### 3. Mixed Precision
|
||||
|
||||
**BF16 (A100/H100)**:
|
||||
```python
|
||||
accelerator = Accelerator(mixed_precision='bf16')
|
||||
|
||||
# Automatic mixed precision
|
||||
for batch in dataloader:
|
||||
outputs = model(**batch) # Forward in BF16
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss) # Backward in FP32
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**FP16 (V100, older GPUs)**:
|
||||
```python
|
||||
from accelerate.utils import GradScalerKwargs
|
||||
|
||||
scaler_kwargs = GradScalerKwargs(
|
||||
init_scale=2.**16,
|
||||
growth_interval=2000
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp16',
|
||||
kwargs_handlers=[scaler_kwargs]
|
||||
)
|
||||
```
|
||||
|
||||
**Memory savings**: 50% compared to FP32
|
||||
|
||||
### 4. CPU Offloading (DeepSpeed)
|
||||
|
||||
```python
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=3,
|
||||
offload_optimizer_device="cpu", # Offload optimizer to CPU
|
||||
offload_param_device="cpu", # Offload parameters to CPU
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
deepspeed_plugin=ds_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
**Memory savings**: 10-20× for optimizer state, 5-10× for parameters
|
||||
|
||||
**Trade-off**: 20-30% slower due to CPU-GPU transfers
|
||||
|
||||
### 5. Flash Attention
|
||||
|
||||
```python
|
||||
# Install flash-attn
|
||||
# pip install flash-attn
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"gpt2",
|
||||
attn_implementation="flash_attention_2" # Enable Flash Attention 2
|
||||
)
|
||||
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Memory savings**: 50% for attention, 2× faster
|
||||
|
||||
**Requirements**: Ampere or newer NVIDIA GPUs (e.g., A100/H100); sequence lengths that are multiples of 128 give the best throughput
|
||||
|
||||
## Communication Optimization
|
||||
|
||||
### 1. Gradient Bucketing (DDP)
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
bucket_cap_mb=25, # Bucket size for gradient reduction
|
||||
gradient_as_bucket_view=True, # Reduce memory copies
|
||||
static_graph=False # Set True if model doesn't change
|
||||
)
|
||||
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
|
||||
```
|
||||
|
||||
**Recommended bucket sizes**:
|
||||
- Small models (<1B): 25 MB
|
||||
- Medium models (1-10B): 50-100 MB
|
||||
- Large models (>10B): 100-200 MB
|
||||
|
||||
### 2. Find Unused Parameters
|
||||
|
||||
```python
|
||||
# Only enable if model has unused parameters (slower!)
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
find_unused_parameters=True
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Models with conditional branches (e.g., mixture of experts)
|
||||
|
||||
**Cost**: 10-20% slower
|
||||
|
||||
### 3. NCCL Tuning
|
||||
|
||||
```bash
|
||||
# Set environment variables before launch
|
||||
export NCCL_DEBUG=INFO # Debug info
|
||||
export NCCL_IB_DISABLE=0 # Enable InfiniBand
|
||||
export NCCL_SOCKET_IFNAME=eth0 # Network interface
|
||||
export NCCL_P2P_LEVEL=NVL # Use NVLink
|
||||
|
||||
accelerate launch train.py
|
||||
```
|
||||
|
||||
**NCCL_P2P_LEVEL options**:
|
||||
- `NVL`: NVLink (fastest, within node)
|
||||
- `PIX`: PCIe (fast, within node)
|
||||
- `PHB`: PCIe host bridge (slow, cross-node)
|
||||
|
||||
## Data Loading Optimization
|
||||
|
||||
### 1. DataLoader Workers
|
||||
|
||||
```python
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
train_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=32,
|
||||
num_workers=4, # Parallel data loading
|
||||
pin_memory=True, # Pin memory for faster GPU transfer
|
||||
prefetch_factor=2, # Prefetch batches per worker
|
||||
persistent_workers=True # Keep workers alive between epochs
|
||||
)
|
||||
|
||||
train_loader = accelerator.prepare(train_loader)
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- `num_workers`: 2-4 per GPU (8 GPUs → 16-32 workers)
|
||||
- `pin_memory`: Always True for GPU training
|
||||
- `prefetch_factor`: 2-4 (higher for slow data loading)
|
||||
|
||||
### 2. Data Preprocessing
|
||||
|
||||
```python
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
# Bad: Preprocess during training (slow)
|
||||
dataset = load_dataset("openwebtext")
|
||||
|
||||
for batch in dataset:
|
||||
tokens = tokenizer(batch['text']) # Slow!
|
||||
...
|
||||
|
||||
# Good: Preprocess once, save
|
||||
dataset = load_dataset("openwebtext")
|
||||
tokenized = dataset.map(
|
||||
lambda x: tokenizer(x['text']),
|
||||
batched=True,
|
||||
num_proc=8, # Parallel preprocessing
|
||||
remove_columns=['text']
|
||||
)
|
||||
tokenized.save_to_disk("preprocessed_data")
|
||||
|
||||
# Load preprocessed
|
||||
dataset = load_from_disk("preprocessed_data")
|
||||
```
|
||||
|
||||
### 3. Faster Tokenization
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
# Allow parallelism in the Rust-based fast tokenizers (roughly 10× faster than the Python ones)
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"gpt2",
|
||||
use_fast=True # Use fast Rust tokenizer
|
||||
)
|
||||
```
|
||||
|
||||
## Compilation (PyTorch 2.0+)
|
||||
|
||||
### Compile Model
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Compile model for faster execution
|
||||
model = torch.compile(
|
||||
model,
|
||||
mode="reduce-overhead", # Options: default, reduce-overhead, max-autotune
|
||||
fullgraph=False, # Compile entire graph (stricter)
|
||||
dynamic=True # Support dynamic shapes
|
||||
)
|
||||
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Speedup**: 10-50% depending on model
|
||||
|
||||
**Compilation modes**:
|
||||
- `default`: Balanced (best for most cases)
|
||||
- `reduce-overhead`: Min overhead (best for small batches)
|
||||
- `max-autotune`: Max performance (slow compile, best for production)
|
||||
|
||||
### Compilation Best Practices
|
||||
|
||||
```python
|
||||
# Bad: Compile after prepare (won't work)
|
||||
model = accelerator.prepare(model)
|
||||
model = torch.compile(model) # Error!
|
||||
|
||||
# Good: Compile before prepare
|
||||
model = torch.compile(model)
|
||||
model = accelerator.prepare(model)
|
||||
|
||||
# Training loop
|
||||
for batch in dataloader:
|
||||
# First iteration: slow (compilation)
|
||||
# Subsequent iterations: fast (compiled)
|
||||
outputs = model(**batch)
|
||||
...
|
||||
```
|
||||
|
||||
## Benchmarking Different Strategies
|
||||
|
||||
### Script Template
|
||||
|
||||
```python
|
||||
import time
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
|
||||
def benchmark_strategy(strategy_name, accelerator_kwargs):
|
||||
"""Benchmark a specific training strategy."""
|
||||
accelerator = Accelerator(**accelerator_kwargs)
|
||||
|
||||
# Setup
|
||||
model = create_model()
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
||||
dataloader = create_dataloader()
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(
|
||||
model, optimizer, dataloader
|
||||
)
|
||||
|
||||
# Warmup
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= 10:
|
||||
break
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Benchmark
|
||||
accelerator.wait_for_everyone()
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
|
||||
num_batches = 100
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= num_batches:
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
torch.cuda.synchronize()
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Metrics
|
||||
throughput = (num_batches * batch_size * accelerator.num_processes) / elapsed
|
||||
memory_used = torch.cuda.max_memory_allocated() / 1e9 # GB
|
||||
|
||||
if accelerator.is_main_process:
|
||||
print(f"\n{strategy_name}:")
|
||||
print(f" Throughput: {throughput:.2f} samples/sec")
|
||||
print(f" Memory: {memory_used:.2f} GB")
|
||||
print(f" Time: {elapsed:.2f} sec")
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Benchmark different strategies
|
||||
strategies = [
|
||||
("DDP + FP32", {}),
|
||||
("DDP + BF16", {"mixed_precision": "bf16"}),
|
||||
("DDP + BF16 + GradAccum", {"mixed_precision": "bf16", "gradient_accumulation_steps": 4}),
|
||||
("FSDP", {"fsdp_plugin": fsdp_plugin}),
|
||||
("DeepSpeed ZeRO-2", {"deepspeed_plugin": ds_plugin_stage2}),
|
||||
("DeepSpeed ZeRO-3", {"deepspeed_plugin": ds_plugin_stage3}),
|
||||
]
|
||||
|
||||
for name, kwargs in strategies:
|
||||
benchmark_strategy(name, kwargs)
|
||||
```
|
||||
|
||||
## Performance Checklist
|
||||
|
||||
**Before training**:
|
||||
- [ ] Use BF16/FP16 mixed precision
|
||||
- [ ] Enable gradient checkpointing (if OOM)
|
||||
- [ ] Set appropriate `num_workers` (2-4 per GPU)
|
||||
- [ ] Enable `pin_memory=True`
|
||||
- [ ] Preprocess data once, not during training
|
||||
- [ ] Compile model with `torch.compile` (PyTorch 2.0+)
|
||||
|
||||
**For large models**:
|
||||
- [ ] Use FSDP or DeepSpeed ZeRO-3
|
||||
- [ ] Enable CPU offloading (if still OOM)
|
||||
- [ ] Use Flash Attention
|
||||
- [ ] Increase gradient accumulation
|
||||
|
||||
**For multi-node**:
|
||||
- [ ] Check network topology (InfiniBand > Ethernet)
|
||||
- [ ] Tune NCCL settings
|
||||
- [ ] Use larger bucket sizes for DDP
|
||||
- [ ] Verify NVLink for tensor parallelism
|
||||
|
||||
**Profiling**:
|
||||
- [ ] Profile first 10-100 batches
|
||||
- [ ] Check GPU utilization (`nvidia-smi dmon`)
|
||||
- [ ] Check data loading time (should be <5% of iteration; see the sketch after this list)
|
||||
- [ ] Identify communication bottlenecks
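For the data-loading item above, one rough way to measure it (a sketch that assumes `dataloader`, `model`, `optimizer`, and `accelerator` already exist) is to time the batch fetch separately from the step:

```python
import time
import torch

fetch_time, step_time = 0.0, 0.0
data_iter = iter(dataloader)

for _ in range(100):
    t0 = time.time()
    batch = next(data_iter)           # time spent waiting on the DataLoader
    fetch_time += time.time() - t0

    t0 = time.time()
    loss = model(**batch).loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    torch.cuda.synchronize()          # make GPU work visible to the wall clock
    step_time += time.time() - t0

# Rule of thumb from the checklist: keep this under ~5%
print(f"Data loading fraction: {fetch_time / (fetch_time + step_time):.1%}")
```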
|
||||
|
||||
## Common Performance Issues
|
||||
|
||||
### Issue: Low GPU Utilization (<80%)
|
||||
|
||||
**Cause 1**: Data loading bottleneck
|
||||
```python
|
||||
# Solution: Increase workers and prefetch
|
||||
num_workers=8
|
||||
prefetch_factor=4
|
||||
```
|
||||
|
||||
**Cause 2**: Small batch size
|
||||
```python
|
||||
# Solution: Increase batch size or use gradient accumulation
|
||||
batch_size=32 # Increase
|
||||
gradient_accumulation_steps=4 # Or accumulate
|
||||
```
|
||||
|
||||
### Issue: High Memory Usage
|
||||
|
||||
**Solution 1**: Gradient checkpointing
|
||||
```python
|
||||
model.gradient_checkpointing_enable()
|
||||
```
|
||||
|
||||
**Solution 2**: Reduce batch size, increase accumulation
|
||||
```python
|
||||
batch_size=8 # Reduce from 32
|
||||
gradient_accumulation_steps=16 # Maintain effective batch
|
||||
```
|
||||
|
||||
**Solution 3**: Use FSDP or DeepSpeed ZeRO-3
|
||||
```python
|
||||
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
|
||||
```
|
||||
|
||||
### Issue: Slow Multi-GPU Training
|
||||
|
||||
**Cause**: Communication bottleneck
|
||||
|
||||
**Check 1**: Gradient bucket size
|
||||
```python
|
||||
ddp_kwargs = DistributedDataParallelKwargs(bucket_cap_mb=100)
|
||||
```
|
||||
|
||||
**Check 2**: NCCL settings
|
||||
```bash
|
||||
export NCCL_DEBUG=INFO
|
||||
# Check for "Using NVLS" (good) vs "Using PHB" (bad)
|
||||
```
|
||||
|
||||
**Check 3**: Network bandwidth
|
||||
```bash
|
||||
# Test inter-GPU bandwidth
|
||||
nvidia-smi nvlink -s
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Accelerate Performance: https://huggingface.co/docs/accelerate/usage_guides/performance
|
||||
- PyTorch Profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
|
||||
- NCCL Tuning: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html
|
||||
- Flash Attention: https://github.com/Dao-AILab/flash-attention
|
||||
skills/mlops-knowledge/audiocraft/SKILL.md (new file, 564 lines)
|
||||
---
|
||||
name: audiocraft-audio-generation
|
||||
description: PyTorch library for audio generation including text-to-music (MusicGen) and text-to-sound (AudioGen). Use when you need to generate music from text descriptions, create sound effects, or perform melody-conditioned music generation.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen]
|
||||
dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0]
|
||||
---
|
||||
|
||||
# AudioCraft: Audio Generation
|
||||
|
||||
Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec.
|
||||
|
||||
## When to use AudioCraft
|
||||
|
||||
**Use AudioCraft when:**
|
||||
- Need to generate music from text descriptions
|
||||
- Creating sound effects and environmental audio
|
||||
- Building music generation applications
|
||||
- Need melody-conditioned music generation
|
||||
- Want stereo audio output
|
||||
- Require controllable music generation with style transfer
|
||||
|
||||
**Key features:**
|
||||
- **MusicGen**: Text-to-music generation with melody conditioning
|
||||
- **AudioGen**: Text-to-sound effects generation
|
||||
- **EnCodec**: High-fidelity neural audio codec
|
||||
- **Multiple model sizes**: Small (300M) to Large (3.3B)
|
||||
- **Stereo support**: Full stereo audio generation
|
||||
- **Style conditioning**: MusicGen-Style for reference-based generation
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Stable Audio**: For longer commercial music generation
|
||||
- **Bark**: For text-to-speech with music/sound effects
|
||||
- **Riffusion**: For spectrogram-based music generation
|
||||
- **OpenAI Jukebox**: For raw audio generation with lyrics
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# From PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# From GitHub (latest)
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Or use HuggingFace Transformers
|
||||
pip install transformers torch torchaudio
|
||||
```
|
||||
|
||||
### Basic text-to-music (AudioCraft)
|
||||
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Set generation parameters
|
||||
model.set_generation_params(
|
||||
duration=8, # seconds
|
||||
top_k=250,
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Generate from text
|
||||
descriptions = ["happy upbeat electronic dance music with synths"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save audio
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Using HuggingFace Transformers
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
import scipy
|
||||
|
||||
# Load model and processor
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
model.to("cuda")
|
||||
|
||||
# Generate music
|
||||
inputs = processor(
|
||||
text=["80s pop track with bassy drums and synth"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True,
|
||||
guidance_scale=3,
|
||||
max_new_tokens=256
|
||||
)
|
||||
|
||||
# Save
|
||||
sampling_rate = model.config.audio_encoder.sampling_rate
|
||||
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
|
||||
```
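A note on output length with the Transformers API (a hedged rule of thumb, not an official parameter mapping): MusicGen generates discrete EnCodec frames at roughly 50 per second of audio, so `max_new_tokens` scales with the desired duration:

```python
# Assumption: ~50 audio frames per second of generated audio
desired_seconds = 5
max_new_tokens = 50 * desired_seconds  # ~250, close to the 256 used in the example above
```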
|
||||
|
||||
### Text-to-sound with AudioGen
|
||||
|
||||
```python
|
||||
import torchaudio
from audiocraft.models import AudioGen
|
||||
|
||||
# Load AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
model.set_generation_params(duration=5)
|
||||
|
||||
# Generate sound effects
|
||||
descriptions = ["dog barking in a park with birds chirping"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Architecture overview
|
||||
|
||||
```
|
||||
AudioCraft Architecture:
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ Text Encoder (T5) │
|
||||
│ │ │
|
||||
│ Text Embeddings │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ Transformer Decoder (LM) │
|
||||
│ Auto-regressively generates audio tokens │
|
||||
│ Using efficient token interleaving patterns │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ EnCodec Audio Decoder │
|
||||
│ Converts tokens back to audio waveform │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Model variants
|
||||
|
||||
| Model | Size | Description | Use Case |
|
||||
|-------|------|-------------|----------|
|
||||
| `musicgen-small` | 300M | Text-to-music | Quick generation |
|
||||
| `musicgen-medium` | 1.5B | Text-to-music | Balanced |
|
||||
| `musicgen-large` | 3.3B | Text-to-music | Best quality |
|
||||
| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning |
|
||||
| `musicgen-melody-large` | 3.3B | Text + melody | Best melody |
|
||||
| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation |
|
||||
| `musicgen-style` | 1.5B | Style transfer | Reference-based |
|
||||
| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects |
|
||||
|
||||
### Generation parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `duration` | 8.0 | Length in seconds (1-120) |
|
||||
| `top_k` | 250 | Top-k sampling |
|
||||
| `top_p` | 0.0 | Nucleus sampling (0 = disabled) |
|
||||
| `temperature` | 1.0 | Sampling temperature |
|
||||
| `cfg_coef` | 3.0 | Classifier-free guidance |
|
||||
|
||||
## MusicGen usage
|
||||
|
||||
### Text-to-music generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Configure generation
|
||||
model.set_generation_params(
|
||||
duration=30, # Up to 30 seconds
|
||||
top_k=250, # Sampling diversity
|
||||
top_p=0.0, # 0 = use top_k only
|
||||
temperature=1.0, # Creativity (higher = more varied)
|
||||
cfg_coef=3.0 # Text adherence (higher = stricter)
|
||||
)
|
||||
|
||||
# Generate multiple samples
|
||||
descriptions = [
|
||||
"epic orchestral soundtrack with strings and brass",
|
||||
"chill lo-fi hip hop beat with jazzy piano",
|
||||
"energetic rock song with electric guitar"
|
||||
]
|
||||
|
||||
# Generate (returns [batch, channels, samples])
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save each
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Melody-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
# Load melody model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
model.set_generation_params(duration=30)
|
||||
|
||||
# Load melody audio
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Generate with melody conditioning
|
||||
descriptions = ["acoustic guitar folk song"]
|
||||
wav = model.generate_with_chroma(descriptions, melody, sr)
|
||||
|
||||
torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Stereo generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load stereo model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
model.set_generation_params(duration=15)
|
||||
|
||||
descriptions = ["ambient electronic music with wide stereo panning"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# wav shape: [batch, 2, samples] for stereo
|
||||
print(f"Stereo shape: {wav.shape}") # [1, 2, 480000]
|
||||
torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Audio continuation
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")
|
||||
|
||||
# Load audio to continue
|
||||
import torchaudio
|
||||
audio, sr = torchaudio.load("intro.wav")
|
||||
|
||||
# Process with text and audio
|
||||
inputs = processor(
|
||||
audio=audio.squeeze().numpy(),
|
||||
sampling_rate=sr,
|
||||
text=["continue with a epic chorus"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Generate continuation
|
||||
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)
|
||||
```
|
||||
|
||||
## MusicGen-Style usage
|
||||
|
||||
### Style-conditioned generation
|
||||
|
||||
```python
|
||||
import torchaudio
from audiocraft.models import MusicGen
|
||||
|
||||
# Load style model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-style')
|
||||
|
||||
# Configure generation with style
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=5.0 # Style influence
|
||||
)
|
||||
|
||||
# Configure style conditioner
|
||||
model.set_style_conditioner_params(
|
||||
eval_q=3, # RVQ quantizers (1-6)
|
||||
excerpt_length=3.0 # Style excerpt length
|
||||
)
|
||||
|
||||
# Load style reference
|
||||
style_audio, sr = torchaudio.load("reference_style.wav")
|
||||
|
||||
# Generate with text + style
|
||||
descriptions = ["upbeat dance track"]
|
||||
wav = model.generate_with_style(descriptions, style_audio, sr)
|
||||
```
|
||||
|
||||
### Style-only generation (no text)
|
||||
|
||||
```python
|
||||
# Generate matching style without text prompt
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=None # Disable double CFG for style-only
|
||||
)
|
||||
|
||||
wav = model.generate_with_style([None], style_audio, sr)
|
||||
```
|
||||
|
||||
## AudioGen usage
|
||||
|
||||
### Sound effect generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate various sounds
|
||||
descriptions = [
|
||||
"thunderstorm with heavy rain and lightning",
|
||||
"busy city traffic with car horns",
|
||||
"ocean waves crashing on rocks",
|
||||
"crackling campfire in forest"
|
||||
]
|
||||
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## EnCodec usage
|
||||
|
||||
### Audio compression
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
# Load EnCodec
|
||||
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Load audio
|
||||
wav, sr = torchaudio.load("audio.wav")
|
||||
|
||||
# Ensure correct sample rate
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Encode to tokens
|
||||
with torch.no_grad():
|
||||
encoded = model.encode(wav.unsqueeze(0))
|
||||
codes = encoded[0] # Audio codes
|
||||
|
||||
# Decode back to audio
|
||||
with torch.no_grad():
|
||||
decoded = model.decode(codes)
|
||||
|
||||
torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Music generation pipeline
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenerator:
|
||||
def __init__(self, model_name="facebook/musicgen-medium"):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.sample_rate = 32000
|
||||
|
||||
def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0):
|
||||
self.model.set_generation_params(
|
||||
duration=duration,
|
||||
top_k=250,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([prompt])
|
||||
|
||||
return wav[0].cpu()
|
||||
|
||||
def generate_batch(self, prompts, duration=30):
|
||||
self.model.set_generation_params(duration=duration)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate(prompts)
|
||||
|
||||
return wav.cpu()
|
||||
|
||||
def save(self, audio, path):
|
||||
torchaudio.save(path, audio, sample_rate=self.sample_rate)
|
||||
|
||||
# Usage
|
||||
generator = MusicGenerator()
|
||||
audio = generator.generate(
|
||||
"epic cinematic orchestral music",
|
||||
duration=30,
|
||||
temperature=1.0
|
||||
)
|
||||
generator.save(audio, "epic_music.wav")
|
||||
```
|
||||
|
||||
### Workflow 2: Sound design batch processing
|
||||
|
||||
```python
|
||||
import json
|
||||
from pathlib import Path
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
def batch_generate_sounds(sound_specs, output_dir):
|
||||
"""
|
||||
Generate multiple sounds from specifications.
|
||||
|
||||
Args:
|
||||
sound_specs: list of {"name": str, "description": str, "duration": float}
|
||||
output_dir: output directory path
|
||||
"""
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
results = []
|
||||
|
||||
for spec in sound_specs:
|
||||
model.set_generation_params(duration=spec.get("duration", 5))
|
||||
|
||||
wav = model.generate([spec["description"]])
|
||||
|
||||
output_path = output_dir / f"{spec['name']}.wav"
|
||||
torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000)
|
||||
|
||||
results.append({
|
||||
"name": spec["name"],
|
||||
"path": str(output_path),
|
||||
"description": spec["description"]
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
# Usage
|
||||
sounds = [
|
||||
{"name": "explosion", "description": "massive explosion with debris", "duration": 3},
|
||||
{"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5},
|
||||
{"name": "door", "description": "wooden door creaking and closing", "duration": 2}
|
||||
]
|
||||
|
||||
results = batch_generate_sounds(sounds, "sound_effects/")
|
||||
```
|
||||
|
||||
### Workflow 3: Gradio demo
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
def generate_music(prompt, duration, temperature, cfg_coef):
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save to temp file
|
||||
path = "temp_output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate_music,
|
||||
inputs=[
|
||||
gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"),
|
||||
gr.Slider(1, 30, value=8, label="Duration (seconds)"),
|
||||
gr.Slider(0.5, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Demo"
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### Memory optimization
|
||||
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Clear cache between generations
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Generate shorter durations
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Use half precision
|
||||
model = model.half()
|
||||
```
|
||||
|
||||
### Batch processing efficiency
|
||||
|
||||
```python
|
||||
# Process multiple prompts at once (more efficient)
|
||||
descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"]
|
||||
wav = model.generate(descriptions) # Single batch
|
||||
|
||||
# Instead of
|
||||
for desc in descriptions:
|
||||
wav = model.generate([desc]) # Multiple batches (slower)
|
||||
```
|
||||
|
||||
### GPU memory requirements
|
||||
|
||||
| Model | FP32 VRAM | FP16 VRAM |
|
||||
|-------|-----------|-----------|
|
||||
| musicgen-small | ~4GB | ~2GB |
|
||||
| musicgen-medium | ~8GB | ~4GB |
|
||||
| musicgen-large | ~16GB | ~8GB |
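
To pick a checkpoint automatically from the figures above, a minimal sketch (the thresholds are the rough FP32 numbers in this table, not values reported by AudioCraft):

```python
import torch
from audiocraft.models import MusicGen

def pick_musicgen_by_vram():
    """Choose a MusicGen checkpoint based on free GPU memory (rough FP32 figures)."""
    if not torch.cuda.is_available():
        return MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
    free_bytes, _total = torch.cuda.mem_get_info()
    free_gb = free_bytes / 1024**3
    if free_gb >= 16:
        name = 'facebook/musicgen-large'
    elif free_gb >= 8:
        name = 'facebook/musicgen-medium'
    else:
        name = 'facebook/musicgen-small'
    return MusicGen.get_pretrained(name)
```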
|
||||
|
||||
## Common issues
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| CUDA OOM | Use smaller model, reduce duration |
|
||||
| Poor quality | Increase cfg_coef, better prompts |
|
||||
| Generation too short | Check max duration setting |
|
||||
| Audio artifacts | Try different temperature |
|
||||
| Stereo not working | Use stereo model variant |
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/audiocraft
|
||||
- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284
|
||||
- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352
|
||||
- **HuggingFace**: https://huggingface.co/facebook/musicgen-small
|
||||
- **Demo**: https://huggingface.co/spaces/facebook/MusicGen
|
||||
666 skills/mlops-knowledge/audiocraft/references/advanced-usage.md (Normal file)
@@ -0,0 +1,666 @@
|
||||
# AudioCraft Advanced Usage Guide
|
||||
|
||||
## Fine-tuning MusicGen
|
||||
|
||||
### Custom dataset preparation
|
||||
|
||||
```python
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import torchaudio
|
||||
|
||||
def prepare_dataset(audio_dir, output_dir, metadata_file):
|
||||
"""
|
||||
Prepare dataset for MusicGen fine-tuning.
|
||||
|
||||
Directory structure:
|
||||
output_dir/
|
||||
├── audio/
|
||||
│ ├── 0001.wav
|
||||
│ ├── 0002.wav
|
||||
│ └── ...
|
||||
└── metadata.json
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
audio_output = output_dir / "audio"
|
||||
audio_output.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load metadata (format: {"path": "...", "description": "..."})
|
||||
with open(metadata_file) as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
processed = []
|
||||
|
||||
for idx, item in enumerate(metadata):
|
||||
audio_path = Path(audio_dir) / item["path"]
|
||||
|
||||
# Load and resample to 32kHz
|
||||
wav, sr = torchaudio.load(str(audio_path))
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Convert to mono if stereo
|
||||
if wav.shape[0] > 1:
|
||||
wav = wav.mean(dim=0, keepdim=True)
|
||||
|
||||
# Save processed audio
|
||||
output_path = audio_output / f"{idx:04d}.wav"
|
||||
torchaudio.save(str(output_path), wav, sample_rate=32000)
|
||||
|
||||
processed.append({
|
||||
"path": str(output_path.relative_to(output_dir)),
|
||||
"description": item["description"],
|
||||
"duration": wav.shape[1] / 32000
|
||||
})
|
||||
|
||||
# Save processed metadata
|
||||
with open(output_dir / "metadata.json", "w") as f:
|
||||
json.dump(processed, f, indent=2)
|
||||
|
||||
print(f"Processed {len(processed)} samples")
|
||||
return processed
|
||||
```
|
||||
|
||||
### Fine-tuning with dora
|
||||
|
||||
```bash
|
||||
# AudioCraft uses dora for experiment management
|
||||
# Install dora
|
||||
pip install dora-search
|
||||
|
||||
# Clone AudioCraft
|
||||
git clone https://github.com/facebookresearch/audiocraft.git
|
||||
cd audiocraft
|
||||
|
||||
# Create config for fine-tuning
|
||||
cat > config/solver/musicgen/finetune.yaml << 'EOF'
|
||||
defaults:
|
||||
- musicgen/musicgen_base
|
||||
- /model: lm/musicgen_lm
|
||||
- /conditioner: cond_base
|
||||
|
||||
solver: musicgen
|
||||
autocast: true
|
||||
autocast_dtype: float16
|
||||
|
||||
optim:
|
||||
epochs: 100
|
||||
batch_size: 4
|
||||
lr: 1e-4
|
||||
ema: 0.999
|
||||
optimizer: adamw
|
||||
|
||||
dataset:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
train:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
valid:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
|
||||
checkpoint:
|
||||
save_every: 10
|
||||
keep_every_states: null
|
||||
EOF
|
||||
|
||||
# Run fine-tuning
|
||||
dora run solver=musicgen/finetune
|
||||
```
|
||||
|
||||
### LoRA fine-tuning
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from audiocraft.models import MusicGen
|
||||
import torch
|
||||
|
||||
# Load base model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Get the language model component
|
||||
lm = model.lm
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none"
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
lm = get_peft_model(lm, lora_config)
|
||||
lm.print_trainable_parameters()
|
||||
```
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
### DataParallel
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Wrap LM with DataParallel
|
||||
if torch.cuda.device_count() > 1:
|
||||
model.lm = nn.DataParallel(model.lm)
|
||||
|
||||
model.to("cuda")
|
||||
```
|
||||
|
||||
### DistributedDataParallel
|
||||
|
||||
```python
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
def setup(rank, world_size):
|
||||
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
def train(rank, world_size):
|
||||
setup(rank, world_size)
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.lm = model.lm.to(rank)
|
||||
model.lm = DDP(model.lm, device_ids=[rank])
|
||||
|
||||
# Training loop
|
||||
# ...
|
||||
|
||||
dist.destroy_process_group()
|
||||
```
|
||||
|
||||
## Custom Conditioning
|
||||
|
||||
### Adding new conditioners
|
||||
|
||||
```python
|
||||
from audiocraft.modules.conditioners import BaseConditioner
|
||||
import torch
|
||||
|
||||
class CustomConditioner(BaseConditioner):
|
||||
"""Custom conditioner for additional control signals."""
|
||||
|
||||
def __init__(self, dim, output_dim):
|
||||
super().__init__(dim, output_dim)
|
||||
self.embed = torch.nn.Linear(dim, output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.embed(x)
|
||||
|
||||
def tokenize(self, x):
|
||||
# Tokenize input for conditioning
|
||||
return x
|
||||
|
||||
# Use with MusicGen
|
||||
from audiocraft.models.builders import get_lm_model
|
||||
|
||||
# Modify model config to include custom conditioner
|
||||
# This requires editing the model configuration
|
||||
```
|
||||
|
||||
### Melody conditioning internals
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider
|
||||
import torch
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Access chroma extractor
|
||||
chroma_extractor = model.lm.condition_provider.conditioners.get('chroma')
|
||||
|
||||
# Manual chroma extraction
|
||||
def extract_chroma(audio, sr):
|
||||
"""Extract chroma features from audio."""
|
||||
import librosa
|
||||
|
||||
# Compute chroma
|
||||
chroma = librosa.feature.chroma_cqt(y=audio.numpy(), sr=sr)
|
||||
|
||||
return torch.from_numpy(chroma).float()
|
||||
|
||||
# Use extracted chroma for conditioning
|
||||
chroma = extract_chroma(melody_audio, sample_rate)
|
||||
```
|
||||
|
||||
## EnCodec Deep Dive
|
||||
|
||||
### Custom compression settings
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
|
||||
# Load EnCodec
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Access codec parameters
|
||||
print(f"Sample rate: {encodec.sample_rate}")
|
||||
print(f"Channels: {encodec.channels}")
|
||||
print(f"Cardinality: {encodec.cardinality}") # Codebook size
|
||||
print(f"Num codebooks: {encodec.num_codebooks}")
|
||||
print(f"Frame rate: {encodec.frame_rate}")
|
||||
|
||||
# Encode with specific bandwidth
|
||||
# Lower bandwidth = more compression, lower quality
|
||||
encodec.set_target_bandwidth(6.0) # 6 kbps
|
||||
|
||||
audio = torch.randn(1, 1, 32000) # 1 second
|
||||
encoded = encodec.encode(audio)
|
||||
decoded = encodec.decode(encoded[0])
|
||||
```
|
||||
|
||||
### Streaming encoding
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.models import CompressionModel
|
||||
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
def encode_streaming(audio_stream, chunk_size=32000):
|
||||
"""Encode audio in streaming fashion."""
|
||||
all_codes = []
|
||||
|
||||
for chunk in audio_stream:
|
||||
# Ensure chunk is right shape
|
||||
if chunk.dim() == 1:
|
||||
chunk = chunk.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
codes = encodec.encode(chunk)[0]
|
||||
all_codes.append(codes)
|
||||
|
||||
return torch.cat(all_codes, dim=-1)
|
||||
|
||||
def decode_streaming(codes_stream, output_stream):
|
||||
"""Decode codes in streaming fashion."""
|
||||
for codes in codes_stream:
|
||||
with torch.no_grad():
|
||||
audio = encodec.decode(codes)
|
||||
output_stream.write(audio.cpu().numpy())
|
||||
```
|
||||
|
||||
## MultiBand Diffusion
|
||||
|
||||
### Using MBD for enhanced quality
|
||||
|
||||
```python
|
||||
import torch
from audiocraft.models import MusicGen, MultiBandDiffusion
|
||||
|
||||
# Load MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Load MultiBand Diffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate with standard decoder
|
||||
descriptions = ["epic orchestral music"]
|
||||
wav_standard = model.generate(descriptions)
|
||||
|
||||
# Generate tokens and use MBD decoder
|
||||
with torch.no_grad():
|
||||
# Get tokens
|
||||
gen_tokens = model.generate_tokens(descriptions)
|
||||
|
||||
# Decode with MBD
|
||||
wav_mbd = mbd.tokens_to_wav(gen_tokens)
|
||||
|
||||
# Compare quality
|
||||
print(f"Standard shape: {wav_standard.shape}")
|
||||
print(f"MBD shape: {wav_mbd.shape}")
|
||||
```
|
||||
|
||||
## API Server Deployment
|
||||
|
||||
### FastAPI server
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import io
|
||||
import base64
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model at startup
|
||||
model = None
|
||||
|
||||
@app.on_event("startup")
|
||||
async def load_model():
|
||||
global model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
prompt: str
|
||||
duration: float = 10.0
|
||||
temperature: float = 1.0
|
||||
cfg_coef: float = 3.0
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
audio_base64: str
|
||||
sample_rate: int
|
||||
duration: float
|
||||
|
||||
@app.post("/generate", response_model=GenerateResponse)
|
||||
async def generate(request: GenerateRequest):
|
||||
if model is None:
|
||||
raise HTTPException(status_code=500, detail="Model not loaded")
|
||||
|
||||
try:
|
||||
model.set_generation_params(
|
||||
duration=min(request.duration, 30),
|
||||
temperature=request.temperature,
|
||||
cfg_coef=request.cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([request.prompt])
|
||||
|
||||
# Convert to bytes
|
||||
buffer = io.BytesIO()
|
||||
torchaudio.save(buffer, wav[0].cpu(), sample_rate=32000, format="wav")
|
||||
buffer.seek(0)
|
||||
|
||||
audio_base64 = base64.b64encode(buffer.read()).decode()
|
||||
|
||||
return GenerateResponse(
|
||||
audio_base64=audio_base64,
|
||||
sample_rate=32000,
|
||||
duration=wav.shape[-1] / 32000
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "model_loaded": model is not None}
|
||||
|
||||
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### Batch processing service
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import torch
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenService:
|
||||
def __init__(self, model_name='facebook/musicgen-small', max_workers=2):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def generate_async(self, prompt, duration=10):
|
||||
"""Async generation with thread pool."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _generate():
|
||||
with torch.no_grad():
|
||||
self.model.set_generation_params(duration=duration)
|
||||
return self.model.generate([prompt])
|
||||
|
||||
# Run in thread pool
|
||||
wav = await loop.run_in_executor(self.executor, _generate)
|
||||
return wav[0].cpu()
|
||||
|
||||
async def generate_batch_async(self, prompts, duration=10):
|
||||
"""Process multiple prompts concurrently."""
|
||||
tasks = [self.generate_async(p, duration) for p in prompts]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
# Usage
|
||||
service = MusicGenService()
|
||||
|
||||
async def main():
|
||||
prompts = ["jazz piano", "rock guitar", "electronic beats"]
|
||||
results = await service.generate_batch_async(prompts)
|
||||
return results
|
||||
```
|
||||
|
||||
## Integration Patterns
|
||||
|
||||
### LangChain tool
|
||||
|
||||
```python
|
||||
from langchain.tools import BaseTool
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import tempfile
|
||||
|
||||
class MusicGeneratorTool(BaseTool):
|
||||
name = "music_generator"
|
||||
description = "Generate music from a text description. Input should be a detailed description of the music style, mood, and instruments."
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
self.model.set_generation_params(duration=15)
|
||||
|
||||
def _run(self, description: str) -> str:
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([description])
|
||||
|
||||
# Save to temp file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
torchaudio.save(f.name, wav[0].cpu(), sample_rate=32000)
|
||||
return f"Generated music saved to: {f.name}"
|
||||
|
||||
async def _arun(self, description: str) -> str:
|
||||
return self._run(description)
|
||||
```
|
||||
|
||||
### Gradio with advanced controls
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
models = {}
|
||||
|
||||
def load_model(model_size):
|
||||
if model_size not in models:
|
||||
model_name = f"facebook/musicgen-{model_size}"
|
||||
models[model_size] = MusicGen.get_pretrained(model_name)
|
||||
return models[model_size]
|
||||
|
||||
def generate(prompt, duration, temperature, cfg_coef, top_k, model_size):
|
||||
model = load_model(model_size)
|
||||
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef,
|
||||
top_k=top_k
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save
|
||||
path = "output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate,
|
||||
inputs=[
|
||||
gr.Textbox(label="Prompt", lines=3),
|
||||
gr.Slider(1, 30, value=10, label="Duration (s)"),
|
||||
gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"),
|
||||
gr.Slider(50, 500, value=250, step=50, label="Top-K"),
|
||||
gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Advanced",
|
||||
allow_flagging="never"
|
||||
)
|
||||
|
||||
demo.launch(share=True)
|
||||
```
|
||||
|
||||
## Audio Processing Pipeline
|
||||
|
||||
### Post-processing chain
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchaudio.transforms as T
|
||||
import numpy as np
|
||||
|
||||
class AudioPostProcessor:
|
||||
def __init__(self, sample_rate=32000):
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
def normalize(self, audio, target_db=-14.0):
|
||||
"""Normalize audio to target loudness."""
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
def fade_in_out(self, audio, fade_duration=0.1):
|
||||
"""Apply fade in/out."""
|
||||
fade_samples = int(fade_duration * self.sample_rate)
|
||||
|
||||
# Create fade curves
|
||||
fade_in = torch.linspace(0, 1, fade_samples)
|
||||
fade_out = torch.linspace(1, 0, fade_samples)
|
||||
|
||||
# Apply fades
|
||||
audio[..., :fade_samples] *= fade_in
|
||||
audio[..., -fade_samples:] *= fade_out
|
||||
|
||||
return audio
|
||||
|
||||
def apply_reverb(self, audio, decay=0.5):
|
||||
"""Apply simple reverb effect."""
|
||||
impulse = torch.zeros(int(self.sample_rate * 0.5))
|
||||
impulse[0] = 1.0
|
||||
impulse[int(self.sample_rate * 0.1)] = decay * 0.5
|
||||
impulse[int(self.sample_rate * 0.2)] = decay * 0.25
|
||||
|
||||
# Convolve
|
||||
audio = torch.nn.functional.conv1d(
|
||||
audio.unsqueeze(0),
|
||||
impulse.unsqueeze(0).unsqueeze(0),
|
||||
padding=len(impulse) // 2
|
||||
).squeeze(0)
|
||||
|
||||
return audio
|
||||
|
||||
def process(self, audio):
|
||||
"""Full processing pipeline."""
|
||||
audio = self.normalize(audio)
|
||||
audio = self.fade_in_out(audio)
|
||||
return audio
|
||||
|
||||
# Usage with MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
wav = model.generate(["chill ambient music"])
|
||||
processor = AudioPostProcessor()
|
||||
wav_processed = processor.process(wav[0].cpu())
|
||||
|
||||
torchaudio.save("processed.wav", wav_processed, sample_rate=32000)
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Audio quality metrics
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.metrics import CLAPTextConsistencyMetric
|
||||
from audiocraft.data.audio import audio_read
|
||||
|
||||
def evaluate_generation(audio_path, text_prompt):
|
||||
"""Evaluate generated audio quality."""
|
||||
# Load audio
|
||||
wav, sr = audio_read(audio_path)
|
||||
|
||||
# CLAP consistency (text-audio alignment)
|
||||
clap_metric = CLAPTextConsistencyMetric()
|
||||
clap_score = clap_metric.compute(wav, [text_prompt])
|
||||
|
||||
return {
|
||||
"clap_score": clap_score,
|
||||
"duration": wav.shape[-1] / sr
|
||||
}
|
||||
|
||||
# Batch evaluation
|
||||
def evaluate_batch(generations):
|
||||
"""Evaluate multiple generations."""
|
||||
results = []
|
||||
for gen in generations:
|
||||
result = evaluate_generation(gen["path"], gen["prompt"])
|
||||
result["prompt"] = gen["prompt"]
|
||||
results.append(result)
|
||||
|
||||
# Aggregate
|
||||
avg_clap = sum(r["clap_score"] for r in results) / len(results)
|
||||
return {
|
||||
"individual": results,
|
||||
"average_clap": avg_clap
|
||||
}
|
||||
```
|
||||
|
||||
## Model Comparison
|
||||
|
||||
### MusicGen variants benchmark
|
||||
|
||||
| Model | CLAP Score | Generation Time (10s) | VRAM |
|
||||
|-------|------------|----------------------|------|
|
||||
| musicgen-small | 0.35 | ~5s | 2GB |
|
||||
| musicgen-medium | 0.42 | ~15s | 4GB |
|
||||
| musicgen-large | 0.48 | ~30s | 8GB |
|
||||
| musicgen-melody | 0.45 | ~15s | 4GB |
|
||||
| musicgen-stereo-medium | 0.41 | ~18s | 5GB |
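
The timings above are indicative; to get a rough comparison on your own hardware, a minimal timing sketch (model names and the 10-second duration match the table, everything else is illustrative):

```python
import time
import torch
from audiocraft.models import MusicGen

def time_generation(model_name, prompt="upbeat electronic dance music", duration=10):
    """Time a single generation for one checkpoint."""
    model = MusicGen.get_pretrained(model_name)
    model.set_generation_params(duration=duration)
    start = time.perf_counter()
    with torch.no_grad():
        model.generate([prompt])
    elapsed = time.perf_counter() - start
    del model
    torch.cuda.empty_cache()
    return elapsed

for name in ("facebook/musicgen-small", "facebook/musicgen-medium"):
    print(name, f"{time_generation(name):.1f}s")
```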
|
||||
|
||||
### Prompt engineering tips
|
||||
|
||||
```python
|
||||
# Good prompts - specific and descriptive
|
||||
good_prompts = [
|
||||
"upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm",
|
||||
"melancholic piano ballad with strings, slow tempo, emotional and cinematic",
|
||||
"funky disco groove with slap bass, brass section, and rhythmic guitar"
|
||||
]
|
||||
|
||||
# Bad prompts - too vague
|
||||
bad_prompts = [
|
||||
"nice music",
|
||||
"song",
|
||||
"good beat"
|
||||
]
|
||||
|
||||
# Structure: [mood] [genre] with [instruments] at [tempo/style]
|
||||
```
|
||||
504 skills/mlops-knowledge/audiocraft/references/troubleshooting.md (Normal file)
@@ -0,0 +1,504 @@
|
||||
# AudioCraft Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Import errors
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'audiocraft'`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install from PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# Or from GitHub
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Verify installation
|
||||
python -c "from audiocraft.models import MusicGen; print('OK')"
|
||||
```
|
||||
|
||||
### FFmpeg not found
|
||||
|
||||
**Error**: `RuntimeError: ffmpeg not found`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install ffmpeg
|
||||
|
||||
# macOS
|
||||
brew install ffmpeg
|
||||
|
||||
# Windows (using conda)
|
||||
conda install -c conda-forge ffmpeg
|
||||
|
||||
# Verify
|
||||
ffmpeg -version
|
||||
```
|
||||
|
||||
### PyTorch CUDA mismatch
|
||||
|
||||
**Error**: `RuntimeError: CUDA error: no kernel image is available`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
python -c "import torch; print(torch.version.cuda)"
|
||||
|
||||
# Install matching PyTorch
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# For CUDA 11.8
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
### xformers issues
|
||||
|
||||
**Error**: `ImportError: xformers` related errors
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install xformers for memory efficiency
|
||||
pip install xformers
|
||||
|
||||
# Or disable xformers
|
||||
export AUDIOCRAFT_USE_XFORMERS=0
|
||||
|
||||
# In Python
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_USE_XFORMERS"] = "0"
|
||||
from audiocraft.models import MusicGen
|
||||
```
|
||||
|
||||
## Model Loading Issues
|
||||
|
||||
### Out of memory during load
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError` during model loading
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Force CPU loading first
|
||||
import torch
|
||||
device = "cpu"
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)
|
||||
model = model.to("cuda")
|
||||
|
||||
# Use HuggingFace with device_map
|
||||
from transformers import MusicgenForConditionalGeneration
|
||||
model = MusicgenForConditionalGeneration.from_pretrained(
|
||||
"facebook/musicgen-small",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
### Download failures
|
||||
|
||||
**Error**: Connection errors or incomplete downloads
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Set cache directory
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/path/to/cache"
|
||||
|
||||
# Or for HuggingFace
|
||||
os.environ["HF_HOME"] = "/path/to/hf_cache"
|
||||
|
||||
# Resume download
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download("facebook/musicgen-small", resume_download=True)
|
||||
|
||||
# Use local files
|
||||
model = MusicGen.get_pretrained('/local/path/to/model')
|
||||
```
|
||||
|
||||
### Wrong model type
|
||||
|
||||
**Error**: Loading wrong model for task
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# For text-to-music: use MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# For text-to-sound: use AudioGen
|
||||
from audiocraft.models import AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
# For melody conditioning: use melody variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# For stereo: use stereo variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
## Generation Issues
|
||||
|
||||
### Empty or silent output
|
||||
|
||||
**Problem**: Generated audio is silent or very quiet
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check output
|
||||
wav = model.generate(["upbeat music"])
|
||||
print(f"Shape: {wav.shape}")
|
||||
print(f"Max amplitude: {wav.abs().max().item()}")
|
||||
print(f"Mean amplitude: {wav.abs().mean().item()}")
|
||||
|
||||
# If too quiet, normalize
|
||||
def normalize_audio(audio, target_db=-14.0):
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
wav_normalized = normalize_audio(wav)
|
||||
```
|
||||
|
||||
### Poor quality output
|
||||
|
||||
**Problem**: Generated music sounds bad or noisy
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use larger model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-large')
|
||||
|
||||
# Adjust generation parameters
|
||||
model.set_generation_params(
|
||||
duration=15,
|
||||
top_k=250, # Increase for more diversity
|
||||
temperature=0.8, # Lower for more focused output
|
||||
cfg_coef=4.0 # Increase for better text adherence
|
||||
)
|
||||
|
||||
# Use better prompts
|
||||
# Bad: "music"
|
||||
# Good: "upbeat electronic dance music with synthesizers and punchy drums"
|
||||
|
||||
# Try MultiBand Diffusion
|
||||
from audiocraft.models import MultiBandDiffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
tokens = model.generate_tokens(["prompt"])
|
||||
wav = mbd.tokens_to_wav(tokens)
|
||||
```
|
||||
|
||||
### Generation too short
|
||||
|
||||
**Problem**: Audio shorter than expected
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check duration setting
|
||||
model.set_generation_params(duration=30) # Set before generate
|
||||
|
||||
# Verify in generation
|
||||
print(f"Duration setting: {model.generation_params}")
|
||||
|
||||
# Check output shape
|
||||
wav = model.generate(["prompt"])
|
||||
actual_duration = wav.shape[-1] / 32000
|
||||
print(f"Actual duration: {actual_duration}s")
|
||||
|
||||
# Note: max duration is typically 30s
|
||||
```
|
||||
|
||||
### Melody conditioning fails
|
||||
|
||||
**Error**: Issues with melody-conditioned generation
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load melody model (not base model)
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Load and prepare melody
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Resample to model sample rate if needed
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
melody = resampler(melody)
|
||||
|
||||
# Ensure correct shape [batch, channels, samples]
|
||||
if melody.dim() == 1:
|
||||
melody = melody.unsqueeze(0).unsqueeze(0)
|
||||
elif melody.dim() == 2:
|
||||
melody = melody.unsqueeze(0)
|
||||
|
||||
# Convert stereo to mono
|
||||
if melody.shape[1] > 1:
|
||||
melody = melody.mean(dim=1, keepdim=True)
|
||||
|
||||
# Generate with melody
|
||||
model.set_generation_params(duration=min(melody.shape[-1] / 32000, 30))
|
||||
wav = model.generate_with_chroma(["piano cover"], melody, 32000)
|
||||
```
|
||||
|
||||
## Memory Issues
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Clear cache before generation
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Generate one at a time
|
||||
for prompt in prompts:
|
||||
wav = model.generate([prompt])
|
||||
save_audio(wav)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use CPU for very large generations
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
|
||||
```
|
||||
|
||||
### Memory leak during batch processing
|
||||
|
||||
**Problem**: Memory grows over time
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import gc
|
||||
import torch
|
||||
|
||||
def generate_with_cleanup(model, prompts):
|
||||
results = []
|
||||
|
||||
for prompt in prompts:
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
results.append(wav.cpu())
|
||||
|
||||
# Cleanup
|
||||
del wav
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return results
|
||||
|
||||
# Use context manager
|
||||
with torch.inference_mode():
|
||||
wav = model.generate(["prompt"])
|
||||
```
|
||||
|
||||
## Audio Format Issues
|
||||
|
||||
### Wrong sample rate
|
||||
|
||||
**Problem**: Audio plays at wrong speed
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
|
||||
# MusicGen outputs at 32kHz
|
||||
sample_rate = 32000
|
||||
|
||||
# AudioGen outputs at 16kHz
|
||||
sample_rate = 16000
|
||||
|
||||
# Always use correct rate when saving
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=sample_rate)
|
||||
|
||||
# Resample if needed
|
||||
resampler = torchaudio.transforms.Resample(32000, 44100)
|
||||
wav_resampled = resampler(wav)
|
||||
```
|
||||
|
||||
### Stereo/mono mismatch
|
||||
|
||||
**Problem**: Wrong number of channels
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check model type
|
||||
print(f"Audio channels: {wav.shape}")
|
||||
# Mono: [batch, 1, samples]
|
||||
# Stereo: [batch, 2, samples]
|
||||
|
||||
# Convert mono to stereo
|
||||
if wav.shape[1] == 1:
|
||||
wav_stereo = wav.repeat(1, 2, 1)
|
||||
|
||||
# Convert stereo to mono
|
||||
if wav.shape[1] == 2:
|
||||
wav_mono = wav.mean(dim=1, keepdim=True)
|
||||
|
||||
# Use stereo model for stereo output
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
### Clipping and distortion
|
||||
|
||||
**Problem**: Audio has clipping or distortion
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check for clipping
|
||||
max_val = wav.abs().max().item()
|
||||
print(f"Max amplitude: {max_val}")
|
||||
|
||||
# Normalize to prevent clipping
|
||||
if max_val > 1.0:
|
||||
wav = wav / max_val
|
||||
|
||||
# Apply soft clipping
|
||||
def soft_clip(x, threshold=0.9):
|
||||
return torch.tanh(x / threshold) * threshold
|
||||
|
||||
wav_clipped = soft_clip(wav)
|
||||
|
||||
# Lower temperature during generation
|
||||
model.set_generation_params(temperature=0.7) # More controlled
|
||||
```
|
||||
|
||||
## HuggingFace Transformers Issues
|
||||
|
||||
### Processor errors
|
||||
|
||||
**Error**: Issues with MusicgenProcessor
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
# Load matching processor and model
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
|
||||
# Ensure inputs are on same device
|
||||
inputs = processor(
|
||||
text=["prompt"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
# Check processor configuration
|
||||
print(processor.tokenizer)
|
||||
print(processor.feature_extractor)
|
||||
```
|
||||
|
||||
### Generation parameter errors
|
||||
|
||||
**Error**: Invalid generation parameters
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# HuggingFace uses different parameter names
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True, # Enable sampling
|
||||
guidance_scale=3.0, # CFG (not cfg_coef)
|
||||
max_new_tokens=256, # Token limit (not duration)
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Calculate tokens from duration
|
||||
# ~50 tokens per second
|
||||
duration_seconds = 10
|
||||
max_tokens = duration_seconds * 50
|
||||
audio_values = model.generate(**inputs, max_new_tokens=max_tokens)
|
||||
```
|
||||
|
||||
## Performance Issues
|
||||
|
||||
### Slow generation
|
||||
|
||||
**Problem**: Generation takes too long
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Use GPU
|
||||
model.to("cuda")
|
||||
|
||||
# Enable flash attention if available
|
||||
# (requires compatible hardware)
|
||||
|
||||
# Batch multiple prompts
|
||||
prompts = ["prompt1", "prompt2", "prompt3"]
|
||||
wav = model.generate(prompts) # Single batch is faster than loop
|
||||
|
||||
# Use compile (PyTorch 2.0+)
|
||||
model.lm = torch.compile(model.lm)
|
||||
```
|
||||
|
||||
### CPU fallback
|
||||
|
||||
**Problem**: Generation running on CPU instead of GPU
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check CUDA availability
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# Explicitly move to GPU
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.to("cuda")
|
||||
|
||||
# Verify model device
|
||||
print(f"Model device: {next(model.lm.parameters()).device}")
|
||||
```
|
||||
|
||||
## Common Error Messages
|
||||
|
||||
| Error | Cause | Solution |
|
||||
|-------|-------|----------|
|
||||
| `CUDA out of memory` | Model too large | Use smaller model, reduce duration |
|
||||
| `ffmpeg not found` | FFmpeg not installed | Install FFmpeg |
|
||||
| `No module named 'audiocraft'` | Not installed | `pip install audiocraft` |
|
||||
| `RuntimeError: Expected 3D tensor` | Wrong input shape | Check tensor dimensions |
|
||||
| `KeyError: 'melody'` | Wrong model for melody | Use musicgen-melody |
|
||||
| `Sample rate mismatch` | Wrong audio format | Resample to model rate |
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/facebookresearch/audiocraft/issues
|
||||
2. **HuggingFace Forums**: https://discuss.huggingface.co
|
||||
3. **Paper**: https://arxiv.org/abs/2306.05284
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include (a snippet for collecting most of these is sketched after the list):
|
||||
- Python version
|
||||
- PyTorch version
|
||||
- CUDA version
|
||||
- AudioCraft version: `pip show audiocraft`
|
||||
- Full error traceback
|
||||
- Minimal reproducible code
|
||||
- Hardware (GPU model, VRAM)
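
A small helper for collecting most of this information (a sketch; adjust for your environment):

```python
import platform
import subprocess
import torch

print("Python:", platform.python_version())
print("PyTorch:", torch.__version__)
print("CUDA (torch):", torch.version.cuda, "| available:", torch.cuda.is_available())
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print("GPU:", props.name, f"{props.total_memory / 1024**3:.0f} GB VRAM")
# AudioCraft version
print(subprocess.run(["pip", "show", "audiocraft"], capture_output=True, text=True).stdout)
```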
|
||||
158 skills/mlops-knowledge/axolotl/SKILL.md (Normal file)
@@ -0,0 +1,158 @@
|
||||
---
|
||||
name: axolotl
|
||||
description: Expert guidance for fine-tuning LLMs with Axolotl - YAML configs, 100+ models, LoRA/QLoRA, DPO/KTO/ORPO/GRPO, multimodal support
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Fine-Tuning, Axolotl, LLM, LoRA, QLoRA, DPO, KTO, ORPO, GRPO, YAML, HuggingFace, DeepSpeed, Multimodal]
|
||||
dependencies: [axolotl, torch, transformers, datasets, peft, accelerate, deepspeed]
|
||||
---
|
||||
|
||||
# Axolotl Skill
|
||||
|
||||
Comprehensive assistance with axolotl development, generated from official documentation.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be triggered when:
|
||||
- Working with axolotl
|
||||
- Asking about axolotl features or APIs
|
||||
- Implementing axolotl solutions
|
||||
- Debugging axolotl code
|
||||
- Learning axolotl best practices
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Common Patterns
|
||||
|
||||
**Pattern 1:** To validate that acceptable data-transfer speeds exist for your training job, run the NCCL Tests to pinpoint bottlenecks, for example:
|
||||
|
||||
```
|
||||
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
|
||||
```
|
||||
|
||||
**Pattern 2:** Configure your model to use FSDP in the Axolotl yaml. For example:
|
||||
|
||||
```
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: true
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
reshard_after_forward: true
|
||||
```
|
||||
|
||||
**Pattern 3:** The context_parallel_size should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
```
|
||||
context_parallel_size
|
||||
```
|
||||
|
||||
**Pattern 4:** For example:

- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `context_parallel_size=4`: only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||
|
||||
```
|
||||
context_parallel_size=4
|
||||
```
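
To make the batch-size arithmetic concrete, a tiny sketch reproducing the numbers quoted above (8 GPUs, micro_batch_size of 2):

```python
num_gpus = 8
micro_batch_size = 2

for context_parallel_size in (1, 4):
    # GPUs in the same context-parallel group share one batch
    data_parallel_replicas = num_gpus // context_parallel_size
    global_batch_size = data_parallel_replicas * micro_batch_size
    print(f"context_parallel_size={context_parallel_size}: global batch size {global_batch_size}")
# context_parallel_size=1: global batch size 16
# context_parallel_size=4: global batch size 4
```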
|
||||
|
||||
**Pattern 5:** Setting `save_compressed: true` in your configuration enables saving models in a compressed format, which:

- Reduces disk space usage by approximately 40%
- Maintains compatibility with vLLM for accelerated inference
- Maintains compatibility with llmcompressor for further optimization (example: quantization)
|
||||
|
||||
```
|
||||
save_compressed: true
|
||||
```
|
||||
|
||||
**Pattern 6:** Note: it is not necessary to place your integration in the `integrations` folder. It can be in any location, as long as it is installed as a package in your Python env. See this repo for an example: https://github.com/axolotl-ai-cloud/diff-transformer
|
||||
|
||||
```
|
||||
integrations
|
||||
```
|
||||
|
||||
**Pattern 7:** Handle both single-example and batched data:

- single example: `sample['input_ids']` is a `list[int]`
- batched data: `sample['input_ids']` is a `list[list[int]]`
|
||||
|
||||
```
|
||||
utils.trainer.drop_long_seq(sample, sequence_len=2048, min_sequence_len=2)
|
||||
```
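
A hypothetical filter illustrating that contract (the name and thresholds mirror `drop_long_seq` above; this is a sketch, not Axolotl's implementation):

```python
def drop_long_seq_like(sample, sequence_len=2048, min_sequence_len=2):
    """Return keep/drop decisions for a single example or a batched sample."""
    ids = sample["input_ids"]
    if ids and isinstance(ids[0], list):
        # batched data: list[list[int]] -> one decision per sequence
        return [min_sequence_len <= len(seq) <= sequence_len for seq in ids]
    # single example: list[int] -> one decision
    return min_sequence_len <= len(ids) <= sequence_len
```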
|
||||
|
||||
### Example Code Patterns
|
||||
|
||||
**Example 1** (python):
|
||||
```python
|
||||
cli.cloud.modal_.ModalCloud(config, app=None)
|
||||
```
|
||||
|
||||
**Example 2** (python):
|
||||
```python
|
||||
cli.cloud.modal_.run_cmd(cmd, run_folder, volumes=None)
|
||||
```
|
||||
|
||||
**Example 3** (python):
|
||||
```python
|
||||
core.trainers.base.AxolotlTrainer(
|
||||
*_args,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
dataset_tags=None,
|
||||
**kwargs,
|
||||
)
|
||||
```
|
||||
|
||||
**Example 4** (python):
|
||||
```python
|
||||
core.trainers.base.AxolotlTrainer.log(logs, start_time=None)
|
||||
```
|
||||
|
||||
**Example 5** (python):
|
||||
```python
|
||||
prompt_strategies.input_output.RawInputOutputPrompter()
|
||||
```
|
||||
|
||||
## Reference Files
|
||||
|
||||
This skill includes comprehensive documentation in `references/`:
|
||||
|
||||
- **api.md** - Api documentation
|
||||
- **dataset-formats.md** - Dataset-Formats documentation
|
||||
- **other.md** - Other documentation
|
||||
|
||||
Use `view` to read specific reference files when detailed information is needed.
|
||||
|
||||
## Working with This Skill
|
||||
|
||||
### For Beginners
|
||||
Start with the getting_started or tutorials reference files for foundational concepts.
|
||||
|
||||
### For Specific Features
|
||||
Use the appropriate category reference file (api, guides, etc.) for detailed information.
|
||||
|
||||
### For Code Examples
|
||||
The quick reference section above contains common patterns extracted from the official docs.
|
||||
|
||||
## Resources
|
||||
|
||||
### references/
|
||||
Organized documentation extracted from official sources. These files contain:
|
||||
- Detailed explanations
|
||||
- Code examples with language annotations
|
||||
- Links to original documentation
|
||||
- Table of contents for quick navigation
|
||||
|
||||
### scripts/
|
||||
Add helper scripts here for common automation tasks.
|
||||
|
||||
### assets/
|
||||
Add templates, boilerplate, or example projects here.
|
||||
|
||||
## Notes
|
||||
|
||||
- This skill was automatically generated from official documentation
|
||||
- Reference files preserve the structure and examples from source docs
|
||||
- Code examples include language detection for better syntax highlighting
|
||||
- Quick reference patterns are extracted from common usage examples in the docs
|
||||
|
||||
## Updating
|
||||
|
||||
To refresh this skill with updated documentation:
|
||||
1. Re-run the scraper with the same configuration
|
||||
2. The skill will be rebuilt with the latest information
|
||||
|
||||
|
||||
5548 skills/mlops-knowledge/axolotl/references/api.md (Normal file) - file diff suppressed because it is too large
1029 skills/mlops-knowledge/axolotl/references/dataset-formats.md (Normal file) - file diff suppressed because it is too large
15 skills/mlops-knowledge/axolotl/references/index.md (Normal file)
@@ -0,0 +1,15 @@
|
||||
# Axolotl Documentation Index
|
||||
|
||||
## Categories
|
||||
|
||||
### Api
|
||||
**File:** `api.md`
|
||||
**Pages:** 150
|
||||
|
||||
### Dataset-Formats
|
||||
**File:** `dataset-formats.md`
|
||||
**Pages:** 9
|
||||
|
||||
### Other
|
||||
**File:** `other.md`
|
||||
**Pages:** 26
|
||||
3563 skills/mlops-knowledge/axolotl/references/other.md (Normal file) - file diff suppressed because it is too large
406 skills/mlops-knowledge/chroma/SKILL.md (Normal file)
@@ -0,0 +1,406 @@
|
||||
---
|
||||
name: chroma
|
||||
description: Open-source embedding database for AI applications. Store embeddings and metadata, perform vector and full-text search, filter by metadata. Simple 4-function API. Scales from notebooks to production clusters. Use for semantic search, RAG applications, or document retrieval. Best for local development and open-source projects.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [RAG, Chroma, Vector Database, Embeddings, Semantic Search, Open Source, Self-Hosted, Document Retrieval, Metadata Filtering]
|
||||
dependencies: [chromadb, sentence-transformers]
|
||||
---
|
||||
|
||||
# Chroma - Open-Source Embedding Database
|
||||
|
||||
The AI-native database for building LLM applications with memory.
|
||||
|
||||
## When to use Chroma
|
||||
|
||||
**Use Chroma when:**
|
||||
- Building RAG (retrieval-augmented generation) applications
|
||||
- Need local/self-hosted vector database
|
||||
- Want open-source solution (Apache 2.0)
|
||||
- Prototyping in notebooks
|
||||
- Semantic search over documents
|
||||
- Storing embeddings with metadata
|
||||
|
||||
**Metrics**:
|
||||
- **24,300+ GitHub stars**
|
||||
- **1,900+ forks**
|
||||
- **v1.3.3** (stable, weekly releases)
|
||||
- **Apache 2.0 license**
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **Pinecone**: Managed cloud, auto-scaling
|
||||
- **FAISS**: Pure similarity search, no metadata
|
||||
- **Weaviate**: Production ML-native database
|
||||
- **Qdrant**: High performance, Rust-based
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Python
|
||||
pip install chromadb
|
||||
|
||||
# JavaScript/TypeScript
|
||||
npm install chromadb @chroma-core/default-embed
|
||||
```
|
||||
|
||||
### Basic usage (Python)
|
||||
|
||||
```python
|
||||
import chromadb
|
||||
|
||||
# Create client
|
||||
client = chromadb.Client()
|
||||
|
||||
# Create collection
|
||||
collection = client.create_collection(name="my_collection")
|
||||
|
||||
# Add documents
|
||||
collection.add(
|
||||
documents=["This is document 1", "This is document 2"],
|
||||
metadatas=[{"source": "doc1"}, {"source": "doc2"}],
|
||||
ids=["id1", "id2"]
|
||||
)
|
||||
|
||||
# Query
|
||||
results = collection.query(
|
||||
query_texts=["document about topic"],
|
||||
n_results=2
|
||||
)
|
||||
|
||||
print(results)
|
||||
```
|
||||
|
||||
## Core operations
|
||||
|
||||
### 1. Create collection
|
||||
|
||||
```python
|
||||
# Simple collection
|
||||
collection = client.create_collection("my_docs")
|
||||
|
||||
# With custom embedding function
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key="your-key",
|
||||
model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
collection = client.create_collection(
|
||||
name="my_docs",
|
||||
embedding_function=openai_ef
|
||||
)
|
||||
|
||||
# Get existing collection
|
||||
collection = client.get_collection("my_docs")
|
||||
|
||||
# Delete collection
|
||||
client.delete_collection("my_docs")
|
||||
```
|
||||
|
||||
### 2. Add documents
|
||||
|
||||
```python
|
||||
# Add with auto-generated IDs
|
||||
collection.add(
|
||||
documents=["Doc 1", "Doc 2", "Doc 3"],
|
||||
metadatas=[
|
||||
{"source": "web", "category": "tutorial"},
|
||||
{"source": "pdf", "page": 5},
|
||||
{"source": "api", "timestamp": "2025-01-01"}
|
||||
],
|
||||
ids=["id1", "id2", "id3"]
|
||||
)
|
||||
|
||||
# Add with custom embeddings
|
||||
collection.add(
|
||||
embeddings=[[0.1, 0.2, ...], [0.3, 0.4, ...]],
|
||||
documents=["Doc 1", "Doc 2"],
|
||||
ids=["id1", "id2"]
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Query (similarity search)
|
||||
|
||||
```python
|
||||
# Basic query
|
||||
results = collection.query(
|
||||
query_texts=["machine learning tutorial"],
|
||||
n_results=5
|
||||
)
|
||||
|
||||
# Query with filters
|
||||
results = collection.query(
|
||||
query_texts=["Python programming"],
|
||||
n_results=3,
|
||||
where={"source": "web"}
|
||||
)
|
||||
|
||||
# Query with metadata filters
|
||||
results = collection.query(
|
||||
query_texts=["advanced topics"],
|
||||
where={
|
||||
"$and": [
|
||||
{"category": "tutorial"},
|
||||
{"difficulty": {"$gte": 3}}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Access results
|
||||
print(results["documents"]) # List of matching documents
|
||||
print(results["metadatas"]) # Metadata for each doc
|
||||
print(results["distances"]) # Similarity scores
|
||||
print(results["ids"]) # Document IDs
|
||||
```
|
||||
|
||||
### 4. Get documents
|
||||
|
||||
```python
|
||||
# Get by IDs
|
||||
docs = collection.get(
|
||||
ids=["id1", "id2"]
|
||||
)
|
||||
|
||||
# Get with filters
|
||||
docs = collection.get(
|
||||
where={"category": "tutorial"},
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Get all documents
|
||||
docs = collection.get()
|
||||
```
|
||||
|
||||
### 5. Update documents
|
||||
|
||||
```python
|
||||
# Update document content
|
||||
collection.update(
|
||||
ids=["id1"],
|
||||
documents=["Updated content"],
|
||||
metadatas=[{"source": "updated"}]
|
||||
)
|
||||
```
|
||||
|
||||
### 6. Delete documents
|
||||
|
||||
```python
|
||||
# Delete by IDs
|
||||
collection.delete(ids=["id1", "id2"])
|
||||
|
||||
# Delete with filter
|
||||
collection.delete(
|
||||
where={"source": "outdated"}
|
||||
)
|
||||
```
|
||||
|
||||
## Persistent storage
|
||||
|
||||
```python
|
||||
# Persist to disk
|
||||
client = chromadb.PersistentClient(path="./chroma_db")
|
||||
|
||||
collection = client.create_collection("my_docs")
|
||||
collection.add(documents=["Doc 1"], ids=["id1"])
|
||||
|
||||
# Data persisted automatically
|
||||
# Reload later with same path
|
||||
client = chromadb.PersistentClient(path="./chroma_db")
|
||||
collection = client.get_collection("my_docs")
|
||||
```
|
||||
|
||||
## Embedding functions
|
||||
|
||||
### Default (Sentence Transformers)
|
||||
|
||||
```python
|
||||
# Uses sentence-transformers by default
|
||||
collection = client.create_collection("my_docs")
|
||||
# Default model: all-MiniLM-L6-v2
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
```python
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key="your-key",
|
||||
model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
collection = client.create_collection(
|
||||
name="openai_docs",
|
||||
embedding_function=openai_ef
|
||||
)
|
||||
```
|
||||
|
||||
### HuggingFace
|
||||
|
||||
```python
|
||||
huggingface_ef = embedding_functions.HuggingFaceEmbeddingFunction(
|
||||
api_key="your-key",
|
||||
model_name="sentence-transformers/all-mpnet-base-v2"
|
||||
)
|
||||
|
||||
collection = client.create_collection(
|
||||
name="hf_docs",
|
||||
embedding_function=huggingface_ef
|
||||
)
|
||||
```
|
||||
|
||||
### Custom embedding function
|
||||
|
||||
```python
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
|
||||
class MyEmbeddingFunction(EmbeddingFunction):
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
# Your embedding logic
|
||||
return embeddings
|
||||
|
||||
my_ef = MyEmbeddingFunction()
|
||||
collection = client.create_collection(
|
||||
name="custom_docs",
|
||||
embedding_function=my_ef
|
||||
)
|
||||
```
|
||||
|
||||
## Metadata filtering
|
||||
|
||||
```python
|
||||
# Exact match
|
||||
results = collection.query(
|
||||
query_texts=["query"],
|
||||
where={"category": "tutorial"}
|
||||
)
|
||||
|
||||
# Comparison operators
|
||||
results = collection.query(
|
||||
query_texts=["query"],
|
||||
where={"page": {"$gt": 10}} # $gt, $gte, $lt, $lte, $ne
|
||||
)
|
||||
|
||||
# Logical operators
|
||||
results = collection.query(
|
||||
query_texts=["query"],
|
||||
where={
|
||||
"$and": [
|
||||
{"category": "tutorial"},
|
||||
{"difficulty": {"$lte": 3}}
|
||||
]
|
||||
} # Also: $or
|
||||
)
|
||||
|
||||
# Value must be one of the listed options ($in)
|
||||
results = collection.query(
|
||||
query_texts=["query"],
|
||||
where={"tags": {"$in": ["python", "ml"]}}
|
||||
)
|
||||
```
|
||||
|
||||
## LangChain integration
|
||||
|
||||
```python
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
# Split documents
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000)
|
||||
docs = text_splitter.split_documents(documents)
|
||||
|
||||
# Create Chroma vector store
|
||||
vectorstore = Chroma.from_documents(
|
||||
documents=docs,
|
||||
embedding=OpenAIEmbeddings(),
|
||||
persist_directory="./chroma_db"
|
||||
)
|
||||
|
||||
# Query
|
||||
results = vectorstore.similarity_search("machine learning", k=3)
|
||||
|
||||
# As retriever
|
||||
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
|
||||
```
|
||||
|
||||
## LlamaIndex integration
|
||||
|
||||
```python
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.core import VectorStoreIndex, StorageContext
|
||||
import chromadb
|
||||
|
||||
# Initialize Chroma
|
||||
db = chromadb.PersistentClient(path="./chroma_db")
|
||||
collection = db.get_or_create_collection("my_collection")
|
||||
|
||||
# Create vector store
|
||||
vector_store = ChromaVectorStore(chroma_collection=collection)
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
|
||||
# Create index
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context
|
||||
)
|
||||
|
||||
# Query
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query("What is machine learning?")
|
||||
```
|
||||
|
||||
## Server mode
|
||||
|
||||
```python
|
||||
# Run Chroma server
|
||||
# Terminal: chroma run --path ./chroma_db --port 8000
|
||||
|
||||
# Connect to server
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
client = chromadb.HttpClient(
|
||||
host="localhost",
|
||||
port=8000,
|
||||
settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
|
||||
# Use as normal
|
||||
collection = client.get_or_create_collection("my_docs")
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use persistent client** - Don't lose data on restart
|
||||
2. **Add metadata** - Enables filtering and tracking
|
||||
3. **Batch operations** - Add multiple docs at once
|
||||
4. **Choose right embedding model** - Balance speed/quality
|
||||
5. **Use filters** - Narrow search space
|
||||
6. **Unique IDs** - Avoid collisions
|
||||
7. **Regular backups** - Copy chroma_db directory
|
||||
8. **Monitor collection size** - Scale up if needed
|
||||
9. **Test embedding functions** - Ensure quality
|
||||
10. **Use server mode for production** - Better for multi-user
|
||||
|
||||
## Performance
|
||||
|
||||
| Operation | Latency | Notes |
|
||||
|-----------|---------|-------|
|
||||
| Add 100 docs | ~1-3s | With embedding |
|
||||
| Query (top 10) | ~50-200ms | Depends on collection size |
|
||||
| Metadata filter | ~10-50ms | Fast with proper indexing |
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/chroma-core/chroma ⭐ 24,300+
|
||||
- **Docs**: https://docs.trychroma.com
|
||||
- **Discord**: https://discord.gg/MMeYNTmh3x
|
||||
- **Version**: 1.3.3+
|
||||
- **License**: Apache 2.0
|
||||
|
||||
|
||||
38
skills/mlops-knowledge/chroma/references/integration.md
Normal file
38
skills/mlops-knowledge/chroma/references/integration.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# Chroma Integration Guide
|
||||
|
||||
Integration with LangChain, LlamaIndex, and frameworks.
|
||||
|
||||
## LangChain
|
||||
|
||||
```python
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
vectorstore = Chroma.from_documents(
|
||||
documents=docs,
|
||||
embedding=OpenAIEmbeddings(),
|
||||
persist_directory="./chroma_db"
|
||||
)
|
||||
|
||||
# Query
|
||||
results = vectorstore.similarity_search("query", k=3)
|
||||
|
||||
# As retriever
|
||||
retriever = vectorstore.as_retriever()
|
||||
```
|
||||
|
||||
## LlamaIndex
|
||||
|
||||
```python
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
import chromadb
|
||||
|
||||
db = chromadb.PersistentClient(path="./chroma_db")
|
||||
collection = db.get_or_create_collection("docs")
|
||||
|
||||
vector_store = ChromaVectorStore(chroma_collection=collection)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Docs**: https://docs.trychroma.com
|
||||
253
skills/mlops-knowledge/clip/SKILL.md
Normal file
253
skills/mlops-knowledge/clip/SKILL.md
Normal file
@@ -0,0 +1,253 @@
|
||||
---
|
||||
name: clip
|
||||
description: OpenAI's model connecting vision and language. Enables zero-shot image classification, image-text matching, and cross-modal retrieval. Trained on 400M image-text pairs. Use for image search, content moderation, or vision-language tasks without fine-tuning. Best for general-purpose image understanding.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Multimodal, CLIP, Vision-Language, Zero-Shot, Image Classification, OpenAI, Image Search, Cross-Modal Retrieval, Content Moderation]
|
||||
dependencies: [transformers, torch, pillow]
|
||||
---
|
||||
|
||||
# CLIP - Contrastive Language-Image Pre-Training
|
||||
|
||||
OpenAI's model that understands images from natural language.
|
||||
|
||||
## When to use CLIP
|
||||
|
||||
**Use when:**
|
||||
- Zero-shot image classification (no training data needed)
|
||||
- Image-text similarity/matching
|
||||
- Semantic image search
|
||||
- Content moderation (detect NSFW, violence)
|
||||
- Visual question answering
|
||||
- Cross-modal retrieval (image→text, text→image)
|
||||
|
||||
**Metrics**:
|
||||
- **25,300+ GitHub stars**
|
||||
- Trained on 400M image-text pairs
|
||||
- Matches ResNet-50 on ImageNet (zero-shot)
|
||||
- MIT License
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **BLIP-2**: Better captioning
|
||||
- **LLaVA**: Vision-language chat
|
||||
- **Segment Anything**: Image segmentation
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/openai/CLIP.git
|
||||
pip install torch torchvision ftfy regex tqdm
|
||||
```
|
||||
|
||||
### Zero-shot classification
|
||||
|
||||
```python
|
||||
import torch
|
||||
import clip
|
||||
from PIL import Image
|
||||
|
||||
# Load model
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model, preprocess = clip.load("ViT-B/32", device=device)
|
||||
|
||||
# Load image
|
||||
image = preprocess(Image.open("photo.jpg")).unsqueeze(0).to(device)
|
||||
|
||||
# Define possible labels
|
||||
text = clip.tokenize(["a dog", "a cat", "a bird", "a car"]).to(device)
|
||||
|
||||
# Compute similarity
|
||||
with torch.no_grad():
|
||||
image_features = model.encode_image(image)
|
||||
text_features = model.encode_text(text)
|
||||
|
||||
# Cosine similarity
|
||||
logits_per_image, logits_per_text = model(image, text)
|
||||
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
||||
|
||||
# Print results
|
||||
labels = ["a dog", "a cat", "a bird", "a car"]
|
||||
for label, prob in zip(labels, probs[0]):
|
||||
print(f"{label}: {prob:.2%}")
|
||||
```
|
||||
|
||||
## Available models
|
||||
|
||||
```python
|
||||
# Models (sorted by size)
|
||||
models = [
|
||||
"RN50", # ResNet-50
|
||||
"RN101", # ResNet-101
|
||||
"ViT-B/32", # Vision Transformer (recommended)
|
||||
"ViT-B/16", # Better quality, slower
|
||||
"ViT-L/14", # Best quality, slowest
|
||||
]
|
||||
|
||||
model, preprocess = clip.load("ViT-B/32")
|
||||
```
|
||||
|
||||
| Model | Parameters | Speed | Quality |
|
||||
|-------|------------|-------|---------|
|
||||
| RN50 | 102M | Fast | Good |
|
||||
| ViT-B/32 | 151M | Medium | Better |
|
||||
| ViT-L/14 | 428M | Slow | Best |
|
||||
|
||||
## Image-text similarity
|
||||
|
||||
```python
|
||||
# Compute embeddings
|
||||
image_features = model.encode_image(image)
|
||||
text_features = model.encode_text(text)
|
||||
|
||||
# Normalize
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Cosine similarity
|
||||
similarity = (image_features @ text_features.T).item()
|
||||
print(f"Similarity: {similarity:.4f}")
|
||||
```
|
||||
|
||||
## Semantic image search
|
||||
|
||||
```python
|
||||
# Index images
|
||||
image_paths = ["img1.jpg", "img2.jpg", "img3.jpg"]
|
||||
image_embeddings = []
|
||||
|
||||
for img_path in image_paths:
|
||||
image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
|
||||
with torch.no_grad():
|
||||
embedding = model.encode_image(image)
|
||||
embedding /= embedding.norm(dim=-1, keepdim=True)
|
||||
image_embeddings.append(embedding)
|
||||
|
||||
image_embeddings = torch.cat(image_embeddings)
|
||||
|
||||
# Search with text query
|
||||
query = "a sunset over the ocean"
|
||||
text_input = clip.tokenize([query]).to(device)
|
||||
with torch.no_grad():
|
||||
text_embedding = model.encode_text(text_input)
|
||||
text_embedding /= text_embedding.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Find most similar images
|
||||
similarities = (text_embedding @ image_embeddings.T).squeeze(0)
|
||||
top_k = similarities.topk(3)
|
||||
|
||||
for idx, score in zip(top_k.indices, top_k.values):
|
||||
print(f"{image_paths[idx]}: {score:.3f}")
|
||||
```
|
||||
|
||||
## Content moderation
|
||||
|
||||
```python
|
||||
# Define categories
|
||||
categories = [
|
||||
"safe for work",
|
||||
"not safe for work",
|
||||
"violent content",
|
||||
"graphic content"
|
||||
]
|
||||
|
||||
text = clip.tokenize(categories).to(device)
|
||||
|
||||
# Check image
|
||||
with torch.no_grad():
|
||||
logits_per_image, _ = model(image, text)
|
||||
probs = logits_per_image.softmax(dim=-1)
|
||||
|
||||
# Get classification
|
||||
max_idx = probs.argmax().item()
|
||||
max_prob = probs[0, max_idx].item()
|
||||
|
||||
print(f"Category: {categories[max_idx]} ({max_prob:.2%})")
|
||||
```
|
||||
|
||||
## Batch processing
|
||||
|
||||
```python
|
||||
# Process multiple images
|
||||
images = [preprocess(Image.open(f"img{i}.jpg")) for i in range(10)]
|
||||
images = torch.stack(images).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
image_features = model.encode_image(images)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Batch text
|
||||
texts = ["a dog", "a cat", "a bird"]
|
||||
text_tokens = clip.tokenize(texts).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
text_features = model.encode_text(text_tokens)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Similarity matrix (10 images × 3 texts)
|
||||
similarities = image_features @ text_features.T
|
||||
print(similarities.shape) # (10, 3)
|
||||
```
|
||||
|
||||
## Integration with vector databases
|
||||
|
||||
```python
|
||||
# Store CLIP embeddings in Chroma/FAISS
|
||||
import chromadb
|
||||
|
||||
client = chromadb.Client()
|
||||
collection = client.create_collection("image_embeddings")
|
||||
|
||||
# Add image embeddings
|
||||
for img_path, embedding in zip(image_paths, image_embeddings):
|
||||
collection.add(
|
||||
embeddings=[embedding.cpu().numpy().tolist()],
|
||||
metadatas=[{"path": img_path}],
|
||||
ids=[img_path]
|
||||
)
|
||||
|
||||
# Query with text
|
||||
query = "a sunset"
|
||||
text_embedding = model.encode_text(clip.tokenize([query]))
|
||||
results = collection.query(
|
||||
query_embeddings=[text_embedding.cpu().numpy().tolist()],
|
||||
n_results=5
|
||||
)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use ViT-B/32 for most cases** - Good balance
|
||||
2. **Normalize embeddings** - Required for cosine similarity
|
||||
3. **Batch processing** - More efficient
|
||||
4. **Cache embeddings** - Expensive to recompute
|
||||
5. **Use descriptive labels** - Better zero-shot performance
|
||||
6. **GPU recommended** - 10-50× faster
|
||||
7. **Preprocess images** - Use provided preprocess function
|
||||
|
||||
## Performance
|
||||
|
||||
| Operation | CPU | GPU (V100) |
|
||||
|-----------|-----|------------|
|
||||
| Image encoding | ~200ms | ~20ms |
|
||||
| Text encoding | ~50ms | ~5ms |
|
||||
| Similarity compute | <1ms | <1ms |
|
||||
|
||||
## Limitations
|
||||
|
||||
1. **Not for fine-grained tasks** - Best for broad categories
|
||||
2. **Requires descriptive text** - Vague labels perform poorly
|
||||
3. **Biased on web data** - May have dataset biases
|
||||
4. **No bounding boxes** - Whole image only
|
||||
5. **Limited spatial understanding** - Position/counting weak
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/openai/CLIP ⭐ 25,300+
|
||||
- **Paper**: https://arxiv.org/abs/2103.00020
|
||||
- **Colab**: https://colab.research.google.com/github/openai/clip/
|
||||
- **License**: MIT
|
||||
|
||||
|
||||
207
skills/mlops-knowledge/clip/references/applications.md
Normal file
207
skills/mlops-knowledge/clip/references/applications.md
Normal file
@@ -0,0 +1,207 @@
|
||||
# CLIP Applications Guide
|
||||
|
||||
Practical applications and use cases for CLIP.
|
||||
|
||||
## Zero-shot image classification
|
||||
|
||||
```python
|
||||
import torch
|
||||
import clip
|
||||
from PIL import Image
|
||||
|
||||
model, preprocess = clip.load("ViT-B/32")
|
||||
|
||||
# Define categories
|
||||
categories = [
|
||||
"a photo of a dog",
|
||||
"a photo of a cat",
|
||||
"a photo of a bird",
|
||||
"a photo of a car",
|
||||
"a photo of a person"
|
||||
]
|
||||
|
||||
# Prepare image
|
||||
image = preprocess(Image.open("photo.jpg")).unsqueeze(0)
|
||||
text = clip.tokenize(categories)
|
||||
|
||||
# Classify
|
||||
with torch.no_grad():
|
||||
image_features = model.encode_image(image)
|
||||
text_features = model.encode_text(text)
|
||||
|
||||
logits_per_image, _ = model(image, text)
|
||||
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
||||
|
||||
# Print results
|
||||
for category, prob in zip(categories, probs[0]):
|
||||
print(f"{category}: {prob:.2%}")
|
||||
```
|
||||
|
||||
## Semantic image search
|
||||
|
||||
```python
|
||||
# Index images
|
||||
image_database = []
|
||||
image_paths = ["img1.jpg", "img2.jpg", "img3.jpg"]
|
||||
|
||||
for img_path in image_paths:
|
||||
image = preprocess(Image.open(img_path)).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
features = model.encode_image(image)
|
||||
features /= features.norm(dim=-1, keepdim=True)
|
||||
image_database.append((img_path, features))
|
||||
|
||||
# Search with text
|
||||
query = "a sunset over mountains"
|
||||
text_input = clip.tokenize([query])
|
||||
|
||||
with torch.no_grad():
|
||||
text_features = model.encode_text(text_input)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Find matches
|
||||
similarities = []
|
||||
for img_path, img_features in image_database:
|
||||
similarity = (text_features @ img_features.T).item()
|
||||
similarities.append((img_path, similarity))
|
||||
|
||||
# Sort by similarity
|
||||
similarities.sort(key=lambda x: x[1], reverse=True)
|
||||
for img_path, score in similarities[:3]:
|
||||
print(f"{img_path}: {score:.3f}")
|
||||
```
|
||||
|
||||
## Content moderation
|
||||
|
||||
```python
|
||||
# Define safety categories
|
||||
categories = [
|
||||
"safe for work content",
|
||||
"not safe for work content",
|
||||
"violent or graphic content",
|
||||
"hate speech or offensive content",
|
||||
"spam or misleading content"
|
||||
]
|
||||
|
||||
text = clip.tokenize(categories)
|
||||
|
||||
# Check image
|
||||
with torch.no_grad():
|
||||
logits, _ = model(image, text)
|
||||
probs = logits.softmax(dim=-1)
|
||||
|
||||
# Get classification
|
||||
max_idx = probs.argmax().item()
|
||||
confidence = probs[0, max_idx].item()
|
||||
|
||||
if confidence > 0.7:
|
||||
print(f"Classified as: {categories[max_idx]} ({confidence:.2%})")
|
||||
else:
|
||||
print(f"Uncertain classification (confidence: {confidence:.2%})")
|
||||
```
|
||||
|
||||
## Image-to-text retrieval
|
||||
|
||||
```python
|
||||
# Text database
|
||||
captions = [
|
||||
"A beautiful sunset over the ocean",
|
||||
"A cute dog playing in the park",
|
||||
"A modern city skyline at night",
|
||||
"A delicious pizza with toppings"
|
||||
]
|
||||
|
||||
# Encode captions
|
||||
caption_features = []
|
||||
for caption in captions:
|
||||
text = clip.tokenize([caption])
|
||||
with torch.no_grad():
|
||||
features = model.encode_text(text)
|
||||
features /= features.norm(dim=-1, keepdim=True)
|
||||
caption_features.append(features)
|
||||
|
||||
caption_features = torch.cat(caption_features)
|
||||
|
||||
# Find matching captions for image
|
||||
with torch.no_grad():
|
||||
image_features = model.encode_image(image)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
similarities = (image_features @ caption_features.T).squeeze(0)
|
||||
top_k = similarities.topk(3)
|
||||
|
||||
for idx, score in zip(top_k.indices, top_k.values):
|
||||
print(f"{captions[idx]}: {score:.3f}")
|
||||
```
|
||||
|
||||
## Visual question answering
|
||||
|
||||
```python
|
||||
# Create yes/no questions
|
||||
image = preprocess(Image.open("photo.jpg")).unsqueeze(0)
|
||||
|
||||
questions = [
|
||||
"a photo showing people",
|
||||
"a photo showing animals",
|
||||
"a photo taken indoors",
|
||||
"a photo taken outdoors",
|
||||
"a photo taken during daytime",
|
||||
"a photo taken at night"
|
||||
]
|
||||
|
||||
text = clip.tokenize(questions)
|
||||
|
||||
with torch.no_grad():
|
||||
logits, _ = model(image, text)
|
||||
probs = logits.softmax(dim=-1)
|
||||
|
||||
# Answer questions
|
||||
for question, prob in zip(questions, probs[0]):
|
||||
answer = "Yes" if prob > 0.5 else "No"
|
||||
print(f"{question}: {answer} ({prob:.2%})")
|
||||
```
|
||||
|
||||
## Image deduplication
|
||||
|
||||
```python
|
||||
# Detect duplicate/similar images
|
||||
def compute_similarity(img1_path, img2_path):
|
||||
img1 = preprocess(Image.open(img1_path)).unsqueeze(0)
|
||||
img2 = preprocess(Image.open(img2_path)).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
feat1 = model.encode_image(img1)
|
||||
feat2 = model.encode_image(img2)
|
||||
|
||||
feat1 /= feat1.norm(dim=-1, keepdim=True)
|
||||
feat2 /= feat2.norm(dim=-1, keepdim=True)
|
||||
|
||||
similarity = (feat1 @ feat2.T).item()
|
||||
|
||||
return similarity
|
||||
|
||||
# Check for duplicates
|
||||
threshold = 0.95
|
||||
image_pairs = [("img1.jpg", "img2.jpg"), ("img1.jpg", "img3.jpg")]
|
||||
|
||||
for img1, img2 in image_pairs:
|
||||
sim = compute_similarity(img1, img2)
|
||||
if sim > threshold:
|
||||
print(f"{img1} and {img2} are duplicates (similarity: {sim:.3f})")
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use descriptive labels** - "a photo of X" works better than just "X"
|
||||
2. **Normalize embeddings** - Always normalize for cosine similarity
|
||||
3. **Batch processing** - Process multiple images/texts together
|
||||
4. **Cache embeddings** - Expensive to recompute
|
||||
5. **Set appropriate thresholds** - Test on validation data
|
||||
6. **Use GPU** - 10-50× faster than CPU
|
||||
7. **Consider model size** - ViT-B/32 good default, ViT-L/14 for best quality
|
||||
|
||||
## Resources
|
||||
|
||||
- **Paper**: https://arxiv.org/abs/2103.00020
|
||||
- **GitHub**: https://github.com/openai/CLIP
|
||||
- **Colab**: https://colab.research.google.com/github/openai/clip/
|
||||
81
skills/mlops-knowledge/code-review/SKILL.md
Normal file
81
skills/mlops-knowledge/code-review/SKILL.md
Normal file
@@ -0,0 +1,81 @@
|
||||
---
|
||||
name: code-review
|
||||
description: Guidelines for performing thorough code reviews with security and quality focus
|
||||
---
|
||||
|
||||
# Code Review Skill
|
||||
|
||||
Use this skill when reviewing code changes, pull requests, or auditing existing code.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
### 1. Security First
|
||||
- [ ] No hardcoded secrets, API keys, or credentials
|
||||
- [ ] Input validation on all user-provided data
|
||||
- [ ] SQL queries use parameterized statements (no string concatenation)
|
||||
- [ ] File operations validate paths (no path traversal)
|
||||
- [ ] Authentication/authorization checks present where needed
|
||||
|
||||
### 2. Error Handling
|
||||
- [ ] All external calls (API, DB, file) have try/catch
|
||||
- [ ] Errors are logged with context (but no sensitive data)
|
||||
- [ ] User-facing errors are helpful but don't leak internals
|
||||
- [ ] Resources are cleaned up in finally blocks or context managers
|
||||
|
||||
### 3. Code Quality
|
||||
- [ ] Functions do one thing and are reasonably sized (<50 lines ideal)
|
||||
- [ ] Variable names are descriptive (no single letters except loops)
|
||||
- [ ] No commented-out code left behind
|
||||
- [ ] Complex logic has explanatory comments
|
||||
- [ ] No duplicate code (DRY principle)
|
||||
|
||||
### 4. Testing Considerations
|
||||
- [ ] Edge cases handled (empty inputs, nulls, boundaries)
|
||||
- [ ] Happy path and error paths both work
|
||||
- [ ] New code has corresponding tests (if test suite exists)
|
||||
|
||||
## Review Response Format
|
||||
|
||||
When providing review feedback, structure it as:
|
||||
|
||||
```
|
||||
## Summary
|
||||
[1-2 sentence overall assessment]
|
||||
|
||||
## Critical Issues (Must Fix)
|
||||
- Issue 1: [description + suggested fix]
|
||||
- Issue 2: ...
|
||||
|
||||
## Suggestions (Nice to Have)
|
||||
- Suggestion 1: [description]
|
||||
|
||||
## Questions
|
||||
- [Any clarifying questions about intent]
|
||||
```
|
||||
|
||||
## Common Patterns to Flag
|
||||
|
||||
### Python
|
||||
```python
|
||||
# Bad: SQL injection risk
|
||||
cursor.execute(f"SELECT * FROM users WHERE id = {user_id}")
|
||||
|
||||
# Good: Parameterized query
|
||||
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
|
||||
```
|
||||
|
||||
### JavaScript
|
||||
```javascript
|
||||
// Bad: XSS risk
|
||||
element.innerHTML = userInput;
|
||||
|
||||
// Good: Safe text content
|
||||
element.textContent = userInput;
|
||||
```
|
||||
|
||||
## Tone Guidelines
|
||||
|
||||
- Be constructive, not critical
|
||||
- Explain *why* something is an issue, not just *what*
|
||||
- Offer solutions, not just problems
|
||||
- Acknowledge good patterns you see
|
||||
590
skills/mlops-knowledge/dspy/SKILL.md
Normal file
590
skills/mlops-knowledge/dspy/SKILL.md
Normal file
@@ -0,0 +1,590 @@
|
||||
---
|
||||
name: dspy
|
||||
description: Build complex AI systems with declarative programming, optimize prompts automatically, create modular RAG systems and agents with DSPy - Stanford NLP's framework for systematic LM programming
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Prompt Engineering, DSPy, Declarative Programming, RAG, Agents, Prompt Optimization, LM Programming, Stanford NLP, Automatic Optimization, Modular AI]
|
||||
dependencies: [dspy, openai, anthropic]
|
||||
---
|
||||
|
||||
# DSPy: Declarative Language Model Programming
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use DSPy when you need to:
|
||||
- **Build complex AI systems** with multiple components and workflows
|
||||
- **Program LMs declaratively** instead of manual prompt engineering
|
||||
- **Optimize prompts automatically** using data-driven methods
|
||||
- **Create modular AI pipelines** that are maintainable and portable
|
||||
- **Improve model outputs systematically** with optimizers
|
||||
- **Build RAG systems, agents, or classifiers** with better reliability
|
||||
|
||||
**GitHub Stars**: 22,000+ | **Created By**: Stanford NLP
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Stable release
|
||||
pip install dspy
|
||||
|
||||
# Latest development version
|
||||
pip install git+https://github.com/stanfordnlp/dspy.git
|
||||
|
||||
# With specific LM providers
|
||||
pip install dspy[openai] # OpenAI
|
||||
pip install dspy[anthropic] # Anthropic Claude
|
||||
pip install dspy[all] # All providers
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Example: Question Answering
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
# Configure your language model
|
||||
lm = dspy.Claude(model="claude-sonnet-4-5-20250929")
|
||||
dspy.settings.configure(lm=lm)
|
||||
|
||||
# Define a signature (input → output)
|
||||
class QA(dspy.Signature):
|
||||
"""Answer questions with short factual answers."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField(desc="often between 1 and 5 words")
|
||||
|
||||
# Create a module
|
||||
qa = dspy.Predict(QA)
|
||||
|
||||
# Use it
|
||||
response = qa(question="What is the capital of France?")
|
||||
print(response.answer) # "Paris"
|
||||
```
|
||||
|
||||
### Chain of Thought Reasoning
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
lm = dspy.Claude(model="claude-sonnet-4-5-20250929")
|
||||
dspy.settings.configure(lm=lm)
|
||||
|
||||
# Use ChainOfThought for better reasoning
|
||||
class MathProblem(dspy.Signature):
|
||||
"""Solve math word problems."""
|
||||
problem = dspy.InputField()
|
||||
answer = dspy.OutputField(desc="numerical answer")
|
||||
|
||||
# ChainOfThought generates reasoning steps automatically
|
||||
cot = dspy.ChainOfThought(MathProblem)
|
||||
|
||||
response = cot(problem="If John has 5 apples and gives 2 to Mary, how many does he have?")
|
||||
print(response.rationale) # Shows reasoning steps
|
||||
print(response.answer) # "3"
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Signatures
|
||||
|
||||
Signatures define the structure of your AI task (inputs → outputs):
|
||||
|
||||
```python
|
||||
# Inline signature (simple)
|
||||
qa = dspy.Predict("question -> answer")
|
||||
|
||||
# Class signature (detailed)
|
||||
class Summarize(dspy.Signature):
|
||||
"""Summarize text into key points."""
|
||||
text = dspy.InputField()
|
||||
summary = dspy.OutputField(desc="bullet points, 3-5 items")
|
||||
|
||||
summarizer = dspy.ChainOfThought(Summarize)
|
||||
```
|
||||
|
||||
**When to use each:**
|
||||
- **Inline**: Quick prototyping, simple tasks
|
||||
- **Class**: Complex tasks, type hints, better documentation
|
||||
|
||||
### 2. Modules
|
||||
|
||||
Modules are reusable components that transform inputs to outputs:
|
||||
|
||||
#### dspy.Predict
|
||||
Basic prediction module:
|
||||
|
||||
```python
|
||||
predictor = dspy.Predict("context, question -> answer")
|
||||
result = predictor(context="Paris is the capital of France",
|
||||
question="What is the capital?")
|
||||
```
|
||||
|
||||
#### dspy.ChainOfThought
|
||||
Generates reasoning steps before answering:
|
||||
|
||||
```python
|
||||
cot = dspy.ChainOfThought("question -> answer")
|
||||
result = cot(question="Why is the sky blue?")
|
||||
print(result.rationale) # Reasoning steps
|
||||
print(result.answer) # Final answer
|
||||
```
|
||||
|
||||
#### dspy.ReAct
|
||||
Agent-like reasoning with tools:
|
||||
|
||||
```python
|
||||
from dspy.predict import ReAct
|
||||
|
||||
class SearchQA(dspy.Signature):
|
||||
"""Answer questions using search."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField()
|
||||
|
||||
def search_tool(query: str) -> str:
|
||||
"""Search Wikipedia."""
|
||||
# Your search implementation
|
||||
return results
|
||||
|
||||
react = ReAct(SearchQA, tools=[search_tool])
|
||||
result = react(question="When was Python created?")
|
||||
```
|
||||
|
||||
#### dspy.ProgramOfThought
|
||||
Generates and executes code for reasoning:
|
||||
|
||||
```python
|
||||
pot = dspy.ProgramOfThought("question -> answer")
|
||||
result = pot(question="What is 15% of 240?")
|
||||
# Generates: answer = 240 * 0.15
|
||||
```
|
||||
|
||||
### 3. Optimizers
|
||||
|
||||
Optimizers improve your modules automatically using training data:
|
||||
|
||||
#### BootstrapFewShot
|
||||
Learns from examples:
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
# Training data
|
||||
trainset = [
|
||||
dspy.Example(question="What is 2+2?", answer="4").with_inputs("question"),
|
||||
dspy.Example(question="What is 3+5?", answer="8").with_inputs("question"),
|
||||
]
|
||||
|
||||
# Define metric
|
||||
def validate_answer(example, pred, trace=None):
|
||||
return example.answer == pred.answer
|
||||
|
||||
# Optimize
|
||||
optimizer = BootstrapFewShot(metric=validate_answer, max_bootstrapped_demos=3)
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# Now optimized_qa performs better!
|
||||
```
|
||||
|
||||
#### MIPRO (Most Important Prompt Optimization)
|
||||
Iteratively improves prompts:
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import MIPRO
|
||||
|
||||
optimizer = MIPRO(
|
||||
metric=validate_answer,
|
||||
num_candidates=10,
|
||||
init_temperature=1.0
|
||||
)
|
||||
|
||||
optimized_cot = optimizer.compile(
|
||||
cot,
|
||||
trainset=trainset,
|
||||
num_trials=100
|
||||
)
|
||||
```
|
||||
|
||||
#### BootstrapFinetune
|
||||
Creates datasets for model fine-tuning:
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import BootstrapFinetune
|
||||
|
||||
optimizer = BootstrapFinetune(metric=validate_answer)
|
||||
optimized_module = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# Exports training data for fine-tuning
|
||||
```
|
||||
|
||||
### 4. Building Complex Systems
|
||||
|
||||
#### Multi-Stage Pipeline
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
class MultiHopQA(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=3)
|
||||
self.generate_query = dspy.ChainOfThought("question -> search_query")
|
||||
self.generate_answer = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Stage 1: Generate search query
|
||||
search_query = self.generate_query(question=question).search_query
|
||||
|
||||
# Stage 2: Retrieve context
|
||||
passages = self.retrieve(search_query).passages
|
||||
context = "\n".join(passages)
|
||||
|
||||
# Stage 3: Generate answer
|
||||
answer = self.generate_answer(context=context, question=question).answer
|
||||
return dspy.Prediction(answer=answer, context=context)
|
||||
|
||||
# Use the pipeline
|
||||
qa_system = MultiHopQA()
|
||||
result = qa_system(question="Who wrote the book that inspired the movie Blade Runner?")
|
||||
```
|
||||
|
||||
#### RAG System with Optimization
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.retrieve.chromadb_rm import ChromadbRM
|
||||
|
||||
# Configure retriever
|
||||
retriever = ChromadbRM(
|
||||
collection_name="documents",
|
||||
persist_directory="./chroma_db"
|
||||
)
|
||||
|
||||
class RAG(dspy.Module):
|
||||
def __init__(self, num_passages=3):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=num_passages)
|
||||
self.generate = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
context = self.retrieve(question).passages
|
||||
return self.generate(context=context, question=question)
|
||||
|
||||
# Create and optimize
|
||||
rag = RAG()
|
||||
|
||||
# Optimize with training data
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
optimizer = BootstrapFewShot(metric=validate_answer)
|
||||
optimized_rag = optimizer.compile(rag, trainset=trainset)
|
||||
```
|
||||
|
||||
## LM Provider Configuration
|
||||
|
||||
### Anthropic Claude
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
lm = dspy.Claude(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
api_key="your-api-key", # Or set ANTHROPIC_API_KEY env var
|
||||
max_tokens=1000,
|
||||
temperature=0.7
|
||||
)
|
||||
dspy.settings.configure(lm=lm)
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
```python
|
||||
lm = dspy.OpenAI(
|
||||
model="gpt-4",
|
||||
api_key="your-api-key",
|
||||
max_tokens=1000
|
||||
)
|
||||
dspy.settings.configure(lm=lm)
|
||||
```
|
||||
|
||||
### Local Models (Ollama)
|
||||
|
||||
```python
|
||||
lm = dspy.OllamaLocal(
|
||||
model="llama3.1",
|
||||
base_url="http://localhost:11434"
|
||||
)
|
||||
dspy.settings.configure(lm=lm)
|
||||
```
|
||||
|
||||
### Multiple Models
|
||||
|
||||
```python
|
||||
# Different models for different tasks
|
||||
cheap_lm = dspy.OpenAI(model="gpt-3.5-turbo")
|
||||
strong_lm = dspy.Claude(model="claude-sonnet-4-5-20250929")
|
||||
|
||||
# Use cheap model for retrieval, strong model for reasoning
|
||||
with dspy.settings.context(lm=cheap_lm):
|
||||
context = retriever(question)
|
||||
|
||||
with dspy.settings.context(lm=strong_lm):
|
||||
answer = generator(context=context, question=question)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern 1: Structured Output
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class PersonInfo(BaseModel):
|
||||
name: str = Field(description="Full name")
|
||||
age: int = Field(description="Age in years")
|
||||
occupation: str = Field(description="Current job")
|
||||
|
||||
class ExtractPerson(dspy.Signature):
|
||||
"""Extract person information from text."""
|
||||
text = dspy.InputField()
|
||||
person: PersonInfo = dspy.OutputField()
|
||||
|
||||
extractor = dspy.TypedPredictor(ExtractPerson)
|
||||
result = extractor(text="John Doe is a 35-year-old software engineer.")
|
||||
print(result.person.name) # "John Doe"
|
||||
print(result.person.age) # 35
|
||||
```
|
||||
|
||||
### Pattern 2: Assertion-Driven Optimization
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
|
||||
|
||||
class MathQA(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.solve = dspy.ChainOfThought("problem -> solution: float")
|
||||
|
||||
def forward(self, problem):
|
||||
solution = self.solve(problem=problem).solution
|
||||
|
||||
# Assert solution is numeric
|
||||
dspy.Assert(
|
||||
isinstance(float(solution), float),
|
||||
"Solution must be a number",
|
||||
backtrack=backtrack_handler
|
||||
)
|
||||
|
||||
return dspy.Prediction(solution=solution)
|
||||
```
|
||||
|
||||
### Pattern 3: Self-Consistency
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from collections import Counter
|
||||
|
||||
class ConsistentQA(dspy.Module):
|
||||
def __init__(self, num_samples=5):
|
||||
super().__init__()
|
||||
self.qa = dspy.ChainOfThought("question -> answer")
|
||||
self.num_samples = num_samples
|
||||
|
||||
def forward(self, question):
|
||||
# Generate multiple answers
|
||||
answers = []
|
||||
for _ in range(self.num_samples):
|
||||
result = self.qa(question=question)
|
||||
answers.append(result.answer)
|
||||
|
||||
# Return most common answer
|
||||
most_common = Counter(answers).most_common(1)[0][0]
|
||||
return dspy.Prediction(answer=most_common)
|
||||
```
|
||||
|
||||
### Pattern 4: Retrieval with Reranking
|
||||
|
||||
```python
|
||||
class RerankedRAG(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=10)
|
||||
self.rerank = dspy.Predict("question, passage -> relevance_score: float")
|
||||
self.answer = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Retrieve candidates
|
||||
passages = self.retrieve(question).passages
|
||||
|
||||
# Rerank passages
|
||||
scored = []
|
||||
for passage in passages:
|
||||
score = float(self.rerank(question=question, passage=passage).relevance_score)
|
||||
scored.append((score, passage))
|
||||
|
||||
# Take top 3
|
||||
top_passages = [p for _, p in sorted(scored, reverse=True)[:3]]
|
||||
context = "\n\n".join(top_passages)
|
||||
|
||||
# Generate answer
|
||||
return self.answer(context=context, question=question)
|
||||
```
|
||||
|
||||
## Evaluation and Metrics
|
||||
|
||||
### Custom Metrics
|
||||
|
||||
```python
|
||||
def exact_match(example, pred, trace=None):
|
||||
"""Exact match metric."""
|
||||
return example.answer.lower() == pred.answer.lower()
|
||||
|
||||
def f1_score(example, pred, trace=None):
|
||||
"""F1 score for text overlap."""
|
||||
pred_tokens = set(pred.answer.lower().split())
|
||||
gold_tokens = set(example.answer.lower().split())
|
||||
|
||||
if not pred_tokens:
|
||||
return 0.0
|
||||
|
||||
precision = len(pred_tokens & gold_tokens) / len(pred_tokens)
|
||||
recall = len(pred_tokens & gold_tokens) / len(gold_tokens)
|
||||
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
|
||||
return 2 * (precision * recall) / (precision + recall)
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
```python
|
||||
from dspy.evaluate import Evaluate
|
||||
|
||||
# Create evaluator
|
||||
evaluator = Evaluate(
|
||||
devset=testset,
|
||||
metric=exact_match,
|
||||
num_threads=4,
|
||||
display_progress=True
|
||||
)
|
||||
|
||||
# Evaluate model
|
||||
score = evaluator(qa_system)
|
||||
print(f"Accuracy: {score}")
|
||||
|
||||
# Compare optimized vs unoptimized
|
||||
score_before = evaluator(qa)
|
||||
score_after = evaluator(optimized_qa)
|
||||
print(f"Improvement: {score_after - score_before:.2%}")
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Start Simple, Iterate
|
||||
|
||||
```python
|
||||
# Start with Predict
|
||||
qa = dspy.Predict("question -> answer")
|
||||
|
||||
# Add reasoning if needed
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Add optimization when you have data
|
||||
optimized_qa = optimizer.compile(qa, trainset=data)
|
||||
```
|
||||
|
||||
### 2. Use Descriptive Signatures
|
||||
|
||||
```python
|
||||
# ❌ Bad: Vague
|
||||
class Task(dspy.Signature):
|
||||
input = dspy.InputField()
|
||||
output = dspy.OutputField()
|
||||
|
||||
# ✅ Good: Descriptive
|
||||
class SummarizeArticle(dspy.Signature):
|
||||
"""Summarize news articles into 3-5 key points."""
|
||||
article = dspy.InputField(desc="full article text")
|
||||
summary = dspy.OutputField(desc="bullet points, 3-5 items")
|
||||
```
|
||||
|
||||
### 3. Optimize with Representative Data
|
||||
|
||||
```python
|
||||
# Create diverse training examples
|
||||
trainset = [
|
||||
dspy.Example(question="factual", answer="...).with_inputs("question"),
|
||||
dspy.Example(question="reasoning", answer="...").with_inputs("question"),
|
||||
dspy.Example(question="calculation", answer="...").with_inputs("question"),
|
||||
]
|
||||
|
||||
# Use validation set for metric
|
||||
def metric(example, pred, trace=None):
|
||||
return example.answer in pred.answer
|
||||
```
|
||||
|
||||
### 4. Save and Load Optimized Models
|
||||
|
||||
```python
|
||||
# Save
|
||||
optimized_qa.save("models/qa_v1.json")
|
||||
|
||||
# Load
|
||||
loaded_qa = dspy.ChainOfThought("question -> answer")
|
||||
loaded_qa.load("models/qa_v1.json")
|
||||
```
|
||||
|
||||
### 5. Monitor and Debug
|
||||
|
||||
```python
|
||||
# Enable tracing
|
||||
dspy.settings.configure(lm=lm, trace=[])
|
||||
|
||||
# Run prediction
|
||||
result = qa(question="...")
|
||||
|
||||
# Inspect trace
|
||||
for call in dspy.settings.trace:
|
||||
print(f"Prompt: {call['prompt']}")
|
||||
print(f"Response: {call['response']}")
|
||||
```
|
||||
|
||||
## Comparison to Other Approaches
|
||||
|
||||
| Feature | Manual Prompting | LangChain | DSPy |
|
||||
|---------|-----------------|-----------|------|
|
||||
| Prompt Engineering | Manual | Manual | Automatic |
|
||||
| Optimization | Trial & error | None | Data-driven |
|
||||
| Modularity | Low | Medium | High |
|
||||
| Type Safety | No | Limited | Yes (Signatures) |
|
||||
| Portability | Low | Medium | High |
|
||||
| Learning Curve | Low | Medium | Medium-High |
|
||||
|
||||
**When to choose DSPy:**
|
||||
- You have training data or can generate it
|
||||
- You need systematic prompt improvement
|
||||
- You're building complex multi-stage systems
|
||||
- You want to optimize across different LMs
|
||||
|
||||
**When to choose alternatives:**
|
||||
- Quick prototypes (manual prompting)
|
||||
- Simple chains with existing tools (LangChain)
|
||||
- Custom optimization logic needed
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://dspy.ai
|
||||
- **GitHub**: https://github.com/stanfordnlp/dspy (22k+ stars)
|
||||
- **Discord**: https://discord.gg/XCGy2WDCQB
|
||||
- **Twitter**: @DSPyOSS
|
||||
- **Paper**: "DSPy: Compiling Declarative Language Model Calls into Self-Improving Pipelines"
|
||||
|
||||
## See Also
|
||||
|
||||
- `references/modules.md` - Detailed module guide (Predict, ChainOfThought, ReAct, ProgramOfThought)
|
||||
- `references/optimizers.md` - Optimization algorithms (BootstrapFewShot, MIPRO, BootstrapFinetune)
|
||||
- `references/examples.md` - Real-world examples (RAG, agents, classifiers)
|
||||
|
||||
|
||||
663
skills/mlops-knowledge/dspy/references/examples.md
Normal file
663
skills/mlops-knowledge/dspy/references/examples.md
Normal file
@@ -0,0 +1,663 @@
|
||||
# DSPy Real-World Examples
|
||||
|
||||
Practical examples of building production systems with DSPy.
|
||||
|
||||
## Table of Contents
|
||||
- RAG Systems
|
||||
- Agent Systems
|
||||
- Classification
|
||||
- Data Processing
|
||||
- Multi-Stage Pipelines
|
||||
|
||||
## RAG Systems
|
||||
|
||||
### Basic RAG
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
class BasicRAG(dspy.Module):
|
||||
def __init__(self, num_passages=3):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=num_passages)
|
||||
self.generate = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
passages = self.retrieve(question).passages
|
||||
context = "\n\n".join(passages)
|
||||
return self.generate(context=context, question=question)
|
||||
|
||||
# Configure retriever (example with Chroma)
|
||||
from dspy.retrieve.chromadb_rm import ChromadbRM
|
||||
|
||||
retriever = ChromadbRM(
|
||||
collection_name="my_docs",
|
||||
persist_directory="./chroma_db",
|
||||
k=3
|
||||
)
|
||||
dspy.settings.configure(rm=retriever)
|
||||
|
||||
# Use RAG
|
||||
rag = BasicRAG()
|
||||
result = rag(question="What is DSPy?")
|
||||
print(result.answer)
|
||||
```
|
||||
|
||||
### Optimized RAG
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
# Training data with question-answer pairs
|
||||
trainset = [
|
||||
dspy.Example(
|
||||
question="What is retrieval augmented generation?",
|
||||
answer="RAG combines retrieval of relevant documents with generation..."
|
||||
).with_inputs("question"),
|
||||
# ... more examples
|
||||
]
|
||||
|
||||
# Define metric
|
||||
def answer_correctness(example, pred, trace=None):
|
||||
# Check if answer contains key information
|
||||
return example.answer.lower() in pred.answer.lower()
|
||||
|
||||
# Optimize RAG
|
||||
optimizer = BootstrapFewShot(metric=answer_correctness)
|
||||
optimized_rag = optimizer.compile(rag, trainset=trainset)
|
||||
|
||||
# Optimized RAG performs better on similar questions
|
||||
result = optimized_rag(question="Explain RAG systems")
|
||||
```
|
||||
|
||||
### Multi-Hop RAG
|
||||
|
||||
```python
|
||||
class MultiHopRAG(dspy.Module):
|
||||
"""RAG that follows chains of reasoning across documents."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=3)
|
||||
self.generate_query = dspy.ChainOfThought("question -> search_query")
|
||||
self.generate_answer = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# First retrieval
|
||||
query1 = self.generate_query(question=question).search_query
|
||||
passages1 = self.retrieve(query1).passages
|
||||
|
||||
# Generate follow-up query based on first results
|
||||
context1 = "\n".join(passages1)
|
||||
query2 = self.generate_query(
|
||||
question=f"Based on: {context1}\nFollow-up: {question}"
|
||||
).search_query
|
||||
|
||||
# Second retrieval
|
||||
passages2 = self.retrieve(query2).passages
|
||||
|
||||
# Combine all context
|
||||
all_context = "\n\n".join(passages1 + passages2)
|
||||
|
||||
# Generate final answer
|
||||
return self.generate_answer(context=all_context, question=question)
|
||||
|
||||
# Use multi-hop RAG
|
||||
multi_rag = MultiHopRAG()
|
||||
result = multi_rag(question="Who wrote the book that inspired Blade Runner?")
|
||||
# Hop 1: Find "Blade Runner was based on..."
|
||||
# Hop 2: Find author of that book
|
||||
```
|
||||
|
||||
### RAG with Reranking
|
||||
|
||||
```python
|
||||
class RerankedRAG(dspy.Module):
|
||||
"""RAG with learned reranking of retrieved passages."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=10) # Get more candidates
|
||||
self.rerank = dspy.Predict("question, passage -> relevance_score: float")
|
||||
self.answer = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Retrieve candidates
|
||||
passages = self.retrieve(question).passages
|
||||
|
||||
# Rerank passages
|
||||
scored_passages = []
|
||||
for passage in passages:
|
||||
score = float(self.rerank(
|
||||
question=question,
|
||||
passage=passage
|
||||
).relevance_score)
|
||||
scored_passages.append((score, passage))
|
||||
|
||||
# Take top 3 after reranking
|
||||
top_passages = [p for _, p in sorted(scored_passages, reverse=True)[:3]]
|
||||
context = "\n\n".join(top_passages)
|
||||
|
||||
# Generate answer from reranked context
|
||||
return self.answer(context=context, question=question)
|
||||
```
|
||||
|
||||
## Agent Systems
|
||||
|
||||
### ReAct Agent
|
||||
|
||||
```python
|
||||
from dspy.predict import ReAct
|
||||
|
||||
# Define tools
|
||||
def search_wikipedia(query: str) -> str:
|
||||
"""Search Wikipedia for information."""
|
||||
import wikipedia
|
||||
try:
|
||||
return wikipedia.summary(query, sentences=3)
|
||||
except:
|
||||
return "No results found"
|
||||
|
||||
def calculate(expression: str) -> str:
|
||||
"""Evaluate mathematical expression safely."""
|
||||
try:
|
||||
# Use safe eval
|
||||
result = eval(expression, {"__builtins__": {}}, {})
|
||||
return str(result)
|
||||
except:
|
||||
return "Invalid expression"
|
||||
|
||||
def search_web(query: str) -> str:
|
||||
"""Search the web."""
|
||||
# Your web search implementation
|
||||
return results
|
||||
|
||||
# Create agent signature
|
||||
class ResearchAgent(dspy.Signature):
|
||||
"""Answer questions using available tools."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField()
|
||||
|
||||
# Create ReAct agent
|
||||
agent = ReAct(ResearchAgent, tools=[search_wikipedia, calculate, search_web])
|
||||
|
||||
# Agent decides which tools to use
|
||||
result = agent(question="What is the population of France divided by 10?")
|
||||
# Agent:
|
||||
# 1. Thinks: "Need population of France"
|
||||
# 2. Acts: search_wikipedia("France population")
|
||||
# 3. Thinks: "Got 67 million, need to divide"
|
||||
# 4. Acts: calculate("67000000 / 10")
|
||||
# 5. Returns: "6,700,000"
|
||||
```
|
||||
|
||||
### Multi-Agent System
|
||||
|
||||
```python
|
||||
class MultiAgentSystem(dspy.Module):
|
||||
"""System with specialized agents for different tasks."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Router agent
|
||||
self.router = dspy.Predict("question -> agent_type: str")
|
||||
|
||||
# Specialized agents
|
||||
self.research_agent = ReAct(
|
||||
ResearchAgent,
|
||||
tools=[search_wikipedia, search_web]
|
||||
)
|
||||
self.math_agent = dspy.ProgramOfThought("problem -> answer")
|
||||
self.reasoning_agent = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Route to appropriate agent
|
||||
agent_type = self.router(question=question).agent_type
|
||||
|
||||
if agent_type == "research":
|
||||
return self.research_agent(question=question)
|
||||
elif agent_type == "math":
|
||||
return self.math_agent(problem=question)
|
||||
else:
|
||||
return self.reasoning_agent(question=question)
|
||||
|
||||
# Use multi-agent system
|
||||
mas = MultiAgentSystem()
|
||||
result = mas(question="What is 15% of the GDP of France?")
|
||||
# Routes to research_agent for GDP, then to math_agent for calculation
|
||||
```
|
||||
|
||||
## Classification
|
||||
|
||||
### Binary Classifier
|
||||
|
||||
```python
|
||||
class SentimentClassifier(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.classify = dspy.Predict("text -> sentiment: str")
|
||||
|
||||
def forward(self, text):
|
||||
return self.classify(text=text)
|
||||
|
||||
# Training data
|
||||
trainset = [
|
||||
dspy.Example(text="I love this!", sentiment="positive").with_inputs("text"),
|
||||
dspy.Example(text="Terrible experience", sentiment="negative").with_inputs("text"),
|
||||
# ... more examples
|
||||
]
|
||||
|
||||
# Optimize
|
||||
def accuracy(example, pred, trace=None):
|
||||
return example.sentiment == pred.sentiment
|
||||
|
||||
optimizer = BootstrapFewShot(metric=accuracy, max_bootstrapped_demos=5)
|
||||
classifier = SentimentClassifier()
|
||||
optimized_classifier = optimizer.compile(classifier, trainset=trainset)
|
||||
|
||||
# Use classifier
|
||||
result = optimized_classifier(text="This product is amazing!")
|
||||
print(result.sentiment) # "positive"
|
||||
```
|
||||
|
||||
### Multi-Class Classifier
|
||||
|
||||
```python
|
||||
class TopicClassifier(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.classify = dspy.ChainOfThought(
|
||||
"text -> category: str, confidence: float"
|
||||
)
|
||||
|
||||
def forward(self, text):
|
||||
result = self.classify(text=text)
|
||||
return dspy.Prediction(
|
||||
category=result.category,
|
||||
confidence=float(result.confidence)
|
||||
)
|
||||
|
||||
# Define categories in signature
|
||||
class TopicSignature(dspy.Signature):
|
||||
"""Classify text into one of: technology, sports, politics, entertainment."""
|
||||
text = dspy.InputField()
|
||||
category = dspy.OutputField(desc="one of: technology, sports, politics, entertainment")
|
||||
confidence = dspy.OutputField(desc="0.0 to 1.0")
|
||||
|
||||
classifier = dspy.ChainOfThought(TopicSignature)
|
||||
result = classifier(text="The Lakers won the championship")
|
||||
print(result.category) # "sports"
|
||||
print(result.confidence) # 0.95
|
||||
```
|
||||
|
||||
### Hierarchical Classifier
|
||||
|
||||
```python
|
||||
class HierarchicalClassifier(dspy.Module):
|
||||
"""Two-stage classification: coarse then fine-grained."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.coarse = dspy.Predict("text -> broad_category: str")
|
||||
self.fine_tech = dspy.Predict("text -> tech_subcategory: str")
|
||||
self.fine_sports = dspy.Predict("text -> sports_subcategory: str")
|
||||
|
||||
def forward(self, text):
|
||||
# Stage 1: Broad category
|
||||
broad = self.coarse(text=text).broad_category
|
||||
|
||||
# Stage 2: Fine-grained based on broad
|
||||
if broad == "technology":
|
||||
fine = self.fine_tech(text=text).tech_subcategory
|
||||
elif broad == "sports":
|
||||
fine = self.fine_sports(text=text).sports_subcategory
|
||||
else:
|
||||
fine = "other"
|
||||
|
||||
return dspy.Prediction(broad_category=broad, fine_category=fine)
|
||||
```
|
||||
|
||||
## Data Processing
|
||||
|
||||
### Text Summarization
|
||||
|
||||
```python
|
||||
class AdaptiveSummarizer(dspy.Module):
|
||||
"""Summarizes text to target length."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.summarize = dspy.ChainOfThought("text, target_length -> summary")
|
||||
|
||||
def forward(self, text, target_length="3 sentences"):
|
||||
return self.summarize(text=text, target_length=target_length)
|
||||
|
||||
# Use summarizer
|
||||
summarizer = AdaptiveSummarizer()
|
||||
long_text = "..." # Long article
|
||||
|
||||
short_summary = summarizer(long_text, target_length="1 sentence")
|
||||
medium_summary = summarizer(long_text, target_length="3 sentences")
|
||||
detailed_summary = summarizer(long_text, target_length="1 paragraph")
|
||||
```
|
||||
|
||||
### Information Extraction
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class PersonInfo(BaseModel):
|
||||
name: str = Field(description="Full name")
|
||||
age: int = Field(description="Age in years")
|
||||
occupation: str = Field(description="Job title")
|
||||
location: str = Field(description="City and country")
|
||||
|
||||
class ExtractPerson(dspy.Signature):
|
||||
"""Extract person information from text."""
|
||||
text = dspy.InputField()
|
||||
person: PersonInfo = dspy.OutputField()
|
||||
|
||||
extractor = dspy.TypedPredictor(ExtractPerson)
|
||||
|
||||
text = "Dr. Jane Smith, 42, is a neuroscientist at Stanford University in Palo Alto, California."
|
||||
result = extractor(text=text)
|
||||
|
||||
print(result.person.name) # "Dr. Jane Smith"
|
||||
print(result.person.age) # 42
|
||||
print(result.person.occupation) # "neuroscientist"
|
||||
print(result.person.location) # "Palo Alto, California"
|
||||
```
|
||||
|
||||
### Batch Processing
|
||||
|
||||
```python
|
||||
class BatchProcessor(dspy.Module):
|
||||
"""Process large datasets efficiently."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.process = dspy.Predict("text -> processed_text")
|
||||
|
||||
def forward(self, texts):
|
||||
# Batch processing for efficiency
|
||||
return self.process.batch([{"text": t} for t in texts])
|
||||
|
||||
# Process 1000 documents
|
||||
processor = BatchProcessor()
|
||||
results = processor(texts=large_dataset)
|
||||
|
||||
# Results are returned in order
|
||||
for original, result in zip(large_dataset, results):
|
||||
print(f"{original} -> {result.processed_text}")
|
||||
```
|
||||
|
||||
## Multi-Stage Pipelines
|
||||
|
||||
### Document Processing Pipeline
|
||||
|
||||
```python
|
||||
class DocumentPipeline(dspy.Module):
|
||||
"""Multi-stage document processing."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.extract = dspy.Predict("document -> key_points")
|
||||
self.classify = dspy.Predict("key_points -> category")
|
||||
self.summarize = dspy.ChainOfThought("key_points, category -> summary")
|
||||
self.tag = dspy.Predict("summary -> tags")
|
||||
|
||||
def forward(self, document):
|
||||
# Stage 1: Extract key points
|
||||
key_points = self.extract(document=document).key_points
|
||||
|
||||
# Stage 2: Classify
|
||||
category = self.classify(key_points=key_points).category
|
||||
|
||||
# Stage 3: Summarize
|
||||
summary = self.summarize(
|
||||
key_points=key_points,
|
||||
category=category
|
||||
).summary
|
||||
|
||||
# Stage 4: Generate tags
|
||||
tags = self.tag(summary=summary).tags
|
||||
|
||||
return dspy.Prediction(
|
||||
key_points=key_points,
|
||||
category=category,
|
||||
summary=summary,
|
||||
tags=tags
|
||||
)
|
||||
```
|
||||
|
||||
### Quality Control Pipeline
|
||||
|
||||
```python
|
||||
class QualityControlPipeline(dspy.Module):
|
||||
"""Generate output and verify quality."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.generate = dspy.ChainOfThought("prompt -> output")
|
||||
self.verify = dspy.Predict("output -> is_valid: bool, issues: str")
|
||||
self.improve = dspy.ChainOfThought("output, issues -> improved_output")
|
||||
|
||||
def forward(self, prompt, max_iterations=3):
|
||||
output = self.generate(prompt=prompt).output
|
||||
|
||||
for _ in range(max_iterations):
|
||||
# Verify output
|
||||
verification = self.verify(output=output)
|
||||
|
||||
if verification.is_valid:
|
||||
return dspy.Prediction(output=output, iterations=_ + 1)
|
||||
|
||||
# Improve based on issues
|
||||
output = self.improve(
|
||||
output=output,
|
||||
issues=verification.issues
|
||||
).improved_output
|
||||
|
||||
return dspy.Prediction(output=output, iterations=max_iterations)
|
||||
```
|
||||
|
||||
## Production Tips
|
||||
|
||||
### 1. Caching for Performance
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
|
||||
class CachedRAG(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=3)
|
||||
self.generate = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
@lru_cache(maxsize=1000)
|
||||
def forward(self, question):
|
||||
passages = self.retrieve(question).passages
|
||||
context = "\n".join(passages)
|
||||
return self.generate(context=context, question=question).answer
|
||||
```
|
||||
|
||||
### 2. Error Handling
|
||||
|
||||
```python
|
||||
class RobustModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.process = dspy.ChainOfThought("input -> output")
|
||||
|
||||
def forward(self, input):
|
||||
try:
|
||||
result = self.process(input=input)
|
||||
return result
|
||||
except Exception as e:
|
||||
# Log error
|
||||
print(f"Error processing {input}: {e}")
|
||||
# Return fallback
|
||||
return dspy.Prediction(output="Error: could not process input")
|
||||
```
|
||||
|
||||
### 3. Monitoring
|
||||
|
||||
```python
|
||||
class MonitoredModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.process = dspy.ChainOfThought("input -> output")
|
||||
self.call_count = 0
|
||||
self.errors = 0
|
||||
|
||||
def forward(self, input):
|
||||
self.call_count += 1
|
||||
|
||||
try:
|
||||
result = self.process(input=input)
|
||||
return result
|
||||
except Exception as e:
|
||||
self.errors += 1
|
||||
raise
|
||||
|
||||
def get_stats(self):
|
||||
return {
|
||||
"calls": self.call_count,
|
||||
"errors": self.errors,
|
||||
"error_rate": self.errors / max(self.call_count, 1)
|
||||
}
|
||||
```
|
||||
|
||||
### 4. A/B Testing
|
||||
|
||||
```python
|
||||
class ABTestModule(dspy.Module):
|
||||
"""Run two variants and compare."""
|
||||
|
||||
def __init__(self, variant_a, variant_b):
|
||||
super().__init__()
|
||||
self.variant_a = variant_a
|
||||
self.variant_b = variant_b
|
||||
self.a_calls = 0
|
||||
self.b_calls = 0
|
||||
|
||||
def forward(self, input, variant="a"):
|
||||
if variant == "a":
|
||||
self.a_calls += 1
|
||||
return self.variant_a(input=input)
|
||||
else:
|
||||
self.b_calls += 1
|
||||
return self.variant_b(input=input)
|
||||
|
||||
# Compare two optimizers
|
||||
baseline = dspy.ChainOfThought("question -> answer")
|
||||
optimized = BootstrapFewShot(...).compile(baseline, trainset=trainset)
|
||||
|
||||
ab_test = ABTestModule(variant_a=baseline, variant_b=optimized)
|
||||
|
||||
# Route 50% to each
|
||||
import random
|
||||
variant = "a" if random.random() < 0.5 else "b"
|
||||
result = ab_test(input=question, variant=variant)
|
||||
```
|
||||
|
||||
## Complete Example: Customer Support Bot
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
class CustomerSupportBot(dspy.Module):
|
||||
"""Complete customer support system."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Classify intent
|
||||
self.classify_intent = dspy.Predict("message -> intent: str")
|
||||
|
||||
# Specialized handlers
|
||||
self.technical_handler = dspy.ChainOfThought("message, history -> response")
|
||||
self.billing_handler = dspy.ChainOfThought("message, history -> response")
|
||||
self.general_handler = dspy.Predict("message, history -> response")
|
||||
|
||||
# Retrieve relevant docs
|
||||
self.retrieve = dspy.Retrieve(k=3)
|
||||
|
||||
# Conversation history
|
||||
self.history = []
|
||||
|
||||
def forward(self, message):
|
||||
# Classify intent
|
||||
intent = self.classify_intent(message=message).intent
|
||||
|
||||
# Retrieve relevant documentation
|
||||
docs = self.retrieve(message).passages
|
||||
context = "\n".join(docs)
|
||||
|
||||
# Build the conversation history and combine retrieved context with the message
|
||||
history_str = "\n".join(self.history)
|
||||
full_message = f"Context: {context}\n\nMessage: {message}"
|
||||
|
||||
# Route to appropriate handler
|
||||
if intent == "technical":
|
||||
response = self.technical_handler(
|
||||
message=full_message,
|
||||
history=history_str
|
||||
).response
|
||||
elif intent == "billing":
|
||||
response = self.billing_handler(
|
||||
message=full_message,
|
||||
history=history_str
|
||||
).response
|
||||
else:
|
||||
response = self.general_handler(
|
||||
message=full_message,
|
||||
history=history_str
|
||||
).response
|
||||
|
||||
# Update history
|
||||
self.history.append(f"User: {message}")
|
||||
self.history.append(f"Bot: {response}")
|
||||
|
||||
return dspy.Prediction(response=response, intent=intent)
|
||||
|
||||
# Training data
|
||||
trainset = [
|
||||
dspy.Example(
|
||||
message="My account isn't working",
|
||||
intent="technical",
|
||||
response="I'd be happy to help. What error are you seeing?"
|
||||
).with_inputs("message"),
|
||||
# ... more examples
|
||||
]
|
||||
|
||||
# Define metric
|
||||
def response_quality(example, pred, trace=None):
|
||||
# Check if response is helpful
|
||||
if len(pred.response) < 20:
|
||||
return 0.0
|
||||
if example.intent != pred.intent:
|
||||
return 0.3
|
||||
return 1.0
|
||||
|
||||
# Optimize
|
||||
optimizer = BootstrapFewShot(metric=response_quality)
|
||||
bot = CustomerSupportBot()
|
||||
optimized_bot = optimizer.compile(bot, trainset=trainset)
|
||||
|
||||
# Use in production
|
||||
optimized_bot.save("models/support_bot_v1.json")
|
||||
|
||||
# Later, load and use
|
||||
loaded_bot = CustomerSupportBot()
|
||||
loaded_bot.load("models/support_bot_v1.json")
|
||||
response = loaded_bot(message="I can't log in")
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://dspy.ai
|
||||
- **Examples Repo**: https://github.com/stanfordnlp/dspy/tree/main/examples
|
||||
- **Discord**: https://discord.gg/XCGy2WDCQB
|
||||
475
skills/mlops-knowledge/dspy/references/modules.md
Normal file
@@ -0,0 +1,475 @@
|
||||
# DSPy Modules
|
||||
|
||||
Complete guide to DSPy's built-in modules for language model programming.
|
||||
|
||||
## Module Basics
|
||||
|
||||
DSPy modules are composable building blocks inspired by PyTorch's NN modules:
|
||||
- Have learnable parameters (prompts, few-shot examples)
|
||||
- Can be composed using Python control flow
|
||||
- Generalized to handle any signature
|
||||
- Optimizable with DSPy optimizers
|
||||
|
||||
### Base Module Pattern
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
class CustomModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Initialize sub-modules
|
||||
self.predictor = dspy.Predict("input -> output")
|
||||
|
||||
def forward(self, input):
|
||||
# Module logic
|
||||
result = self.predictor(input=input)
|
||||
return result
|
||||
```
|
||||
|
||||
## Core Modules
|
||||
|
||||
### dspy.Predict
|
||||
|
||||
**Basic prediction module** - Makes LM calls without reasoning steps.
|
||||
|
||||
```python
|
||||
# Inline signature
|
||||
qa = dspy.Predict("question -> answer")
|
||||
result = qa(question="What is 2+2?")
|
||||
|
||||
# Class signature
|
||||
class QA(dspy.Signature):
|
||||
"""Answer questions concisely."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField(desc="short, factual answer")
|
||||
|
||||
qa = dspy.Predict(QA)
|
||||
result = qa(question="What is the capital of France?")
|
||||
print(result.answer) # "Paris"
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Simple, direct predictions
|
||||
- No reasoning steps needed
|
||||
- Fast responses required
|
||||
|
||||
### dspy.ChainOfThought
|
||||
|
||||
**Step-by-step reasoning** - Generates rationale before answer.
|
||||
|
||||
**Parameters:**
|
||||
- `signature`: Task signature
|
||||
- `rationale_field`: Custom reasoning field (optional)
|
||||
- `rationale_field_type`: Type for rationale (default: `str`)
|
||||
|
||||
```python
|
||||
# Basic usage
|
||||
cot = dspy.ChainOfThought("question -> answer")
|
||||
result = cot(question="If I have 5 apples and give away 2, how many remain?")
|
||||
print(result.rationale) # "Let's think step by step..."
|
||||
print(result.answer) # "3"
|
||||
|
||||
# Custom rationale field
|
||||
cot = dspy.ChainOfThought(
|
||||
signature="problem -> solution",
|
||||
rationale_field=dspy.OutputField(
|
||||
prefix="Reasoning: Let's break this down step by step to"
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Complex reasoning tasks
|
||||
- Math word problems
|
||||
- Logical deduction
|
||||
- Quality > speed
|
||||
|
||||
**Performance:**
|
||||
- ~2x slower than Predict
|
||||
- Significantly better accuracy on reasoning tasks
|
||||
|
||||
### dspy.ProgramOfThought
|
||||
|
||||
**Code-based reasoning** - Generates and executes Python code.
|
||||
|
||||
```python
|
||||
pot = dspy.ProgramOfThought("question -> answer")
|
||||
|
||||
result = pot(question="What is 15% of 240?")
|
||||
# Internally generates: answer = 240 * 0.15
|
||||
# Executes code and returns result
|
||||
print(result.answer) # 36.0
|
||||
|
||||
result = pot(question="If a train travels 60 mph for 2.5 hours, how far does it go?")
|
||||
# Generates: distance = 60 * 2.5
|
||||
print(result.answer) # 150.0
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Arithmetic calculations
|
||||
- Symbolic math
|
||||
- Data transformations
|
||||
- Deterministic computations
|
||||
|
||||
**Benefits:**
|
||||
- More reliable than text-based math
|
||||
- Handles complex calculations
|
||||
- Transparent (shows generated code)
|
||||
|
||||
### dspy.ReAct
|
||||
|
||||
**Reasoning + Acting** - Agent that uses tools iteratively.
|
||||
|
||||
```python
|
||||
from dspy.predict import ReAct
|
||||
|
||||
# Define tools
|
||||
def search_wikipedia(query: str) -> str:
|
||||
"""Search Wikipedia for information."""
|
||||
# Your search implementation
|
||||
return f"Search results for {query}"  # placeholder - replace with a real search call
|
||||
|
||||
def calculate(expression: str) -> float:
|
||||
"""Evaluate a mathematical expression."""
|
||||
return eval(expression)
|
||||
|
||||
# Create ReAct agent
|
||||
class ResearchQA(dspy.Signature):
|
||||
"""Answer questions using available tools."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField()
|
||||
|
||||
react = ReAct(ResearchQA, tools=[search_wikipedia, calculate])
|
||||
|
||||
# Agent decides which tools to use
|
||||
result = react(question="How old was Einstein when he published special relativity?")
|
||||
# Internally:
|
||||
# 1. Thinks: "Need birth year and publication year"
|
||||
# 2. Acts: search_wikipedia("Albert Einstein")
|
||||
# 3. Acts: search_wikipedia("Special relativity 1905")
|
||||
# 4. Acts: calculate("1905 - 1879")
|
||||
# 5. Returns: "26 years old"
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Multi-step research tasks
|
||||
- Tool-using agents
|
||||
- Complex information retrieval
|
||||
- Tasks requiring multiple API calls
|
||||
|
||||
**Best practices:**
|
||||
- Keep tool descriptions clear and specific
|
||||
- Limit to 5-7 tools (too many = confusion)
|
||||
- Provide tool usage examples in docstrings
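For example, a tool written to these guidelines might look like the sketch below (`get_weather` is an illustrative, made-up tool, not part of DSPy):

```python
def get_weather(city: str) -> str:
    """Return the current weather for a city.

    Use when the user asks about current conditions.
    Input is a city name only, e.g. get_weather("Paris") -> "Paris: 18C, partly cloudy".
    """
    # Placeholder implementation - call your weather API here
    return f"{city}: 18C, partly cloudy"
```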
|
||||
|
||||
### dspy.MultiChainComparison
|
||||
|
||||
**Generate multiple outputs and compare** - Self-consistency pattern.
|
||||
|
||||
```python
|
||||
mcc = dspy.MultiChainComparison("question -> answer", M=5)
|
||||
|
||||
result = mcc(question="What is the capital of France?")
|
||||
# Generates 5 candidate answers
|
||||
# Compares and selects most consistent
|
||||
print(result.answer) # "Paris"
|
||||
print(result.candidates) # All 5 generated answers
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `M`: Number of candidates to generate (default: 5)
|
||||
- `temperature`: Sampling temperature for diversity
|
||||
|
||||
**When to use:**
|
||||
- High-stakes decisions
|
||||
- Ambiguous questions
|
||||
- When single answer may be unreliable
|
||||
|
||||
**Tradeoff:**
|
||||
- M times slower (M parallel calls)
|
||||
- Higher accuracy on ambiguous tasks
|
||||
|
||||
### dspy.majority
|
||||
|
||||
**Majority voting over multiple predictions.**
|
||||
|
||||
```python
|
||||
from dspy.primitives import majority
|
||||
|
||||
# Generate multiple predictions
|
||||
predictor = dspy.Predict("question -> answer")
|
||||
predictions = [predictor(question="What is 2+2?") for _ in range(5)]
|
||||
|
||||
# Take majority vote
|
||||
answer = majority([p.answer for p in predictions])
|
||||
print(answer) # "4"
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Combining multiple model outputs
|
||||
- Reducing variance in predictions
|
||||
- Ensemble approaches
|
||||
|
||||
## Advanced Modules
|
||||
|
||||
### dspy.TypedPredictor
|
||||
|
||||
**Structured output with Pydantic models.**
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class PersonInfo(BaseModel):
|
||||
name: str = Field(description="Full name")
|
||||
age: int = Field(description="Age in years")
|
||||
occupation: str = Field(description="Current job")
|
||||
|
||||
class ExtractPerson(dspy.Signature):
|
||||
"""Extract person information from text."""
|
||||
text = dspy.InputField()
|
||||
person: PersonInfo = dspy.OutputField()
|
||||
|
||||
extractor = dspy.TypedPredictor(ExtractPerson)
|
||||
result = extractor(text="John Doe is a 35-year-old software engineer.")
|
||||
|
||||
print(result.person.name) # "John Doe"
|
||||
print(result.person.age) # 35
|
||||
print(result.person.occupation) # "software engineer"
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Type safety
|
||||
- Automatic validation
|
||||
- JSON schema generation
|
||||
- IDE autocomplete
|
||||
|
||||
### dspy.Retry
|
||||
|
||||
**Automatic retry with validation.**
|
||||
|
||||
```python
|
||||
from dspy.primitives import Retry
|
||||
|
||||
def validate_number(example, pred, trace=None):
|
||||
"""Validate output is a number."""
|
||||
try:
|
||||
float(pred.answer)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
# Retry up to 3 times if validation fails
|
||||
qa = Retry(
|
||||
dspy.ChainOfThought("question -> answer"),
|
||||
validate=validate_number,
|
||||
max_retries=3
|
||||
)
|
||||
|
||||
result = qa(question="What is 15% of 80?")
|
||||
# If first attempt returns non-numeric, retries automatically
|
||||
```
|
||||
|
||||
### dspy.Assert
|
||||
|
||||
**Assertion-driven optimization.**
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
|
||||
|
||||
class ValidatedQA(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.qa = dspy.ChainOfThought("question -> answer: float")
|
||||
|
||||
def forward(self, question):
|
||||
answer = self.qa(question=question).answer
|
||||
|
||||
# Check the answer is numeric without raising, so a failed assertion can backtrack
try:
    float(answer)
    is_numeric = True
except (TypeError, ValueError):
    is_numeric = False

dspy.Assert(
    is_numeric,
    "Answer must be a number",
    backtrack=backtrack_handler
)
|
||||
|
||||
return dspy.Prediction(answer=answer)
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Catches errors during optimization
|
||||
- Guides LM toward valid outputs
|
||||
- Better than post-hoc filtering
|
||||
|
||||
## Module Composition
|
||||
|
||||
### Sequential Pipeline
|
||||
|
||||
```python
|
||||
class Pipeline(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.stage1 = dspy.Predict("input -> intermediate")
|
||||
self.stage2 = dspy.ChainOfThought("intermediate -> output")
|
||||
|
||||
def forward(self, input):
|
||||
intermediate = self.stage1(input=input).intermediate
|
||||
output = self.stage2(intermediate=intermediate).output
|
||||
return dspy.Prediction(output=output)
|
||||
```
|
||||
|
||||
### Conditional Logic
|
||||
|
||||
```python
|
||||
class ConditionalModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.router = dspy.Predict("question -> category: str")
|
||||
self.simple_qa = dspy.Predict("question -> answer")
|
||||
self.complex_qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
category = self.router(question=question).category
|
||||
|
||||
if category == "simple":
|
||||
return self.simple_qa(question=question)
|
||||
else:
|
||||
return self.complex_qa(question=question)
|
||||
```
|
||||
|
||||
### Parallel Execution
|
||||
|
||||
```python
|
||||
class ParallelModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.approach1 = dspy.ChainOfThought("question -> answer")
|
||||
self.approach2 = dspy.ProgramOfThought("question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Run both approaches
|
||||
answer1 = self.approach1(question=question).answer
|
||||
answer2 = self.approach2(question=question).answer
|
||||
|
||||
# Compare or combine results
|
||||
if answer1 == answer2:
|
||||
return dspy.Prediction(answer=answer1, confidence="high")
|
||||
else:
|
||||
return dspy.Prediction(answer=answer1, confidence="low")
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
All modules support batch processing for efficiency:
|
||||
|
||||
```python
|
||||
cot = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
questions = [
|
||||
"What is 2+2?",
|
||||
"What is 3+3?",
|
||||
"What is 4+4?"
|
||||
]
|
||||
|
||||
# Process all at once
|
||||
results = cot.batch([{"question": q} for q in questions])
|
||||
|
||||
for result in results:
|
||||
print(result.answer)
|
||||
```
|
||||
|
||||
## Saving and Loading
|
||||
|
||||
```python
|
||||
# Save module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
qa.save("models/qa_v1.json")
|
||||
|
||||
# Load module
|
||||
loaded_qa = dspy.ChainOfThought("question -> answer")
|
||||
loaded_qa.load("models/qa_v1.json")
|
||||
```
|
||||
|
||||
**What gets saved:**
|
||||
- Few-shot examples
|
||||
- Prompt instructions
|
||||
- Module configuration
|
||||
|
||||
**What doesn't get saved:**
|
||||
- Model weights (DSPy doesn't fine-tune by default)
|
||||
- LM provider configuration
|
||||
|
||||
## Module Selection Guide
|
||||
|
||||
| Task | Module | Reason |
|
||||
|------|--------|--------|
|
||||
| Simple classification | Predict | Fast, direct |
|
||||
| Math word problems | ProgramOfThought | Reliable calculations |
|
||||
| Logical reasoning | ChainOfThought | Better with steps |
|
||||
| Multi-step research | ReAct | Tool usage |
|
||||
| High-stakes decisions | MultiChainComparison | Self-consistency |
|
||||
| Structured extraction | TypedPredictor | Type safety |
|
||||
| Ambiguous questions | MultiChainComparison | Multiple perspectives |
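Read as code, the table boils down to a small routing helper (a rough sketch; `pick_module` and the task labels are ours, not a DSPy API):

```python
import dspy

def pick_module(task: str, signature: str = "question -> answer"):
    """Map a task type to the module suggested in the table above."""
    if task == "classification":
        return dspy.Predict(signature)            # fast, direct
    if task == "math":
        return dspy.ProgramOfThought(signature)   # reliable calculations
    if task == "reasoning":
        return dspy.ChainOfThought(signature)     # step-by-step rationale
    if task == "research":
        return dspy.ReAct(signature, tools=[])    # supply real tools in practice
    return dspy.MultiChainComparison(signature)   # high-stakes or ambiguous questions
```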
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Start with Predict**, add reasoning only if needed
|
||||
2. **Use batch processing** for multiple inputs
|
||||
3. **Cache predictions** for repeated queries
|
||||
4. **Profile token usage** with `track_usage=True`
|
||||
5. **Optimize after prototyping** with teleprompters
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern: Retrieval + Generation
|
||||
|
||||
```python
|
||||
class RAG(dspy.Module):
|
||||
def __init__(self, k=3):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=k)
|
||||
self.generate = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
context = self.retrieve(question).passages
|
||||
return self.generate(context=context, question=question)
|
||||
```
|
||||
|
||||
### Pattern: Verification Loop
|
||||
|
||||
```python
|
||||
class VerifiedQA(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.answer = dspy.ChainOfThought("question -> answer")
|
||||
self.verify = dspy.Predict("question, answer -> is_correct: bool")
|
||||
|
||||
def forward(self, question, max_attempts=3):
|
||||
for _ in range(max_attempts):
|
||||
answer = self.answer(question=question).answer
|
||||
is_correct = self.verify(question=question, answer=answer).is_correct
|
||||
|
||||
if is_correct:
|
||||
return dspy.Prediction(answer=answer)
|
||||
|
||||
return dspy.Prediction(answer="Unable to verify answer")
|
||||
```
|
||||
|
||||
### Pattern: Multi-Turn Dialog
|
||||
|
||||
```python
|
||||
class DialogAgent(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.respond = dspy.Predict("history, user_message -> assistant_message")
|
||||
self.history = []
|
||||
|
||||
def forward(self, user_message):
|
||||
history_str = "\n".join(self.history)
|
||||
response = self.respond(history=history_str, user_message=user_message)
|
||||
|
||||
self.history.append(f"User: {user_message}")
|
||||
self.history.append(f"Assistant: {response.assistant_message}")
|
||||
|
||||
return response
|
||||
```
|
||||
566
skills/mlops-knowledge/dspy/references/optimizers.md
Normal file
@@ -0,0 +1,566 @@
|
||||
# DSPy Optimizers (Teleprompters)
|
||||
|
||||
Complete guide to DSPy's optimization algorithms for improving prompts and model weights.
|
||||
|
||||
## What are Optimizers?
|
||||
|
||||
DSPy optimizers (called "teleprompters") automatically improve your modules by:
|
||||
- **Synthesizing few-shot examples** from training data
|
||||
- **Proposing better instructions** through search
|
||||
- **Fine-tuning model weights** (optional)
|
||||
|
||||
**Key idea**: Instead of manually tuning prompts, define a metric and let DSPy optimize.
|
||||
|
||||
## Optimizer Selection Guide
|
||||
|
||||
| Optimizer | Best For | Speed | Quality | Data Needed |
|
||||
|-----------|----------|-------|---------|-------------|
|
||||
| BootstrapFewShot | General purpose | Fast | Good | 10-50 examples |
|
||||
| MIPRO | Instruction tuning | Medium | Excellent | 50-200 examples |
|
||||
| BootstrapFinetune | Fine-tuning | Slow | Excellent | 100+ examples |
|
||||
| COPRO | Prompt optimization | Medium | Good | 20-100 examples |
|
||||
| KNNFewShot | Quick baseline | Very fast | Fair | 10+ examples |
|
||||
|
||||
## Core Optimizers
|
||||
|
||||
### BootstrapFewShot
|
||||
|
||||
**Most popular optimizer** - Generates few-shot demonstrations from training data.
|
||||
|
||||
**How it works:**
|
||||
1. Takes your training examples
|
||||
2. Uses your module to generate predictions
|
||||
3. Selects high-quality predictions (based on metric)
|
||||
4. Uses these as few-shot examples in future prompts
|
||||
|
||||
**Parameters:**
|
||||
- `metric`: Function that scores predictions (required)
|
||||
- `max_bootstrapped_demos`: Max demonstrations to generate (default: 4)
|
||||
- `max_labeled_demos`: Max labeled examples to use (default: 16)
|
||||
- `max_rounds`: Optimization iterations (default: 1)
|
||||
- `metric_threshold`: Minimum score to accept (optional)
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
# Define metric
|
||||
def validate_answer(example, pred, trace=None):
|
||||
"""Return True if prediction matches gold answer."""
|
||||
return example.answer.lower() == pred.answer.lower()
|
||||
|
||||
# Training data
|
||||
trainset = [
|
||||
dspy.Example(question="What is 2+2?", answer="4").with_inputs("question"),
|
||||
dspy.Example(question="What is 3+5?", answer="8").with_inputs("question"),
|
||||
dspy.Example(question="What is 10-3?", answer="7").with_inputs("question"),
|
||||
]
|
||||
|
||||
# Create module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Optimize
|
||||
optimizer = BootstrapFewShot(
|
||||
metric=validate_answer,
|
||||
max_bootstrapped_demos=3,
|
||||
max_rounds=2
|
||||
)
|
||||
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# Now optimized_qa has learned few-shot examples!
|
||||
result = optimized_qa(question="What is 5+7?")
|
||||
```
|
||||
|
||||
**Best practices:**
|
||||
- Start with 10-50 training examples
|
||||
- Use diverse examples covering edge cases
|
||||
- Set `max_bootstrapped_demos=3-5` for most tasks
|
||||
- Increase `max_rounds=2-3` for better quality
|
||||
|
||||
**When to use:**
|
||||
- First optimizer to try
|
||||
- You have 10+ labeled examples
|
||||
- Want quick improvements
|
||||
- General-purpose tasks
|
||||
|
||||
### MIPRO (Multi-prompt Instruction Proposal Optimizer)
|
||||
|
||||
**State-of-the-art optimizer** - Iteratively searches for better instructions.
|
||||
|
||||
**How it works:**
|
||||
1. Generates candidate instructions
|
||||
2. Tests each on validation set
|
||||
3. Selects best-performing instructions
|
||||
4. Iterates to refine further
|
||||
|
||||
**Parameters:**
|
||||
- `metric`: Evaluation metric (required)
|
||||
- `num_candidates`: Instructions to try per iteration (default: 10)
|
||||
- `init_temperature`: Sampling temperature (default: 1.0)
|
||||
- `verbose`: Show progress (default: False)
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import MIPRO
|
||||
|
||||
# Define metric with more nuance
|
||||
def answer_quality(example, pred, trace=None):
|
||||
"""Score answer quality 0-1."""
|
||||
if example.answer.lower() in pred.answer.lower():
|
||||
return 1.0
|
||||
# Partial credit for similar answers
|
||||
return 0.5 if len(set(example.answer.split()) & set(pred.answer.split())) > 0 else 0.0
|
||||
|
||||
# Larger training set (MIPRO benefits from more data)
|
||||
trainset = [...] # 50-200 examples
|
||||
valset = [...] # 20-50 examples
|
||||
|
||||
# Create module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Optimize with MIPRO
|
||||
optimizer = MIPRO(
|
||||
metric=answer_quality,
|
||||
num_candidates=10,
|
||||
init_temperature=1.0,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
optimized_qa = optimizer.compile(
|
||||
student=qa,
|
||||
trainset=trainset,
|
||||
valset=valset, # MIPRO uses separate validation set
|
||||
num_trials=100 # More trials = better quality
|
||||
)
|
||||
```
|
||||
|
||||
**Best practices:**
|
||||
- Use 50-200 training examples
|
||||
- Separate validation set (20-50 examples)
|
||||
- Run 100-200 trials for best results
|
||||
- Takes 10-30 minutes typically
|
||||
|
||||
**When to use:**
|
||||
- You have 50+ labeled examples
|
||||
- Want state-of-the-art performance
|
||||
- Willing to wait for optimization
|
||||
- Complex reasoning tasks
|
||||
|
||||
### BootstrapFinetune
|
||||
|
||||
**Fine-tune model weights** - Creates training dataset for fine-tuning.
|
||||
|
||||
**How it works:**
|
||||
1. Generates synthetic training data
|
||||
2. Exports data in fine-tuning format
|
||||
3. You fine-tune model separately
|
||||
4. Load fine-tuned model back
|
||||
|
||||
**Parameters:**
|
||||
- `metric`: Evaluation metric (required)
|
||||
- `max_bootstrapped_demos`: Demonstrations to generate (default: 4)
|
||||
- `max_rounds`: Data generation rounds (default: 1)
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import BootstrapFinetune
|
||||
|
||||
# Training data
|
||||
trainset = [...] # 100+ examples recommended
|
||||
|
||||
# Define metric
|
||||
def validate(example, pred, trace=None):
|
||||
return example.answer == pred.answer
|
||||
|
||||
# Create module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Generate fine-tuning data
|
||||
optimizer = BootstrapFinetune(metric=validate)
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# Exports training data to file
|
||||
# You then fine-tune using your LM provider's API
|
||||
|
||||
# After fine-tuning, load your model:
|
||||
finetuned_lm = dspy.OpenAI(model="ft:gpt-3.5-turbo:your-model-id")
|
||||
dspy.settings.configure(lm=finetuned_lm)
|
||||
```
|
||||
|
||||
**Best practices:**
|
||||
- Use 100+ training examples
|
||||
- Validate on held-out test set
|
||||
- Monitor for overfitting
|
||||
- Compare with prompt-based methods first
|
||||
|
||||
**When to use:**
|
||||
- You have 100+ examples
|
||||
- Latency is critical (fine-tuned models faster)
|
||||
- Task is narrow and well-defined
|
||||
- Prompt optimization isn't enough
|
||||
|
||||
### COPRO (Coordinate Prompt Optimization)
|
||||
|
||||
**Optimize prompts via gradient-free search.**
|
||||
|
||||
**How it works:**
|
||||
1. Generates prompt variants
|
||||
2. Evaluates each variant
|
||||
3. Selects best prompts
|
||||
4. Iterates to refine
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import COPRO
|
||||
|
||||
# Training data
|
||||
trainset = [...]
|
||||
|
||||
# Define metric
|
||||
def metric(example, pred, trace=None):
|
||||
return example.answer == pred.answer
|
||||
|
||||
# Create module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Optimize with COPRO
|
||||
optimizer = COPRO(
|
||||
metric=metric,
|
||||
breadth=10, # Candidates per iteration
|
||||
depth=3 # Optimization rounds
|
||||
)
|
||||
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Want prompt optimization
|
||||
- Have 20-100 examples
|
||||
- MIPRO too slow
|
||||
|
||||
### KNNFewShot
|
||||
|
||||
**Simple k-nearest neighbors** - Selects similar examples for each query.
|
||||
|
||||
**How it works:**
|
||||
1. Embeds all training examples
|
||||
2. For each query, finds k most similar examples
|
||||
3. Uses these as few-shot demonstrations
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import KNNFewShot
|
||||
|
||||
trainset = [...]
|
||||
|
||||
# No metric needed - just selects similar examples
|
||||
optimizer = KNNFewShot(k=3)
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# For each query, uses 3 most similar examples from trainset
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Quick baseline
|
||||
- Have diverse training examples
|
||||
- Similarity is good proxy for helpfulness
|
||||
|
||||
## Writing Metrics
|
||||
|
||||
Metrics are functions that score predictions. They're critical for optimization.
|
||||
|
||||
### Binary Metrics
|
||||
|
||||
```python
|
||||
def exact_match(example, pred, trace=None):
|
||||
"""Return True if prediction exactly matches gold."""
|
||||
return example.answer == pred.answer
|
||||
|
||||
def contains_answer(example, pred, trace=None):
|
||||
"""Return True if prediction contains gold answer."""
|
||||
return example.answer.lower() in pred.answer.lower()
|
||||
```
|
||||
|
||||
### Continuous Metrics
|
||||
|
||||
```python
|
||||
def f1_score(example, pred, trace=None):
|
||||
"""F1 score between prediction and gold."""
|
||||
pred_tokens = set(pred.answer.lower().split())
|
||||
gold_tokens = set(example.answer.lower().split())
|
||||
|
||||
if not pred_tokens:
|
||||
return 0.0
|
||||
|
||||
precision = len(pred_tokens & gold_tokens) / len(pred_tokens)
|
||||
recall = len(pred_tokens & gold_tokens) / len(gold_tokens)
|
||||
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
|
||||
return 2 * (precision * recall) / (precision + recall)
|
||||
|
||||
def semantic_similarity(example, pred, trace=None):
|
||||
"""Embedding similarity between prediction and gold."""
|
||||
from sentence_transformers import SentenceTransformer, util
|
||||
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
|
||||
emb1 = model.encode(example.answer)
|
||||
emb2 = model.encode(pred.answer)
|
||||
|
||||
similarity = float(util.cos_sim(emb1, emb2))  # cosine similarity in [-1, 1]
|
||||
return similarity
|
||||
```
|
||||
|
||||
### Multi-Factor Metrics
|
||||
|
||||
```python
|
||||
def comprehensive_metric(example, pred, trace=None):
|
||||
"""Combine multiple factors."""
|
||||
score = 0.0
|
||||
|
||||
# Correctness (50%)
|
||||
if example.answer.lower() in pred.answer.lower():
|
||||
score += 0.5
|
||||
|
||||
# Conciseness (25%)
|
||||
if len(pred.answer.split()) <= 20:
|
||||
score += 0.25
|
||||
|
||||
# Citation (25%)
|
||||
if "source:" in pred.answer.lower():
|
||||
score += 0.25
|
||||
|
||||
return score
|
||||
```
|
||||
|
||||
### Using Trace for Debugging
|
||||
|
||||
```python
|
||||
def metric_with_trace(example, pred, trace=None):
|
||||
"""Metric that uses trace for debugging."""
|
||||
is_correct = example.answer == pred.answer
|
||||
|
||||
if trace is not None and not is_correct:
|
||||
# Log failures for analysis
|
||||
print(f"Failed on: {example.question}")
|
||||
print(f"Expected: {example.answer}")
|
||||
print(f"Got: {pred.answer}")
|
||||
|
||||
return is_correct
|
||||
```
|
||||
|
||||
## Evaluation Best Practices
|
||||
|
||||
### Train/Val/Test Split
|
||||
|
||||
```python
|
||||
# Split data
|
||||
trainset = data[:100] # 70%
|
||||
valset = data[100:120] # 15%
|
||||
testset = data[120:] # 15%
|
||||
|
||||
# Optimize on train
|
||||
optimized = optimizer.compile(module, trainset=trainset)
|
||||
|
||||
# Validate during optimization (for MIPRO)
|
||||
optimized = optimizer.compile(module, trainset=trainset, valset=valset)
|
||||
|
||||
# Evaluate on test
|
||||
from dspy.evaluate import Evaluate
|
||||
evaluator = Evaluate(devset=testset, metric=metric)
|
||||
score = evaluator(optimized)
|
||||
```
|
||||
|
||||
### Cross-Validation
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import KFold
|
||||
|
||||
kfold = KFold(n_splits=5)
|
||||
scores = []
|
||||
|
||||
for train_idx, val_idx in kfold.split(data):
|
||||
trainset = [data[i] for i in train_idx]
|
||||
valset = [data[i] for i in val_idx]
|
||||
|
||||
optimized = optimizer.compile(module, trainset=trainset)
|
||||
score = evaluator(optimized, devset=valset)
|
||||
scores.append(score)
|
||||
|
||||
print(f"Average score: {sum(scores) / len(scores):.2f}")
|
||||
```
|
||||
|
||||
### Comparing Optimizers
|
||||
|
||||
```python
|
||||
results = {}
|
||||
|
||||
for opt_name, optimizer in [
|
||||
("baseline", None),
|
||||
("fewshot", BootstrapFewShot(metric=metric)),
|
||||
("mipro", MIPRO(metric=metric)),
|
||||
]:
|
||||
if optimizer is None:
|
||||
module_opt = module
|
||||
else:
|
||||
module_opt = optimizer.compile(module, trainset=trainset)
|
||||
|
||||
score = evaluator(module_opt, devset=testset)
|
||||
results[opt_name] = score
|
||||
|
||||
print(results)
|
||||
# {'baseline': 0.65, 'fewshot': 0.78, 'mipro': 0.85}
|
||||
```
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Custom Optimizer
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import Teleprompter
|
||||
|
||||
class CustomOptimizer(Teleprompter):
|
||||
def __init__(self, metric):
|
||||
self.metric = metric
|
||||
|
||||
def compile(self, student, trainset, **kwargs):
|
||||
# Your optimization logic here
|
||||
# Return optimized student module
|
||||
return student
|
||||
```
|
||||
|
||||
### Multi-Stage Optimization
|
||||
|
||||
```python
|
||||
# Stage 1: Bootstrap few-shot
|
||||
stage1 = BootstrapFewShot(metric=metric, max_bootstrapped_demos=3)
|
||||
optimized1 = stage1.compile(module, trainset=trainset)
|
||||
|
||||
# Stage 2: Instruction tuning
|
||||
stage2 = MIPRO(metric=metric, num_candidates=10)
|
||||
optimized2 = stage2.compile(optimized1, trainset=trainset, valset=valset)
|
||||
|
||||
# Final optimized module
|
||||
final_module = optimized2
|
||||
```
|
||||
|
||||
### Ensemble Optimization
|
||||
|
||||
```python
|
||||
class EnsembleModule(dspy.Module):
|
||||
def __init__(self, modules):
|
||||
super().__init__()
|
||||
self.modules = modules
|
||||
|
||||
def forward(self, question):
|
||||
predictions = [m(question=question).answer for m in self.modules]
|
||||
# Vote or average
|
||||
return dspy.Prediction(answer=max(set(predictions), key=predictions.count))
|
||||
|
||||
# Optimize multiple modules
|
||||
opt1 = BootstrapFewShot(metric=metric).compile(module, trainset=trainset)
|
||||
opt2 = MIPRO(metric=metric).compile(module, trainset=trainset)
|
||||
opt3 = COPRO(metric=metric).compile(module, trainset=trainset)
|
||||
|
||||
# Ensemble
|
||||
ensemble = EnsembleModule([opt1, opt2, opt3])
|
||||
```
|
||||
|
||||
## Optimization Workflow
|
||||
|
||||
### 1. Start with Baseline
|
||||
|
||||
```python
|
||||
# No optimization
|
||||
baseline = dspy.ChainOfThought("question -> answer")
|
||||
baseline_score = evaluator(baseline, devset=testset)
|
||||
print(f"Baseline: {baseline_score}")
|
||||
```
|
||||
|
||||
### 2. Try BootstrapFewShot
|
||||
|
||||
```python
|
||||
# Quick optimization
|
||||
fewshot = BootstrapFewShot(metric=metric, max_bootstrapped_demos=3)
|
||||
optimized = fewshot.compile(baseline, trainset=trainset)
|
||||
fewshot_score = evaluator(optimized, devset=testset)
|
||||
print(f"Few-shot: {fewshot_score} (+{fewshot_score - baseline_score:.2f})")
|
||||
```
|
||||
|
||||
### 3. If More Data Available, Try MIPRO
|
||||
|
||||
```python
|
||||
# State-of-the-art optimization
|
||||
mipro = MIPRO(metric=metric, num_candidates=10)
|
||||
optimized_mipro = mipro.compile(baseline, trainset=trainset, valset=valset)
|
||||
mipro_score = evaluator(optimized_mipro, devset=testset)
|
||||
print(f"MIPRO: {mipro_score} (+{mipro_score - baseline_score:.2f})")
|
||||
```
|
||||
|
||||
### 4. Save Best Model
|
||||
|
||||
```python
|
||||
if mipro_score > fewshot_score:
|
||||
optimized_mipro.save("models/best_model.json")
|
||||
else:
|
||||
optimized.save("models/best_model.json")
|
||||
```
|
||||
|
||||
## Common Pitfalls
|
||||
|
||||
### 1. Overfitting to Training Data
|
||||
|
||||
```python
|
||||
# ❌ Bad: Too many demos
|
||||
optimizer = BootstrapFewShot(max_bootstrapped_demos=20) # Overfits!
|
||||
|
||||
# ✅ Good: Moderate demos
|
||||
optimizer = BootstrapFewShot(max_bootstrapped_demos=4)  # 3-5 demos is a good range
|
||||
```
|
||||
|
||||
### 2. Metric Doesn't Match Task
|
||||
|
||||
```python
|
||||
# ❌ Bad: Binary metric for nuanced task
|
||||
def bad_metric(example, pred, trace=None):
|
||||
return example.answer == pred.answer # Too strict!
|
||||
|
||||
# ✅ Good: Graded metric
|
||||
def good_metric(example, pred, trace=None):
|
||||
return f1_score(example, pred)  # allows partial credit via token overlap
|
||||
```
|
||||
|
||||
### 3. Insufficient Training Data
|
||||
|
||||
```python
|
||||
# ❌ Bad: Too little data
|
||||
trainset = data[:5] # Not enough!
|
||||
|
||||
# ✅ Good: Sufficient data
|
||||
trainset = data[:50] # Better
|
||||
```
|
||||
|
||||
### 4. No Validation Set
|
||||
|
||||
```python
|
||||
# ❌ Bad: Optimizing on test set
|
||||
optimizer.compile(module, trainset=testset) # Cheating!
|
||||
|
||||
# ✅ Good: Proper splits
|
||||
optimizer.compile(module, trainset=trainset, valset=valset)
|
||||
evaluator(optimized, devset=testset)
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Start simple**: BootstrapFewShot first
|
||||
2. **Use representative data**: Cover edge cases
|
||||
3. **Monitor overfitting**: Validate on held-out set
|
||||
4. **Iterate metrics**: Refine based on failures
|
||||
5. **Save checkpoints**: Don't lose progress
|
||||
6. **Compare to baseline**: Measure improvement
|
||||
7. **Test multiple optimizers**: Find best fit
|
||||
|
||||
## Resources
|
||||
|
||||
- **Paper**: "DSPy: Compiling Declarative Language Model Calls into Self-Improving Pipelines"
|
||||
- **GitHub**: https://github.com/stanfordnlp/dspy
|
||||
- **Discord**: https://discord.gg/XCGy2WDCQB
|
||||
221
skills/mlops-knowledge/faiss/SKILL.md
Normal file
@@ -0,0 +1,221 @@
|
||||
---
|
||||
name: faiss
|
||||
description: Facebook's library for efficient similarity search and clustering of dense vectors. Supports billions of vectors, GPU acceleration, and various index types (Flat, IVF, HNSW). Use for fast k-NN search, large-scale vector retrieval, or when you need pure similarity search without metadata. Best for high-performance applications.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [RAG, FAISS, Similarity Search, Vector Search, Facebook AI, GPU Acceleration, Billion-Scale, K-NN, HNSW, High Performance, Large Scale]
|
||||
dependencies: [faiss-cpu, faiss-gpu, numpy]
|
||||
---
|
||||
|
||||
# FAISS - Efficient Similarity Search
|
||||
|
||||
Facebook AI's library for billion-scale vector similarity search.
|
||||
|
||||
## When to use FAISS
|
||||
|
||||
**Use FAISS when:**
|
||||
- Need fast similarity search on large vector datasets (millions/billions)
|
||||
- GPU acceleration required
|
||||
- Pure vector similarity (no metadata filtering needed)
|
||||
- High throughput, low latency critical
|
||||
- Offline/batch processing of embeddings
|
||||
|
||||
**Metrics**:
|
||||
- **31,700+ GitHub stars**
|
||||
- Meta/Facebook AI Research
|
||||
- **Handles billions of vectors**
|
||||
- **C++** with Python bindings
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **Chroma/Pinecone**: Need metadata filtering
|
||||
- **Weaviate**: Need full database features
|
||||
- **Annoy**: Simpler, fewer features
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# CPU only
|
||||
pip install faiss-cpu
|
||||
|
||||
# GPU support
|
||||
pip install faiss-gpu
|
||||
```
|
||||
|
||||
### Basic usage
|
||||
|
||||
```python
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
# Create sample data (1000 vectors, 128 dimensions)
|
||||
d = 128
|
||||
nb = 1000
|
||||
vectors = np.random.random((nb, d)).astype('float32')
|
||||
|
||||
# Create index
|
||||
index = faiss.IndexFlatL2(d) # L2 distance
|
||||
index.add(vectors) # Add vectors
|
||||
|
||||
# Search
|
||||
k = 5 # Find 5 nearest neighbors
|
||||
query = np.random.random((1, d)).astype('float32')
|
||||
distances, indices = index.search(query, k)
|
||||
|
||||
print(f"Nearest neighbors: {indices}")
|
||||
print(f"Distances: {distances}")
|
||||
```
|
||||
|
||||
## Index types
|
||||
|
||||
### 1. Flat (exact search)
|
||||
|
||||
```python
|
||||
# L2 (Euclidean) distance
|
||||
index = faiss.IndexFlatL2(d)
|
||||
|
||||
# Inner product (cosine similarity if normalized)
|
||||
index = faiss.IndexFlatIP(d)
|
||||
|
||||
# Slowest, most accurate
|
||||
```
|
||||
|
||||
### 2. IVF (inverted file) - Fast approximate
|
||||
|
||||
```python
|
||||
# Create quantizer
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
|
||||
# IVF index with 100 clusters
|
||||
nlist = 100
|
||||
index = faiss.IndexIVFFlat(quantizer, d, nlist)
|
||||
|
||||
# Train on data
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search (nprobe = clusters to search)
|
||||
index.nprobe = 10
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
### 3. HNSW (Hierarchical NSW) - Best quality/speed
|
||||
|
||||
```python
|
||||
# HNSW index
|
||||
M = 32 # Number of connections per layer
|
||||
index = faiss.IndexHNSWFlat(d, M)
|
||||
|
||||
# No training needed
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
### 4. Product Quantization - Memory efficient
|
||||
|
||||
```python
|
||||
# PQ reduces memory by 16-32×
|
||||
m = 8 # Number of subquantizers
|
||||
nbits = 8
|
||||
index = faiss.IndexPQ(d, m, nbits)
|
||||
|
||||
# Train and add
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
```
|
||||
|
||||
## Save and load
|
||||
|
||||
```python
|
||||
# Save index
|
||||
faiss.write_index(index, "large.index")
|
||||
|
||||
# Load index
|
||||
index = faiss.read_index("large.index")
|
||||
|
||||
# Continue using
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
## GPU acceleration
|
||||
|
||||
```python
|
||||
# Single GPU
|
||||
res = faiss.StandardGpuResources()
|
||||
index_cpu = faiss.IndexFlatL2(d)
|
||||
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
|
||||
|
||||
# Multi-GPU
|
||||
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
|
||||
|
||||
# 10-100× faster than CPU
|
||||
```
|
||||
|
||||
## LangChain integration
|
||||
|
||||
```python
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
# Create FAISS vector store
|
||||
vectorstore = FAISS.from_documents(docs, OpenAIEmbeddings())
|
||||
|
||||
# Save
|
||||
vectorstore.save_local("faiss_index")
|
||||
|
||||
# Load
|
||||
vectorstore = FAISS.load_local(
|
||||
"faiss_index",
|
||||
OpenAIEmbeddings(),
|
||||
allow_dangerous_deserialization=True
|
||||
)
|
||||
|
||||
# Search
|
||||
results = vectorstore.similarity_search("query", k=5)
|
||||
```
|
||||
|
||||
## LlamaIndex integration
|
||||
|
||||
```python
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
import faiss
|
||||
|
||||
# Create FAISS index
|
||||
d = 1536
|
||||
faiss_index = faiss.IndexFlatL2(d)
|
||||
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Choose right index type** - Flat for <10K, IVF for 10K-1M, HNSW for quality
|
||||
2. **Normalize for cosine** - Use IndexFlatIP with normalized vectors
|
||||
3. **Use GPU for large datasets** - 10-100× faster
|
||||
4. **Save trained indices** - Training is expensive
|
||||
5. **Tune nprobe/ef_search** - Balance speed/accuracy
|
||||
6. **Monitor memory** - PQ for large datasets
|
||||
7. **Batch queries** - Better GPU utilization
|
||||
|
||||
## Performance
|
||||
|
||||
| Index Type | Build Time | Search Time | Memory | Accuracy |
|
||||
|------------|------------|-------------|--------|----------|
|
||||
| Flat | Fast | Slow | High | 100% |
|
||||
| IVF | Medium | Fast | Medium | 95-99% |
|
||||
| HNSW | Slow | Fastest | High | 99% |
|
||||
| PQ | Medium | Fast | Low | 90-95% |
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/faiss ⭐ 31,700+
|
||||
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
|
||||
- **License**: MIT
|
||||
|
||||
|
||||
280
skills/mlops-knowledge/faiss/references/index_types.md
Normal file
@@ -0,0 +1,280 @@
|
||||
# FAISS Index Types Guide
|
||||
|
||||
Complete guide to choosing and using FAISS index types.
|
||||
|
||||
## Index selection guide
|
||||
|
||||
| Dataset Size | Index Type | Training | Accuracy | Speed |
|
||||
|--------------|------------|----------|----------|-------|
|
||||
| < 10K | Flat | No | 100% | Slow |
|
||||
| 10K-1M | IVF | Yes | 95-99% | Fast |
|
||||
| 1M-10M | HNSW | No | 99% | Fastest |
|
||||
| > 10M | IVF+PQ | Yes | 90-95% | Fast, low memory |
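Translated into code, the table suggests a simple rule of thumb (an illustrative helper; `pick_index` and the thresholds are ours, not part of FAISS):

```python
import faiss

def pick_index(num_vectors: int, dim: int, low_memory: bool = False):
    """Choose an index type following the table above."""
    if num_vectors < 10_000:
        return faiss.IndexFlatL2(dim)                    # exact search, no training
    if num_vectors <= 1_000_000:
        return faiss.index_factory(dim, "IVF1024,Flat")  # needs training
    if low_memory or num_vectors > 10_000_000:
        return faiss.index_factory(dim, "IVF4096,PQ8")   # needs training; dim must divide by 8
    return faiss.IndexHNSWFlat(dim, 32)                  # no training, higher memory
```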
|
||||
|
||||
## Flat indices (exact search)
|
||||
|
||||
### IndexFlatL2 - L2 (Euclidean) distance
|
||||
|
||||
```python
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
d = 128 # Dimension
|
||||
index = faiss.IndexFlatL2(d)
|
||||
|
||||
# Add vectors
|
||||
vectors = np.random.random((1000, d)).astype('float32')
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
k = 5
|
||||
query = np.random.random((1, d)).astype('float32')
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Dataset < 10,000 vectors
|
||||
- Need 100% accuracy
|
||||
- Serving as baseline
|
||||
|
||||
### IndexFlatIP - Inner product (cosine similarity)
|
||||
|
||||
```python
|
||||
# For cosine similarity, normalize vectors first
|
||||
import faiss
|
||||
|
||||
d = 128
|
||||
index = faiss.IndexFlatIP(d)
|
||||
|
||||
# Normalize vectors (required for cosine similarity)
|
||||
faiss.normalize_L2(vectors)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
faiss.normalize_L2(query)
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Need cosine similarity
|
||||
- Recommendation systems
|
||||
- Text embeddings
|
||||
|
||||
## IVF indices (inverted file)
|
||||
|
||||
### IndexIVFFlat - Cluster-based search
|
||||
|
||||
```python
|
||||
# Create quantizer
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
|
||||
# Create IVF index with 100 clusters
|
||||
nlist = 100 # Number of clusters
|
||||
index = faiss.IndexIVFFlat(quantizer, d, nlist)
|
||||
|
||||
# Train on data (required!)
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search (nprobe = clusters to search)
|
||||
index.nprobe = 10 # Search 10 closest clusters
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `nlist`: Number of clusters (√N to 4√N recommended)
|
||||
- `nprobe`: Clusters to search (1-nlist, higher = more accurate)
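As a quick sanity check on these settings, the √N rule works out as follows (illustrative numbers):

```python
import math

N = 1_000_000                       # vectors you plan to index
nlist_min = int(math.sqrt(N))       # ~1,000 clusters
nlist_max = int(4 * math.sqrt(N))   # ~4,000 clusters
nprobe = 10                         # start small and increase until recall is acceptable
```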
|
||||
|
||||
**Use when:**
|
||||
- Dataset 10K-1M vectors
|
||||
- Need fast approximate search
|
||||
- Can afford training time
|
||||
|
||||
### Tuning nprobe
|
||||
|
||||
```python
|
||||
# Test different nprobe values
|
||||
for nprobe in [1, 5, 10, 20, 50]:
|
||||
index.nprobe = nprobe
|
||||
distances, indices = index.search(query, k)
|
||||
# Measure recall/speed trade-off
|
||||
```
|
||||
|
||||
**Guidelines:**
|
||||
- `nprobe=1`: Fastest, ~50% recall
|
||||
- `nprobe=10`: Good balance, ~95% recall
|
||||
- `nprobe=nlist`: Exact search (same as Flat)
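One way to measure that recall in practice is to compare against an exact Flat index on the same data (a minimal sketch, assuming `vectors`, `query`, `d`, `k`, and the trained IVF `index` from the snippets above):

```python
import faiss
import numpy as np

# Ground truth from an exact index
flat = faiss.IndexFlatL2(d)
flat.add(vectors)
_, true_ids = flat.search(query, k)

for nprobe in [1, 5, 10, 20, 50]:
    index.nprobe = nprobe
    _, ivf_ids = index.search(query, k)
    recall = np.isin(ivf_ids, true_ids).mean()  # fraction of exact neighbors recovered
    print(f"nprobe={nprobe}: recall@{k} = {recall:.2f}")
```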
|
||||
|
||||
## HNSW indices (graph-based)
|
||||
|
||||
### IndexHNSWFlat - Hierarchical NSW
|
||||
|
||||
```python
|
||||
# HNSW index
|
||||
M = 32 # Number of connections per layer (16-64)
|
||||
index = faiss.IndexHNSWFlat(d, M)
|
||||
|
||||
# Optional: Set ef_construction (build time parameter)
|
||||
index.hnsw.efConstruction = 40 # Higher = better quality, slower build
|
||||
|
||||
# Add vectors (no training needed!)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
index.hnsw.efSearch = 16 # Search time parameter
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `M`: Connections per layer (16-64, default 32)
|
||||
- `efConstruction`: Build quality (40-200, higher = better)
|
||||
- `efSearch`: Search quality (16-512, higher = more accurate)
|
||||
|
||||
**Use when:**
|
||||
- Need best quality approximate search
|
||||
- Can afford higher memory (more connections)
|
||||
- Dataset 1M-10M vectors
|
||||
|
||||
## PQ indices (product quantization)
|
||||
|
||||
### IndexPQ - Memory-efficient
|
||||
|
||||
```python
|
||||
# PQ reduces memory by 16-32×
|
||||
m = 8 # Number of subquantizers (divides d)
|
||||
nbits = 8 # Bits per subquantizer
|
||||
|
||||
index = faiss.IndexPQ(d, m, nbits)
|
||||
|
||||
# Train (required!)
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `m`: Subquantizers (d must be divisible by m)
|
||||
- `nbits`: Bits per code (8 or 16)
|
||||
|
||||
**Memory savings:**
|
||||
- Original: d × 4 bytes (float32)
|
||||
- PQ: m bytes
|
||||
- Compression ratio: 4d/m
|
||||
|
||||
**Use when:**
|
||||
- Limited memory
|
||||
- Large datasets (> 10M vectors)
|
||||
- Can accept ~90-95% accuracy
|
||||
|
||||
### IndexIVFPQ - IVF + PQ combined
|
||||
|
||||
```python
|
||||
# Best for very large datasets
|
||||
nlist = 4096
|
||||
m = 8
|
||||
nbits = 8
|
||||
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)
|
||||
|
||||
# Train
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
index.nprobe = 32
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Dataset > 10M vectors
|
||||
- Need fast search + low memory
|
||||
- Can accept 90-95% accuracy
|
||||
|
||||
## GPU indices
|
||||
|
||||
### Single GPU
|
||||
|
||||
```python
|
||||
import faiss
|
||||
|
||||
# Create CPU index
|
||||
index_cpu = faiss.IndexFlatL2(d)
|
||||
|
||||
# Move to GPU
|
||||
res = faiss.StandardGpuResources() # GPU resources
|
||||
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
|
||||
|
||||
# Use normally
|
||||
index_gpu.add(vectors)
|
||||
distances, indices = index_gpu.search(query, k)
|
||||
```
|
||||
|
||||
### Multi-GPU
|
||||
|
||||
```python
|
||||
# Use all available GPUs
|
||||
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
|
||||
|
||||
# Or specific GPUs
|
||||
gpus = [0, 1, 2, 3] # Use GPUs 0-3
|
||||
index_gpu = faiss.index_cpu_to_gpus_list(index_cpu, gpus)
|
||||
```
|
||||
|
||||
**Speedup:**
|
||||
- Single GPU: 10-50× faster than CPU
|
||||
- Multi-GPU: Near-linear scaling
|
||||
|
||||
## Index factory
|
||||
|
||||
```python
|
||||
# Easy index creation with string descriptors
|
||||
index = faiss.index_factory(d, "IVF100,Flat")
|
||||
index = faiss.index_factory(d, "HNSW32")
|
||||
index = faiss.index_factory(d, "IVF4096,PQ8")
|
||||
|
||||
# Train and use
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
```
|
||||
|
||||
**Common descriptors:**
|
||||
- `"Flat"`: Exact search
|
||||
- `"IVF100,Flat"`: IVF with 100 clusters
|
||||
- `"HNSW32"`: HNSW with M=32
|
||||
- `"IVF4096,PQ8"`: IVF + PQ compression
|
||||
|
||||
## Performance comparison
|
||||
|
||||
### Search speed (1M vectors, k=10)
|
||||
|
||||
| Index | Build Time | Search Time | Memory | Recall |
|
||||
|-------|------------|-------------|--------|--------|
|
||||
| Flat | 0s | 50ms | 512 MB | 100% |
|
||||
| IVF100 | 5s | 2ms | 512 MB | 95% |
|
||||
| HNSW32 | 60s | 1ms | 1GB | 99% |
|
||||
| IVF4096+PQ8 | 30s | 3ms | 32 MB | 90% |
|
||||
|
||||
*CPU (16 cores), 128-dim vectors*
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with Flat** - Baseline for comparison
|
||||
2. **Use IVF for medium datasets** - Good balance
|
||||
3. **Use HNSW for best quality** - If memory allows
|
||||
4. **Add PQ for memory savings** - Large datasets
|
||||
5. **GPU for > 100K vectors** - 10-50× speedup
|
||||
6. **Tune nprobe/efSearch** - Trade-off speed/accuracy
|
||||
7. **Train on representative data** - Better clustering
|
||||
8. **Save trained indices** - Avoid retraining
|
||||
|
||||
## Resources
|
||||
|
||||
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
|
||||
- **Paper**: https://arxiv.org/abs/1702.08734
|
||||
367
skills/mlops-knowledge/flash-attention/SKILL.md
Normal file
@@ -0,0 +1,367 @@
|
||||
---
|
||||
name: optimizing-attention-flash
|
||||
description: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers]
|
||||
dependencies: [flash-attn, torch, transformers]
|
||||
---
|
||||
|
||||
# Flash Attention - Fast Memory-Efficient Attention
|
||||
|
||||
## Quick start
|
||||
|
||||
Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
|
||||
|
||||
**PyTorch native (easiest, PyTorch 2.2+)**:
|
||||
```python
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
|
||||
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
|
||||
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
|
||||
|
||||
# Automatically uses Flash Attention if available
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
```
|
||||
|
||||
**flash-attn library (more features)**:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# q, k, v: [batch, seqlen, nheads, headdim]
|
||||
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Enable in existing PyTorch model
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
Flash Attention Integration:
|
||||
- [ ] Step 1: Check PyTorch version (≥2.2)
|
||||
- [ ] Step 2: Enable Flash Attention backend
|
||||
- [ ] Step 3: Verify speedup with profiling
|
||||
- [ ] Step 4: Test accuracy matches baseline
|
||||
```
|
||||
|
||||
**Step 1: Check PyTorch version**
|
||||
|
||||
```bash
|
||||
python -c "import torch; print(torch.__version__)"
|
||||
# Should be ≥2.2.0
|
||||
```
|
||||
|
||||
If <2.2, upgrade:
|
||||
```bash
|
||||
pip install --upgrade torch
|
||||
```
|
||||
|
||||
**Step 2: Enable Flash Attention backend**
|
||||
|
||||
Replace standard attention:
|
||||
```python
|
||||
# Before (standard attention)
|
||||
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
|
||||
out = attn_weights @ v
|
||||
|
||||
# After (Flash Attention)
|
||||
import torch.nn.functional as F
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
|
||||
```
|
||||
|
||||
Force Flash Attention backend:
|
||||
```python
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True,
|
||||
enable_math=False,
|
||||
enable_mem_efficient=False
|
||||
):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
```
|
||||
|
||||
**Step 3: Verify speedup with profiling**
|
||||
|
||||
```python
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
def test_attention(use_flash):
|
||||
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
if use_flash:
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=True):
|
||||
return F.scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
|
||||
return attn @ v
|
||||
|
||||
# Benchmark
|
||||
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
|
||||
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
|
||||
|
||||
print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
|
||||
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
|
||||
```
|
||||
|
||||
Expected: 2-4x speedup for sequences >512 tokens.
|
||||
|
||||
**Step 4: Test accuracy matches baseline**
|
||||
|
||||
```python
|
||||
# Compare outputs
|
||||
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
# Flash Attention
|
||||
out_flash = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
# Standard attention
|
||||
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
|
||||
out_standard = attn_weights @ v
|
||||
|
||||
# Check difference
|
||||
diff = (out_flash - out_standard).abs().max()
|
||||
print(f"Max difference: {diff:.6f}")
|
||||
# Should be <1e-3 for float16
|
||||
```
|
||||
|
||||
### Workflow 2: Use flash-attn library for advanced features
|
||||
|
||||
For multi-query attention, sliding window, or H100 FP8.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
flash-attn Library Setup:
|
||||
- [ ] Step 1: Install flash-attn library
|
||||
- [ ] Step 2: Modify attention code
|
||||
- [ ] Step 3: Enable advanced features
|
||||
- [ ] Step 4: Benchmark performance
|
||||
```
|
||||
|
||||
**Step 1: Install flash-attn library**
|
||||
|
||||
```bash
|
||||
# NVIDIA GPUs (CUDA 12.0+)
|
||||
pip install flash-attn --no-build-isolation
|
||||
|
||||
# Verify installation
|
||||
python -c "from flash_attn import flash_attn_func; print('Success')"
|
||||
```
|
||||
|
||||
**Step 2: Modify attention code**
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# Input: [batch_size, seq_len, num_heads, head_dim]
|
||||
# Transpose from [batch, heads, seq, dim] if needed
|
||||
q = q.transpose(1, 2) # [batch, seq, heads, dim]
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
out = flash_attn_func(
|
||||
q, k, v,
|
||||
dropout_p=0.1,
|
||||
causal=True, # For autoregressive models
|
||||
window_size=(-1, -1), # No sliding window
|
||||
softmax_scale=None # Auto-scale
|
||||
)
|
||||
|
||||
out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]
|
||||
```
|
||||
|
||||
**Step 3: Enable advanced features**
|
||||
|
||||
Multi-query attention (shared K/V across heads):
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# q: [batch, seq, num_q_heads, dim]
|
||||
# k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads
|
||||
out = flash_attn_func(q, k, v) # Automatically handles MQA
|
||||
```
|
||||
|
||||
Sliding window attention (local attention):
|
||||
```python
|
||||
# Only attend to window of 256 tokens before/after
|
||||
out = flash_attn_func(
|
||||
q, k, v,
|
||||
window_size=(256, 256), # (left, right) window
|
||||
causal=True
|
||||
)
|
||||
```
|
||||
|
||||
**Step 4: Benchmark performance**
|
||||
|
||||
```python
|
||||
import torch
|
||||
from flash_attn import flash_attn_func
|
||||
import time
|
||||
|
||||
q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
_ = flash_attn_func(q, k, v)
|
||||
|
||||
# Benchmark
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
for _ in range(100):
|
||||
out = flash_attn_func(q, k, v)
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
|
||||
print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
|
||||
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
|
||||
```
|
||||
|
||||
### Workflow 3: H100 FP8 optimization (FlashAttention-3)
|
||||
|
||||
For maximum performance on H100 GPUs.
|
||||
|
||||
```
|
||||
FP8 Setup:
|
||||
- [ ] Step 1: Verify H100 GPU available
|
||||
- [ ] Step 2: Install flash-attn with FP8 support
|
||||
- [ ] Step 3: Convert inputs to FP8
|
||||
- [ ] Step 4: Run with FP8 attention
|
||||
```
|
||||
|
||||
**Step 1: Verify H100 GPU**
|
||||
|
||||
```bash
|
||||
nvidia-smi --query-gpu=name --format=csv
|
||||
# Should show "H100" or "H800"
|
||||
```
|
||||
|
||||
**Step 2: Install flash-attn with FP8 support**
|
||||
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
# FP8 support included for H100
|
||||
```
|
||||
|
||||
**Step 3: Convert inputs to FP8**
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
|
||||
# Convert to float8_e4m3 (FP8)
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
k_fp8 = k.to(torch.float8_e4m3fn)
|
||||
v_fp8 = v.to(torch.float8_e4m3fn)
|
||||
```
|
||||
|
||||
**Step 4: Run with FP8 attention**
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# FlashAttention-3 automatically uses FP8 kernels on H100
|
||||
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
|
||||
# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use Flash Attention when:**
|
||||
- Training transformers with sequences >512 tokens
|
||||
- Running inference with long context (>2K tokens)
|
||||
- GPU memory constrained (OOM with standard attention)
|
||||
- Need 2-4x speedup without accuracy loss
|
||||
- Using PyTorch 2.2+ or can install flash-attn
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Standard attention**: Sequences <256 tokens (overhead not worth it)
|
||||
- **xFormers**: Need more attention variants (not just speed)
|
||||
- **Memory-efficient attention**: CPU inference (Flash Attention needs GPU)
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: ImportError: cannot import flash_attn**
|
||||
|
||||
Install with the `--no-build-isolation` flag:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Or install CUDA toolkit first:
|
||||
```bash
|
||||
conda install cuda -c nvidia
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
**Issue: Slower than expected (no speedup)**
|
||||
|
||||
Flash Attention benefits increase with sequence length:
|
||||
- <512 tokens: Minimal speedup (10-20%)
|
||||
- 512-2K tokens: 2-3x speedup
|
||||
- >2K tokens: 3-4x speedup
|
||||
|
||||
Check sequence length is sufficient.
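
You can also confirm that the Flash kernel is actually being selected. A minimal sketch (assumes PyTorch ≥2.2 and a CUDA device; restricting SDPA to the Flash backend raises an error if it cannot be used for these inputs):

```python
import torch
import torch.nn.functional as F

print(torch.backends.cuda.flash_sdp_enabled())  # True if the Flash backend is enabled

q, k, v = [torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3)]

# Allow only the Flash backend; this fails loudly if Flash Attention is not usable here
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)
```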
|
||||
|
||||
**Issue: RuntimeError: CUDA error**
|
||||
|
||||
Verify GPU supports Flash Attention:
|
||||
```python
|
||||
import torch
|
||||
print(torch.cuda.get_device_capability())
|
||||
# Should be ≥(7, 5) for Turing+
|
||||
```
|
||||
|
||||
Flash Attention requires:
|
||||
- Ampere (A100, A10): ✅ Full support
|
||||
- Turing (T4): ✅ Supported
|
||||
- Volta (V100): ❌ Not supported
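
A quick way to map the installed GPU onto this list is its compute capability; a rough sketch (keyed off the major version only):

```python
import torch

major, minor = torch.cuda.get_device_capability()
name = torch.cuda.get_device_name()

if major >= 8:
    print(f"{name}: Ampere or newer (compute {major}.{minor}) - fully supported")
elif (major, minor) >= (7, 5):
    print(f"{name}: Turing (compute {major}.{minor}) - supported")
else:
    print(f"{name}: compute {major}.{minor} - not supported by Flash Attention")
```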
|
||||
|
||||
**Issue: Accuracy degradation**
|
||||
|
||||
Check dtype is float16 or bfloat16 (not float32):
|
||||
```python
|
||||
q = q.to(torch.float16) # Or torch.bfloat16
|
||||
```
|
||||
|
||||
Flash Attention uses float16/bfloat16 for speed. Float32 not supported.
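
If the surrounding model must stay in float32, one workable pattern is to cast only the attention inputs down and cast the output back up; a minimal sketch:

```python
import torch
import torch.nn.functional as F

q, k, v = [torch.randn(1, 8, 1024, 64, device="cuda") for _ in range(3)]  # float32 activations

# Cast just the attention inputs to half precision so the Flash kernel can run,
# then return the result to float32 for the rest of the network
out = F.scaled_dot_product_attention(
    q.to(torch.float16), k.to(torch.float16), v.to(torch.float16)
).to(torch.float32)

print(out.dtype)  # torch.float32
```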
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models.
|
||||
|
||||
**Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.
|
||||
|
||||
**Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.
|
||||
|
||||
**Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
|
||||
- **VRAM**: Same as standard attention (Flash Attention doesn't increase memory)
|
||||
- **CUDA**: 12.0+ (11.8 minimum)
|
||||
- **PyTorch**: 2.2+ for native support
|
||||
|
||||
**Not supported**: V100 (Volta), CPU inference
|
||||
|
||||
## Resources
|
||||
|
||||
- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
|
||||
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
|
||||
- Blog: https://tridao.me/blog/2024/flash3/
|
||||
- GitHub: https://github.com/Dao-AILab/flash-attention
|
||||
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
|
||||
|
||||
|
||||
skills/mlops-knowledge/flash-attention/references/benchmarks.md
@@ -0,0 +1,215 @@
|
||||
# Performance Benchmarks
|
||||
|
||||
## Contents
|
||||
- Speed comparisons across GPUs
|
||||
- Memory usage analysis
|
||||
- Scaling with sequence length
|
||||
- Training vs inference performance
|
||||
- Flash Attention versions comparison
|
||||
|
||||
## Speed comparisons across GPUs
|
||||
|
||||
### A100 80GB (Ampere)
|
||||
|
||||
**Forward pass time** (milliseconds, batch=8, heads=32, dim=64):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) |
|
||||
|------------|----------|--------------|--------------|---------------|
|
||||
| 512 | 1.2 | 0.9 | N/A | 1.3x |
|
||||
| 1024 | 3.8 | 1.4 | N/A | 2.7x |
|
||||
| 2048 | 14.2 | 4.8 | N/A | 3.0x |
|
||||
| 4096 | 55.1 | 17.3 | N/A | 3.2x |
|
||||
| 8192 | 218.5 | 66.2 | N/A | 3.3x |
|
||||
|
||||
### H100 80GB (Hopper)
|
||||
|
||||
**Forward pass time** (milliseconds, same config):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup |
|
||||
|------------|----------|--------------|---------------------|--------------------|--------------|
|
||||
| 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x |
|
||||
| 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x |
|
||||
| 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x |
|
||||
| 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x |
|
||||
| 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x |
|
||||
|
||||
**Key insight**: Flash Attention 3 on H100 with FP8 achieves ~1.2 PFLOPS (75% of theoretical max).
|
||||
|
||||
### A10G 24GB (Ampere)
|
||||
|
||||
**Forward pass time** (milliseconds, batch=4):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Speedup |
|
||||
|------------|----------|--------------|---------|
|
||||
| 512 | 2.1 | 1.6 | 1.3x |
|
||||
| 1024 | 6.8 | 2.8 | 2.4x |
|
||||
| 2048 | 25.9 | 9.4 | 2.8x |
|
||||
| 4096 | 102.1 | 35.2 | 2.9x |
|
||||
|
||||
## Memory usage analysis
|
||||
|
||||
### GPU memory consumption (batch=8, heads=32, dim=64)
|
||||
|
||||
**Standard attention memory**:
|
||||
|
||||
| Seq Length | Attention Matrix | KV Cache | Total | Notes |
|
||||
|------------|------------------|----------|-------|-------|
|
||||
| 512 | 8 MB | 32 MB | 40 MB | Manageable |
|
||||
| 2048 | 128 MB | 128 MB | 256 MB | Growing |
|
||||
| 8192 | 2048 MB (2 GB) | 512 MB | 2.5 GB | Large |
|
||||
| 32768 | 32768 MB (32 GB) | 2048 MB | 34 GB | OOM on 24GB GPUs |
|
||||
|
||||
**Flash Attention 2 memory**:
|
||||
|
||||
| Seq Length | Attention (on-chip) | KV Cache | Total | Reduction |
|
||||
|------------|---------------------|----------|-------|-----------|
|
||||
| 512 | 0 MB (recomputed) | 32 MB | 32 MB | 20% |
|
||||
| 2048 | 0 MB | 128 MB | 128 MB | 50% |
|
||||
| 8192 | 0 MB | 512 MB | 512 MB | 80% |
|
||||
| 32768 | 0 MB | 2048 MB | 2 GB | 94% |
|
||||
|
||||
**Key insight**: Flash Attention doesn't materialize attention matrix, saving O(N²) memory.
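
As a sanity check on the quadratic growth, the score matrix that standard attention materializes has shape [batch, heads, seq, seq]; a small sketch of the arithmetic (float16, 2 bytes per element — exact totals depend on batch size and head count):

```python
def attn_scores_mb(seq_len, batch=1, heads=1, bytes_per_elem=2):
    # Standard attention materializes a [batch, heads, seq, seq] score matrix;
    # Flash Attention never stores it, so this term disappears entirely.
    return batch * heads * seq_len**2 * bytes_per_elem / 2**20

for n in (512, 2048, 8192, 32768):
    print(f"seq={n:>6}: {attn_scores_mb(n):>12,.1f} MB per (batch, head) pair")
```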
|
||||
|
||||
### Memory scaling comparison
|
||||
|
||||
**Llama 2 7B model memory** (float16, batch=1):
|
||||
|
||||
| Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? |
|
||||
|----------------|-------------------|-------------------|-------------------|
|
||||
| 2K | 3.2 GB | 2.1 GB | Both: Yes |
|
||||
| 4K | 5.8 GB | 2.8 GB | Both: Yes |
|
||||
| 8K | 12.1 GB | 4.2 GB | Both: Yes |
|
||||
| 16K | 26.3 GB (OOM) | 7.8 GB | Only Flash: Yes |
|
||||
| 32K | OOM | 14.2 GB | Only Flash: Yes |
|
||||
|
||||
### Training memory (Llama 2 7B, batch=4)
|
||||
|
||||
| Context | Standard (GB) | Flash Attn (GB) | Reduction |
|
||||
|---------|---------------|-----------------|-----------|
|
||||
| 2K | 18.2 | 12.4 | 32% |
|
||||
| 4K | 34.8 | 16.8 | 52% |
|
||||
| 8K | OOM (>40GB) | 26.2 | Fits! |
|
||||
|
||||
## Scaling with sequence length
|
||||
|
||||
### Computational complexity
|
||||
|
||||
**Standard attention**:
|
||||
- Time: O(N² × d)
|
||||
- Memory: O(N² + N × d)
|
||||
|
||||
**Flash Attention**:
|
||||
- Time: O(N² × d) (same, but with better constants)
|
||||
- Memory: O(N × d) (linear!)
|
||||
|
||||
### Empirical scaling (A100, batch=1, heads=32, dim=64)
|
||||
|
||||
**Time per token (milliseconds)**:
|
||||
|
||||
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|
||||
|----------|-----|-----|-----|-----|-----|------|
|
||||
| Standard | 0.15 | 0.37 | 1.11 | 3.44 | 13.4 | 52.8 |
|
||||
| Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83 | 1.64 |
|
||||
| Speedup | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x |
|
||||
|
||||
**Observation**: Per-token speedup grows roughly linearly with sequence length, doubling each time the sequence length doubles.
|
||||
|
||||
### Memory per token (MB)
|
||||
|
||||
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|
||||
|----------|-----|-----|-----|-----|-----|------|
|
||||
| Standard | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 |
|
||||
| Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 |
|
||||
|
||||
**Observation**: Flash Attention memory per token is constant!
|
||||
|
||||
## Training vs inference performance
|
||||
|
||||
### Training (forward + backward, Llama 2 7B, A100)
|
||||
|
||||
| Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|
||||
|-------------|------------------------|--------------------------|---------|
|
||||
| 4 × 2K | 1.2 | 3.1 | 2.6x |
|
||||
| 8 × 2K | 2.1 | 5.8 | 2.8x |
|
||||
| 4 × 4K | 0.4 | 1.3 | 3.3x |
|
||||
| 8 × 4K | OOM | 2.4 | Enabled |
|
||||
| 2 × 8K | 0.1 | 0.4 | 4.0x |
|
||||
|
||||
### Inference (generation, Llama 2 7B, A100)
|
||||
|
||||
| Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|
||||
|----------------|----------------------|-------------------------|---------|
|
||||
| 512 | 48 | 52 | 1.1x |
|
||||
| 2K | 42 | 62 | 1.5x |
|
||||
| 4K | 31 | 58 | 1.9x |
|
||||
| 8K | 18 | 51 | 2.8x |
|
||||
| 16K | OOM | 42 | Enabled |
|
||||
|
||||
**Note**: Inference speedup is less dramatic than training speedup because generation is memory-bound (dominated by KV-cache accesses).
|
||||
|
||||
## Flash Attention versions comparison
|
||||
|
||||
### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8)
|
||||
|
||||
| Metric | FA1 | FA2 | FA3 (FP16) | FA3 (FP8) |
|
||||
|--------|-----|-----|------------|-----------|
|
||||
| Forward time (ms) | 28.4 | 12.5 | 7.2 | 4.8 |
|
||||
| Memory (GB) | 4.8 | 4.2 | 4.2 | 2.8 |
|
||||
| TFLOPS | 180 | 420 | 740 | 1150 |
|
||||
| GPU util % | 35% | 55% | 75% | 82% |
|
||||
|
||||
**Key improvements**:
|
||||
- FA2: 2.3x faster than FA1 (better parallelism)
|
||||
- FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations)
|
||||
- FA3 (FP8): 2.6x faster than FA2 (low precision)
|
||||
|
||||
### Features by version
|
||||
|
||||
| Feature | FA1 | FA2 | FA3 |
|
||||
|---------|-----|-----|-----|
|
||||
| Basic attention | ✅ | ✅ | ✅ |
|
||||
| Causal masking | ✅ | ✅ | ✅ |
|
||||
| Multi-query attention | ❌ | ✅ | ✅ |
|
||||
| Sliding window | ❌ | ✅ | ✅ |
|
||||
| Paged KV cache | ❌ | ✅ | ✅ |
|
||||
| FP8 support | ❌ | ❌ | ✅ (H100 only) |
|
||||
| Work partitioning | Basic | Advanced | Optimal |
|
||||
|
||||
## Real-world model benchmarks
|
||||
|
||||
### Llama 2 models (A100 80GB, batch=4, seq=2048)
|
||||
|
||||
| Model | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|
||||
|-------|--------|------------------------|--------------------------|---------|
|
||||
| Llama 2 7B | 7B | 1.2 | 3.1 | 2.6x |
|
||||
| Llama 2 13B | 13B | 0.6 | 1.7 | 2.8x |
|
||||
| Llama 2 70B | 70B | 0.12 | 0.34 | 2.8x |
|
||||
|
||||
### GPT-style models (seq=1024)
|
||||
|
||||
| Model | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|
||||
|-------|----------------------|-------------------------|---------|
|
||||
| GPT-2 (124M) | 520 | 680 | 1.3x |
|
||||
| GPT-J (6B) | 42 | 98 | 2.3x |
|
||||
| GPT-NeoX (20B) | 8 | 22 | 2.75x |
|
||||
|
||||
## Recommendations by use case
|
||||
|
||||
**Training large models (>7B parameters)**:
|
||||
- Use Flash Attention 2 on A100
|
||||
- Use Flash Attention 3 FP8 on H100 for maximum speed
|
||||
- Expected: 2.5-3x speedup
|
||||
|
||||
**Long context inference (>4K tokens)**:
|
||||
- Flash Attention essential (enables contexts standard attention can't handle)
|
||||
- Expected: 2-4x speedup, 5-10x memory reduction
|
||||
|
||||
**Short sequences (<512 tokens)**:
|
||||
- Flash Attention provides 1.2-1.5x speedup
|
||||
- Minimal memory benefit
|
||||
- Still worth enabling (no downside)
|
||||
|
||||
**Multi-user serving**:
|
||||
- Flash Attention reduces per-request memory
|
||||
- Allows higher concurrent batch sizes
|
||||
- Can serve 2-3x more users on same hardware
|
||||
skills/mlops-knowledge/flash-attention/references/transformers-integration.md
@@ -0,0 +1,293 @@
|
||||
# HuggingFace Transformers Integration
|
||||
|
||||
## Contents
|
||||
- Enabling Flash Attention in Transformers
|
||||
- Supported model architectures
|
||||
- Configuration examples
|
||||
- Performance comparisons
|
||||
- Troubleshooting model-specific issues
|
||||
|
||||
## Enabling Flash Attention in Transformers
|
||||
|
||||
HuggingFace Transformers (v4.36+) supports Flash Attention 2 natively.
|
||||
|
||||
**Simple enable for any supported model**:
|
||||
```python
|
||||
import torch
from transformers import AutoModel
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
**Install requirements**:
|
||||
```bash
|
||||
pip install transformers>=4.36
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
## Supported model architectures
|
||||
|
||||
As of Transformers 4.40:
|
||||
|
||||
**Fully supported**:
|
||||
- Llama / Llama 2 / Llama 3
|
||||
- Mistral / Mixtral
|
||||
- Falcon
|
||||
- GPT-NeoX
|
||||
- Phi / Phi-2 / Phi-3
|
||||
- Qwen / Qwen2
|
||||
- Gemma
|
||||
- Starcoder2
|
||||
- GPT-J
|
||||
- OPT
|
||||
- BLOOM
|
||||
|
||||
**Partially supported** (encoder-decoder):
|
||||
- BART
|
||||
- T5 / Flan-T5
|
||||
- Whisper
|
||||
|
||||
**Check support**:
|
||||
```python
|
||||
from transformers import AutoConfig
|
||||
|
||||
config = AutoConfig.from_pretrained("model-name")
|
||||
print(config._attn_implementation_internal)
|
||||
# 'flash_attention_2' if supported
|
||||
```
|
||||
|
||||
## Configuration examples
|
||||
|
||||
### Llama 2 with Flash Attention
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-2-7b-hf"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# Generate
|
||||
inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
|
||||
outputs = model.generate(**inputs, max_length=100)
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
```
|
||||
|
||||
### Mistral with Flash Attention for long context
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16, # Better for long context
|
||||
device_map="auto",
|
||||
max_position_embeddings=32768 # Extended context
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

# Process long document (32K tokens)
|
||||
long_text = "..." * 10000
|
||||
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to("cuda")
|
||||
outputs = model.generate(**inputs, max_new_tokens=512)
|
||||
```
|
||||
|
||||
### Fine-tuning with Flash Attention
|
||||
|
||||
```python
|
||||
from transformers import Trainer, TrainingArguments
|
||||
from transformers import AutoModelForCausalLM
import torch
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
num_train_epochs=3,
|
||||
fp16=True, # Must match model dtype
|
||||
optim="adamw_torch_fused" # Fast optimizer
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Multi-GPU training
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
import torch
|
||||
|
||||
# Model parallelism with Flash Attention
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-13b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto", # Automatic multi-GPU placement
|
||||
max_memory={0: "20GB", 1: "20GB"} # Limit per GPU
|
||||
)
|
||||
```
|
||||
|
||||
## Performance comparisons
|
||||
|
||||
### Memory usage (Llama 2 7B, batch=1)
|
||||
|
||||
| Sequence Length | Standard Attention | Flash Attention 2 | Reduction |
|
||||
|-----------------|-------------------|-------------------|-----------|
|
||||
| 512 | 1.2 GB | 0.9 GB | 25% |
|
||||
| 2048 | 3.8 GB | 1.4 GB | 63% |
|
||||
| 8192 | 14.2 GB | 3.2 GB | 77% |
|
||||
| 32768 | OOM (>24GB) | 10.8 GB | Fits! |
|
||||
|
||||
### Speed (tokens/sec, A100 80GB)
|
||||
|
||||
| Model | Standard | Flash Attn 2 | Speedup |
|
||||
|-------|----------|--------------|---------|
|
||||
| Llama 2 7B (seq=2048) | 42 | 118 | 2.8x |
|
||||
| Llama 2 13B (seq=4096) | 18 | 52 | 2.9x |
|
||||
| Llama 2 70B (seq=2048) | 4 | 11 | 2.75x |
|
||||
|
||||
### Training throughput (samples/sec)
|
||||
|
||||
| Model | Batch Size | Standard | Flash Attn 2 | Speedup |
|
||||
|-------|------------|----------|--------------|---------|
|
||||
| Llama 2 7B | 4 | 1.2 | 3.1 | 2.6x |
|
||||
| Llama 2 7B | 8 | 2.1 | 5.8 | 2.8x |
|
||||
| Llama 2 13B | 2 | 0.6 | 1.7 | 2.8x |
|
||||
|
||||
## Troubleshooting model-specific issues
|
||||
|
||||
### Issue: Model doesn't support Flash Attention
|
||||
|
||||
Check support list above. If not supported, use PyTorch SDPA as fallback:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="sdpa", # PyTorch native (still faster)
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: CUDA out of memory during loading
|
||||
|
||||
Reduce memory footprint:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
max_memory={0: "18GB"}, # Reserve memory for KV cache
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Slower inference than expected
|
||||
|
||||
Ensure dtype matches:
|
||||
|
||||
```python
|
||||
# Model and inputs must both be float16/bfloat16
|
||||
model = model.to(torch.float16)
|
||||
inputs = tokenizer(..., return_tensors="pt").to("cuda")
|
||||
inputs = {k: v.to(torch.float16) if v.dtype == torch.float32 else v
|
||||
for k, v in inputs.items()}
|
||||
```
|
||||
|
||||
### Issue: Different outputs vs standard attention
|
||||
|
||||
Flash Attention is numerically equivalent but uses different computation order. Small differences (<1e-3) are normal:
|
||||
|
||||
```python
|
||||
# Compare outputs
|
||||
model_standard = AutoModelForCausalLM.from_pretrained(
    "model-name",
    torch_dtype=torch.float16,
    device_map="auto"
)
model_flash = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto"
)
|
||||
|
||||
inputs = tokenizer("Test", return_tensors="pt").to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
out_standard = model_standard(**inputs).logits
|
||||
out_flash = model_flash(**inputs).logits
|
||||
|
||||
diff = (out_standard - out_flash).abs().max()
|
||||
print(f"Max diff: {diff:.6f}") # Should be ~1e-3 to 1e-4
|
||||
```
|
||||
|
||||
### Issue: ImportError during model loading
|
||||
|
||||
Install flash-attn:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Or disable Flash Attention:
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="eager", # Standard PyTorch
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Always use float16/bfloat16** with Flash Attention (not float32)
|
||||
2. **Set device_map="auto"** for automatic memory management
|
||||
3. **Use bfloat16 for long context** (better numerical stability)
|
||||
4. **Enable gradient checkpointing** for training large models
|
||||
5. **Monitor memory** with `torch.cuda.max_memory_allocated()`
|
||||
|
||||
**Example with all best practices**:
|
||||
```python
|
||||
import torch
from transformers import AutoModelForCausalLM, TrainingArguments
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16, # Better for training
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
# Enable gradient checkpointing for memory
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Training with optimizations
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
per_device_train_batch_size=8,
|
||||
gradient_accumulation_steps=2,
|
||||
bf16=True, # Match model dtype
|
||||
optim="adamw_torch_fused",
|
||||
gradient_checkpointing=True
|
||||
)
|
||||
```
|
||||
skills/mlops-knowledge/gguf/SKILL.md
@@ -0,0 +1,427 @@
|
||||
---
|
||||
name: gguf-quantization
|
||||
description: GGUF format and llama.cpp quantization for efficient CPU/GPU inference. Use when deploying models on consumer hardware, Apple Silicon, or when needing flexible quantization from 2-8 bit without GPU requirements.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [GGUF, Quantization, llama.cpp, CPU Inference, Apple Silicon, Model Compression, Optimization]
|
||||
dependencies: [llama-cpp-python>=0.2.0]
|
||||
---
|
||||
|
||||
# GGUF - Quantization Format for llama.cpp
|
||||
|
||||
GGUF (GPT-Generated Unified Format) is the standard file format for llama.cpp, enabling efficient inference on CPUs, Apple Silicon, and GPUs with flexible quantization options.
|
||||
|
||||
## When to use GGUF
|
||||
|
||||
**Use GGUF when:**
|
||||
- Deploying on consumer hardware (laptops, desktops)
|
||||
- Running on Apple Silicon (M1/M2/M3) with Metal acceleration
|
||||
- Need CPU inference without GPU requirements
|
||||
- Want flexible quantization (Q2_K to Q8_0)
|
||||
- Using local AI tools (LM Studio, Ollama, text-generation-webui)
|
||||
|
||||
**Key advantages:**
|
||||
- **Universal hardware**: CPU, Apple Silicon, NVIDIA, AMD support
|
||||
- **No Python runtime**: Pure C/C++ inference
|
||||
- **Flexible quantization**: 2-8 bit with various methods (K-quants)
|
||||
- **Ecosystem support**: LM Studio, Ollama, koboldcpp, and more
|
||||
- **imatrix**: Importance matrix for better low-bit quality
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **AWQ/GPTQ**: Maximum accuracy with calibration on NVIDIA GPUs
|
||||
- **HQQ**: Fast calibration-free quantization for HuggingFace
|
||||
- **bitsandbytes**: Simple integration with transformers library
|
||||
- **TensorRT-LLM**: Production NVIDIA deployment with maximum speed
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Clone llama.cpp
|
||||
git clone https://github.com/ggml-org/llama.cpp
|
||||
cd llama.cpp
|
||||
|
||||
# Build (CPU)
|
||||
make
|
||||
|
||||
# Build with CUDA (NVIDIA)
|
||||
make GGML_CUDA=1
|
||||
|
||||
# Build with Metal (Apple Silicon)
|
||||
make GGML_METAL=1
|
||||
|
||||
# Install Python bindings (optional)
|
||||
pip install llama-cpp-python
|
||||
```
|
||||
|
||||
### Convert model to GGUF
|
||||
|
||||
```bash
|
||||
# Install requirements
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Convert HuggingFace model to GGUF (FP16)
|
||||
python convert_hf_to_gguf.py ./path/to/model --outfile model-f16.gguf
|
||||
|
||||
# Or specify output type
|
||||
python convert_hf_to_gguf.py ./path/to/model \
|
||||
--outfile model-f16.gguf \
|
||||
--outtype f16
|
||||
```
|
||||
|
||||
### Quantize model
|
||||
|
||||
```bash
|
||||
# Basic quantization to Q4_K_M
|
||||
./llama-quantize model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
|
||||
# Quantize with importance matrix (better quality)
|
||||
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
|
||||
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Run inference
|
||||
|
||||
```bash
|
||||
# CLI inference
|
||||
./llama-cli -m model-q4_k_m.gguf -p "Hello, how are you?"
|
||||
|
||||
# Interactive mode
|
||||
./llama-cli -m model-q4_k_m.gguf --interactive
|
||||
|
||||
# With GPU offload
|
||||
./llama-cli -m model-q4_k_m.gguf -ngl 35 -p "Hello!"
|
||||
```
|
||||
|
||||
## Quantization types
|
||||
|
||||
### K-quant methods (recommended)
|
||||
|
||||
| Type | Bits | Size (7B) | Quality | Use Case |
|
||||
|------|------|-----------|---------|----------|
|
||||
| Q2_K | 2.5 | ~2.8 GB | Low | Extreme compression |
|
||||
| Q3_K_S | 3.0 | ~3.0 GB | Low-Med | Memory constrained |
|
||||
| Q3_K_M | 3.3 | ~3.3 GB | Medium | Balance |
|
||||
| Q4_K_S | 4.0 | ~3.8 GB | Med-High | Good balance |
|
||||
| Q4_K_M | 4.5 | ~4.1 GB | High | **Recommended default** |
|
||||
| Q5_K_S | 5.0 | ~4.6 GB | High | Quality focused |
|
||||
| Q5_K_M | 5.5 | ~4.8 GB | Very High | High quality |
|
||||
| Q6_K | 6.0 | ~5.5 GB | Excellent | Near-original |
|
||||
| Q8_0 | 8.0 | ~7.2 GB | Best | Maximum quality |
|
||||
|
||||
### Legacy methods
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| Q4_0 | 4-bit, basic |
|
||||
| Q4_1 | 4-bit with delta |
|
||||
| Q5_0 | 5-bit, basic |
|
||||
| Q5_1 | 5-bit with delta |
|
||||
|
||||
**Recommendation**: Use K-quant methods (Q4_K_M, Q5_K_M) for best quality/size ratio.
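
The size column above is roughly parameter count times bits per weight; a quick back-of-the-envelope sketch (ignores the small overhead for metadata and any tensors kept at higher precision):

```python
def gguf_size_gb(n_params_billion, bits_per_weight):
    # bytes = params * bits / 8; metadata overhead is ignored here
    return n_params_billion * 1e9 * bits_per_weight / 8 / 1e9

for name, bits in [("Q4_K_M", 4.5), ("Q5_K_M", 5.5), ("Q8_0", 8.0)]:
    print(f"7B @ {name}: ~{gguf_size_gb(7, bits):.1f} GB")
```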
|
||||
|
||||
## Conversion workflows
|
||||
|
||||
### Workflow 1: HuggingFace to GGUF
|
||||
|
||||
```bash
|
||||
# 1. Download model
|
||||
huggingface-cli download meta-llama/Llama-3.1-8B --local-dir ./llama-3.1-8b
|
||||
|
||||
# 2. Convert to GGUF (FP16)
|
||||
python convert_hf_to_gguf.py ./llama-3.1-8b \
|
||||
--outfile llama-3.1-8b-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# 3. Quantize
|
||||
./llama-quantize llama-3.1-8b-f16.gguf llama-3.1-8b-q4_k_m.gguf Q4_K_M
|
||||
|
||||
# 4. Test
|
||||
./llama-cli -m llama-3.1-8b-q4_k_m.gguf -p "Hello!" -n 50
|
||||
```
|
||||
|
||||
### Workflow 2: With importance matrix (better quality)
|
||||
|
||||
```bash
|
||||
# 1. Convert to GGUF
|
||||
python convert_hf_to_gguf.py ./model --outfile model-f16.gguf
|
||||
|
||||
# 2. Create calibration text (diverse samples)
|
||||
cat > calibration.txt << 'EOF'
|
||||
The quick brown fox jumps over the lazy dog.
|
||||
Machine learning is a subset of artificial intelligence.
|
||||
Python is a popular programming language.
|
||||
# Add more diverse text samples...
|
||||
EOF
|
||||
|
||||
# 3. Generate importance matrix
|
||||
./llama-imatrix -m model-f16.gguf \
|
||||
-f calibration.txt \
|
||||
--chunk 512 \
|
||||
-o model.imatrix \
|
||||
-ngl 35 # GPU layers if available
|
||||
|
||||
# 4. Quantize with imatrix
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf \
|
||||
model-q4_k_m.gguf \
|
||||
Q4_K_M
|
||||
```
|
||||
|
||||
### Workflow 3: Multiple quantizations
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
MODEL="llama-3.1-8b-f16.gguf"
|
||||
IMATRIX="llama-3.1-8b.imatrix"
|
||||
|
||||
# Generate imatrix once
|
||||
./llama-imatrix -m $MODEL -f wiki.txt -o $IMATRIX -ngl 35
|
||||
|
||||
# Create multiple quantizations
|
||||
for QUANT in Q4_K_M Q5_K_M Q6_K Q8_0; do
|
||||
OUTPUT="llama-3.1-8b-${QUANT,,}.gguf"
|
||||
./llama-quantize --imatrix $IMATRIX $MODEL $OUTPUT $QUANT
|
||||
echo "Created: $OUTPUT ($(du -h $OUTPUT | cut -f1))"
|
||||
done
|
||||
```
|
||||
|
||||
## Python usage
|
||||
|
||||
### llama-cpp-python
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Load model
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096, # Context window
|
||||
n_gpu_layers=35, # GPU offload (0 for CPU only)
|
||||
n_threads=8 # CPU threads
|
||||
)
|
||||
|
||||
# Generate
|
||||
output = llm(
|
||||
"What is machine learning?",
|
||||
max_tokens=256,
|
||||
temperature=0.7,
|
||||
stop=["</s>", "\n\n"]
|
||||
)
|
||||
print(output["choices"][0]["text"])
|
||||
```
|
||||
|
||||
### Chat completion
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35,
|
||||
chat_format="llama-3" # Or "chatml", "mistral", etc.
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is Python?"}
|
||||
]
|
||||
|
||||
response = llm.create_chat_completion(
|
||||
messages=messages,
|
||||
max_tokens=256,
|
||||
temperature=0.7
|
||||
)
|
||||
print(response["choices"][0]["message"]["content"])
|
||||
```
|
||||
|
||||
### Streaming
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(model_path="./model-q4_k_m.gguf", n_gpu_layers=35)
|
||||
|
||||
# Stream tokens
|
||||
for chunk in llm(
|
||||
"Explain quantum computing:",
|
||||
max_tokens=256,
|
||||
stream=True
|
||||
):
|
||||
print(chunk["choices"][0]["text"], end="", flush=True)
|
||||
```
|
||||
|
||||
## Server mode
|
||||
|
||||
### Start OpenAI-compatible server
|
||||
|
||||
```bash
|
||||
# Start server
|
||||
./llama-server -m model-q4_k_m.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-ngl 35 \
|
||||
-c 4096
|
||||
|
||||
# Or with Python bindings
|
||||
python -m llama_cpp.server \
|
||||
--model model-q4_k_m.gguf \
|
||||
--n_gpu_layers 35 \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080
|
||||
```
|
||||
|
||||
### Use with OpenAI client
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="not-needed"
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="local-model",
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
max_tokens=256
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
## Hardware optimization
|
||||
|
||||
### Apple Silicon (Metal)
|
||||
|
||||
```bash
|
||||
# Build with Metal
|
||||
make clean && make GGML_METAL=1
|
||||
|
||||
# Run with Metal acceleration
|
||||
./llama-cli -m model.gguf -ngl 99 -p "Hello"
|
||||
|
||||
# Python with Metal
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=99, # Offload all layers
|
||||
n_threads=1 # Metal handles parallelism
|
||||
)
|
||||
```
|
||||
|
||||
### NVIDIA CUDA
|
||||
|
||||
```bash
|
||||
# Build with CUDA
|
||||
make clean && make GGML_CUDA=1
|
||||
|
||||
# Run with CUDA
|
||||
./llama-cli -m model.gguf -ngl 35 -p "Hello"
|
||||
|
||||
# Specify GPU
|
||||
CUDA_VISIBLE_DEVICES=0 ./llama-cli -m model.gguf -ngl 35
|
||||
```
|
||||
|
||||
### CPU optimization
|
||||
|
||||
```bash
|
||||
# Build with AVX2/AVX512
|
||||
make clean && make
|
||||
|
||||
# Run with optimal threads
|
||||
./llama-cli -m model.gguf -t 8 -p "Hello"
|
||||
|
||||
# Python CPU config
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=0, # CPU only
|
||||
n_threads=8, # Match physical cores
|
||||
n_batch=512 # Batch size for prompt processing
|
||||
)
|
||||
```
|
||||
|
||||
## Integration with tools
|
||||
|
||||
### Ollama
|
||||
|
||||
```bash
|
||||
# Create Modelfile
|
||||
cat > Modelfile << 'EOF'
|
||||
FROM ./model-q4_k_m.gguf
|
||||
TEMPLATE """{{ .System }}
|
||||
{{ .Prompt }}"""
|
||||
PARAMETER temperature 0.7
|
||||
PARAMETER num_ctx 4096
|
||||
EOF
|
||||
|
||||
# Create Ollama model
|
||||
ollama create mymodel -f Modelfile
|
||||
|
||||
# Run
|
||||
ollama run mymodel "Hello!"
|
||||
```
|
||||
|
||||
### LM Studio
|
||||
|
||||
1. Place GGUF file in `~/.cache/lm-studio/models/`
|
||||
2. Open LM Studio and select the model
|
||||
3. Configure context length and GPU offload
|
||||
4. Start inference
|
||||
|
||||
### text-generation-webui
|
||||
|
||||
```bash
|
||||
# Place in models folder
|
||||
cp model-q4_k_m.gguf text-generation-webui/models/
|
||||
|
||||
# Start with llama.cpp loader
|
||||
python server.py --model model-q4_k_m.gguf --loader llama.cpp --n-gpu-layers 35
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use K-quants**: Q4_K_M offers best quality/size balance
|
||||
2. **Use imatrix**: Always use importance matrix for Q4 and below
|
||||
3. **GPU offload**: Offload as many layers as VRAM allows
|
||||
4. **Context length**: Start with 4096, increase if needed
|
||||
5. **Thread count**: Match physical CPU cores, not logical
|
||||
6. **Batch size**: Increase n_batch for faster prompt processing
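
A llama-cpp-python configuration that follows these practices might look like the sketch below (the layer count is a placeholder to adjust for your VRAM, and `psutil` is assumed to be available for counting physical cores):

```python
import psutil
from llama_cpp import Llama

llm = Llama(
    model_path="model-q4_k_m.gguf",             # K-quant, ideally built with an imatrix
    n_gpu_layers=35,                            # offload as many layers as VRAM allows
    n_ctx=4096,                                 # start at 4096, raise only if needed
    n_threads=psutil.cpu_count(logical=False),  # physical cores, not logical
    n_batch=512,                                # larger batch speeds up prompt processing
)
```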
|
||||
|
||||
## Common issues
|
||||
|
||||
**Model loads slowly:**
|
||||
```bash
|
||||
# Model weights are memory-mapped by default, so loading is lazy; make sure
# --no-mmap is not set, and use --mlock to keep the mapped weights in RAM
./llama-cli -m model.gguf --mlock
|
||||
```
|
||||
|
||||
**Out of memory:**
|
||||
```bash
|
||||
# Reduce GPU layers
|
||||
./llama-cli -m model.gguf -ngl 20 # Reduce from 35
|
||||
|
||||
# Or use smaller quantization
|
||||
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
|
||||
```
|
||||
|
||||
**Poor quality at low bits:**
|
||||
```bash
|
||||
# Always use imatrix for Q4 and below
|
||||
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
|
||||
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Batching, speculative decoding, custom builds
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues, debugging, benchmarks
|
||||
|
||||
## Resources
|
||||
|
||||
- **Repository**: https://github.com/ggml-org/llama.cpp
|
||||
- **Python Bindings**: https://github.com/abetlen/llama-cpp-python
|
||||
- **Pre-quantized Models**: https://huggingface.co/TheBloke
|
||||
- **GGUF Converter**: https://huggingface.co/spaces/ggml-org/gguf-my-repo
|
||||
- **License**: MIT
|
||||
skills/mlops-knowledge/gguf/references/advanced-usage.md
@@ -0,0 +1,504 @@
|
||||
# GGUF Advanced Usage Guide
|
||||
|
||||
## Speculative Decoding
|
||||
|
||||
### Draft Model Approach
|
||||
|
||||
```bash
|
||||
# Use smaller model as draft for faster generation
|
||||
./llama-speculative \
|
||||
-m large-model-q4_k_m.gguf \
|
||||
-md draft-model-q4_k_m.gguf \
|
||||
-p "Write a story about AI" \
|
||||
-n 500 \
|
||||
--draft 8 # Draft tokens before verification
|
||||
```
|
||||
|
||||
### Self-Speculative Decoding
|
||||
|
||||
```bash
|
||||
# Use same model with different context for speculation
|
||||
./llama-cli -m model-q4_k_m.gguf \
|
||||
--lookup-cache-static lookup.bin \
|
||||
--lookup-cache-dynamic lookup-dynamic.bin \
|
||||
-p "Hello world"
|
||||
```
|
||||
|
||||
## Batched Inference
|
||||
|
||||
### Process Multiple Prompts
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35,
|
||||
n_batch=512 # Larger batch for parallel processing
|
||||
)
|
||||
|
||||
prompts = [
|
||||
"What is Python?",
|
||||
"Explain machine learning.",
|
||||
"Describe neural networks."
|
||||
]
|
||||
|
||||
# Process in batch (each prompt gets separate context)
|
||||
for prompt in prompts:
|
||||
output = llm(prompt, max_tokens=100)
|
||||
print(f"Q: {prompt}")
|
||||
print(f"A: {output['choices'][0]['text']}\n")
|
||||
```
|
||||
|
||||
### Server Batching
|
||||
|
||||
```bash
|
||||
# Start server with batching
|
||||
./llama-server -m model-q4_k_m.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-ngl 35 \
|
||||
-c 4096 \
|
||||
--parallel 4 \
--cont-batching   # 4 concurrent request slots with continuous batching
|
||||
```
|
||||
|
||||
## Custom Model Conversion
|
||||
|
||||
### Convert with Vocabulary Modifications
|
||||
|
||||
```python
|
||||
# custom_convert.py
|
||||
import sys
|
||||
sys.path.insert(0, './llama.cpp')
|
||||
|
||||
from convert_hf_to_gguf import main
|
||||
from gguf import GGUFWriter
|
||||
|
||||
# Custom conversion with modified vocab
|
||||
def convert_with_custom_vocab(model_path, output_path):
|
||||
# Load and modify tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
# Add special tokens if needed
|
||||
special_tokens = {"additional_special_tokens": ["<|custom|>"]}
|
||||
tokenizer.add_special_tokens(special_tokens)
|
||||
tokenizer.save_pretrained(model_path)
|
||||
|
||||
# Then run standard conversion
|
||||
main([model_path, "--outfile", output_path])
|
||||
```
|
||||
|
||||
### Convert Specific Architecture
|
||||
|
||||
```bash
|
||||
# For Mistral-style models
|
||||
python convert_hf_to_gguf.py ./mistral-model \
|
||||
--outfile mistral-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# For Qwen models
|
||||
python convert_hf_to_gguf.py ./qwen-model \
|
||||
--outfile qwen-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# For Phi models
|
||||
python convert_hf_to_gguf.py ./phi-model \
|
||||
--outfile phi-f16.gguf \
|
||||
--outtype f16
|
||||
```
|
||||
|
||||
## Advanced Quantization
|
||||
|
||||
### Mixed Quantization
|
||||
|
||||
```bash
|
||||
# Quantize different layer types differently
|
||||
./llama-quantize --allow-requantize --leave-output-tensor \
    model-f16.gguf model-mixed.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Quantization with Token Embeddings
|
||||
|
||||
```bash
|
||||
# Keep embeddings at higher precision
|
||||
./llama-quantize --token-embedding-type f16 \
    model-f16.gguf model-q4.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### IQ Quantization (Importance-aware)
|
||||
|
||||
```bash
|
||||
# Ultra-low bit quantization with importance
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf model-iq2_xxs.gguf IQ2_XXS
|
||||
|
||||
# Available IQ types: IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ3_XS, IQ3_S, IQ4_XS
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### Memory Mapping
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Use memory mapping for large models
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
use_mmap=True, # Memory map the model
|
||||
use_mlock=False, # Don't lock in RAM
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
### Partial GPU Offload
|
||||
|
||||
```python
|
||||
# Calculate layers to offload based on VRAM
|
||||
import subprocess
|
||||
|
||||
def get_free_vram_gb():
|
||||
result = subprocess.run(
|
||||
['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader'],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
return int(result.stdout.strip()) / 1024
|
||||
|
||||
# Estimate layers based on VRAM (rough: 0.5GB per layer for 7B Q4)
|
||||
free_vram = get_free_vram_gb()
|
||||
layers_to_offload = int(free_vram / 0.5)
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_gpu_layers=min(layers_to_offload, 35) # Cap at total layers
|
||||
)
|
||||
```
|
||||
|
||||
### KV Cache Optimization
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Optimize KV cache for long contexts
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=8192, # Large context
|
||||
n_gpu_layers=35,
|
||||
    type_k=8,  # GGML_TYPE_Q8_0 for the K cache
    type_v=8,  # GGML_TYPE_Q8_0 for the V cache
    # Or use 2 (GGML_TYPE_Q4_0) for more compression
|
||||
)
|
||||
```
|
||||
|
||||
## Context Management
|
||||
|
||||
### Context Shifting
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35
|
||||
)
|
||||
|
||||
# Handle long conversations with context shifting
|
||||
conversation = []
|
||||
max_history = 10
|
||||
|
||||
def chat(user_message):
    global conversation  # reassigned below, so it must be declared global
|
||||
conversation.append({"role": "user", "content": user_message})
|
||||
|
||||
# Keep only recent history
|
||||
if len(conversation) > max_history * 2:
|
||||
conversation = conversation[-max_history * 2:]
|
||||
|
||||
response = llm.create_chat_completion(
|
||||
messages=conversation,
|
||||
max_tokens=256
|
||||
)
|
||||
|
||||
assistant_message = response["choices"][0]["message"]["content"]
|
||||
conversation.append({"role": "assistant", "content": assistant_message})
|
||||
return assistant_message
|
||||
```
|
||||
|
||||
### Save and Load State
|
||||
|
||||
```bash
|
||||
# Save state to file
|
||||
./llama-cli -m model.gguf \
|
||||
-p "Once upon a time" \
|
||||
--save-session session.bin \
|
||||
-n 100
|
||||
|
||||
# Load and continue
|
||||
./llama-cli -m model.gguf \
|
||||
--load-session session.bin \
|
||||
-p " and they lived" \
|
||||
-n 100
|
||||
```
|
||||
|
||||
## Grammar Constrained Generation
|
||||
|
||||
### JSON Output
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama, LlamaGrammar
|
||||
|
||||
# Define JSON grammar
|
||||
json_grammar = LlamaGrammar.from_string('''
|
||||
root ::= object
|
||||
object ::= "{" ws pair ("," ws pair)* "}" ws
|
||||
pair ::= string ":" ws value
|
||||
value ::= string | number | object | array | "true" | "false" | "null"
|
||||
array ::= "[" ws value ("," ws value)* "]" ws
|
||||
string ::= "\\"" [^"\\\\]* "\\""
|
||||
number ::= [0-9]+
|
||||
ws ::= [ \\t\\n]*
|
||||
''')
|
||||
|
||||
llm = Llama(model_path="model-q4_k_m.gguf", n_gpu_layers=35)
|
||||
|
||||
output = llm(
|
||||
"Output a JSON object with name and age:",
|
||||
grammar=json_grammar,
|
||||
max_tokens=100
|
||||
)
|
||||
print(output["choices"][0]["text"])
|
||||
```
|
||||
|
||||
### Custom Grammar
|
||||
|
||||
```python
|
||||
# Grammar for specific format
|
||||
answer_grammar = LlamaGrammar.from_string('''
|
||||
root ::= "Answer: " letter "\\n" "Explanation: " explanation
|
||||
letter ::= [A-D]
|
||||
explanation ::= [a-zA-Z0-9 .,!?]+
|
||||
''')
|
||||
|
||||
output = llm(
|
||||
"Q: What is 2+2? A) 3 B) 4 C) 5 D) 6",
|
||||
grammar=answer_grammar,
|
||||
max_tokens=100
|
||||
)
|
||||
```
|
||||
|
||||
## LoRA Integration
|
||||
|
||||
### Load LoRA Adapter
|
||||
|
||||
```bash
|
||||
# Apply LoRA at runtime
|
||||
./llama-cli -m base-model-q4_k_m.gguf \
|
||||
--lora-scaled lora-adapter.gguf 1.0 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Multiple LoRA Adapters
|
||||
|
||||
```bash
|
||||
# Stack multiple adapters
|
||||
./llama-cli -m base-model.gguf \
|
||||
--lora-scaled adapter1.gguf 0.5 \
--lora-scaled adapter2.gguf 0.5 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Python LoRA Usage
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="base-model-q4_k_m.gguf",
|
||||
lora_path="lora-adapter.gguf",
|
||||
lora_scale=1.0,
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
## Embedding Generation
|
||||
|
||||
### Extract Embeddings
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
embedding=True, # Enable embedding mode
|
||||
n_gpu_layers=35
|
||||
)
|
||||
|
||||
# Get embeddings
|
||||
embeddings = llm.embed("This is a test sentence.")
|
||||
print(f"Embedding dimension: {len(embeddings)}")
|
||||
```
|
||||
|
||||
### Batch Embeddings
|
||||
|
||||
```python
|
||||
texts = [
|
||||
"Machine learning is fascinating.",
|
||||
"Deep learning uses neural networks.",
|
||||
"Python is a programming language."
|
||||
]
|
||||
|
||||
embeddings = [llm.embed(text) for text in texts]
|
||||
|
||||
# Calculate similarity
|
||||
import numpy as np
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||
|
||||
sim = cosine_similarity(embeddings[0], embeddings[1])
|
||||
print(f"Similarity: {sim:.4f}")
|
||||
```
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Benchmark Script
|
||||
|
||||
```python
|
||||
import time
|
||||
from llama_cpp import Llama
|
||||
|
||||
def benchmark(model_path, prompt, n_tokens=100, n_runs=5):
|
||||
llm = Llama(
|
||||
model_path=model_path,
|
||||
n_gpu_layers=35,
|
||||
n_ctx=2048,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Warmup
|
||||
llm(prompt, max_tokens=10)
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
for _ in range(n_runs):
|
||||
start = time.time()
|
||||
output = llm(prompt, max_tokens=n_tokens)
|
||||
elapsed = time.time() - start
|
||||
times.append(elapsed)
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
tokens_per_sec = n_tokens / avg_time
|
||||
|
||||
print(f"Model: {model_path}")
|
||||
print(f"Avg time: {avg_time:.2f}s")
|
||||
print(f"Tokens/sec: {tokens_per_sec:.1f}")
|
||||
|
||||
return tokens_per_sec
|
||||
|
||||
# Compare quantizations
|
||||
for quant in ["q4_k_m", "q5_k_m", "q8_0"]:
|
||||
benchmark(f"model-{quant}.gguf", "Explain quantum computing:", 100)
|
||||
```
|
||||
|
||||
### Optimal Configuration Finder
|
||||
|
||||
```python
|
||||
def find_optimal_config(model_path, target_vram_gb=8):
|
||||
"""Find optimal n_gpu_layers and n_batch for target VRAM."""
|
||||
    from llama_cpp import Llama
    import gc
    import time
|
||||
|
||||
best_config = None
|
||||
best_speed = 0
|
||||
|
||||
for n_gpu_layers in range(0, 50, 5):
|
||||
for n_batch in [128, 256, 512, 1024]:
|
||||
try:
|
||||
gc.collect()
|
||||
llm = Llama(
|
||||
model_path=model_path,
|
||||
n_gpu_layers=n_gpu_layers,
|
||||
n_batch=n_batch,
|
||||
n_ctx=2048,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Quick benchmark
|
||||
start = time.time()
|
||||
llm("Hello", max_tokens=50)
|
||||
speed = 50 / (time.time() - start)
|
||||
|
||||
if speed > best_speed:
|
||||
best_speed = speed
|
||||
best_config = {
|
||||
"n_gpu_layers": n_gpu_layers,
|
||||
"n_batch": n_batch,
|
||||
"speed": speed
|
||||
}
|
||||
|
||||
del llm
|
||||
gc.collect()
|
||||
|
||||
except Exception as e:
|
||||
print(f"OOM at layers={n_gpu_layers}, batch={n_batch}")
|
||||
break
|
||||
|
||||
return best_config
|
||||
```
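
Called on a quantized model, the helper above might be used like this (the path and VRAM budget are illustrative):

```python
config = find_optimal_config("model-q4_k_m.gguf", target_vram_gb=8)
print(f"Best config: {config}")
```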
|
||||
|
||||
## Multi-GPU Setup
|
||||
|
||||
### Distribute Across GPUs
|
||||
|
||||
```bash
|
||||
# Split model across multiple GPUs
|
||||
./llama-cli -m large-model.gguf \
|
||||
--tensor-split 0.5,0.5 \
|
||||
-ngl 60 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Python Multi-GPU
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
||||
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="large-model-q4_k_m.gguf",
|
||||
n_gpu_layers=60,
|
||||
tensor_split=[0.5, 0.5] # Split evenly across 2 GPUs
|
||||
)
|
||||
```
|
||||
|
||||
## Custom Builds
|
||||
|
||||
### Build with All Optimizations
|
||||
|
||||
```bash
|
||||
# Clean build with all CPU optimizations
|
||||
make clean
|
||||
LLAMA_OPENBLAS=1 LLAMA_BLAS_VENDOR=OpenBLAS make -j
|
||||
|
||||
# With CUDA and cuBLAS
|
||||
make clean
|
||||
GGML_CUDA=1 LLAMA_CUBLAS=1 make -j
|
||||
|
||||
# With specific CUDA architecture
|
||||
GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_86 make -j
|
||||
```
|
||||
|
||||
### CMake Build
|
||||
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
cmake .. -DGGML_CUDA=ON -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build . --config Release -j
|
||||
```
|
||||
skills/mlops-knowledge/gguf/references/troubleshooting.md
@@ -0,0 +1,442 @@
|
||||
# GGUF Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Build Fails
|
||||
|
||||
**Error**: `make: *** No targets specified and no makefile found`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Ensure you're in llama.cpp directory
|
||||
cd llama.cpp
|
||||
make
|
||||
```
|
||||
|
||||
**Error**: `fatal error: cuda_runtime.h: No such file or directory`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Install CUDA toolkit
|
||||
# Ubuntu
|
||||
sudo apt install nvidia-cuda-toolkit
|
||||
|
||||
# Or set CUDA path
|
||||
export CUDA_PATH=/usr/local/cuda
|
||||
export PATH=$CUDA_PATH/bin:$PATH
|
||||
make GGML_CUDA=1
|
||||
```
|
||||
|
||||
### Python Bindings Issues
|
||||
|
||||
**Error**: `ERROR: Failed building wheel for llama-cpp-python`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Install build dependencies
|
||||
pip install cmake scikit-build-core
|
||||
|
||||
# For CUDA support
|
||||
CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
|
||||
|
||||
# For Metal (macOS)
|
||||
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
|
||||
```
|
||||
|
||||
**Error**: `ImportError: libcudart.so.XX: cannot open shared object file`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Add CUDA libraries to path
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
# Or reinstall with correct CUDA version
|
||||
pip uninstall llama-cpp-python
|
||||
CUDACXX=/usr/local/cuda/bin/nvcc CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python
|
||||
```
|
||||
|
||||
## Conversion Issues
|
||||
|
||||
### Model Not Supported
|
||||
|
||||
**Error**: `KeyError: 'model.embed_tokens.weight'`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Check model architecture
|
||||
python -c "from transformers import AutoConfig; print(AutoConfig.from_pretrained('./model').architectures)"
|
||||
|
||||
# Use appropriate conversion script
|
||||
# For most models:
|
||||
python convert_hf_to_gguf.py ./model --outfile model.gguf
|
||||
|
||||
# For older models, check if legacy script needed
|
||||
```
|
||||
|
||||
### Vocabulary Mismatch
|
||||
|
||||
**Error**: `RuntimeError: Vocabulary size mismatch`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure tokenizer matches model
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("./model")
|
||||
model = AutoModelForCausalLM.from_pretrained("./model")
|
||||
|
||||
print(f"Tokenizer vocab size: {len(tokenizer)}")
|
||||
print(f"Model vocab size: {model.config.vocab_size}")
|
||||
|
||||
# If mismatch, resize embeddings before conversion
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
model.save_pretrained("./model-fixed")
|
||||
```
|
||||
|
||||
### Out of Memory During Conversion
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError` during conversion
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Use CPU for conversion
|
||||
CUDA_VISIBLE_DEVICES="" python convert_hf_to_gguf.py ./model --outfile model.gguf
|
||||
|
||||
# Or use low memory mode
|
||||
python convert_hf_to_gguf.py ./model --outfile model.gguf --outtype f16
|
||||
```
|
||||
|
||||
## Quantization Issues
|
||||
|
||||
### Wrong Output File Size
|
||||
|
||||
**Problem**: Quantized file is larger than expected
|
||||
|
||||
**Check**:
|
||||
```bash
|
||||
# Verify quantization type
|
||||
./llama-cli -m model.gguf --verbose
|
||||
|
||||
# Expected sizes for 7B model:
|
||||
# Q4_K_M: ~4.1 GB
|
||||
# Q5_K_M: ~4.8 GB
|
||||
# Q8_0: ~7.2 GB
|
||||
# F16: ~13.5 GB
|
||||
```
|
||||
|
||||
### Quantization Crashes
|
||||
|
||||
**Error**: `Segmentation fault` during quantization
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Increase stack size
|
||||
ulimit -s unlimited
|
||||
|
||||
# Or use fewer threads (thread count is the final positional argument)
./llama-quantize model-f16.gguf model-q4.gguf Q4_K_M 4
|
||||
```
|
||||
|
||||
### Poor Quality After Quantization
|
||||
|
||||
**Problem**: Model outputs gibberish after quantization
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Use importance matrix**:
|
||||
```bash
|
||||
# Generate imatrix with good calibration data
|
||||
./llama-imatrix -m model-f16.gguf \
|
||||
-f wiki_sample.txt \
|
||||
--chunk 512 \
|
||||
-o model.imatrix
|
||||
|
||||
# Quantize with imatrix
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
2. **Try higher precision**:
|
||||
```bash
|
||||
# Use Q5_K_M or Q6_K instead of Q4
|
||||
./llama-quantize model-f16.gguf model-q5_k_m.gguf Q5_K_M
|
||||
```
|
||||
|
||||
3. **Check original model**:
|
||||
```bash
|
||||
# Test FP16 version first
|
||||
./llama-cli -m model-f16.gguf -p "Hello, how are you?" -n 50
|
||||
```
|
||||
|
||||
## Inference Issues
|
||||
|
||||
### Slow Generation
|
||||
|
||||
**Problem**: Generation is slower than expected
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable GPU offload**:
|
||||
```bash
|
||||
./llama-cli -m model.gguf -ngl 35 -p "Hello"
|
||||
```
|
||||
|
||||
2. **Optimize batch size**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_batch=512, # Increase for faster prompt processing
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
3. **Use appropriate threads**:
|
||||
```bash
|
||||
# Match physical cores, not logical
|
||||
./llama-cli -m model.gguf -t 8 -p "Hello"
|
||||
```
|
||||
|
||||
4. **Enable Flash Attention** (if supported):
|
||||
```bash
|
||||
./llama-cli -m model.gguf -ngl 35 --flash-attn -p "Hello"
|
||||
```
|
||||
|
||||
### Out of Memory
|
||||
|
||||
**Error**: `CUDA out of memory` or system freeze
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce GPU layers**:
|
||||
```python
|
||||
# Start low and increase
|
||||
llm = Llama(model_path="model.gguf", n_gpu_layers=10)
|
||||
```
|
||||
|
||||
2. **Use smaller quantization**:
|
||||
```bash
|
||||
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
|
||||
```
|
||||
|
||||
3. **Reduce context length**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_ctx=2048, # Reduce from 4096
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
4. **Quantize KV cache**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
type_k=2, # Q4_0 for K cache
|
||||
type_v=2, # Q4_0 for V cache
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
### Garbage Output
|
||||
|
||||
**Problem**: Model outputs random characters or nonsense
|
||||
|
||||
**Diagnose**:
|
||||
```python
|
||||
# Check model loading
|
||||
llm = Llama(model_path="model.gguf", verbose=True)
|
||||
|
||||
# Test with simple prompt
|
||||
output = llm("1+1=", max_tokens=5, temperature=0)
|
||||
print(output)
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Check model integrity**:
|
||||
```bash
|
||||
# Verify GGUF file
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | head -50
|
||||
```
|
||||
|
||||
2. **Use correct chat format**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
chat_format="llama-3" # Match your model: chatml, mistral, etc.
|
||||
)
|
||||
```
|
||||
|
||||
3. **Check temperature**:
|
||||
```python
|
||||
# Use lower temperature for deterministic output
|
||||
output = llm("Hello", max_tokens=50, temperature=0.1)
|
||||
```
|
||||
|
||||
### Token Issues
|
||||
|
||||
**Error**: `RuntimeError: unknown token` or encoding errors
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure the prompt is valid UTF-8 before tokenizing
raw = b"Hello, world!"  # e.g. bytes read from a file or socket
prompt = raw.decode('utf-8', errors='replace')  # replace invalid byte sequences
output = llm(prompt, max_tokens=50)
|
||||
```
|
||||
|
||||
## Server Issues
|
||||
|
||||
### Connection Refused
|
||||
|
||||
**Error**: `Connection refused` when accessing server
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Bind to all interfaces
|
||||
./llama-server -m model.gguf --host 0.0.0.0 --port 8080
|
||||
|
||||
# Check if port is in use
|
||||
lsof -i :8080
|
||||
```
|
||||
|
||||
### Server Crashes Under Load
|
||||
|
||||
**Problem**: Server crashes with multiple concurrent requests
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Limit parallelism**:
|
||||
```bash
|
||||
./llama-server -m model.gguf \
|
||||
--parallel 2 \
|
||||
-c 4096 \
|
||||
--cont-batching
|
||||
```
|
||||
|
||||
2. **Add request timeout**:
|
||||
```bash
|
||||
./llama-server -m model.gguf --timeout 300
|
||||
```
|
||||
|
||||
3. **Monitor memory**:
|
||||
```bash
|
||||
watch -n 1 nvidia-smi # For GPU
|
||||
watch -n 1 free -h # For RAM
|
||||
```
|
||||
|
||||
### API Compatibility Issues
|
||||
|
||||
**Problem**: OpenAI client not working with server
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# Use correct base URL format
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1", # Include /v1
|
||||
api_key="not-needed"
|
||||
)
|
||||
|
||||
# Use correct model name
|
||||
response = client.chat.completions.create(
|
||||
model="local", # Or the actual model name
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
```
|
||||
|
||||
## Apple Silicon Issues
|
||||
|
||||
### Metal Not Working
|
||||
|
||||
**Problem**: Metal acceleration not enabled
|
||||
|
||||
**Check**:
|
||||
```bash
|
||||
# Verify Metal support
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | grep -i metal
|
||||
```
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Rebuild with Metal
|
||||
make clean
|
||||
make GGML_METAL=1
|
||||
|
||||
# Python bindings
|
||||
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall
|
||||
```
|
||||
|
||||
### Incorrect Memory Usage on M1/M2
|
||||
|
||||
**Problem**: Model uses too much unified memory
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Offload all layers for Metal
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=99, # Offload everything
|
||||
n_threads=1 # Metal handles parallelism
|
||||
)
|
||||
```
|
||||
|
||||
## Debugging
|
||||
|
||||
### Enable Verbose Output
|
||||
|
||||
```bash
|
||||
# CLI verbose mode
|
||||
./llama-cli -m model.gguf --verbose -p "Hello" -n 50
|
||||
|
||||
# Python verbose
|
||||
llm = Llama(model_path="model.gguf", verbose=True)
|
||||
```
|
||||
|
||||
### Check Model Metadata
|
||||
|
||||
```bash
|
||||
# View GGUF metadata
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | head -100
|
||||
```
|
||||
|
||||
### Validate GGUF File
|
||||
|
||||
```python
|
||||
import struct
|
||||
|
||||
def validate_gguf(filepath):
|
||||
with open(filepath, 'rb') as f:
|
||||
magic = f.read(4)
|
||||
if magic != b'GGUF':
|
||||
print(f"Invalid magic: {magic}")
|
||||
return False
|
||||
|
||||
version = struct.unpack('<I', f.read(4))[0]
|
||||
print(f"GGUF version: {version}")
|
||||
|
||||
tensor_count = struct.unpack('<Q', f.read(8))[0]
|
||||
metadata_count = struct.unpack('<Q', f.read(8))[0]
|
||||
print(f"Tensors: {tensor_count}, Metadata: {metadata_count}")
|
||||
|
||||
return True
|
||||
|
||||
validate_gguf("model.gguf")
|
||||
```
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/ggml-org/llama.cpp/issues
|
||||
2. **Discussions**: https://github.com/ggml-org/llama.cpp/discussions
|
||||
3. **Reddit**: r/LocalLLaMA
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- llama.cpp version/commit hash
|
||||
- Build command used
|
||||
- Model name and quantization
|
||||
- Full error message/stack trace
|
||||
- Hardware: CPU/GPU model, RAM, VRAM
|
||||
- OS version
|
||||
- Minimal reproduction steps
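A small helper sketch for collecting the details above before filing a report. It assumes `nvidia-smi` and `llama-cli` are on your PATH and that your `llama-cli` build supports `--version`; adjust the commands for your setup.

```python
# Sketch: gather environment details for a llama.cpp bug report.
import platform, shutil, subprocess

def run(cmd):
    try:
        return subprocess.run(cmd, capture_output=True, text=True, timeout=10).stdout.strip()
    except Exception as e:
        return f"<unavailable: {e}>"

print("OS:", platform.platform())
print("Python:", platform.python_version())
if shutil.which("nvidia-smi"):
    print("GPU:", run(["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"]))
if shutil.which("llama-cli"):
    print("llama.cpp:", run(["llama-cli", "--version"]))
```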
|
||||
97
skills/mlops-knowledge/grpo-rl-training/README.md
Normal file
@@ -0,0 +1,97 @@
|
||||
# GRPO/RL Training Skill
|
||||
|
||||
**Expert-level guidance for Group Relative Policy Optimization with TRL**
|
||||
|
||||
## 📁 Skill Structure
|
||||
|
||||
```
|
||||
grpo-rl-training/
|
||||
├── SKILL.md # Main skill documentation (READ THIS FIRST)
|
||||
├── README.md # This file
|
||||
├── templates/
|
||||
│ └── basic_grpo_training.py # Production-ready training template
|
||||
└── examples/
|
||||
└── reward_functions_library.py # 20+ reward function examples
|
||||
```
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
1. **Read SKILL.md** - Comprehensive guide with all concepts and patterns
|
||||
2. **Copy `templates/basic_grpo_training.py`** - Start with working code
|
||||
3. **Browse `examples/reward_functions_library.py`** - Pick reward functions for your task
|
||||
4. **Modify for your use case** - Adapt dataset, rewards, and config
|
||||
|
||||
## 💡 What's Inside
|
||||
|
||||
### SKILL.md (Main Documentation)
|
||||
- Core GRPO concepts and algorithm fundamentals
|
||||
- Complete implementation workflow (dataset → rewards → training → deployment)
|
||||
- 10+ reward function examples with code
|
||||
- Hyperparameter tuning guide
|
||||
- Training insights (loss behavior, metrics, debugging)
|
||||
- Troubleshooting guide
|
||||
- Production best practices
|
||||
|
||||
### Templates
|
||||
- **basic_grpo_training.py**: Minimal, production-ready training script
|
||||
- Uses Qwen 2.5 1.5B Instruct
|
||||
- 3 reward functions (format + correctness)
|
||||
- LoRA for efficient training
|
||||
- Fully documented and ready to run
|
||||
|
||||
### Examples
|
||||
- **reward_functions_library.py**: 20+ battle-tested reward functions
|
||||
- Correctness rewards (exact match, fuzzy match, numeric, code execution)
|
||||
- Format rewards (XML, JSON, strict/soft)
|
||||
- Length rewards (ideal length, min/max)
|
||||
- Style rewards (reasoning quality, citations, repetition penalty)
|
||||
- Combined rewards (multi-objective optimization)
|
||||
- Preset collections for common tasks
|
||||
|
||||
## 📖 Usage for Agents
|
||||
|
||||
When this skill is loaded in your agent's context:
|
||||
|
||||
1. **Always read SKILL.md first** before implementing
|
||||
2. **Start simple** - Use length-based reward to validate setup
|
||||
3. **Build incrementally** - Add one reward function at a time
|
||||
4. **Reference examples** - Copy patterns from reward_functions_library.py
|
||||
5. **Monitor training** - Watch reward metrics (not loss!)
|
||||
|
||||
## 🎯 Common Use Cases
|
||||
|
||||
| Task Type | Recommended Rewards | Template |
|
||||
|-----------|---------------------|----------|
|
||||
| Math reasoning | `MATH_REASONING_REWARDS` preset | basic_grpo_training.py |
|
||||
| Code generation | `CODE_GENERATION_REWARDS` preset | Modify dataset in template |
|
||||
| Summarization | `SUMMARIZATION_REWARDS` preset | Adjust prompts + rewards |
|
||||
| Q&A | `QA_REWARDS` preset | Use fuzzy match + citations |
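For example, a preset from the table above could be wired into a trainer roughly like this. The import path and preset name follow the file layout described in this README and are illustrative, not guaranteed; `model`, `tokenizer`, `training_args`, and `dataset` are set up as in `templates/basic_grpo_training.py`.

```python
# Sketch: plugging a reward preset into GRPOTrainer (see templates/basic_grpo_training.py).
# The import below assumes examples/ is importable from your working directory.
from trl import GRPOTrainer
from examples.reward_functions_library import MATH_REASONING_REWARDS  # illustrative import

trainer = GRPOTrainer(
    model=model,                      # defined as in the training template
    processing_class=tokenizer,
    reward_funcs=list(MATH_REASONING_REWARDS),
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```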
|
||||
|
||||
## ⚠️ Critical Reminders
|
||||
|
||||
- **Loss goes UP during training** - This is normal (it's KL divergence)
|
||||
- **Use 3-5 reward functions** - Single rewards often fail
|
||||
- **Test rewards before training** - Debug each function independently
|
||||
- **Monitor reward_std** - Should stay > 0.1 (avoid mode collapse)
|
||||
- **Start with num_generations=4-8** - Scale up if GPU allows
|
||||
|
||||
## 🔗 External Resources
|
||||
|
||||
- [TRL Documentation](https://huggingface.co/docs/trl)
|
||||
- [DeepSeek R1 Paper](https://arxiv.org/abs/2501.12948)
|
||||
- [Open R1 Implementation](https://github.com/huggingface/open-r1)
|
||||
- [Unsloth (2-3x faster)](https://docs.unsloth.ai/)
|
||||
|
||||
## 📝 Version
|
||||
|
||||
**v1.0.0** - Initial release (January 2025)
|
||||
|
||||
## 👨‍💻 Maintained By
|
||||
|
||||
Orchestra Research
|
||||
For questions or improvements, see https://orchestra.com
|
||||
|
||||
---
|
||||
|
||||
**License:** MIT
|
||||
**Last Updated:** January 2025
|
||||
572
skills/mlops-knowledge/grpo-rl-training/SKILL.md
Normal file
@@ -0,0 +1,572 @@
|
||||
---
|
||||
name: grpo-rl-training
|
||||
description: Expert guidance for GRPO/RL fine-tuning with TRL for reasoning and task-specific model training
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Post-Training, Reinforcement Learning, GRPO, TRL, RLHF, Reward Modeling, Reasoning, DPO, PPO, Structured Output]
|
||||
dependencies: [transformers>=4.47.0, trl>=0.14.0, datasets>=3.2.0, peft>=0.14.0, torch]
|
||||
---
|
||||
|
||||
# GRPO/RL Training with TRL
|
||||
|
||||
Expert-level guidance for implementing Group Relative Policy Optimization (GRPO) using the Transformer Reinforcement Learning (TRL) library. This skill provides battle-tested patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use GRPO training when you need to:
|
||||
- **Enforce specific output formats** (e.g., XML tags, JSON, structured reasoning)
|
||||
- **Teach verifiable tasks** with objective correctness metrics (math, coding, fact-checking)
|
||||
- **Improve reasoning capabilities** by rewarding chain-of-thought patterns
|
||||
- **Align models to domain-specific behaviors** without labeled preference data
|
||||
- **Optimize for multiple objectives** simultaneously (format + correctness + style)
|
||||
|
||||
**Do NOT use GRPO for:**
|
||||
- Simple supervised fine-tuning tasks (use SFT instead)
|
||||
- Tasks without clear reward signals
|
||||
- When you already have high-quality preference pairs (use DPO/PPO instead)
|
||||
|
||||
---
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. GRPO Algorithm Fundamentals
|
||||
|
||||
**Key Mechanism:**
|
||||
- Generates **multiple completions** for each prompt (group size: 4-16)
|
||||
- Compares completions within each group using reward functions
|
||||
- Updates policy to favor higher-rewarded responses relative to the group
|
||||
|
||||
**Critical Difference from PPO:**
|
||||
- No separate reward model needed
|
||||
- More sample-efficient (learns from within-group comparisons)
|
||||
- Simpler to implement and debug
|
||||
|
||||
**Mathematical Intuition:**
|
||||
```
|
||||
For each prompt p:
|
||||
1. Generate N completions: {c₁, c₂, ..., cₙ}
|
||||
2. Compute rewards: {r₁, r₂, ..., rₙ}
|
||||
3. Learn to increase probability of high-reward completions
|
||||
relative to low-reward ones in the same group
|
||||
```
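Concretely, "relative to the group" can be pictured as normalizing each completion's reward by its own group's statistics. A minimal sketch of that idea (illustrative only, not the exact TRL implementation):

```python
# Sketch of group-relative scoring: each completion's reward is compared to the
# mean/std of its own group, so only within-group differences drive the update.
def group_relative_advantages(rewards, eps=1e-6):
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = var ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]

# Higher-reward completions get positive advantages, lower-reward ones negative.
print(group_relative_advantages([2.0, 0.5, 0.0, 2.5]))
```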
|
||||
|
||||
### 2. Reward Function Design Philosophy
|
||||
|
||||
**Golden Rules:**
|
||||
1. **Compose multiple reward functions** - Each handles one aspect (format, correctness, style)
|
||||
2. **Scale rewards appropriately** - Higher weight = stronger signal
|
||||
3. **Use incremental rewards** - Partial credit for partial compliance
|
||||
4. **Test rewards independently** - Debug each reward function in isolation
|
||||
|
||||
**Reward Function Types:**
|
||||
|
||||
| Type | Use Case | Example Weight |
|
||||
|------|----------|----------------|
|
||||
| **Correctness** | Verifiable tasks (math, code) | 2.0 (highest) |
|
||||
| **Format** | Strict structure enforcement | 0.5-1.0 |
|
||||
| **Length** | Encourage verbosity/conciseness | 0.1-0.5 |
|
||||
| **Style** | Penalize unwanted patterns | -0.5 to 0.5 |
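In code, these weights usually show up as multipliers inside (or wrapped around) each reward function. A tiny illustrative wrapper, not part of TRL, applying the weights from the table above:

```python
# Sketch: scale a base reward function by a weight from the table above.
def weighted(reward_func, weight):
    def wrapped(*args, **kwargs):
        return [weight * r for r in reward_func(*args, **kwargs)]
    return wrapped

# e.g. reward_funcs=[weighted(correctness_reward, 2.0), weighted(format_reward, 0.5)]
```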
|
||||
|
||||
---
|
||||
|
||||
## Implementation Workflow
|
||||
|
||||
### Step 1: Dataset Preparation
|
||||
|
||||
**Critical Requirements:**
|
||||
- Prompts in chat format (list of dicts with 'role' and 'content')
|
||||
- Include system prompts to set expectations
|
||||
- For verifiable tasks, include ground truth answers as additional columns
|
||||
|
||||
**Example Structure:**
|
||||
```python
|
||||
from datasets import load_dataset, Dataset
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Respond in the following format:
|
||||
<reasoning>
|
||||
[Your step-by-step thinking]
|
||||
</reasoning>
|
||||
<answer>
|
||||
[Final answer]
|
||||
</answer>
|
||||
"""
|
||||
|
||||
def prepare_dataset(raw_data):
|
||||
"""
|
||||
Transform raw data into GRPO-compatible format.
|
||||
|
||||
Returns: Dataset with columns:
|
||||
- 'prompt': List[Dict] with role/content (system + user messages)
|
||||
- 'answer': str (ground truth, optional but recommended)
|
||||
"""
|
||||
return raw_data.map(lambda x: {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': extract_answer(x['raw_answer'])
|
||||
})
|
||||
```
|
||||
|
||||
**Pro Tips:**
|
||||
- Use one-shot or few-shot examples in system prompt for complex formats
|
||||
- Keep prompts concise (max_prompt_length: 256-512 tokens)
|
||||
- Validate data quality before training (garbage in = garbage out)
|
||||
|
||||
### Step 2: Reward Function Implementation
|
||||
|
||||
**Template Structure:**
|
||||
```python
|
||||
def reward_function_name(
|
||||
prompts, # List[List[Dict]]: Original prompts
|
||||
completions, # List[List[Dict]]: Model generations
|
||||
answer=None, # Optional: Ground truth from dataset
|
||||
**kwargs # Additional dataset columns
|
||||
) -> list[float]:
|
||||
"""
|
||||
Evaluate completions and return rewards.
|
||||
|
||||
Returns: List of floats (one per completion)
|
||||
"""
|
||||
# Extract completion text
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
|
||||
# Compute rewards
|
||||
rewards = []
|
||||
for response in responses:
|
||||
score = compute_score(response)
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Example 1: Correctness Reward (Math/Coding)**
|
||||
```python
|
||||
def correctness_reward(prompts, completions, answer, **kwargs):
|
||||
"""Reward correct answers with high score."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_final_answer(r) for r in responses]
|
||||
return [2.0 if ans == gt else 0.0
|
||||
for ans, gt in zip(extracted, answer)]
|
||||
```
|
||||
|
||||
**Example 2: Format Reward (Structured Output)**
|
||||
```python
|
||||
import re
|
||||
|
||||
def format_reward(completions, **kwargs):
|
||||
"""Reward XML-like structured format."""
|
||||
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
return [1.0 if re.search(pattern, r, re.DOTALL) else 0.0
|
||||
for r in responses]
|
||||
```
|
||||
|
||||
**Example 3: Incremental Format Reward (Partial Credit)**
|
||||
```python
|
||||
def incremental_format_reward(completions, **kwargs):
|
||||
"""Award partial credit for format compliance."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
if '<reasoning>' in r:
|
||||
score += 0.25
|
||||
if '</reasoning>' in r:
|
||||
score += 0.25
|
||||
if '<answer>' in r:
|
||||
score += 0.25
|
||||
if '</answer>' in r:
|
||||
score += 0.25
|
||||
# Penalize extra text after closing tag
|
||||
if r.count('</answer>') == 1:
|
||||
extra_text = r.split('</answer>')[-1].strip()
|
||||
score -= len(extra_text) * 0.001
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Critical Insight:**
|
||||
Combine 3-5 reward functions for robust training. Order matters less than diversity of signals.
|
||||
|
||||
### Step 3: Training Configuration
|
||||
|
||||
**Memory-Optimized Config (Small GPU)**
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir="outputs/grpo-model",
|
||||
|
||||
# Learning rate
|
||||
learning_rate=5e-6, # Lower = more stable
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type='cosine',
|
||||
|
||||
# Batch settings
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4, # Effective batch = 4
|
||||
|
||||
# GRPO-specific
|
||||
num_generations=8, # Group size: 8-16 recommended
|
||||
max_prompt_length=256,
|
||||
max_completion_length=512,
|
||||
|
||||
# Training duration
|
||||
num_train_epochs=1,
|
||||
max_steps=None, # Or set fixed steps (e.g., 500)
|
||||
|
||||
# Optimization
|
||||
bf16=True, # Faster on A100/H100
|
||||
optim="adamw_8bit", # Memory-efficient optimizer
|
||||
max_grad_norm=0.1,
|
||||
|
||||
# Logging
|
||||
logging_steps=1,
|
||||
save_steps=100,
|
||||
report_to="wandb", # Or "none" for no logging
|
||||
)
|
||||
```
|
||||
|
||||
**High-Performance Config (Large GPU)**
|
||||
```python
|
||||
training_args = GRPOConfig(
|
||||
output_dir="outputs/grpo-model",
|
||||
learning_rate=1e-5,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=2,
|
||||
num_generations=16, # Larger groups = better signal
|
||||
max_prompt_length=512,
|
||||
max_completion_length=1024,
|
||||
num_train_epochs=1,
|
||||
bf16=True,
|
||||
use_vllm=True, # Fast generation with vLLM
|
||||
logging_steps=10,
|
||||
)
|
||||
```
|
||||
|
||||
**Critical Hyperparameters:**
|
||||
|
||||
| Parameter | Impact | Tuning Advice |
|
||||
|-----------|--------|---------------|
|
||||
| `num_generations` | Group size for comparison | Start with 8, increase to 16 if GPU allows |
|
||||
| `learning_rate` | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) |
|
||||
| `max_completion_length` | Output verbosity | Match your task (512 for reasoning, 256 for short answers) |
|
||||
| `gradient_accumulation_steps` | Effective batch size | Increase if GPU memory limited |
|
||||
|
||||
### Step 4: Model Setup and Training
|
||||
|
||||
**Standard Setup (Transformers)**
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOTrainer
|
||||
|
||||
# Load model
|
||||
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2", # 2-3x faster
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Optional: LoRA for parameter-efficient training
|
||||
peft_config = LoraConfig(
|
||||
r=16, # Rank (higher = more capacity)
|
||||
lora_alpha=32, # Scaling factor (typically 2*r)
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=[
|
||||
incremental_format_reward,
|
||||
format_reward,
|
||||
correctness_reward,
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config, # Remove for full fine-tuning
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer.train()
|
||||
|
||||
# Save
|
||||
trainer.save_model("final_model")
|
||||
```
|
||||
|
||||
**Unsloth Setup (2-3x Faster)**
|
||||
```python
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="google/gemma-3-1b-it",
|
||||
max_seq_length=1024,
|
||||
load_in_4bit=True,
|
||||
fast_inference=True,
|
||||
max_lora_rank=32,
|
||||
)
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=32,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"],
|
||||
lora_alpha=32,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
)
|
||||
|
||||
# Rest is identical to standard setup
|
||||
trainer = GRPOTrainer(model=model, ...)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Critical Training Insights
|
||||
|
||||
### 1. Loss Behavior (EXPECTED PATTERN)
|
||||
- **Loss starts near 0 and INCREASES during training**
|
||||
- This is CORRECT - loss measures KL divergence from initial policy
|
||||
- Model is learning (diverging from original behavior to optimize rewards)
|
||||
- Monitor reward metrics instead of loss for progress
|
||||
|
||||
### 2. Reward Tracking
|
||||
Key metrics to watch:
|
||||
- `reward`: Average across all completions
|
||||
- `reward_std`: Diversity within groups (should remain > 0)
|
||||
- `kl`: KL divergence from reference (should grow moderately)
|
||||
|
||||
**Healthy Training Pattern:**
|
||||
```
|
||||
Step Reward Reward_Std KL
|
||||
100 0.5 0.3 0.02
|
||||
200 0.8 0.25 0.05
|
||||
300 1.2 0.2 0.08 ← Good progression
|
||||
400 1.5 0.15 0.12
|
||||
```
|
||||
|
||||
**Warning Signs:**
|
||||
- Reward std → 0 (model collapsing to single response)
|
||||
- KL exploding (> 0.5) (diverging too much, reduce LR)
|
||||
- Reward stuck (reward functions too harsh or model capacity issue)
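A lightweight way to catch these warning signs automatically is a logging callback. The sketch below assumes the trainer logs `reward_std` and `kl` keys; the exact key names may differ across TRL versions, so check your run's log output first.

```python
# Sketch: warn on collapse/divergence from the trainer's log stream.
from transformers import TrainerCallback

class GRPOHealthCheck(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        logs = logs or {}
        if logs.get("reward_std", 1.0) < 0.1:
            print(f"[step {state.global_step}] reward_std low -> possible mode collapse")
        if logs.get("kl", 0.0) > 0.5:
            print(f"[step {state.global_step}] KL high -> consider lowering the learning rate")

# trainer.add_callback(GRPOHealthCheck())
```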
|
||||
|
||||
### 3. Common Pitfalls and Solutions
|
||||
|
||||
| Problem | Symptom | Solution |
|
||||
|---------|---------|----------|
|
||||
| **Mode collapse** | All completions identical | Increase `num_generations`, add diversity penalty |
|
||||
| **No learning** | Flat rewards | Check reward function logic, increase LR |
|
||||
| **OOM errors** | GPU memory exceeded | Reduce `num_generations`, enable gradient checkpointing |
|
||||
| **Slow training** | < 1 it/s | Enable `use_vllm=True`, use Unsloth, reduce seq length |
|
||||
| **Format ignored** | Model doesn't follow structure | Increase format reward weight, add incremental rewards |
|
||||
|
||||
---
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### 1. Multi-Stage Training
|
||||
For complex tasks, train in stages:
|
||||
|
||||
```python
|
||||
# Stage 1: Format compliance (epochs=1)
|
||||
trainer_stage1 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[incremental_format_reward, format_reward],
|
||||
...
|
||||
)
|
||||
trainer_stage1.train()
|
||||
|
||||
# Stage 2: Correctness (epochs=1)
|
||||
trainer_stage2 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[format_reward, correctness_reward],
|
||||
...
|
||||
)
|
||||
trainer_stage2.train()
|
||||
```
|
||||
|
||||
### 2. Adaptive Reward Scaling
|
||||
```python
|
||||
class AdaptiveReward:
|
||||
def __init__(self, base_reward_func, initial_weight=1.0):
|
||||
self.func = base_reward_func
|
||||
self.weight = initial_weight
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
rewards = self.func(*args, **kwargs)
|
||||
return [r * self.weight for r in rewards]
|
||||
|
||||
def adjust_weight(self, success_rate):
|
||||
"""Increase weight if model struggling, decrease if succeeding."""
|
||||
if success_rate < 0.3:
|
||||
self.weight *= 1.2
|
||||
elif success_rate > 0.8:
|
||||
self.weight *= 0.9
|
||||
```
|
||||
|
||||
### 3. Custom Dataset Integration
|
||||
```python
|
||||
def load_custom_knowledge_base(csv_path):
|
||||
"""Example: School communication platform docs."""
|
||||
import pandas as pd
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
dataset = Dataset.from_pandas(df).map(lambda x: {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': CUSTOM_SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': x['expert_answer']
|
||||
})
|
||||
return dataset
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Deployment and Inference
|
||||
|
||||
### Save and Merge LoRA
|
||||
```python
|
||||
# Merge LoRA adapters into base model
|
||||
if hasattr(trainer.model, 'merge_and_unload'):
|
||||
merged_model = trainer.model.merge_and_unload()
|
||||
merged_model.save_pretrained("production_model")
|
||||
tokenizer.save_pretrained("production_model")
|
||||
```
|
||||
|
||||
### Inference Example
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
generator = pipeline(
|
||||
"text-generation",
|
||||
model="production_model",
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
result = generator(
|
||||
[
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': "What is 15 + 27?"}
|
||||
],
|
||||
max_new_tokens=256,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9
|
||||
)
|
||||
print(result[0]['generated_text'])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Best Practices Checklist
|
||||
|
||||
**Before Training:**
|
||||
- [ ] Validate dataset format (prompts as List[Dict])
|
||||
- [ ] Test reward functions on sample data
|
||||
- [ ] Calculate expected max_prompt_length from data
|
||||
- [ ] Choose appropriate num_generations based on GPU memory
|
||||
- [ ] Set up logging (wandb recommended)
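For the "test reward functions on sample data" item, a quick offline check can use the reward functions from Step 2 on hand-written completions in the same nested format GRPOTrainer passes. This assumes `extract_final_answer` returns the text inside the `<answer>` tags:

```python
# Quick offline check of reward functions before any training run.
fake_completions = [
    [{"role": "assistant", "content": "<reasoning>2 + 2 = 4</reasoning>\n<answer>4</answer>"}],
    [{"role": "assistant", "content": "The answer is 4."}],
]

print(format_reward(completions=fake_completions))        # expected: [1.0, 0.0]
print(correctness_reward(prompts=None,
                         completions=fake_completions,
                         answer=["4", "4"]))               # expected: [2.0, 0.0]
```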
|
||||
|
||||
**During Training:**
|
||||
- [ ] Monitor reward progression (should increase)
|
||||
- [ ] Check reward_std (should stay > 0.1)
|
||||
- [ ] Watch for OOM errors (reduce batch size if needed)
|
||||
- [ ] Sample generations every 50-100 steps
|
||||
- [ ] Validate format compliance on holdout set
|
||||
|
||||
**After Training:**
|
||||
- [ ] Merge LoRA weights if using PEFT
|
||||
- [ ] Test on diverse prompts
|
||||
- [ ] Compare to baseline model
|
||||
- [ ] Document reward weights and hyperparameters
|
||||
- [ ] Save reproducibility config
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting Guide
|
||||
|
||||
### Debugging Workflow
|
||||
1. **Isolate reward functions** - Test each independently
|
||||
2. **Check data distribution** - Ensure diversity in prompts
|
||||
3. **Reduce complexity** - Start with single reward, add gradually
|
||||
4. **Monitor generations** - Print samples every N steps
|
||||
5. **Validate extraction logic** - Ensure answer parsing works
|
||||
|
||||
### Quick Fixes
|
||||
```python
|
||||
# Debug reward function
|
||||
def debug_reward(completions, **kwargs):
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
for i, r in enumerate(responses[:2]): # Print first 2
|
||||
print(f"Response {i}: {r[:200]}...")
|
||||
return [1.0] * len(responses) # Dummy rewards
|
||||
|
||||
# Test without training
|
||||
trainer = GRPOTrainer(..., reward_funcs=[debug_reward])
|
||||
trainer.generate_completions(dataset[:1]) # Generate without updating
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## References and Resources
|
||||
|
||||
**Official Documentation:**
|
||||
- TRL GRPO Trainer: https://huggingface.co/docs/trl/grpo_trainer
|
||||
- DeepSeek R1 Paper: https://arxiv.org/abs/2501.12948
|
||||
- Unsloth Docs: https://docs.unsloth.ai/
|
||||
|
||||
**Example Repositories:**
|
||||
- Open R1 Implementation: https://github.com/huggingface/open-r1
|
||||
- TRL Examples: https://github.com/huggingface/trl/tree/main/examples
|
||||
|
||||
**Recommended Reading:**
|
||||
- Progressive Disclosure Pattern for agent instructions
|
||||
- Reward shaping in RL (Ng et al.)
|
||||
- LoRA paper (Hu et al., 2021)
|
||||
|
||||
---
|
||||
|
||||
## Usage Instructions for Agents
|
||||
|
||||
When this skill is loaded:
|
||||
|
||||
1. **Read this entire file** before implementing GRPO training
|
||||
2. **Start with the simplest reward function** (e.g., length-based) to validate setup
|
||||
3. **Use the templates** in `templates/` directory as starting points
|
||||
4. **Reference examples** in `examples/` for task-specific implementations
|
||||
5. **Follow the workflow** sequentially (don't skip steps)
|
||||
6. **Debug incrementally** - add one reward function at a time
|
||||
|
||||
**Critical Reminders:**
|
||||
- Always use multiple reward functions (3-5 is optimal)
|
||||
- Monitor reward metrics, not loss
|
||||
- Test reward functions before training
|
||||
- Start small (num_generations=4), scale up gradually
|
||||
- Save checkpoints frequently (every 100 steps)
|
||||
|
||||
This skill is designed for **expert-level implementation**. Beginners should start with supervised fine-tuning before attempting GRPO.
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Basic GRPO Training Template
|
||||
=============================
|
||||
|
||||
A minimal, production-ready template for GRPO training with TRL.
|
||||
Adapt this for your specific task by modifying:
|
||||
1. Dataset loading (get_dataset function)
|
||||
2. Reward functions (reward_*_func)
|
||||
3. System prompt (SYSTEM_PROMPT)
|
||||
4. Hyperparameters (GRPOConfig)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import re
|
||||
from datasets import load_dataset, Dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
|
||||
# ==================== CONFIGURATION ====================
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
OUTPUT_DIR = "outputs/grpo-model"
|
||||
MAX_PROMPT_LENGTH = 256
|
||||
MAX_COMPLETION_LENGTH = 512
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Respond in the following format:
|
||||
<reasoning>
|
||||
[Your step-by-step thinking]
|
||||
</reasoning>
|
||||
<answer>
|
||||
[Final answer]
|
||||
</answer>
|
||||
"""
|
||||
|
||||
# ==================== DATASET ====================
|
||||
|
||||
def get_dataset(split="train"):
|
||||
"""
|
||||
Load and prepare your dataset.
|
||||
|
||||
Returns: Dataset with columns:
|
||||
- 'prompt': List[Dict] with role/content
|
||||
- 'answer': str (ground truth, optional)
|
||||
"""
|
||||
# Example: GSM8K math dataset
|
||||
data = load_dataset('openai/gsm8k', 'main')[split]
|
||||
|
||||
def process_example(x):
|
||||
# Extract ground truth answer
|
||||
answer = x['answer'].split('####')[1].strip() if '####' in x['answer'] else None
|
||||
|
||||
return {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': answer
|
||||
}
|
||||
|
||||
return data.map(process_example)
|
||||
|
||||
# ==================== HELPER FUNCTIONS ====================
|
||||
|
||||
def extract_xml_tag(text: str, tag: str) -> str:
|
||||
"""Extract content between XML tags."""
|
||||
pattern = f'<{tag}>(.*?)</{tag}>'
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
return match.group(1).strip() if match else ""
|
||||
|
||||
def extract_answer(text: str) -> str:
|
||||
"""Extract the final answer from structured output."""
|
||||
return extract_xml_tag(text, 'answer')
|
||||
|
||||
# ==================== REWARD FUNCTIONS ====================
|
||||
|
||||
def correctness_reward_func(prompts, completions, answer, **kwargs):
|
||||
"""
|
||||
Reward correct answers.
|
||||
Weight: 2.0 (highest priority)
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_answer(r) for r in responses]
|
||||
return [2.0 if ans == gt else 0.0 for ans, gt in zip(extracted, answer)]
|
||||
|
||||
def format_reward_func(completions, **kwargs):
|
||||
"""
|
||||
Reward proper XML format.
|
||||
Weight: 0.5
|
||||
"""
|
||||
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]
|
||||
|
||||
def incremental_format_reward_func(completions, **kwargs):
|
||||
"""
|
||||
Incremental reward for partial format compliance.
|
||||
Weight: up to 0.5
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
if '<reasoning>' in r:
|
||||
score += 0.125
|
||||
if '</reasoning>' in r:
|
||||
score += 0.125
|
||||
if '<answer>' in r:
|
||||
score += 0.125
|
||||
if '</answer>' in r:
|
||||
score += 0.125
|
||||
|
||||
# Penalize extra content after closing tag
|
||||
if '</answer>' in r:
|
||||
extra = r.split('</answer>')[-1].strip()
|
||||
score -= len(extra) * 0.001
|
||||
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
|
||||
# ==================== MODEL SETUP ====================
|
||||
|
||||
def setup_model_and_tokenizer():
|
||||
"""Load model and tokenizer with optimizations."""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_NAME,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
def get_peft_config():
|
||||
"""LoRA configuration for parameter-efficient training."""
|
||||
return LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
|
||||
# ==================== TRAINING ====================
|
||||
|
||||
def main():
|
||||
"""Main training function."""
|
||||
|
||||
# Load data
|
||||
print("Loading dataset...")
|
||||
dataset = get_dataset()
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
|
||||
# Setup model
|
||||
print("Loading model...")
|
||||
model, tokenizer = setup_model_and_tokenizer()
|
||||
|
||||
# Training configuration
|
||||
training_args = GRPOConfig(
|
||||
output_dir=OUTPUT_DIR,
|
||||
run_name="grpo-training",
|
||||
|
||||
# Learning rate
|
||||
learning_rate=5e-6,
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type='cosine',
|
||||
|
||||
# Batch settings
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
|
||||
# GRPO specific
|
||||
num_generations=8,
|
||||
max_prompt_length=MAX_PROMPT_LENGTH,
|
||||
max_completion_length=MAX_COMPLETION_LENGTH,
|
||||
|
||||
# Training duration
|
||||
num_train_epochs=1,
|
||||
|
||||
# Optimization
|
||||
bf16=True,
|
||||
optim="adamw_8bit",
|
||||
max_grad_norm=0.1,
|
||||
|
||||
# Logging
|
||||
logging_steps=1,
|
||||
save_steps=100,
|
||||
report_to="wandb", # Change to "none" to disable logging
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=[
|
||||
incremental_format_reward_func,
|
||||
format_reward_func,
|
||||
correctness_reward_func,
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=get_peft_config(),
|
||||
)
|
||||
|
||||
# Train
|
||||
print("Starting training...")
|
||||
trainer.train()
|
||||
|
||||
# Save final model
|
||||
print(f"Saving model to {OUTPUT_DIR}/final")
|
||||
trainer.save_model(f"{OUTPUT_DIR}/final")
|
||||
|
||||
print("Training complete!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
572
skills/mlops-knowledge/guidance/SKILL.md
Normal file
@@ -0,0 +1,572 @@
|
||||
---
|
||||
name: guidance
|
||||
description: Control LLM output with regex and grammars, guarantee valid JSON/XML/code generation, enforce structured formats, and build multi-step workflows with Guidance - Microsoft Research's constrained generation framework
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Prompt Engineering, Guidance, Constrained Generation, Structured Output, JSON Validation, Grammar, Microsoft Research, Format Enforcement, Multi-Step Workflows]
|
||||
dependencies: [guidance, transformers]
|
||||
---
|
||||
|
||||
# Guidance: Constrained LLM Generation
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use Guidance when you need to:
|
||||
- **Control LLM output syntax** with regex or grammars
|
||||
- **Guarantee valid JSON/XML/code** generation
|
||||
- **Reduce latency** vs traditional prompting approaches
|
||||
- **Enforce structured formats** (dates, emails, IDs, etc.)
|
||||
- **Build multi-step workflows** with Pythonic control flow
|
||||
- **Prevent invalid outputs** through grammatical constraints
|
||||
|
||||
**GitHub Stars**: 18,000+ | **From**: Microsoft Research
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Base installation
|
||||
pip install guidance
|
||||
|
||||
# With specific backends
|
||||
pip install guidance[transformers] # Hugging Face models
|
||||
pip install guidance[llama_cpp] # llama.cpp models
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Example: Structured Generation
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
# Load model (supports OpenAI, Transformers, llama.cpp)
|
||||
lm = models.OpenAI("gpt-4")
|
||||
|
||||
# Generate with constraints
|
||||
result = lm + "The capital of France is " + gen("capital", max_tokens=5)
|
||||
|
||||
print(result["capital"]) # "Paris"
|
||||
```
|
||||
|
||||
### With Anthropic Claude
|
||||
|
||||
```python
|
||||
from guidance import models, gen, system, user, assistant
|
||||
|
||||
# Configure Claude
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Use context managers for chat format
|
||||
with system():
|
||||
lm += "You are a helpful assistant."
|
||||
|
||||
with user():
|
||||
lm += "What is the capital of France?"
|
||||
|
||||
with assistant():
|
||||
lm += gen(max_tokens=20)
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Context Managers
|
||||
|
||||
Guidance uses Pythonic context managers for chat-style interactions.
|
||||
|
||||
```python
|
||||
from guidance import system, user, assistant, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# System message
|
||||
with system():
|
||||
lm += "You are a JSON generation expert."
|
||||
|
||||
# User message
|
||||
with user():
|
||||
lm += "Generate a person object with name and age."
|
||||
|
||||
# Assistant response
|
||||
with assistant():
|
||||
lm += gen("response", max_tokens=100)
|
||||
|
||||
print(lm["response"])
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Natural chat flow
|
||||
- Clear role separation
|
||||
- Easy to read and maintain
|
||||
|
||||
### 2. Constrained Generation
|
||||
|
||||
Guidance ensures outputs match specified patterns using regex or grammars.
|
||||
|
||||
#### Regex Constraints
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Constrain to valid email format
|
||||
lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
|
||||
# Constrain to date format (YYYY-MM-DD)
|
||||
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}")
|
||||
|
||||
# Constrain to phone number
|
||||
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}")
|
||||
|
||||
print(lm["email"]) # Guaranteed valid email
|
||||
print(lm["date"]) # Guaranteed YYYY-MM-DD format
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
- Regex converted to grammar at token level
|
||||
- Invalid tokens filtered during generation
|
||||
- Model can only produce matching outputs
|
||||
|
||||
#### Selection Constraints
|
||||
|
||||
```python
|
||||
from guidance import models, gen, select
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Constrain to specific choices
|
||||
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
|
||||
|
||||
# Multiple-choice selection
|
||||
lm += "Best answer: " + select(
|
||||
["A) Paris", "B) London", "C) Berlin", "D) Madrid"],
|
||||
name="answer"
|
||||
)
|
||||
|
||||
print(lm["sentiment"]) # One of: positive, negative, neutral
|
||||
print(lm["answer"]) # One of: A, B, C, or D
|
||||
```
|
||||
|
||||
### 3. Token Healing
|
||||
|
||||
Guidance automatically "heals" token boundaries between prompt and generation.
|
||||
|
||||
**Problem:** Tokenization creates unnatural boundaries.
|
||||
|
||||
```python
|
||||
# Without token healing
|
||||
prompt = "The capital of France is "
|
||||
# Last token: " is "
|
||||
# First generated token might be " Par" (with leading space)
|
||||
# Result: "The capital of France is Paris" (double space!)
|
||||
```
|
||||
|
||||
**Solution:** Guidance backs up one token and regenerates.
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Token healing enabled by default
|
||||
lm += "The capital of France is " + gen("capital", max_tokens=5)
|
||||
# Result: "The capital of France is Paris" (correct spacing)
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Natural text boundaries
|
||||
- No awkward spacing issues
|
||||
- Better model performance (sees natural token sequences)
|
||||
|
||||
### 4. Grammar-Based Generation
|
||||
|
||||
Define complex structures using context-free grammars.
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# JSON grammar (simplified)
|
||||
json_grammar = """
|
||||
{
|
||||
"name": <gen name regex="[A-Za-z ]+" max_tokens=20>,
|
||||
"age": <gen age regex="[0-9]+" max_tokens=3>,
|
||||
"email": <gen email regex="[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}" max_tokens=50>
|
||||
}
|
||||
"""
|
||||
|
||||
# Generate valid JSON
|
||||
lm += gen("person", grammar=json_grammar)
|
||||
|
||||
print(lm["person"]) # Guaranteed valid JSON structure
|
||||
```
|
||||
|
||||
**Use cases:**
|
||||
- Complex structured outputs
|
||||
- Nested data structures
|
||||
- Programming language syntax
|
||||
- Domain-specific languages
|
||||
|
||||
### 5. Guidance Functions
|
||||
|
||||
Create reusable generation patterns with the `@guidance` decorator.
|
||||
|
||||
```python
|
||||
from guidance import guidance, gen, models
|
||||
|
||||
@guidance
|
||||
def generate_person(lm):
|
||||
"""Generate a person with name and age."""
|
||||
lm += "Name: " + gen("name", max_tokens=20, stop="\n")
|
||||
lm += "\nAge: " + gen("age", regex=r"[0-9]+", max_tokens=3)
|
||||
return lm
|
||||
|
||||
# Use the function
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_person(lm)
|
||||
|
||||
print(lm["name"])
|
||||
print(lm["age"])
|
||||
```
|
||||
|
||||
**Stateful Functions:**
|
||||
|
||||
```python
|
||||
from guidance import guidance, gen, select

@guidance(stateless=False)
|
||||
def react_agent(lm, question, tools, max_rounds=5):
|
||||
"""ReAct agent with tool use."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
for i in range(max_rounds):
|
||||
# Thought
|
||||
lm += f"Thought {i+1}: " + gen("thought", stop="\n")
|
||||
|
||||
# Action
|
||||
lm += "\nAction: " + select(list(tools.keys()), name="action")
|
||||
|
||||
# Execute tool
|
||||
tool_result = tools[lm["action"]]()
|
||||
lm += f"\nObservation: {tool_result}\n\n"
|
||||
|
||||
# Check if done
|
||||
lm += "Done? " + select(["Yes", "No"], name="done")
|
||||
if lm["done"] == "Yes":
|
||||
break
|
||||
|
||||
# Final answer
|
||||
lm += "\nFinal Answer: " + gen("answer", max_tokens=100)
|
||||
return lm
|
||||
```
|
||||
|
||||
## Backend Configuration
|
||||
|
||||
### Anthropic Claude
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
lm = models.Anthropic(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
api_key="your-api-key" # Or set ANTHROPIC_API_KEY env var
|
||||
)
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
```python
|
||||
lm = models.OpenAI(
|
||||
model="gpt-4o-mini",
|
||||
api_key="your-api-key" # Or set OPENAI_API_KEY env var
|
||||
)
|
||||
```
|
||||
|
||||
### Local Models (Transformers)
|
||||
|
||||
```python
|
||||
from guidance.models import Transformers
|
||||
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda" # Or "cpu"
|
||||
)
|
||||
```
|
||||
|
||||
### Local Models (llama.cpp)
|
||||
|
||||
```python
|
||||
from guidance.models import LlamaCpp
|
||||
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern 1: JSON Generation
|
||||
|
||||
```python
|
||||
from guidance import models, gen, system, user, assistant
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
with system():
|
||||
lm += "You generate valid JSON."
|
||||
|
||||
with user():
|
||||
lm += "Generate a user profile with name, age, and email."
|
||||
|
||||
with assistant():
|
||||
lm += """{
|
||||
"name": """ + gen("name", regex=r'"[A-Za-z ]+"', max_tokens=30) + """,
|
||||
"age": """ + gen("age", regex=r"[0-9]+", max_tokens=3) + """,
|
||||
"email": """ + gen("email", regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"', max_tokens=50) + """
|
||||
}"""
|
||||
|
||||
print(lm) # Valid JSON guaranteed
|
||||
```
|
||||
|
||||
### Pattern 2: Classification
|
||||
|
||||
```python
|
||||
from guidance import models, gen, select
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
text = "This product is amazing! I love it."
|
||||
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
|
||||
lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]+", max_tokens=3) + "%"
|
||||
|
||||
print(f"Sentiment: {lm['sentiment']}")
|
||||
print(f"Confidence: {lm['confidence']}%")
|
||||
```
|
||||
|
||||
### Pattern 3: Multi-Step Reasoning
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def chain_of_thought(lm, question):
|
||||
"""Generate answer with step-by-step reasoning."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
# Generate multiple reasoning steps
|
||||
for i in range(3):
|
||||
lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Final answer
|
||||
lm += "\nTherefore, the answer is: " + gen("answer", max_tokens=50)
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = chain_of_thought(lm, "What is 15% of 200?")
|
||||
|
||||
print(lm["answer"])
|
||||
```
|
||||
|
||||
### Pattern 4: ReAct Agent
|
||||
|
||||
```python
|
||||
from guidance import models, gen, select, guidance
|
||||
|
||||
@guidance(stateless=False)
|
||||
def react_agent(lm, question):
|
||||
"""ReAct agent with tool use."""
|
||||
tools = {
|
||||
"calculator": lambda expr: eval(expr),
|
||||
"search": lambda query: f"Search results for: {query}",
|
||||
}
|
||||
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
for round in range(5):
|
||||
# Thought
|
||||
lm += f"Thought: " + gen("thought", stop="\n") + "\n"
|
||||
|
||||
# Action selection
|
||||
lm += "Action: " + select(["calculator", "search", "answer"], name="action")
|
||||
|
||||
if lm["action"] == "answer":
|
||||
lm += "\nFinal Answer: " + gen("answer", max_tokens=100)
|
||||
break
|
||||
|
||||
# Action input
|
||||
lm += "\nAction Input: " + gen("action_input", stop="\n") + "\n"
|
||||
|
||||
# Execute tool
|
||||
if lm["action"] in tools:
|
||||
result = tools[lm["action"]](lm["action_input"])
|
||||
lm += f"Observation: {result}\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = react_agent(lm, "What is 25 * 4 + 10?")
|
||||
print(lm["answer"])
|
||||
```
|
||||
|
||||
### Pattern 5: Data Extraction
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def extract_entities(lm, text):
|
||||
"""Extract structured entities from text."""
|
||||
lm += f"Text: {text}\n\n"
|
||||
|
||||
# Extract person
|
||||
lm += "Person: " + gen("person", stop="\n", max_tokens=30) + "\n"
|
||||
|
||||
# Extract organization
|
||||
lm += "Organization: " + gen("organization", stop="\n", max_tokens=30) + "\n"
|
||||
|
||||
# Extract date
|
||||
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}", max_tokens=10) + "\n"
|
||||
|
||||
# Extract location
|
||||
lm += "Location: " + gen("location", stop="\n", max_tokens=30) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
text = "Tim Cook announced at Apple Park on 2024-09-15 in Cupertino."
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = extract_entities(lm, text)
|
||||
|
||||
print(f"Person: {lm['person']}")
|
||||
print(f"Organization: {lm['organization']}")
|
||||
print(f"Date: {lm['date']}")
|
||||
print(f"Location: {lm['location']}")
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Use Regex for Format Validation
|
||||
|
||||
```python
|
||||
# ✅ Good: Regex ensures valid format
|
||||
lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
|
||||
# ❌ Bad: Free generation may produce invalid emails
|
||||
lm += "Email: " + gen("email", max_tokens=50)
|
||||
```
|
||||
|
||||
### 2. Use select() for Fixed Categories
|
||||
|
||||
```python
|
||||
# ✅ Good: Guaranteed valid category
|
||||
lm += "Status: " + select(["pending", "approved", "rejected"], name="status")
|
||||
|
||||
# ❌ Bad: May generate typos or invalid values
|
||||
lm += "Status: " + gen("status", max_tokens=20)
|
||||
```
|
||||
|
||||
### 3. Leverage Token Healing
|
||||
|
||||
```python
|
||||
# Token healing is enabled by default
|
||||
# No special action needed - just concatenate naturally
|
||||
lm += "The capital is " + gen("capital") # Automatic healing
|
||||
```
|
||||
|
||||
### 4. Use stop Sequences
|
||||
|
||||
```python
|
||||
# ✅ Good: Stop at newline for single-line outputs
|
||||
lm += "Name: " + gen("name", stop="\n")
|
||||
|
||||
# ❌ Bad: May generate multiple lines
|
||||
lm += "Name: " + gen("name", max_tokens=50)
|
||||
```
|
||||
|
||||
### 5. Create Reusable Functions
|
||||
|
||||
```python
|
||||
# ✅ Good: Reusable pattern
|
||||
@guidance
|
||||
def generate_person(lm):
|
||||
lm += "Name: " + gen("name", stop="\n")
|
||||
lm += "\nAge: " + gen("age", regex=r"[0-9]+")
|
||||
return lm
|
||||
|
||||
# Use multiple times
|
||||
lm = generate_person(lm)
|
||||
lm += "\n\n"
|
||||
lm = generate_person(lm)
|
||||
```
|
||||
|
||||
### 6. Balance Constraints
|
||||
|
||||
```python
|
||||
# ✅ Good: Reasonable constraints
|
||||
lm += gen("name", regex=r"[A-Za-z ]+", max_tokens=30)
|
||||
|
||||
# ❌ Too strict: May fail or be very slow
|
||||
lm += gen("name", regex=r"^(John|Jane)$", max_tokens=10)
|
||||
```
|
||||
|
||||
## Comparison to Alternatives
|
||||
|
||||
| Feature | Guidance | Instructor | Outlines | LMQL |
|
||||
|---------|----------|------------|----------|------|
|
||||
| Regex Constraints | ✅ Yes | ❌ No | ✅ Yes | ✅ Yes |
|
||||
| Grammar Support | ✅ CFG | ❌ No | ✅ CFG | ✅ CFG |
|
||||
| Pydantic Validation | ❌ No | ✅ Yes | ✅ Yes | ❌ No |
|
||||
| Token Healing | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
|
||||
| Local Models | ✅ Yes | ⚠️ Limited | ✅ Yes | ✅ Yes |
|
||||
| API Models | ✅ Yes | ✅ Yes | ⚠️ Limited | ✅ Yes |
|
||||
| Pythonic Syntax | ✅ Yes | ✅ Yes | ✅ Yes | ❌ SQL-like |
|
||||
| Learning Curve | Low | Low | Medium | High |
|
||||
|
||||
**When to choose Guidance:**
|
||||
- Need regex/grammar constraints
|
||||
- Want token healing
|
||||
- Building complex workflows with control flow
|
||||
- Using local models (Transformers, llama.cpp)
|
||||
- Prefer Pythonic syntax
|
||||
|
||||
**When to choose alternatives:**
|
||||
- Instructor: Need Pydantic validation with automatic retrying
|
||||
- Outlines: Need JSON schema validation
|
||||
- LMQL: Prefer declarative query syntax
|
||||
|
||||
## Performance Characteristics
|
||||
|
||||
**Latency Reduction:**
|
||||
- 30-50% faster than traditional prompting for constrained outputs
|
||||
- Token healing reduces unnecessary regeneration
|
||||
- Grammar constraints prevent invalid token generation
|
||||
|
||||
**Memory Usage:**
|
||||
- Minimal overhead vs unconstrained generation
|
||||
- Grammar compilation cached after first use
|
||||
- Efficient token filtering at inference time
|
||||
|
||||
**Token Efficiency:**
|
||||
- Prevents wasted tokens on invalid outputs
|
||||
- No need for retry loops
|
||||
- Direct path to valid outputs
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://guidance.readthedocs.io
|
||||
- **GitHub**: https://github.com/guidance-ai/guidance (18k+ stars)
|
||||
- **Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks
|
||||
- **Discord**: Community support available
|
||||
|
||||
## See Also
|
||||
|
||||
- `references/constraints.md` - Comprehensive regex and grammar patterns
|
||||
- `references/backends.md` - Backend-specific configuration
|
||||
- `references/examples.md` - Production-ready examples
|
||||
|
||||
|
||||
554
skills/mlops-knowledge/guidance/references/backends.md
Normal file
@@ -0,0 +1,554 @@
|
||||
# Backend Configuration Guide
|
||||
|
||||
Complete guide to configuring Guidance with different LLM backends.
|
||||
|
||||
## Table of Contents
|
||||
- API-Based Models (Anthropic, OpenAI)
|
||||
- Local Models (Transformers, llama.cpp)
|
||||
- Backend Comparison
|
||||
- Performance Tuning
|
||||
- Advanced Configuration
|
||||
|
||||
## API-Based Models
|
||||
|
||||
### Anthropic Claude
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
# Using environment variable
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
# Reads ANTHROPIC_API_KEY from environment
|
||||
|
||||
# Explicit API key
|
||||
lm = models.Anthropic(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
api_key="your-api-key-here"
|
||||
)
|
||||
```
|
||||
|
||||
#### Available Models
|
||||
|
||||
```python
|
||||
# Claude Sonnet 4.5 (Latest, recommended)
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Claude 3.7 Sonnet (Fast, cost-effective)
|
||||
lm = models.Anthropic("claude-sonnet-3.7-20250219")
|
||||
|
||||
# Claude 3 Opus (Most capable)
|
||||
lm = models.Anthropic("claude-3-opus-20240229")
|
||||
|
||||
# Claude 3.5 Haiku (Fastest, cheapest)
|
||||
lm = models.Anthropic("claude-3-5-haiku-20241022")
|
||||
```
|
||||
|
||||
#### Configuration Options
|
||||
|
||||
```python
|
||||
lm = models.Anthropic(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
api_key="your-api-key",
|
||||
max_tokens=4096, # Max tokens to generate
|
||||
temperature=0.7, # Sampling temperature (0-1)
|
||||
top_p=0.9, # Nucleus sampling
|
||||
timeout=30, # Request timeout (seconds)
|
||||
max_retries=3 # Retry failed requests
|
||||
)
|
||||
```
|
||||
|
||||
#### With Context Managers
|
||||
|
||||
```python
|
||||
from guidance import models, system, user, assistant, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
with system():
|
||||
lm += "You are a helpful assistant."
|
||||
|
||||
with user():
|
||||
lm += "What is the capital of France?"
|
||||
|
||||
with assistant():
|
||||
lm += gen(max_tokens=50)
|
||||
|
||||
print(lm)
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
# Using environment variable
|
||||
lm = models.OpenAI("gpt-4o")
|
||||
# Reads OPENAI_API_KEY from environment
|
||||
|
||||
# Explicit API key
|
||||
lm = models.OpenAI(
|
||||
model="gpt-4o",
|
||||
api_key="your-api-key-here"
|
||||
)
|
||||
```
|
||||
|
||||
#### Available Models
|
||||
|
||||
```python
|
||||
# GPT-4o (Latest, multimodal)
|
||||
lm = models.OpenAI("gpt-4o")
|
||||
|
||||
# GPT-4o Mini (Fast, cost-effective)
|
||||
lm = models.OpenAI("gpt-4o-mini")
|
||||
|
||||
# GPT-4 Turbo
|
||||
lm = models.OpenAI("gpt-4-turbo")
|
||||
|
||||
# GPT-3.5 Turbo (Cheapest)
|
||||
lm = models.OpenAI("gpt-3.5-turbo")
|
||||
```
|
||||
|
||||
#### Configuration Options
|
||||
|
||||
```python
|
||||
lm = models.OpenAI(
|
||||
model="gpt-4o-mini",
|
||||
api_key="your-api-key",
|
||||
max_tokens=2048,
|
||||
temperature=0.7,
|
||||
top_p=1.0,
|
||||
frequency_penalty=0.0,
|
||||
presence_penalty=0.0,
|
||||
timeout=30
|
||||
)
|
||||
```
|
||||
|
||||
#### Chat Format
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.OpenAI("gpt-4o-mini")
|
||||
|
||||
# OpenAI uses chat format
|
||||
lm += [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is 2+2?"}
|
||||
]
|
||||
|
||||
# Generate response
|
||||
lm += gen(max_tokens=50)
|
||||
```
|
||||
|
||||
### Azure OpenAI
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
lm = models.AzureOpenAI(
|
||||
model="gpt-4o",
|
||||
azure_endpoint="https://your-resource.openai.azure.com/",
|
||||
api_key="your-azure-api-key",
|
||||
api_version="2024-02-15-preview",
|
||||
deployment_name="your-deployment-name"
|
||||
)
|
||||
```
|
||||
|
||||
## Local Models
|
||||
|
||||
### Transformers (Hugging Face)
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance.models import Transformers
|
||||
|
||||
# Load model from Hugging Face
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
```
|
||||
|
||||
#### GPU Configuration
|
||||
|
||||
```python
|
||||
# Use GPU
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Use specific GPU
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda:0" # GPU 0
|
||||
)
|
||||
|
||||
# Use CPU
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cpu"
|
||||
)
|
||||
```
|
||||
|
||||
#### Advanced Configuration
|
||||
|
||||
```python
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda",
|
||||
torch_dtype="float16", # Use FP16 (faster, less memory)
|
||||
load_in_8bit=True, # 8-bit quantization
|
||||
max_memory={0: "20GB"}, # GPU memory limit
|
||||
offload_folder="./offload" # Offload to disk if needed
|
||||
)
|
||||
```
|
||||
|
||||
#### Popular Models
|
||||
|
||||
```python
|
||||
# Phi-4 (Microsoft)
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
lm = Transformers("microsoft/Phi-3-medium-4k-instruct")
|
||||
|
||||
# Llama 3 (Meta)
|
||||
lm = Transformers("meta-llama/Llama-3.1-8B-Instruct")
|
||||
lm = Transformers("meta-llama/Llama-3.1-70B-Instruct")
|
||||
|
||||
# Mistral (Mistral AI)
|
||||
lm = Transformers("mistralai/Mistral-7B-Instruct-v0.3")
|
||||
lm = Transformers("mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
|
||||
# Qwen (Alibaba)
|
||||
lm = Transformers("Qwen/Qwen2.5-7B-Instruct")
|
||||
|
||||
# Gemma (Google)
|
||||
lm = Transformers("google/gemma-2-9b-it")
|
||||
```
|
||||
|
||||
#### Generation Configuration
|
||||
|
||||
```python
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Configure generation
|
||||
from guidance import gen
|
||||
|
||||
result = lm + gen(
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
top_k=50,
|
||||
repetition_penalty=1.1
|
||||
)
|
||||
```
|
||||
|
||||
### llama.cpp
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance.models import LlamaCpp
|
||||
|
||||
# Load GGUF model
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096 # Context window
|
||||
)
|
||||
```
|
||||
|
||||
#### GPU Configuration
|
||||
|
||||
```python
|
||||
# Use GPU acceleration
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35, # Offload 35 layers to GPU
|
||||
n_threads=8 # CPU threads for remaining layers
|
||||
)
|
||||
|
||||
# Full GPU offload
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=-1 # Offload all layers
|
||||
)
|
||||
```
|
||||
|
||||
#### Advanced Configuration
|
||||
|
||||
```python
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/llama-3.1-8b-instruct.Q4_K_M.gguf",
|
||||
n_ctx=8192, # Context window (tokens)
|
||||
n_gpu_layers=35, # GPU layers
|
||||
n_threads=8, # CPU threads
|
||||
n_batch=512, # Batch size for prompt processing
|
||||
use_mmap=True, # Memory-map the model file
|
||||
use_mlock=False, # Lock model in RAM
|
||||
seed=42, # Random seed
|
||||
verbose=False # Suppress verbose output
|
||||
)
|
||||
```
|
||||
|
||||
#### Quantized Models
|
||||
|
||||
```python
|
||||
# Q4_K_M (4-bit, recommended for most cases)
|
||||
lm = LlamaCpp("/path/to/model.Q4_K_M.gguf")
|
||||
|
||||
# Q5_K_M (5-bit, better quality)
|
||||
lm = LlamaCpp("/path/to/model.Q5_K_M.gguf")
|
||||
|
||||
# Q8_0 (8-bit, high quality)
|
||||
lm = LlamaCpp("/path/to/model.Q8_0.gguf")
|
||||
|
||||
# F16 (16-bit float, highest quality)
|
||||
lm = LlamaCpp("/path/to/model.F16.gguf")
|
||||
```
|
||||
|
||||
#### Popular GGUF Models
|
||||
|
||||
```python
|
||||
# Llama 3.1
|
||||
lm = LlamaCpp("llama-3.1-8b-instruct.Q4_K_M.gguf")
|
||||
|
||||
# Mistral
|
||||
lm = LlamaCpp("mistral-7b-instruct-v0.3.Q4_K_M.gguf")
|
||||
|
||||
# Phi-4
|
||||
lm = LlamaCpp("phi-4-mini-instruct.Q4_K_M.gguf")
|
||||
```
|
||||
|
||||
## Backend Comparison
|
||||
|
||||
### Feature Matrix
|
||||
|
||||
| Feature | Anthropic | OpenAI | Transformers | llama.cpp |
|
||||
|---------|-----------|--------|--------------|-----------|
|
||||
| Constrained Generation | ✅ Full | ✅ Full | ✅ Full | ✅ Full |
|
||||
| Token Healing | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
|
||||
| Streaming | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
|
||||
| GPU Support | N/A | N/A | ✅ Yes | ✅ Yes |
|
||||
| Quantization | N/A | N/A | ✅ Yes | ✅ Yes |
|
||||
| Cost | $$$ | $$$ | Free | Free |
|
||||
| Latency | Low | Low | Medium | Low |
|
||||
| Setup Difficulty | Easy | Easy | Medium | Medium |
|
||||
|
||||
### Performance Characteristics
|
||||
|
||||
**Anthropic Claude:**
|
||||
- **Latency**: 200-500ms (API call)
|
||||
- **Throughput**: Limited by API rate limits
|
||||
- **Cost**: $3-15 per 1M input tokens
|
||||
- **Best for**: Production systems, high-quality outputs
|
||||
|
||||
**OpenAI:**
|
||||
- **Latency**: 200-400ms (API call)
|
||||
- **Throughput**: Limited by API rate limits
|
||||
- **Cost**: $0.15-30 per 1M input tokens
|
||||
- **Best for**: Cost-sensitive production, gpt-4o-mini
|
||||
|
||||
**Transformers:**
|
||||
- **Latency**: 50-200ms (local inference)
|
||||
- **Throughput**: GPU-dependent (10-100 tokens/sec)
|
||||
- **Cost**: Hardware cost only
|
||||
- **Best for**: Privacy-sensitive, high-volume, experimentation
|
||||
|
||||
**llama.cpp:**
|
||||
- **Latency**: 30-150ms (local inference)
|
||||
- **Throughput**: Hardware-dependent (20-150 tokens/sec)
|
||||
- **Cost**: Hardware cost only
|
||||
- **Best for**: Edge deployment, Apple Silicon, CPU inference
|
||||
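One way to act on these trade-offs is to hide the backend choice behind a small factory so the rest of the pipeline stays the same. This is a sketch under the assumption that all four backends above are installed; the `GUIDANCE_BACKEND` and `GGUF_PATH` environment variables are a hypothetical convention, not part of Guidance.

```python
import os
from guidance import models
from guidance.models import Transformers, LlamaCpp

def load_backend():
    """Pick a backend from an environment variable (hypothetical convention)."""
    choice = os.environ.get("GUIDANCE_BACKEND", "anthropic")

    if choice == "anthropic":      # hosted, high quality
        return models.Anthropic("claude-sonnet-4-5-20250929")
    if choice == "openai":         # hosted, cost-sensitive
        return models.OpenAI("gpt-4o-mini")
    if choice == "transformers":   # local GPU inference
        return Transformers("microsoft/Phi-4-mini-instruct", device="cuda")
    if choice == "llamacpp":       # local CPU / Apple Silicon
        return LlamaCpp(model_path=os.environ["GGUF_PATH"], n_ctx=4096)
    raise ValueError(f"Unknown backend: {choice}")

lm = load_backend()
```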
|
||||
### Memory Requirements
|
||||
|
||||
**Transformers (FP16):**
|
||||
- 7B model: ~14GB GPU VRAM
|
||||
- 13B model: ~26GB GPU VRAM
|
||||
- 70B model: ~140GB GPU VRAM (multi-GPU)
|
||||
|
||||
**llama.cpp (Q4_K_M):**
|
||||
- 7B model: ~4.5GB RAM
|
||||
- 13B model: ~8GB RAM
|
||||
- 70B model: ~40GB RAM
|
||||
|
||||
**Optimization Tips:**
|
||||
- Use quantized models (Q4_K_M) for lower memory
|
||||
- Use GPU offloading for faster inference
|
||||
- Use CPU inference for smaller models (<7B)
|
||||
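The figures above follow from a simple rule of thumb: parameter count times bytes per parameter, plus some overhead for activations and the KV cache. A rough estimator, assuming dense models and ignoring context-length effects:

```python
def estimate_memory_gb(params_billion: float, bits_per_param: float, overhead: float = 1.2) -> float:
    """Rough weight-memory estimate: params * bytes/param * overhead factor."""
    bytes_total = params_billion * 1e9 * (bits_per_param / 8)
    return bytes_total * overhead / 1e9

print(estimate_memory_gb(7, 16))    # FP16 7B   -> ~16.8 GB (the ~14GB above is weights only)
print(estimate_memory_gb(7, 4.5))   # Q4_K_M 7B -> ~4.7 GB
print(estimate_memory_gb(70, 4.5))  # Q4_K_M 70B -> ~47 GB
```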
|
||||
## Performance Tuning
|
||||
|
||||
### API Models (Anthropic, OpenAI)
|
||||
|
||||
#### Reduce Latency
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Use lower max_tokens (faster response)
|
||||
lm += gen(max_tokens=100) # Instead of 1000
|
||||
|
||||
# Use streaming (perceived latency reduction)
|
||||
for chunk in lm.stream(gen(max_tokens=500)):
|
||||
print(chunk, end="", flush=True)
|
||||
```
|
||||
|
||||
#### Reduce Cost
|
||||
|
||||
```python
|
||||
# Use cheaper models
|
||||
lm = models.Anthropic("claude-3-5-haiku-20241022") # vs Sonnet
|
||||
lm = models.OpenAI("gpt-4o-mini") # vs gpt-4o
|
||||
|
||||
# Reduce context size
|
||||
# - Keep prompts concise
|
||||
# - Avoid large few-shot examples
|
||||
# - Use max_tokens limits
|
||||
```
|
||||
|
||||
### Local Models (Transformers, llama.cpp)
|
||||
|
||||
#### Optimize GPU Usage
|
||||
|
||||
```python
|
||||
from guidance.models import Transformers
|
||||
|
||||
# Use FP16 for 2x speedup
|
||||
lm = Transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
torch_dtype="float16"
|
||||
)
|
||||
|
||||
# Use 8-bit quantization for 4x memory reduction
|
||||
lm = Transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
load_in_8bit=True
|
||||
)
|
||||
|
||||
# Use flash attention (requires flash-attn package)
|
||||
lm = Transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
use_flash_attention_2=True
|
||||
)
|
||||
```
|
||||
|
||||
#### Optimize llama.cpp
|
||||
|
||||
```python
|
||||
from guidance.models import LlamaCpp
|
||||
|
||||
# Maximize GPU layers
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.Q4_K_M.gguf",
|
||||
n_gpu_layers=-1 # All layers on GPU
|
||||
)
|
||||
|
||||
# Optimize batch size
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.Q4_K_M.gguf",
|
||||
n_batch=512, # Larger batch = faster prompt processing
|
||||
n_gpu_layers=-1
|
||||
)
|
||||
|
||||
# Use Metal (Apple Silicon)
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.Q4_K_M.gguf",
|
||||
n_gpu_layers=-1, # Use Metal GPU acceleration
|
||||
use_mmap=True
|
||||
)
|
||||
```
|
||||
|
||||
#### Batch Processing
|
||||
|
||||
```python
|
||||
# Process multiple requests efficiently
|
||||
requests = [
|
||||
"What is 2+2?",
|
||||
"What is the capital of France?",
|
||||
"What is photosynthesis?"
|
||||
]
|
||||
|
||||
# Bad: Sequential processing
|
||||
for req in requests:
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
lm += req + gen(max_tokens=50)
|
||||
|
||||
# Good: Reuse loaded model
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
for req in requests:
|
||||
lm += req + gen(max_tokens=50)
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Custom Model Configurations
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from guidance.models import Transformers
|
||||
|
||||
# Load custom model
|
||||
tokenizer = AutoTokenizer.from_pretrained("your-model")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"your-model",
|
||||
device_map="auto",
|
||||
torch_dtype="float16"
|
||||
)
|
||||
|
||||
# Use with Guidance
|
||||
lm = Transformers(model=model, tokenizer=tokenizer)
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# API keys
|
||||
export ANTHROPIC_API_KEY="sk-ant-..."
|
||||
export OPENAI_API_KEY="sk-..."
|
||||
|
||||
# Transformers cache
|
||||
export HF_HOME="/path/to/cache"
|
||||
export TRANSFORMERS_CACHE="/path/to/cache"
|
||||
|
||||
# GPU selection
|
||||
export CUDA_VISIBLE_DEVICES=0,1 # Use GPU 0 and 1
|
||||
```
|
||||
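Before constructing an API-backed model it can help to fail fast when a key is missing. A minimal check, assuming the variables above are the ones your deployment relies on:

```python
import os

REQUIRED_KEYS = ["ANTHROPIC_API_KEY"]  # add OPENAI_API_KEY etc. as needed

missing = [key for key in REQUIRED_KEYS if not os.environ.get(key)]
if missing:
    raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
```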
|
||||
### Debugging
|
||||
|
||||
```python
# Enable verbose logging
import logging

import torch
from guidance import models
from guidance.models import Transformers

logging.basicConfig(level=logging.DEBUG)

# Check backend info
lm = models.Anthropic("claude-sonnet-4-5-20250929")
print(f"Model: {lm.model_name}")
print(f"Backend: {lm.backend}")

# Check GPU usage (Transformers)
lm = Transformers("microsoft/Phi-4-mini-instruct", device="cuda")
print(f"Device: {lm.device}")
print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Anthropic Docs**: https://docs.anthropic.com
|
||||
- **OpenAI Docs**: https://platform.openai.com/docs
|
||||
- **Hugging Face Models**: https://huggingface.co/models
|
||||
- **llama.cpp**: https://github.com/ggerganov/llama.cpp
|
||||
- **GGUF Models**: https://huggingface.co/models?library=gguf
|
||||
skills/mlops-knowledge/guidance/references/constraints.md (new file, 674 lines)
@@ -0,0 +1,674 @@
|
||||
# Comprehensive Constraint Patterns
|
||||
|
||||
Guide to regex constraints, grammar-based generation, and token healing in Guidance.
|
||||
|
||||
## Table of Contents
|
||||
- Regex Constraints
|
||||
- Grammar-Based Generation
|
||||
- Token Healing
|
||||
- Selection Constraints
|
||||
- Complex Patterns
|
||||
- Performance Optimization
|
||||
|
||||
## Regex Constraints
|
||||
|
||||
### Basic Patterns
|
||||
|
||||
#### Numeric Constraints
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Integer (positive)
|
||||
lm += "Age: " + gen("age", regex=r"[0-9]+")
|
||||
|
||||
# Integer (with negatives)
|
||||
lm += "Temperature: " + gen("temp", regex=r"-?[0-9]+")
|
||||
|
||||
# Float (positive)
|
||||
lm += "Price: $" + gen("price", regex=r"[0-9]+\.[0-9]{2}")
|
||||
|
||||
# Float (with negatives and optional decimals)
|
||||
lm += "Value: " + gen("value", regex=r"-?[0-9]+(\.[0-9]+)?")
|
||||
|
||||
# Percentage (0-100)
|
||||
lm += "Progress: " + gen("progress", regex=r"(100|[0-9]{1,2})")
|
||||
|
||||
# Range (1-5 stars)
|
||||
lm += "Rating: " + gen("rating", regex=r"[1-5]") + " stars"
|
||||
```
|
||||
|
||||
#### Text Constraints
|
||||
|
||||
```python
|
||||
# Alphabetic only
|
||||
lm += "Name: " + gen("name", regex=r"[A-Za-z]+")
|
||||
|
||||
# Alphabetic with spaces
|
||||
lm += "Full Name: " + gen("full_name", regex=r"[A-Za-z ]+")
|
||||
|
||||
# Alphanumeric
|
||||
lm += "Username: " + gen("username", regex=r"[A-Za-z0-9_]+")
|
||||
|
||||
# Capitalized words
|
||||
lm += "Title: " + gen("title", regex=r"[A-Z][a-z]+( [A-Z][a-z]+)*")
|
||||
|
||||
# Lowercase only
|
||||
lm += "Code: " + gen("code", regex=r"[a-z0-9-]+")
|
||||
|
||||
# Specific length
|
||||
lm += "ID: " + gen("id", regex=r"[A-Z]{3}-[0-9]{6}") # e.g., "ABC-123456"
|
||||
```
|
||||
|
||||
#### Date and Time Constraints
|
||||
|
||||
```python
|
||||
# Date (YYYY-MM-DD)
|
||||
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}")
|
||||
|
||||
# Date (MM/DD/YYYY)
|
||||
lm += "Date: " + gen("date_us", regex=r"\d{2}/\d{2}/\d{4}")
|
||||
|
||||
# Time (HH:MM)
|
||||
lm += "Time: " + gen("time", regex=r"\d{2}:\d{2}")
|
||||
|
||||
# Time (HH:MM:SS)
|
||||
lm += "Time: " + gen("time_full", regex=r"\d{2}:\d{2}:\d{2}")
|
||||
|
||||
# ISO 8601 datetime
|
||||
lm += "Timestamp: " + gen(
|
||||
"timestamp",
|
||||
regex=r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z"
|
||||
)
|
||||
|
||||
# Year (YYYY)
|
||||
lm += "Year: " + gen("year", regex=r"(19|20)\d{2}")
|
||||
|
||||
# Month name
|
||||
lm += "Month: " + gen(
|
||||
"month",
|
||||
regex=r"(January|February|March|April|May|June|July|August|September|October|November|December)"
|
||||
)
|
||||
```
|
||||
|
||||
#### Contact Information
|
||||
|
||||
```python
|
||||
# Email
|
||||
lm += "Email: " + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
|
||||
)
|
||||
|
||||
# Phone (US format)
|
||||
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}")
|
||||
|
||||
# Phone (international format)
|
||||
lm += "Phone: " + gen("phone_intl", regex=r"\+[0-9]{1,3}-[0-9]{1,14}")
|
||||
|
||||
# ZIP code (US)
|
||||
lm += "ZIP: " + gen("zip", regex=r"\d{5}(-\d{4})?")
|
||||
|
||||
# Postal code (Canada)
|
||||
lm += "Postal: " + gen("postal", regex=r"[A-Z]\d[A-Z] \d[A-Z]\d")
|
||||
|
||||
# URL
|
||||
lm += "URL: " + gen(
|
||||
"url",
|
||||
regex=r"https?://[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}(/[a-zA-Z0-9._~:/?#\[\]@!$&'()*+,;=-]*)?"
|
||||
)
|
||||
```
|
||||
|
||||
### Advanced Patterns
|
||||
|
||||
#### JSON Field Constraints
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# String field with quotes
|
||||
lm += '"name": ' + gen("name", regex=r'"[A-Za-z ]+"')
|
||||
|
||||
# Numeric field (no quotes)
|
||||
lm += '"age": ' + gen("age", regex=r"[0-9]+")
|
||||
|
||||
# Boolean field
|
||||
lm += '"active": ' + gen("active", regex=r"(true|false)")
|
||||
|
||||
# Null field
|
||||
lm += '"optional": ' + gen("optional", regex=r"(null|[0-9]+)")
|
||||
|
||||
# Array of strings
|
||||
lm += '"tags": [' + gen(
|
||||
"tags",
|
||||
regex=r'"[a-z]+"(, "[a-z]+")*'
|
||||
) + ']'
|
||||
|
||||
# Complete JSON object
|
||||
lm += """{
|
||||
"name": """ + gen("name", regex=r'"[A-Za-z ]+"') + """,
|
||||
"age": """ + gen("age", regex=r"[0-9]+") + """,
|
||||
"email": """ + gen(
|
||||
"email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + """
|
||||
}"""
|
||||
```
|
||||
|
||||
#### Code Patterns
|
||||
|
||||
```python
|
||||
# Python variable name
|
||||
lm += "Variable: " + gen("var", regex=r"[a-z_][a-z0-9_]*")
|
||||
|
||||
# Python function name
|
||||
lm += "Function: " + gen("func", regex=r"[a-z_][a-z0-9_]*")
|
||||
|
||||
# Hex color code
|
||||
lm += "Color: #" + gen("color", regex=r"[0-9A-Fa-f]{6}")
|
||||
|
||||
# UUID
|
||||
lm += "UUID: " + gen(
|
||||
"uuid",
|
||||
regex=r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||
)
|
||||
|
||||
# Git commit hash (short)
|
||||
lm += "Commit: " + gen("commit", regex=r"[0-9a-f]{7}")
|
||||
|
||||
# Semantic version
|
||||
lm += "Version: " + gen("version", regex=r"[0-9]+\.[0-9]+\.[0-9]+")
|
||||
|
||||
# IP address (IPv4)
|
||||
lm += "IP: " + gen(
|
||||
"ip",
|
||||
regex=r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
|
||||
)
|
||||
```
|
||||
|
||||
#### Domain-Specific Patterns
|
||||
|
||||
```python
|
||||
# Credit card number
|
||||
lm += "Card: " + gen("card", regex=r"\d{4}-\d{4}-\d{4}-\d{4}")
|
||||
|
||||
# Social Security Number (US)
|
||||
lm += "SSN: " + gen("ssn", regex=r"\d{3}-\d{2}-\d{4}")
|
||||
|
||||
# ISBN-13
|
||||
lm += "ISBN: " + gen("isbn", regex=r"978-\d{1,5}-\d{1,7}-\d{1,7}-\d")
|
||||
|
||||
# License plate (US)
|
||||
lm += "Plate: " + gen("plate", regex=r"[A-Z]{3}-\d{4}")
|
||||
|
||||
# Currency amount
|
||||
lm += "Amount: $" + gen("amount", regex=r"[0-9]{1,3}(,[0-9]{3})*\.[0-9]{2}")
|
||||
|
||||
# Percentage with decimal
|
||||
lm += "Rate: " + gen("rate", regex=r"[0-9]+\.[0-9]{1,2}%")
|
||||
```
|
||||
|
||||
## Grammar-Based Generation
|
||||
|
||||
### JSON Grammar
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def json_object(lm):
|
||||
"""Generate valid JSON object."""
|
||||
lm += "{\n"
|
||||
|
||||
# Name field (required)
|
||||
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
|
||||
# Age field (required)
|
||||
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n"
|
||||
|
||||
# Email field (required)
|
||||
lm += ' "email": ' + gen(
|
||||
"email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + ",\n"
|
||||
|
||||
# Active field (required, boolean)
|
||||
lm += ' "active": ' + gen("active", regex=r"(true|false)") + "\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = json_object(lm)
|
||||
print(lm) # Valid JSON guaranteed
|
||||
```
|
||||
|
||||
### Nested JSON Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def nested_json(lm):
|
||||
"""Generate nested JSON structure."""
|
||||
lm += "{\n"
|
||||
|
||||
# User object
|
||||
lm += ' "user": {\n'
|
||||
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + "\n"
|
||||
lm += " },\n"
|
||||
|
||||
# Address object
|
||||
lm += ' "address": {\n'
|
||||
lm += ' "street": ' + gen("street", regex=r'"[A-Za-z0-9 ]+"') + ",\n"
|
||||
lm += ' "city": ' + gen("city", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "zip": ' + gen("zip", regex=r'"\d{5}"') + "\n"
|
||||
lm += " }\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
```
|
||||
|
||||
### Array Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def json_array(lm, count=3):
|
||||
"""Generate JSON array with fixed count."""
|
||||
lm += "[\n"
|
||||
|
||||
for i in range(count):
|
||||
lm += " {\n"
|
||||
lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n"
|
||||
lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + "\n"
|
||||
lm += " }"
|
||||
if i < count - 1:
|
||||
lm += ","
|
||||
lm += "\n"
|
||||
|
||||
lm += "]"
|
||||
return lm
|
||||
```
|
||||
|
||||
### XML Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def xml_document(lm):
|
||||
"""Generate valid XML document."""
|
||||
lm += '<?xml version="1.0"?>\n'
|
||||
lm += "<person>\n"
|
||||
|
||||
# Name element
|
||||
lm += " <name>" + gen("name", regex=r"[A-Za-z ]+") + "</name>\n"
|
||||
|
||||
# Age element
|
||||
lm += " <age>" + gen("age", regex=r"[0-9]+") + "</age>\n"
|
||||
|
||||
# Email element
|
||||
lm += " <email>" + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
|
||||
) + "</email>\n"
|
||||
|
||||
lm += "</person>"
|
||||
return lm
|
||||
```
|
||||
|
||||
### CSV Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def csv_row(lm):
|
||||
"""Generate CSV row."""
|
||||
lm += gen("name", regex=r"[A-Za-z ]+") + ","
|
||||
lm += gen("age", regex=r"[0-9]+") + ","
|
||||
lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
return lm
|
||||
|
||||
@guidance
|
||||
def csv_document(lm, rows=5):
|
||||
"""Generate complete CSV."""
|
||||
# Header
|
||||
lm += "Name,Age,Email\n"
|
||||
|
||||
# Rows
|
||||
for i in range(rows):
|
||||
lm = csv_row(lm)
|
||||
if i < rows - 1:
|
||||
lm += "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
## Token Healing
|
||||
|
||||
### How Token Healing Works
|
||||
|
||||
**Problem:** Tokenization creates unnatural boundaries.
|
||||
|
||||
```python
|
||||
# Example without token healing
|
||||
prompt = "The capital of France is "
|
||||
# Tokenization: ["The", " capital", " of", " France", " is", " "]
|
||||
# Model sees last token: " "
|
||||
# First generated token might include leading space: " Paris"
|
||||
# Result: "The capital of France is Paris" (double space)
|
||||
```
|
||||
|
||||
**Solution:** Guidance backs up and regenerates the last token.
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Token healing enabled by default
|
||||
lm += "The capital of France is " + gen("capital", max_tokens=5)
|
||||
|
||||
# Process:
# 1. Back up over the trailing prompt token (the final " ")
# 2. Regenerate that token together with the completion, constrained to start with it
# 3. Result: "The capital of France is Paris" (correct, single space)
|
||||
```
|
||||
|
||||
### Token Healing Examples
|
||||
|
||||
#### Natural Continuations
|
||||
|
||||
```python
|
||||
# Before token healing
|
||||
lm += "The function name is get" + gen("rest")
|
||||
# Might generate: "The function name is get User" (space before User)
|
||||
|
||||
# With token healing
|
||||
lm += "The function name is get" + gen("rest")
|
||||
# Generates: "The function name is getUser" (correct camelCase)
|
||||
```
|
||||
|
||||
#### Code Generation
|
||||
|
||||
```python
|
||||
# Function name completion
|
||||
lm += "def calculate_" + gen("rest", stop="(")
|
||||
# Token healing ensures smooth connection: "calculate_total"
|
||||
|
||||
# Variable name completion
|
||||
lm += "my_" + gen("var_name", regex=r"[a-z_]+")
|
||||
# Token healing ensures: "my_variable_name" (not "my_ variable_name")
|
||||
```
|
||||
|
||||
#### Domain-Specific Terms
|
||||
|
||||
```python
|
||||
# Medical terms
|
||||
lm += "The patient has hyper" + gen("condition")
|
||||
# Token healing helps: "hypertension" (not "hyper tension")
|
||||
|
||||
# Technical terms
|
||||
lm += "Using micro" + gen("tech")
|
||||
# Token healing helps: "microservices" (not "micro services")
|
||||
```
|
||||
|
||||
### Disabling Token Healing
|
||||
|
||||
```python
|
||||
# Disable token healing if needed (rare)
|
||||
lm += gen("text", token_healing=False)
|
||||
```
|
||||
|
||||
## Selection Constraints
|
||||
|
||||
### Basic Selection
|
||||
|
||||
```python
|
||||
from guidance import models, select
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Simple selection
|
||||
lm += "Status: " + select(["active", "inactive", "pending"], name="status")
|
||||
|
||||
# Boolean selection
|
||||
lm += "Approved: " + select(["Yes", "No"], name="approved")
|
||||
|
||||
# Multiple choice
|
||||
lm += "Answer: " + select(
|
||||
["A) Paris", "B) London", "C) Berlin", "D) Madrid"],
|
||||
name="answer"
|
||||
)
|
||||
```
|
||||
|
||||
### Conditional Selection
|
||||
|
||||
```python
|
||||
from guidance import models, select, gen, guidance
|
||||
|
||||
@guidance
|
||||
def conditional_fields(lm):
|
||||
"""Generate fields conditionally based on type."""
|
||||
lm += "Type: " + select(["person", "company"], name="type")
|
||||
|
||||
if lm["type"] == "person":
|
||||
lm += "\nName: " + gen("name", regex=r"[A-Za-z ]+")
|
||||
lm += "\nAge: " + gen("age", regex=r"[0-9]+")
|
||||
else:
|
||||
lm += "\nCompany Name: " + gen("company", regex=r"[A-Za-z ]+")
|
||||
lm += "\nEmployees: " + gen("employees", regex=r"[0-9]+")
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Repeated Selection
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def multiple_selections(lm):
|
||||
"""Select multiple items."""
|
||||
lm += "Select 3 colors:\n"
|
||||
|
||||
colors = ["red", "blue", "green", "yellow", "purple"]
|
||||
|
||||
for i in range(3):
|
||||
lm += f"{i+1}. " + select(colors, name=f"color_{i}") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
## Complex Patterns
|
||||
|
||||
### Pattern 1: Structured Forms
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def user_form(lm):
|
||||
"""Generate structured user form."""
|
||||
lm += "=== User Registration ===\n\n"
|
||||
|
||||
# Name (alphabetic only)
|
||||
lm += "Full Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Age (numeric)
|
||||
lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n"
|
||||
|
||||
# Email (validated format)
|
||||
lm += "Email: " + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
|
||||
stop="\n"
|
||||
) + "\n"
|
||||
|
||||
# Phone (US format)
|
||||
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}") + "\n"
|
||||
|
||||
# Account type (selection)
|
||||
lm += "Account Type: " + select(
|
||||
["Standard", "Premium", "Enterprise"],
|
||||
name="account_type"
|
||||
) + "\n"
|
||||
|
||||
# Active status (boolean)
|
||||
lm += "Active: " + select(["Yes", "No"], name="active") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Pattern 2: Multi-Entity Extraction
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def extract_entities(lm, text):
|
||||
"""Extract multiple entities with constraints."""
|
||||
lm += f"Text: {text}\n\n"
|
||||
|
||||
# Person name (alphabetic)
|
||||
lm += "Person: " + gen("person", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Organization (alphanumeric with spaces)
|
||||
lm += "Organization: " + gen(
|
||||
"organization",
|
||||
regex=r"[A-Za-z0-9 ]+",
|
||||
stop="\n"
|
||||
) + "\n"
|
||||
|
||||
# Date (YYYY-MM-DD format)
|
||||
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}") + "\n"
|
||||
|
||||
# Location (alphabetic with spaces)
|
||||
lm += "Location: " + gen("location", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Amount (currency)
|
||||
lm += "Amount: $" + gen("amount", regex=r"[0-9,]+\.[0-9]{2}") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Pattern 3: Code Generation
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_python_function(lm):
|
||||
"""Generate Python function with constraints."""
|
||||
# Function name (valid Python identifier)
|
||||
lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "("
|
||||
|
||||
# Parameter name
|
||||
lm += gen("param", regex=r"[a-z_][a-z0-9_]*") + "):\n"
|
||||
|
||||
# Docstring
|
||||
lm += ' """' + gen("docstring", stop='"""', max_tokens=50) + '"""\n'
|
||||
|
||||
# Function body (constrained to valid Python)
|
||||
lm += " return " + gen("return_value", stop="\n") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Pattern 4: Hierarchical Data
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def org_chart(lm):
|
||||
"""Generate organizational chart."""
|
||||
lm += "Company: " + gen("company", regex=r"[A-Za-z ]+") + "\n\n"
|
||||
|
||||
# CEO
|
||||
lm += "CEO: " + gen("ceo", regex=r"[A-Za-z ]+") + "\n"
|
||||
|
||||
# Departments
|
||||
for dept in ["Engineering", "Sales", "Marketing"]:
|
||||
lm += f"\n{dept} Department:\n"
|
||||
lm += " Head: " + gen(f"{dept.lower()}_head", regex=r"[A-Za-z ]+") + "\n"
|
||||
lm += " Size: " + gen(f"{dept.lower()}_size", regex=r"[0-9]+") + " employees\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Best Practices
|
||||
|
||||
#### 1. Use Specific Patterns
|
||||
|
||||
```python
|
||||
# ✅ Good: Specific pattern
|
||||
lm += gen("age", regex=r"[0-9]{1,3}") # Fast
|
||||
|
||||
# ❌ Bad: Overly broad pattern
|
||||
lm += gen("age", regex=r"[0-9]+") # Slower
|
||||
```
|
||||
|
||||
#### 2. Limit Max Tokens
|
||||
|
||||
```python
|
||||
# ✅ Good: Reasonable limit
|
||||
lm += gen("name", max_tokens=30)
|
||||
|
||||
# ❌ Bad: No limit
|
||||
lm += gen("name") # May generate forever
|
||||
```
|
||||
|
||||
#### 3. Use stop Sequences
|
||||
|
||||
```python
|
||||
# ✅ Good: Stop at newline
|
||||
lm += gen("line", stop="\n")
|
||||
|
||||
# ❌ Bad: Rely on max_tokens
|
||||
lm += gen("line", max_tokens=100)
|
||||
```
|
||||
|
||||
#### 4. Cache Compiled Grammars
|
||||
|
||||
```python
|
||||
# Grammars are cached automatically after first use
|
||||
# No manual caching needed
|
||||
@guidance
|
||||
def reusable_pattern(lm):
|
||||
"""This grammar is compiled once and cached."""
|
||||
lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
return lm
|
||||
|
||||
# First call: compiles grammar
|
||||
lm = reusable_pattern(lm)
|
||||
|
||||
# Subsequent calls: uses cached grammar (fast)
|
||||
lm = reusable_pattern(lm)
|
||||
```
|
||||
|
||||
#### 5. Avoid Overlapping Constraints
|
||||
|
||||
```python
|
||||
# ✅ Good: Clear constraints
|
||||
lm += gen("age", regex=r"[0-9]+", max_tokens=3)
|
||||
|
||||
# ❌ Bad: Conflicting constraints
|
||||
lm += gen("age", regex=r"[0-9]{2}", max_tokens=10) # max_tokens unnecessary
|
||||
```
|
||||
|
||||
### Performance Benchmarks
|
||||
|
||||
**Regex vs Free Generation:**
|
||||
- Simple regex (digits): ~1.2x slower than free gen
|
||||
- Complex regex (email): ~1.5x slower than free gen
|
||||
- Grammar-based: ~2x slower than free gen
|
||||
|
||||
**But:**
|
||||
- 100% valid outputs (vs ~70% with free gen + validation)
|
||||
- No retry loops needed
|
||||
- Overall faster end-to-end for structured outputs
|
||||
|
||||
**Optimization Tips:**
|
||||
- Use regex for critical fields only
|
||||
- Use `select()` for small fixed sets (fastest)
|
||||
- Use `stop` sequences when possible (faster than max_tokens)
|
||||
- Cache compiled grammars by reusing functions
|
||||
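A small sketch that applies these tips together: `select()` for the fixed set, a narrow regex only on the critical field, and a `stop` sequence instead of relying solely on `max_tokens`. It follows the same `@guidance` function style used throughout this reference.

```python
from guidance import models, gen, select, guidance

@guidance
def optimized_record(lm):
    """Reusable (and therefore cached) pattern following the tips above."""
    # Fixed set -> select() is the fastest constraint
    lm += "Priority: " + select(["low", "medium", "high"], name="priority") + "\n"
    # Critical field -> narrow, bounded regex
    lm += "Ticket: " + gen("ticket", regex=r"[A-Z]{3}-[0-9]{4}") + "\n"
    # Free text -> stop sequence rather than a large max_tokens
    lm += "Summary: " + gen("summary", stop="\n", max_tokens=60) + "\n"
    return lm

lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = optimized_record(lm)
```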
|
||||
## Resources
|
||||
|
||||
- **Token Healing Paper**: https://arxiv.org/abs/2306.17648
|
||||
- **Guidance Docs**: https://guidance.readthedocs.io
|
||||
- **GitHub**: https://github.com/guidance-ai/guidance
|
||||
skills/mlops-knowledge/guidance/references/examples.md (new file, 767 lines)
@@ -0,0 +1,767 @@
|
||||
# Production-Ready Examples
|
||||
|
||||
Real-world examples of using Guidance for structured generation, agents, and workflows.
|
||||
|
||||
## Table of Contents
|
||||
- JSON Generation
|
||||
- Data Extraction
|
||||
- Classification Systems
|
||||
- Agent Systems
|
||||
- Multi-Step Workflows
|
||||
- Code Generation
|
||||
- Production Tips
|
||||
|
||||
## JSON Generation
|
||||
|
||||
### Basic JSON
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def generate_user(lm):
|
||||
"""Generate valid user JSON."""
|
||||
lm += "{\n"
|
||||
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n"
|
||||
lm += ' "email": ' + gen(
|
||||
"email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + "\n"
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
# Use it
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm += "Generate a user profile:\n"
|
||||
lm = generate_user(lm)
|
||||
|
||||
print(lm)
|
||||
# Output: Valid JSON guaranteed
|
||||
```
|
||||
|
||||
### Nested JSON
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_order(lm):
|
||||
"""Generate nested order JSON."""
|
||||
lm += "{\n"
|
||||
|
||||
# Customer info
|
||||
lm += ' "customer": {\n'
|
||||
lm += ' "name": ' + gen("customer_name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "email": ' + gen(
|
||||
"customer_email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + "\n"
|
||||
lm += " },\n"
|
||||
|
||||
# Order details
|
||||
lm += ' "order": {\n'
|
||||
lm += ' "id": ' + gen("order_id", regex=r'"ORD-[0-9]{6}"') + ",\n"
|
||||
lm += ' "date": ' + gen("order_date", regex=r'"\d{4}-\d{2}-\d{2}"') + ",\n"
|
||||
lm += ' "total": ' + gen("order_total", regex=r"[0-9]+\.[0-9]{2}") + "\n"
|
||||
lm += " },\n"
|
||||
|
||||
# Status
|
||||
lm += ' "status": ' + gen(
|
||||
"status",
|
||||
regex=r'"(pending|processing|shipped|delivered)"'
|
||||
) + "\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_order(lm)
|
||||
```
|
||||
|
||||
### JSON Array
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_user_list(lm, count=3):
|
||||
"""Generate JSON array of users."""
|
||||
lm += "[\n"
|
||||
|
||||
for i in range(count):
|
||||
lm += " {\n"
|
||||
lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n"
|
||||
lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "active": ' + gen(f"active_{i}", regex=r"(true|false)") + "\n"
|
||||
lm += " }"
|
||||
if i < count - 1:
|
||||
lm += ","
|
||||
lm += "\n"
|
||||
|
||||
lm += "]"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_user_list(lm, count=5)
|
||||
```
|
||||
|
||||
### Dynamic JSON Schema
|
||||
|
||||
```python
|
||||
import json
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def json_from_schema(lm, schema):
|
||||
"""Generate JSON matching a schema."""
|
||||
lm += "{\n"
|
||||
|
||||
fields = list(schema["properties"].items())
|
||||
for i, (field_name, field_schema) in enumerate(fields):
|
||||
lm += f' "{field_name}": '
|
||||
|
||||
# Handle different types
|
||||
if field_schema["type"] == "string":
|
||||
if "pattern" in field_schema:
|
||||
lm += gen(field_name, regex=f'"{field_schema["pattern"]}"')
|
||||
else:
|
||||
lm += gen(field_name, regex=r'"[^"]+"')
|
||||
elif field_schema["type"] == "number":
|
||||
lm += gen(field_name, regex=r"[0-9]+(\.[0-9]+)?")
|
||||
elif field_schema["type"] == "integer":
|
||||
lm += gen(field_name, regex=r"[0-9]+")
|
||||
elif field_schema["type"] == "boolean":
|
||||
lm += gen(field_name, regex=r"(true|false)")
|
||||
|
||||
if i < len(fields) - 1:
|
||||
lm += ","
|
||||
lm += "\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
# Define schema
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"score": {"type": "number"},
|
||||
"active": {"type": "boolean"}
|
||||
}
|
||||
}
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = json_from_schema(lm, schema)
|
||||
```
|
||||
|
||||
## Data Extraction
|
||||
|
||||
### Extract from Text
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance, system, user, assistant
|
||||
|
||||
@guidance
|
||||
def extract_person_info(lm, text):
|
||||
"""Extract structured info from text."""
|
||||
lm += f"Text: {text}\n\n"
|
||||
|
||||
with assistant():
|
||||
lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n"
|
||||
lm += "Occupation: " + gen("occupation", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
lm += "Email: " + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
|
||||
stop="\n"
|
||||
) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
text = "John Smith is a 35-year-old software engineer. Contact: john@example.com"
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
with system():
|
||||
lm += "You extract structured information from text."
|
||||
|
||||
with user():
|
||||
lm = extract_person_info(lm, text)
|
||||
|
||||
print(f"Name: {lm['name']}")
|
||||
print(f"Age: {lm['age']}")
|
||||
print(f"Occupation: {lm['occupation']}")
|
||||
print(f"Email: {lm['email']}")
|
||||
```
|
||||
|
||||
### Multi-Entity Extraction
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def extract_entities(lm, text):
|
||||
"""Extract multiple entity types."""
|
||||
lm += f"Analyze: {text}\n\n"
|
||||
|
||||
# Person entities
|
||||
lm += "People:\n"
|
||||
for i in range(3): # Up to 3 people
|
||||
lm += f"- " + gen(f"person_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Organization entities
|
||||
lm += "\nOrganizations:\n"
|
||||
for i in range(2): # Up to 2 orgs
|
||||
lm += f"- " + gen(f"org_{i}", regex=r"[A-Za-z0-9 ]+", stop="\n") + "\n"
|
||||
|
||||
# Dates
|
||||
lm += "\nDates:\n"
|
||||
for i in range(2): # Up to 2 dates
|
||||
lm += f"- " + gen(f"date_{i}", regex=r"\d{4}-\d{2}-\d{2}", stop="\n") + "\n"
|
||||
|
||||
# Locations
|
||||
lm += "\nLocations:\n"
|
||||
for i in range(2): # Up to 2 locations
|
||||
lm += f"- " + gen(f"location_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
text = """
|
||||
Tim Cook and Satya Nadella met at Microsoft headquarters in Redmond on 2024-09-15
|
||||
to discuss the collaboration between Apple and Microsoft. The meeting continued
|
||||
in Cupertino on 2024-09-20.
|
||||
"""
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = extract_entities(lm, text)
|
||||
```
|
||||
|
||||
### Batch Extraction
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def batch_extract(lm, texts):
|
||||
"""Extract from multiple texts."""
|
||||
lm += "Batch Extraction Results:\n\n"
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
lm += f"=== Item {i+1} ===\n"
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Name: " + gen(f"name_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
lm += "Sentiment: " + gen(
|
||||
f"sentiment_{i}",
|
||||
regex=r"(positive|negative|neutral)",
|
||||
stop="\n"
|
||||
) + "\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
texts = [
|
||||
"Alice is happy with the product",
|
||||
"Bob is disappointed with the service",
|
||||
"Carol has no strong feelings either way"
|
||||
]
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = batch_extract(lm, texts)
|
||||
```
|
||||
|
||||
## Classification Systems
|
||||
|
||||
### Sentiment Analysis
|
||||
|
||||
```python
|
||||
from guidance import models, select, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
text = "This product is absolutely amazing! Best purchase ever."
|
||||
|
||||
lm += f"Text: {text}\n\n"
|
||||
lm += "Sentiment: " + select(
|
||||
["positive", "negative", "neutral"],
|
||||
name="sentiment"
|
||||
)
|
||||
lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]{1,3}") + "%\n"
|
||||
lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=50)
|
||||
|
||||
print(f"Sentiment: {lm['sentiment']}")
|
||||
print(f"Confidence: {lm['confidence']}%")
|
||||
print(f"Reasoning: {lm['reasoning']}")
|
||||
```
|
||||
|
||||
### Multi-Label Classification
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def classify_article(lm, text):
|
||||
"""Classify article with multiple labels."""
|
||||
lm += f"Article: {text}\n\n"
|
||||
|
||||
# Primary category
|
||||
lm += "Primary Category: " + select(
|
||||
["Technology", "Business", "Science", "Politics", "Entertainment"],
|
||||
name="primary_category"
|
||||
) + "\n"
|
||||
|
||||
# Secondary categories (up to 3)
|
||||
lm += "\nSecondary Categories:\n"
|
||||
categories = ["Technology", "Business", "Science", "Politics", "Entertainment"]
|
||||
for i in range(3):
|
||||
lm += f"{i+1}. " + select(categories, name=f"secondary_{i}") + "\n"
|
||||
|
||||
# Tags
|
||||
lm += "\nTags: " + gen("tags", stop="\n", max_tokens=50) + "\n"
|
||||
|
||||
# Target audience
|
||||
lm += "Target Audience: " + select(
|
||||
["General", "Expert", "Beginner"],
|
||||
name="audience"
|
||||
)
|
||||
|
||||
return lm
|
||||
|
||||
article = """
|
||||
Apple announced new AI features in iOS 18, leveraging machine learning to improve
|
||||
battery life and performance. The company's stock rose 5% following the announcement.
|
||||
"""
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = classify_article(lm, article)
|
||||
```
|
||||
|
||||
### Intent Classification
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def classify_intent(lm, message):
|
||||
"""Classify user intent."""
|
||||
lm += f"User Message: {message}\n\n"
|
||||
|
||||
# Intent
|
||||
lm += "Intent: " + select(
|
||||
["question", "complaint", "request", "feedback", "other"],
|
||||
name="intent"
|
||||
) + "\n"
|
||||
|
||||
# Urgency
|
||||
lm += "Urgency: " + select(
|
||||
["low", "medium", "high", "critical"],
|
||||
name="urgency"
|
||||
) + "\n"
|
||||
|
||||
# Department
|
||||
lm += "Route To: " + select(
|
||||
["support", "sales", "billing", "technical"],
|
||||
name="department"
|
||||
) + "\n"
|
||||
|
||||
# Sentiment
|
||||
lm += "Sentiment: " + select(
|
||||
["positive", "neutral", "negative"],
|
||||
name="sentiment"
|
||||
)
|
||||
|
||||
return lm
|
||||
|
||||
message = "My account was charged twice for the same order. Need help ASAP!"
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = classify_intent(lm, message)
|
||||
|
||||
print(f"Intent: {lm['intent']}")
|
||||
print(f"Urgency: {lm['urgency']}")
|
||||
print(f"Department: {lm['department']}")
|
||||
```
|
||||
|
||||
## Agent Systems
|
||||
|
||||
### ReAct Agent
|
||||
|
||||
```python
|
||||
from guidance import models, gen, select, guidance
|
||||
|
||||
@guidance(stateless=False)
|
||||
def react_agent(lm, question, tools, max_rounds=5):
|
||||
"""ReAct agent with tool use."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
for round in range(max_rounds):
|
||||
# Thought
|
||||
lm += f"Thought {round+1}: " + gen("thought", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Action selection
|
||||
lm += "Action: " + select(
|
||||
list(tools.keys()) + ["answer"],
|
||||
name="action"
|
||||
)
|
||||
|
||||
if lm["action"] == "answer":
|
||||
lm += "\n\nFinal Answer: " + gen("answer", max_tokens=200)
|
||||
break
|
||||
|
||||
# Action input
|
||||
lm += "\nAction Input: " + gen("action_input", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Execute tool
|
||||
if lm["action"] in tools:
|
||||
try:
|
||||
result = tools[lm["action"]](lm["action_input"])
|
||||
lm += f"Observation: {result}\n\n"
|
||||
except Exception as e:
|
||||
lm += f"Observation: Error - {str(e)}\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
# Define tools
|
||||
tools = {
|
||||
"calculator": lambda expr: eval(expr),
|
||||
"search": lambda query: f"Search results for '{query}': [Mock results]",
|
||||
"weather": lambda city: f"Weather in {city}: Sunny, 72°F"
|
||||
}
|
||||
|
||||
# Use agent
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = react_agent(lm, "What is (25 * 4) + 10?", tools)
|
||||
|
||||
print(lm["answer"])
|
||||
```
|
||||
|
||||
### Multi-Agent System
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def coordinator_agent(lm, task):
|
||||
"""Coordinator that delegates to specialists."""
|
||||
lm += f"Task: {task}\n\n"
|
||||
|
||||
# Determine which specialist to use
|
||||
lm += "Specialist: " + select(
|
||||
["researcher", "writer", "coder", "analyst"],
|
||||
name="specialist"
|
||||
) + "\n"
|
||||
|
||||
lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
@guidance
|
||||
def researcher_agent(lm, query):
|
||||
"""Research specialist."""
|
||||
lm += f"Research Query: {query}\n\n"
|
||||
lm += "Findings:\n"
|
||||
for i in range(3):
|
||||
lm += f"{i+1}. " + gen(f"finding_{i}", stop="\n", max_tokens=100) + "\n"
|
||||
return lm
|
||||
|
||||
@guidance
|
||||
def writer_agent(lm, topic):
|
||||
"""Writing specialist."""
|
||||
lm += f"Topic: {topic}\n\n"
|
||||
lm += "Title: " + gen("title", stop="\n", max_tokens=50) + "\n"
|
||||
lm += "Content:\n" + gen("content", max_tokens=500)
|
||||
return lm
|
||||
|
||||
# Coordination workflow
|
||||
task = "Write an article about AI safety"
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = coordinator_agent(lm, task)
|
||||
|
||||
specialist = lm["specialist"]
|
||||
if specialist == "researcher":
|
||||
lm = researcher_agent(lm, task)
|
||||
elif specialist == "writer":
|
||||
lm = writer_agent(lm, task)
|
||||
```
|
||||
|
||||
### Tool Use with Validation
|
||||
|
||||
```python
|
||||
@guidance(stateless=False)
|
||||
def validated_tool_agent(lm, question):
|
||||
"""Agent with validated tool calls."""
|
||||
tools = {
|
||||
"add": lambda a, b: float(a) + float(b),
|
||||
"multiply": lambda a, b: float(a) * float(b),
|
||||
"divide": lambda a, b: float(a) / float(b) if float(b) != 0 else "Error: Division by zero"
|
||||
}
|
||||
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
for i in range(5):
|
||||
# Select tool
|
||||
lm += "Tool: " + select(list(tools.keys()) + ["done"], name="tool")
|
||||
|
||||
if lm["tool"] == "done":
|
||||
lm += "\nAnswer: " + gen("answer", max_tokens=100)
|
||||
break
|
||||
|
||||
# Get validated numeric arguments
|
||||
lm += "\nArg1: " + gen("arg1", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n"
|
||||
lm += "Arg2: " + gen("arg2", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n"
|
||||
|
||||
# Execute
|
||||
result = tools[lm["tool"]](lm["arg1"], lm["arg2"])
|
||||
lm += f"Result: {result}\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = validated_tool_agent(lm, "What is (10 + 5) * 3?")
|
||||
```
|
||||
|
||||
## Multi-Step Workflows
|
||||
|
||||
### Chain of Thought
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def chain_of_thought(lm, question):
|
||||
"""Multi-step reasoning with CoT."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
# Generate reasoning steps
|
||||
lm += "Let me think step by step:\n\n"
|
||||
for i in range(4):
|
||||
lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Final answer
|
||||
lm += "\nTherefore, the answer is: " + gen("answer", stop="\n", max_tokens=50)
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = chain_of_thought(lm, "If a train travels 60 mph for 2.5 hours, how far does it go?")
|
||||
|
||||
print(lm["answer"])
|
||||
```
|
||||
|
||||
### Self-Consistency
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def self_consistency(lm, question, num_samples=3):
|
||||
"""Generate multiple reasoning paths and aggregate."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
answers = []
|
||||
for i in range(num_samples):
|
||||
lm += f"=== Attempt {i+1} ===\n"
|
||||
lm += "Reasoning: " + gen(f"reasoning_{i}", stop="\n", max_tokens=100) + "\n"
|
||||
lm += "Answer: " + gen(f"answer_{i}", stop="\n", max_tokens=50) + "\n\n"
|
||||
answers.append(lm[f"answer_{i}"])
|
||||
|
||||
# Aggregate (simple majority vote)
|
||||
from collections import Counter
|
||||
most_common = Counter(answers).most_common(1)[0][0]
|
||||
|
||||
lm += f"Final Answer (by majority): {most_common}\n"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = self_consistency(lm, "What is 15% of 200?")
|
||||
```
|
||||
|
||||
### Planning and Execution
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def plan_and_execute(lm, goal):
|
||||
"""Plan tasks then execute them."""
|
||||
lm += f"Goal: {goal}\n\n"
|
||||
|
||||
# Planning phase
|
||||
lm += "Plan:\n"
|
||||
num_steps = 4
|
||||
for i in range(num_steps):
|
||||
lm += f"{i+1}. " + gen(f"plan_step_{i}", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Execution phase
|
||||
lm += "\nExecution:\n\n"
|
||||
for i in range(num_steps):
|
||||
lm += f"Step {i+1}: {lm[f'plan_step_{i}']}\n"
|
||||
lm += "Status: " + select(["completed", "in-progress", "blocked"], name=f"status_{i}") + "\n"
|
||||
lm += "Result: " + gen(f"result_{i}", stop="\n", max_tokens=150) + "\n\n"
|
||||
|
||||
# Summary
|
||||
lm += "Summary: " + gen("summary", max_tokens=200)
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = plan_and_execute(lm, "Build a REST API for a blog platform")
|
||||
```
|
||||
|
||||
## Code Generation
|
||||
|
||||
### Python Function
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_python_function(lm, description):
|
||||
"""Generate Python function from description."""
|
||||
lm += f"Description: {description}\n\n"
|
||||
|
||||
# Function signature
|
||||
lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "("
|
||||
lm += gen("params", regex=r"[a-z_][a-z0-9_]*(, [a-z_][a-z0-9_]*)*") + "):\n"
|
||||
|
||||
# Docstring
|
||||
lm += ' """' + gen("docstring", stop='"""', max_tokens=100) + '"""\n'
|
||||
|
||||
# Function body
|
||||
lm += " " + gen("body", stop="\n", max_tokens=200) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_python_function(lm, "Check if a number is prime")
|
||||
|
||||
print(lm)
|
||||
```
|
||||
|
||||
### SQL Query
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_sql(lm, description):
|
||||
"""Generate SQL query from description."""
|
||||
lm += f"Description: {description}\n\n"
|
||||
lm += "SQL Query:\n"
|
||||
|
||||
# SELECT clause
|
||||
lm += "SELECT " + gen("select_clause", stop=" FROM", max_tokens=100)
|
||||
|
||||
# FROM clause
|
||||
lm += " FROM " + gen("from_clause", stop=" WHERE", max_tokens=50)
|
||||
|
||||
# WHERE clause (optional)
|
||||
lm += " WHERE " + gen("where_clause", stop=";", max_tokens=100) + ";"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_sql(lm, "Get all users who signed up in the last 30 days")
|
||||
```
|
||||
|
||||
### API Endpoint
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_api_endpoint(lm, description):
|
||||
"""Generate REST API endpoint."""
|
||||
lm += f"Description: {description}\n\n"
|
||||
|
||||
# HTTP method
|
||||
lm += "Method: " + select(["GET", "POST", "PUT", "DELETE"], name="method") + "\n"
|
||||
|
||||
# Path
|
||||
lm += "Path: /" + gen("path", regex=r"[a-z0-9/-]+", stop="\n") + "\n"
|
||||
|
||||
# Request body (if POST/PUT)
|
||||
if lm["method"] in ["POST", "PUT"]:
|
||||
lm += "\nRequest Body:\n"
|
||||
lm += "{\n"
|
||||
lm += ' "field1": ' + gen("field1", regex=r'"[a-z_]+"') + ",\n"
|
||||
lm += ' "field2": ' + gen("field2", regex=r'"[a-z_]+"') + "\n"
|
||||
lm += "}\n"
|
||||
|
||||
# Response
|
||||
lm += "\nResponse (200 OK):\n"
|
||||
lm += "{\n"
|
||||
lm += ' "status": "success",\n'
|
||||
lm += ' "data": ' + gen("response_data", max_tokens=100) + "\n"
|
||||
lm += "}\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_api_endpoint(lm, "Create a new blog post")
|
||||
```
|
||||
|
||||
## Production Tips
|
||||
|
||||
### Error Handling
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def safe_extraction(lm, text):
|
||||
"""Extract with fallback handling."""
|
||||
try:
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n", max_tokens=30)
|
||||
return lm
|
||||
except Exception as e:
|
||||
# Fallback to less strict extraction
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Name: " + gen("name", stop="\n", max_tokens=30)
|
||||
return lm
|
||||
```
|
||||
|
||||
### Caching
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
def cached_generation(text):
|
||||
"""Cache LLM generations."""
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm += f"Analyze: {text}\n"
|
||||
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
|
||||
return lm["sentiment"]
|
||||
|
||||
# First call: hits LLM
|
||||
result1 = cached_generation("This is great!")
|
||||
|
||||
# Second call: returns cached result
|
||||
result2 = cached_generation("This is great!") # Instant!
|
||||
```
|
||||
|
||||
### Monitoring
|
||||
|
||||
```python
|
||||
import time
|
||||
|
||||
@guidance
|
||||
def monitored_generation(lm, text):
|
||||
"""Track generation metrics."""
|
||||
start_time = time.time()
|
||||
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Analysis: " + gen("analysis", max_tokens=100)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Log metrics
|
||||
print(f"Generation time: {elapsed:.2f}s")
|
||||
print(f"Output length: {len(lm['analysis'])} chars")
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Batch Processing
|
||||
|
||||
```python
def batch_process(texts, batch_size=10):
    """Process texts in batches, keeping one named result per text."""
    lm = models.Anthropic("claude-sonnet-4-5-20250929")
    results = []

    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]

        for offset, text in enumerate(batch):
            idx = start + offset  # unique index per text
            lm += f"Text: {text}\n"
            lm += "Sentiment: " + select(
                ["positive", "negative", "neutral"],
                name=f"sentiment_{idx}"
            ) + "\n\n"

        results.extend(lm[f"sentiment_{start + offset}"] for offset in range(len(batch)))

    return results
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Guidance Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks
|
||||
- **Guidance Docs**: https://guidance.readthedocs.io
|
||||
- **Community Examples**: https://github.com/guidance-ai/guidance/discussions
|
||||
516
skills/mlops-knowledge/huggingface-tokenizers/SKILL.md
Normal file
@@ -0,0 +1,516 @@
|
||||
---
|
||||
name: huggingface-tokenizers
|
||||
description: Fast tokenizers optimized for research and production. Rust-based implementation tokenizes 1GB in <20 seconds. Supports BPE, WordPiece, and Unigram algorithms. Train custom vocabularies, track alignments, handle padding/truncation. Integrates seamlessly with transformers. Use when you need high-performance tokenization or custom tokenizer training.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Tokenization, HuggingFace, BPE, WordPiece, Unigram, Fast Tokenization, Rust, Custom Tokenizer, Alignment Tracking, Production]
|
||||
dependencies: [tokenizers, transformers, datasets]
|
||||
---
|
||||
|
||||
# HuggingFace Tokenizers - Fast Tokenization for NLP
|
||||
|
||||
Fast, production-ready tokenizers with Rust performance and Python ease-of-use.
|
||||
|
||||
## When to use HuggingFace Tokenizers
|
||||
|
||||
**Use HuggingFace Tokenizers when:**
|
||||
- Need extremely fast tokenization (<20s per GB of text)
|
||||
- Training custom tokenizers from scratch
|
||||
- Want alignment tracking (token → original text position)
|
||||
- Building production NLP pipelines
|
||||
- Need to tokenize large corpora efficiently
|
||||
|
||||
**Performance**:
|
||||
- **Speed**: <20 seconds to tokenize 1GB on CPU
|
||||
- **Implementation**: Rust core with Python/Node.js bindings
|
||||
- **Efficiency**: 10-100× faster than pure Python implementations
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **SentencePiece**: Language-independent, used by T5/ALBERT
|
||||
- **tiktoken**: OpenAI's BPE tokenizer for GPT models
|
||||
- **transformers AutoTokenizer**: Loading pretrained only (uses this library internally)
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Install tokenizers
|
||||
pip install tokenizers
|
||||
|
||||
# With transformers integration
|
||||
pip install tokenizers transformers
|
||||
```
|
||||
|
||||
### Load pretrained tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
|
||||
# Load from HuggingFace Hub
|
||||
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Encode text
|
||||
output = tokenizer.encode("Hello, how are you?")
|
||||
print(output.tokens) # ['hello', ',', 'how', 'are', 'you', '?']
|
||||
print(output.ids) # [7592, 1010, 2129, 2024, 2017, 1029]
|
||||
|
||||
# Decode back
|
||||
text = tokenizer.decode(output.ids)
|
||||
print(text) # "hello, how are you?"
|
||||
```
|
||||
|
||||
### Train custom BPE tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import Whitespace
|
||||
|
||||
# Initialize tokenizer with BPE model
|
||||
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
|
||||
tokenizer.pre_tokenizer = Whitespace()
|
||||
|
||||
# Configure trainer
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=30000,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
min_frequency=2
|
||||
)
|
||||
|
||||
# Train on files
|
||||
files = ["train.txt", "validation.txt"]
|
||||
tokenizer.train(files, trainer)
|
||||
|
||||
# Save
|
||||
tokenizer.save("my-tokenizer.json")
|
||||
```
|
||||
|
||||
**Training time**: ~1-2 minutes for 100MB corpus, ~10-20 minutes for 1GB
|
||||
|
||||
### Batch encoding with padding
|
||||
|
||||
```python
|
||||
# Enable padding
|
||||
tokenizer.enable_padding(pad_id=3, pad_token="[PAD]")
|
||||
|
||||
# Encode batch
|
||||
texts = ["Hello world", "This is a longer sentence"]
|
||||
encodings = tokenizer.encode_batch(texts)
|
||||
|
||||
for encoding in encodings:
|
||||
print(encoding.ids)
|
||||
# [101, 7592, 2088, 102, 3, 3, 3]
|
||||
# [101, 2023, 2003, 1037, 2936, 6251, 102]
|
||||
```
|
||||
|
||||
## Tokenization algorithms
|
||||
|
||||
### BPE (Byte-Pair Encoding)
|
||||
|
||||
**How it works**:
|
||||
1. Start with character-level vocabulary
|
||||
2. Find most frequent character pair
|
||||
3. Merge into new token, add to vocabulary
|
||||
4. Repeat until vocabulary size reached
|
||||
|
||||
**Used by**: GPT-2, GPT-3, RoBERTa, BART, DeBERTa
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
|
||||
tokenizer = Tokenizer(BPE(unk_token="<|endoftext|>"))
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50257,
|
||||
special_tokens=["<|endoftext|>"],
|
||||
min_frequency=2
|
||||
)
|
||||
|
||||
tokenizer.train(files=["data.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- Handles OOV words well (breaks into subwords)
|
||||
- Flexible vocabulary size
|
||||
- Good for morphologically rich languages
|
||||
|
||||
**Trade-offs**:
|
||||
- Tokenization depends on merge order
|
||||
- May split common words unexpectedly
|
||||
|
||||
### WordPiece
|
||||
|
||||
**How it works**:
|
||||
1. Start with character vocabulary
|
||||
2. Score merge pairs: `frequency(pair) / (frequency(first) × frequency(second))`
|
||||
3. Merge highest scoring pair
|
||||
4. Repeat until vocabulary size reached
|
||||
|
||||
**Used by**: BERT, DistilBERT, MobileBERT
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import WordPiece
|
||||
from tokenizers.trainers import WordPieceTrainer
|
||||
from tokenizers.pre_tokenizers import Whitespace
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
|
||||
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
|
||||
tokenizer.normalizer = BertNormalizer(lowercase=True)
|
||||
tokenizer.pre_tokenizer = Whitespace()
|
||||
|
||||
trainer = WordPieceTrainer(
|
||||
vocab_size=30522,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
continuing_subword_prefix="##"
|
||||
)
|
||||
|
||||
tokenizer.train(files=["corpus.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- Prioritizes meaningful merges (high score = semantically related)
|
||||
- Used successfully in BERT (state-of-the-art results)
|
||||
|
||||
**Trade-offs**:
|
||||
- Unknown words become `[UNK]` if no subword match
|
||||
- Saves vocabulary, not merge rules (larger files)
|
||||
|
||||
### Unigram
|
||||
|
||||
**How it works**:
|
||||
1. Start with large vocabulary (all substrings)
|
||||
2. Compute loss for corpus with current vocabulary
|
||||
3. Remove tokens with minimal impact on loss
|
||||
4. Repeat until vocabulary size reached
|
||||
|
||||
**Used by**: ALBERT, T5, mBART, XLNet (via SentencePiece)
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import Unigram
|
||||
from tokenizers.trainers import UnigramTrainer
|
||||
|
||||
tokenizer = Tokenizer(Unigram())
|
||||
|
||||
trainer = UnigramTrainer(
|
||||
vocab_size=8000,
|
||||
special_tokens=["<unk>", "<s>", "</s>"],
|
||||
unk_token="<unk>"
|
||||
)
|
||||
|
||||
tokenizer.train(files=["data.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- Probabilistic (finds most likely tokenization)
|
||||
- Works well for languages without word boundaries
|
||||
- Handles diverse linguistic contexts
|
||||
|
||||
**Trade-offs**:
|
||||
- Computationally expensive to train
|
||||
- More hyperparameters to tune
|
||||
|
||||
## Tokenization pipeline
|
||||
|
||||
Complete pipeline: **Normalization → Pre-tokenization → Model → Post-processing**
|
||||
|
||||
### Normalization
|
||||
|
||||
Clean and standardize text:
|
||||
|
||||
```python
|
||||
from tokenizers.normalizers import NFD, StripAccents, Lowercase, Sequence
|
||||
|
||||
tokenizer.normalizer = Sequence([
|
||||
NFD(), # Unicode normalization (decompose)
|
||||
Lowercase(), # Convert to lowercase
|
||||
StripAccents() # Remove accents
|
||||
])
|
||||
|
||||
# Input: "Héllo WORLD"
|
||||
# After normalization: "hello world"
|
||||
```
|
||||
|
||||
**Common normalizers**:
|
||||
- `NFD`, `NFC`, `NFKD`, `NFKC` - Unicode normalization forms
|
||||
- `Lowercase()` - Convert to lowercase
|
||||
- `StripAccents()` - Remove accents (é → e)
|
||||
- `Strip()` - Remove whitespace
|
||||
- `Replace(pattern, content)` - Regex replacement
|
||||
|
||||
### Pre-tokenization
|
||||
|
||||
Split text into word-like units:
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence, ByteLevel
|
||||
|
||||
# Split on whitespace and punctuation
|
||||
tokenizer.pre_tokenizer = Sequence([
|
||||
Whitespace(),
|
||||
Punctuation()
|
||||
])
|
||||
|
||||
# Input: "Hello, world!"
|
||||
# After pre-tokenization: ["Hello", ",", "world", "!"]
|
||||
```
|
||||
|
||||
**Common pre-tokenizers**:
|
||||
- `Whitespace()` - Split on spaces, tabs, newlines
|
||||
- `ByteLevel()` - GPT-2 style byte-level splitting
|
||||
- `Punctuation()` - Isolate punctuation
|
||||
- `Digits(individual_digits=True)` - Split digits individually
|
||||
- `Metaspace()` - Replace spaces with ▁ (SentencePiece style)
|
||||
|
||||
### Post-processing
|
||||
|
||||
Add special tokens for model input:
|
||||
|
||||
```python
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
|
||||
# BERT-style: [CLS] sentence [SEP]
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
pair="[CLS] $A [SEP] $B [SEP]",
|
||||
special_tokens=[
|
||||
("[CLS]", 1),
|
||||
("[SEP]", 2),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
**Common patterns**:
|
||||
```python
|
||||
# GPT-2: sentence <|endoftext|>
|
||||
TemplateProcessing(
|
||||
single="$A <|endoftext|>",
|
||||
special_tokens=[("<|endoftext|>", 50256)]
|
||||
)
|
||||
|
||||
# RoBERTa: <s> sentence </s>
|
||||
TemplateProcessing(
|
||||
single="<s> $A </s>",
|
||||
pair="<s> $A </s> </s> $B </s>",
|
||||
special_tokens=[("<s>", 0), ("</s>", 2)]
|
||||
)
|
||||
```
|
||||
|
||||
## Alignment tracking
|
||||
|
||||
Track token positions in original text:
|
||||
|
||||
```python
|
||||
output = tokenizer.encode("Hello, world!")
|
||||
|
||||
# Get token offsets
|
||||
for token, offset in zip(output.tokens, output.offsets):
|
||||
start, end = offset
|
||||
print(f"{token:10} → [{start:2}, {end:2}): {text[start:end]!r}")
|
||||
|
||||
# Output:
|
||||
# hello → [ 0, 5): 'Hello'
|
||||
# , → [ 5, 6): ','
|
||||
# world → [ 7, 12): 'world'
|
||||
# ! → [12, 13): '!'
|
||||
```
|
||||
|
||||
**Use cases**:
|
||||
- Named entity recognition (map predictions back to text)
|
||||
- Question answering (extract answer spans)
|
||||
- Token classification (align labels to original positions)
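As a quick sketch of those use cases, the snippet below maps a character-level entity span back to token indices using only the `offsets` field; the example sentence and span are made up for illustration.

```python
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("bert-base-uncased")

text = "Alice moved to Berlin last year"
entity_span = (15, 21)  # character span of "Berlin" (hypothetical NER label)

output = tokenizer.encode(text)

# A token belongs to the entity if its character range overlaps the span;
# start != end filters out special tokens, whose offsets are (0, 0)
entity_token_indices = [
    i for i, (start, end) in enumerate(output.offsets)
    if start < entity_span[1] and end > entity_span[0] and start != end
]

print([output.tokens[i] for i in entity_token_indices])  # e.g. ['berlin']
```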
|
||||
|
||||
## Integration with transformers
|
||||
|
||||
### Load with AutoTokenizer
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# AutoTokenizer automatically uses fast tokenizers
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Check if using fast tokenizer
|
||||
print(tokenizer.is_fast) # True
|
||||
|
||||
# Access underlying tokenizers.Tokenizer
|
||||
fast_tokenizer = tokenizer.backend_tokenizer
|
||||
print(type(fast_tokenizer)) # <class 'tokenizers.Tokenizer'>
|
||||
```
|
||||
|
||||
### Convert custom tokenizer to transformers
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
# Train custom tokenizer
|
||||
tokenizer = Tokenizer(BPE())
|
||||
# ... train tokenizer ...
|
||||
tokenizer.save("my-tokenizer.json")
|
||||
|
||||
# Wrap for transformers
|
||||
transformers_tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_file="my-tokenizer.json",
|
||||
unk_token="[UNK]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
sep_token="[SEP]",
|
||||
mask_token="[MASK]"
|
||||
)
|
||||
|
||||
# Use like any transformers tokenizer
|
||||
outputs = transformers_tokenizer(
|
||||
"Hello world",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
return_tensors="pt"
|
||||
)
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Train from iterator (large datasets)
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
|
||||
|
||||
# Create batch iterator
|
||||
def batch_iterator(batch_size=1000):
|
||||
for i in range(0, len(dataset), batch_size):
|
||||
yield dataset[i:i + batch_size]["text"]
|
||||
|
||||
# Train tokenizer
|
||||
tokenizer.train_from_iterator(
|
||||
batch_iterator(),
|
||||
trainer=trainer,
|
||||
length=len(dataset) # For progress bar
|
||||
)
|
||||
```
|
||||
|
||||
**Performance**: Processes 1GB in ~10-20 minutes
|
||||
|
||||
### Enable truncation and padding
|
||||
|
||||
```python
|
||||
# Enable truncation
|
||||
tokenizer.enable_truncation(max_length=512)
|
||||
|
||||
# Enable padding
|
||||
tokenizer.enable_padding(
|
||||
pad_id=tokenizer.token_to_id("[PAD]"),
|
||||
pad_token="[PAD]",
|
||||
length=512 # Fixed length, or None for batch max
|
||||
)
|
||||
|
||||
# Encode with both
|
||||
output = tokenizer.encode("This is a long sentence that will be truncated...")
|
||||
print(len(output.ids)) # 512
|
||||
```
|
||||
|
||||
### Multi-processing
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from multiprocessing import Pool
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = Tokenizer.from_file("tokenizer.json")
|
||||
|
||||
def encode_batch(texts):
|
||||
return tokenizer.encode_batch(texts)
|
||||
|
||||
# Process a large corpus (a list of input strings) in parallel
|
||||
with Pool(8) as pool:
|
||||
# Split corpus into chunks
|
||||
chunk_size = 1000
|
||||
chunks = [corpus[i:i+chunk_size] for i in range(0, len(corpus), chunk_size)]
|
||||
|
||||
# Encode in parallel
|
||||
results = pool.map(encode_batch, chunks)
|
||||
```
|
||||
|
||||
**Speedup**: 5-8× with 8 cores
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
### Training speed
|
||||
|
||||
| Corpus Size | BPE (30k vocab) | WordPiece (30k) | Unigram (8k) |
|
||||
|-------------|-----------------|-----------------|--------------|
|
||||
| 10 MB | 15 sec | 18 sec | 25 sec |
|
||||
| 100 MB | 1.5 min | 2 min | 4 min |
|
||||
| 1 GB | 15 min | 20 min | 40 min |
|
||||
|
||||
**Hardware**: 16-core CPU, tested on English Wikipedia
|
||||
|
||||
### Tokenization speed
|
||||
|
||||
| Implementation | 1 GB corpus | Throughput |
|
||||
|----------------|-------------|---------------|
|
||||
| Pure Python | ~20 minutes | ~50 MB/min |
|
||||
| HF Tokenizers | ~15 seconds | ~4 GB/min |
|
||||
| **Speedup** | **80×** | **80×** |
|
||||
|
||||
**Test**: English text, average sentence length 20 words
|
||||
|
||||
### Memory usage
|
||||
|
||||
| Task | Memory |
|
||||
|-------------------------|---------|
|
||||
| Load tokenizer | ~10 MB |
|
||||
| Train BPE (30k vocab) | ~200 MB |
|
||||
| Encode 1M sentences | ~500 MB |
|
||||
|
||||
## Supported models
|
||||
|
||||
Pre-trained tokenizers available via `from_pretrained()`:
|
||||
|
||||
**BERT family**:
|
||||
- `bert-base-uncased`, `bert-large-cased`
|
||||
- `distilbert-base-uncased`
|
||||
- `roberta-base`, `roberta-large`
|
||||
|
||||
**GPT family**:
|
||||
- `gpt2`, `gpt2-medium`, `gpt2-large`
|
||||
- `distilgpt2`
|
||||
|
||||
**T5 family**:
|
||||
- `t5-small`, `t5-base`, `t5-large`
|
||||
- `google/flan-t5-xxl`
|
||||
|
||||
**Other**:
|
||||
- `facebook/bart-base`, `facebook/mbart-large-cc25`
|
||||
- `albert-base-v2`, `albert-xlarge-v2`
|
||||
- `xlm-roberta-base`, `xlm-roberta-large`
|
||||
|
||||
Browse all: https://huggingface.co/models?library=tokenizers
|
||||
|
||||
## References
|
||||
|
||||
- **[Training Guide](references/training.md)** - Train custom tokenizers, configure trainers, handle large datasets
|
||||
- **[Algorithms Deep Dive](references/algorithms.md)** - BPE, WordPiece, Unigram explained in detail
|
||||
- **[Pipeline Components](references/pipeline.md)** - Normalizers, pre-tokenizers, post-processors, decoders
|
||||
- **[Transformers Integration](references/integration.md)** - AutoTokenizer, PreTrainedTokenizerFast, special tokens
|
||||
|
||||
## Resources
|
||||
|
||||
- **Docs**: https://huggingface.co/docs/tokenizers
|
||||
- **GitHub**: https://github.com/huggingface/tokenizers ⭐ 9,000+
|
||||
- **Version**: 0.20.0+
|
||||
- **Course**: https://huggingface.co/learn/nlp-course/chapter6/1
|
||||
- **Paper**: BPE (Sennrich et al., 2016), WordPiece (Schuster & Nakajima, 2012)
|
||||
|
||||
|
||||
@@ -0,0 +1,653 @@
|
||||
# Tokenization Algorithms Deep Dive
|
||||
|
||||
Comprehensive explanation of BPE, WordPiece, and Unigram algorithms.
|
||||
|
||||
## Byte-Pair Encoding (BPE)
|
||||
|
||||
### Algorithm overview
|
||||
|
||||
BPE iteratively merges the most frequent pair of tokens in a corpus.
|
||||
|
||||
**Training process**:
|
||||
1. Initialize vocabulary with all characters
|
||||
2. Count frequency of all adjacent token pairs
|
||||
3. Merge most frequent pair into new token
|
||||
4. Add new token to vocabulary
|
||||
5. Update corpus with new token
|
||||
6. Repeat until vocabulary size reached
|
||||
|
||||
### Step-by-step example
|
||||
|
||||
**Corpus**:
|
||||
```
|
||||
low: 5
|
||||
lower: 2
|
||||
newest: 6
|
||||
widest: 3
|
||||
```
|
||||
|
||||
**Iteration 1**:
|
||||
```
|
||||
Count pairs:
|
||||
'e' + 's': 9 (newest: 6, widest: 3) ← most frequent
|
||||
'l' + 'o': 7
|
||||
'o' + 'w': 7
|
||||
...
|
||||
|
||||
Merge: 'e' + 's' → 'es'
|
||||
|
||||
Updated corpus:
|
||||
low: 5
|
||||
lower: 2
|
||||
newest: 6 → newes|t: 6
|
||||
widest: 3 → wides|t: 3
|
||||
|
||||
Vocabulary: [a-z] + ['es']
|
||||
```
|
||||
|
||||
**Iteration 2**:
|
||||
```
|
||||
Count pairs:
|
||||
'es' + 't': 9 ← most frequent
|
||||
'l' + 'o': 7
|
||||
...
|
||||
|
||||
Merge: 'es' + 't' → 'est'
|
||||
|
||||
Updated corpus:
|
||||
low: 5
|
||||
lower: 2
|
||||
newest: 6 → new|est: 6
|
||||
widest: 3 → wid|est: 3
|
||||
|
||||
Vocabulary: [a-z] + ['es', 'est']
|
||||
```
|
||||
|
||||
**Continue until desired vocabulary size...**
|
||||
|
||||
### Tokenization with trained BPE
|
||||
|
||||
Given vocabulary: `['l', 'o', 'w', 'e', 'r', 'n', 's', 't', 'i', 'd', 'es', 'est', 'lo', 'low', 'ne', 'new', 'newest', 'wi', 'wid', 'widest']`
|
||||
|
||||
Tokenize "lowest":
|
||||
```
|
||||
Step 1: Split into characters
|
||||
['l', 'o', 'w', 'e', 's', 't']
|
||||
|
||||
Step 2: Apply merges in order learned during training
|
||||
- Merge 'l' + 'o' → 'lo' (if this merge was learned)
|
||||
- Merge 'lo' + 'w' → 'low' (if learned)
|
||||
- Merge 'e' + 's' → 'es' (learned)
|
||||
- Merge 'es' + 't' → 'est' (learned)
|
||||
|
||||
Final: ['low', 'est']
|
||||
```
|
||||
|
||||
### Implementation
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import Whitespace
|
||||
|
||||
# Initialize
|
||||
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
|
||||
tokenizer.pre_tokenizer = Whitespace()
|
||||
|
||||
# Configure trainer
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=1000,
|
||||
min_frequency=2,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
||||
)
|
||||
|
||||
# Train
|
||||
corpus = [
|
||||
"This is a sample corpus for BPE training.",
|
||||
"BPE learns subword units from the training data.",
|
||||
# ... more sentences
|
||||
]
|
||||
|
||||
tokenizer.train_from_iterator(corpus, trainer=trainer)
|
||||
|
||||
# Use
|
||||
output = tokenizer.encode("This is tokenization")
|
||||
print(output.tokens) # ['This', 'is', 'token', 'ization']
|
||||
```
|
||||
|
||||
### Byte-level BPE (GPT-2 variant)
|
||||
|
||||
**Problem**: Standard BPE needs every base character in its vocabulary, and Unicode defines far more characters than a practical vocabulary can hold, so rare characters fall back to the unknown token
|
||||
|
||||
**Solution**: Operate on byte level (256 bytes)
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
|
||||
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Byte-level pre-tokenization
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
tokenizer.decoder = ByteLevelDecoder()
|
||||
|
||||
# This handles ALL possible characters, including emojis
|
||||
text = "Hello 🌍 世界"
|
||||
tokens = tokenizer.encode(text).tokens
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- Handles any Unicode character (256 byte coverage)
|
||||
- No unknown tokens (worst case: bytes)
|
||||
- Used by GPT-2, GPT-3, BART
|
||||
|
||||
**Trade-offs**:
|
||||
- Slightly worse compression (bytes vs characters)
|
||||
- More tokens for non-ASCII text
|
||||
|
||||
### BPE variants
|
||||
|
||||
**SentencePiece BPE**:
|
||||
- Language-independent (no pre-tokenization)
|
||||
- Treats input as raw byte stream
|
||||
- Used by T5, ALBERT, XLNet
|
||||
|
||||
**Robust BPE**:
|
||||
- Dropout during training (randomly skip merges)
|
||||
- More robust tokenization at inference
|
||||
- Reduces overfitting to training data
|
||||
|
||||
## WordPiece
|
||||
|
||||
### Algorithm overview
|
||||
|
||||
WordPiece is similar to BPE but uses a different merge selection criterion.
|
||||
|
||||
**Training process**:
|
||||
1. Initialize vocabulary with all characters
|
||||
2. Count frequency of all token pairs
|
||||
3. Score each pair: `score = freq(pair) / (freq(first) × freq(second))`
|
||||
4. Merge pair with highest score
|
||||
5. Repeat until vocabulary size reached
|
||||
|
||||
### Why different scoring?
|
||||
|
||||
**BPE**: Merges most frequent pairs
|
||||
- "aa" appears 100 times → high priority
|
||||
- Even if 'a' appears 1000 times alone
|
||||
|
||||
**WordPiece**: Merges pairs that are semantically related
|
||||
- "aa" appears 100 times, 'a' appears 1000 times → low score (100 / (1000 × 1000))
|
||||
- "th" appears 50 times, 't' appears 60 times, 'h' appears 55 times → high score (50 / (60 × 55))
|
||||
- Prioritizes pairs that appear together more than expected
|
||||
|
||||
### Step-by-step example
|
||||
|
||||
**Corpus**:
|
||||
```
|
||||
low: 5
|
||||
lower: 2
|
||||
newest: 6
|
||||
widest: 3
|
||||
```
|
||||
|
||||
**Iteration 1**:
|
||||
```
|
||||
Count frequencies:
|
||||
'e': 11 (lower: 2, newest: 6, widest: 3)
|
||||
's': 9
|
||||
't': 9
|
||||
...
|
||||
|
||||
Count pairs:
|
||||
'e' + 's': 9 (newest: 6, widest: 3)
|
||||
's' + 't': 9 (newest: 6, widest: 3)
|
||||
...
|
||||
|
||||
Compute scores:
|
||||
score('e' + 's') = 9 / (11 × 9) = 0.091
|
||||
score('s' + 't') = 9 / (9 × 9) = 0.111
|
||||
score('l' + 'o') = 7 / (7 × 7) = 0.143 ← highest score
|
||||
|
||||
Choose: 'l' + 'o' → 'lo'
|
||||
```
|
||||
|
||||
**Key difference**: BPE's first merge here is the most frequent pair ('e' + 's'), while WordPiece picks 'l' + 'o' because 'l' and 'o' almost never occur apart in this corpus.
|
||||
|
||||
### Tokenization with WordPiece
|
||||
|
||||
Given vocabulary: `['##e', '##s', '##t', 'l', 'o', 'w', 'new', 'est', 'low']`
|
||||
|
||||
Tokenize "lowest":
|
||||
```
|
||||
Step 1: Find longest matching prefix
|
||||
'lowest' → 'low' (matches)
|
||||
|
||||
Step 2: Find longest match for remainder
|
||||
'est' → 'est' (matches)
|
||||
|
||||
Final: ['low', 'est']
|
||||
```
|
||||
|
||||
**If no match**:
|
||||
```
|
||||
Tokenize "unknownword":
|
||||
'unknownword' → no match
|
||||
'unknown' → no match
|
||||
'unkn' → no match
|
||||
'un' → no match
|
||||
'u' → no match
|
||||
→ [UNK]
|
||||
```
|
||||
|
||||
### Implementation
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import WordPiece
|
||||
from tokenizers.trainers import WordPieceTrainer
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
from tokenizers.pre_tokenizers import BertPreTokenizer
|
||||
|
||||
# Initialize BERT-style tokenizer
|
||||
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
|
||||
|
||||
# Normalization (lowercase, accent stripping)
|
||||
tokenizer.normalizer = BertNormalizer(lowercase=True)
|
||||
|
||||
# Pre-tokenization (whitespace + punctuation)
|
||||
tokenizer.pre_tokenizer = BertPreTokenizer()
|
||||
|
||||
# Configure trainer
|
||||
trainer = WordPieceTrainer(
|
||||
vocab_size=30522, # BERT vocab size
|
||||
min_frequency=2,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
continuing_subword_prefix="##" # BERT uses ##
|
||||
)
|
||||
|
||||
# Train
|
||||
tokenizer.train_from_iterator(corpus, trainer=trainer)
|
||||
|
||||
# Use
|
||||
output = tokenizer.encode("Tokenization works great!")
|
||||
print(output.tokens) # ['token', '##ization', 'works', 'great', '!']
|
||||
```
|
||||
|
||||
### Subword prefix
|
||||
|
||||
**BERT uses `##` prefix**:
|
||||
```
|
||||
"unbelievable" → ['un', '##believ', '##able']
|
||||
```
|
||||
|
||||
**Why?**
|
||||
- Indicates token is a continuation
|
||||
- Allows reconstruction: remove ##, concatenate
|
||||
- Helps model distinguish word boundaries
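A minimal sketch of that reconstruction rule (plain string handling, no library calls):

```python
def wordpiece_detokenize(tokens):
    """Rebuild words from WordPiece tokens by stripping '##' continuations."""
    words = []
    for token in tokens:
        if token.startswith("##") and words:
            words[-1] += token[2:]   # continuation: glue onto the previous piece
        else:
            words.append(token)      # start of a new word
    return " ".join(words)

print(wordpiece_detokenize(["un", "##believ", "##able"]))  # unbelievable
```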
|
||||
|
||||
### WordPiece advantages
|
||||
|
||||
**Semantic merges**:
|
||||
- Prioritizes meaningful combinations
|
||||
- "qu" has high score (always together)
|
||||
- "qx" has low score (rare combination)
|
||||
|
||||
**Better for morphology**:
|
||||
- Captures affixes: un-, -ing, -ed
|
||||
- Preserves word stems
|
||||
|
||||
**Trade-offs**:
|
||||
- Slower training than BPE
|
||||
- More memory (stores vocabulary, not merges)
|
||||
- Original implementation not open-source (HF reimplementation)
|
||||
|
||||
## Unigram
|
||||
|
||||
### Algorithm overview
|
||||
|
||||
Unigram works backward: start with large vocabulary, remove tokens.
|
||||
|
||||
**Training process**:
|
||||
1. Initialize with large vocabulary (all substrings)
|
||||
2. Estimate probability of each token (frequency-based)
|
||||
3. For each token, compute loss increase if removed
|
||||
4. Remove 10-20% of tokens with lowest loss impact
|
||||
5. Re-estimate probabilities
|
||||
6. Repeat until desired vocabulary size
|
||||
|
||||
### Probabilistic tokenization
|
||||
|
||||
**Unigram assumption**: Each token is independent.
|
||||
|
||||
Given vocabulary with probabilities:
|
||||
```
|
||||
P('low') = 0.02
|
||||
P('l') = 0.01
|
||||
P('o') = 0.015
|
||||
P('w') = 0.01
|
||||
P('est') = 0.03
|
||||
P('e') = 0.02
|
||||
P('s') = 0.015
|
||||
P('t') = 0.015
|
||||
```
|
||||
|
||||
Tokenize "lowest":
|
||||
```
|
||||
Option 1: ['low', 'est']
|
||||
P = P('low') × P('est') = 0.02 × 0.03 = 0.0006
|
||||
|
||||
Option 2: ['l', 'o', 'w', 'est']
|
||||
P = 0.01 × 0.015 × 0.01 × 0.03 = 0.000000045
|
||||
|
||||
Option 3: ['low', 'e', 's', 't']
|
||||
P = 0.02 × 0.02 × 0.015 × 0.015 = 0.0000009
|
||||
|
||||
Choose option 1 (highest probability)
|
||||
```
|
||||
|
||||
### Viterbi algorithm
|
||||
|
||||
Finding best tokenization is expensive (exponential possibilities).
|
||||
|
||||
**Viterbi algorithm** (dynamic programming):
|
||||
```python
|
||||
from math import log

def tokenize_viterbi(word, vocab, probs):
|
||||
n = len(word)
|
||||
# dp[i] = (best_prob, best_tokens) for word[:i]
|
||||
dp = [(float('-inf'), []) for _ in range(n + 1)]
|
||||
dp[0] = (0.0, []) # log probability
|
||||
|
||||
for i in range(1, n + 1):
|
||||
best_prob = float('-inf')
|
||||
best_tokens = []
|
||||
|
||||
# Try all possible last tokens
|
||||
for j in range(i):
|
||||
token = word[j:i]
|
||||
if token in vocab:
|
||||
prob = dp[j][0] + log(probs[token])
|
||||
if prob > best_prob:
|
||||
best_prob = prob
|
||||
best_tokens = dp[j][1] + [token]
|
||||
|
||||
dp[i] = (best_prob, best_tokens)
|
||||
|
||||
return dp[n][1]
|
||||
```
|
||||
|
||||
**Time complexity**: O(n² × vocab_size) vs O(2^n) brute force
|
||||
|
||||
### Implementation
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import Unigram
|
||||
from tokenizers.trainers import UnigramTrainer
|
||||
|
||||
# Initialize
|
||||
tokenizer = Tokenizer(Unigram())
|
||||
|
||||
# Configure trainer
|
||||
trainer = UnigramTrainer(
|
||||
vocab_size=8000,
|
||||
special_tokens=["<unk>", "<s>", "</s>"],
|
||||
unk_token="<unk>",
|
||||
max_piece_length=16, # Max token length
|
||||
n_sub_iterations=2, # EM iterations
|
||||
shrinking_factor=0.75 # Remove 25% each iteration
|
||||
)
|
||||
|
||||
# Train
|
||||
tokenizer.train_from_iterator(corpus, trainer=trainer)
|
||||
|
||||
# Use
|
||||
output = tokenizer.encode("Tokenization with Unigram")
|
||||
print(output.tokens) # ['▁Token', 'ization', '▁with', '▁Un', 'igram']
|
||||
```
|
||||
|
||||
### Unigram advantages
|
||||
|
||||
**Probabilistic**:
|
||||
- Multiple valid tokenizations
|
||||
- Can sample different tokenizations (data augmentation)
|
||||
|
||||
**Subword regularization**:
|
||||
```python
|
||||
# Sample different tokenizations
|
||||
for _ in range(3):
|
||||
tokens = tokenizer.encode("tokenization", is_pretokenized=False).tokens
|
||||
print(tokens)
|
||||
|
||||
# Output (different each time):
|
||||
# ['token', 'ization']
|
||||
# ['tok', 'en', 'ization']
|
||||
# ['token', 'iz', 'ation']
|
||||
```
|
||||
|
||||
**Language-independent**:
|
||||
- No word boundaries needed
|
||||
- Works for CJK languages (Chinese, Japanese, Korean)
|
||||
- Treats input as character stream
|
||||
|
||||
**Trade-offs**:
|
||||
- Slower training (EM algorithm)
|
||||
- More hyperparameters
|
||||
- Larger model (stores probabilities)
|
||||
|
||||
## Algorithm comparison
|
||||
|
||||
### Training speed
|
||||
|
||||
| Algorithm | Small (10MB) | Medium (100MB) | Large (1GB) |
|
||||
|------------|--------------|----------------|-------------|
|
||||
| BPE | 10-15 sec | 1-2 min | 10-20 min |
|
||||
| WordPiece | 15-20 sec | 2-3 min | 15-30 min |
|
||||
| Unigram | 20-30 sec | 3-5 min | 30-60 min |
|
||||
|
||||
**Tested on**: 16-core CPU, 30k vocab
|
||||
|
||||
### Tokenization quality
|
||||
|
||||
Tested on English Wikipedia (perplexity measurement):
|
||||
|
||||
| Algorithm | Vocab Size | Tokens/Word | Unknown Rate |
|
||||
|------------|------------|-------------|--------------|
|
||||
| BPE | 30k | 1.3 | 0.5% |
|
||||
| WordPiece | 30k | 1.2 | 1.2% |
|
||||
| Unigram | 8k | 1.5 | 0.3% |
|
||||
|
||||
**Key observations**:
|
||||
- WordPiece: Slightly better compression
|
||||
- BPE: Lower unknown rate
|
||||
- Unigram: Smallest vocab, good coverage
|
||||
|
||||
### Compression ratio
|
||||
|
||||
Characters per token (higher = better compression):
|
||||
|
||||
| Language | BPE (30k) | WordPiece (30k) | Unigram (8k) |
|
||||
|----------|-----------|-----------------|--------------|
|
||||
| English | 4.2 | 4.5 | 3.8 |
|
||||
| Chinese | 2.1 | 2.3 | 2.5 |
|
||||
| Arabic | 3.5 | 3.8 | 3.2 |
|
||||
|
||||
**Best for each**:
|
||||
- English: WordPiece
|
||||
- Chinese: Unigram (language-independent)
|
||||
- Arabic: WordPiece
|
||||
|
||||
### Use case recommendations
|
||||
|
||||
**BPE** - Best for:
|
||||
- English language models
|
||||
- Code (handles symbols well)
|
||||
- Fast training needed
|
||||
- **Models**: GPT-2, GPT-3, RoBERTa, BART
|
||||
|
||||
**WordPiece** - Best for:
|
||||
- Masked language modeling (BERT-style)
|
||||
- Morphologically rich languages
|
||||
- Semantic understanding tasks
|
||||
- **Models**: BERT, DistilBERT, ELECTRA
|
||||
|
||||
**Unigram** - Best for:
|
||||
- Multilingual models
|
||||
- Languages without word boundaries (CJK)
|
||||
- Data augmentation via subword regularization
|
||||
- **Models**: T5, ALBERT, XLNet (via SentencePiece)
|
||||
|
||||
## Advanced topics
|
||||
|
||||
### Handling rare words
|
||||
|
||||
**BPE approach**:
|
||||
```
|
||||
"antidisestablishmentarianism"
|
||||
→ ['anti', 'dis', 'establish', 'ment', 'arian', 'ism']
|
||||
```
|
||||
|
||||
**WordPiece approach**:
|
||||
```
|
||||
"antidisestablishmentarianism"
|
||||
→ ['anti', '##dis', '##establish', '##ment', '##arian', '##ism']
|
||||
```
|
||||
|
||||
**Unigram approach**:
|
||||
```
|
||||
"antidisestablishmentarianism"
|
||||
→ ['▁anti', 'dis', 'establish', 'ment', 'arian', 'ism']
|
||||
```
|
||||
|
||||
### Handling numbers
|
||||
|
||||
**Challenge**: Infinite number combinations
|
||||
|
||||
**BPE solution**: Byte-level (handles any digit sequence)
|
||||
```python
|
||||
tokenizer = Tokenizer(BPE())
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
# Handles any number
|
||||
"123456789" → byte-level tokens
|
||||
```
|
||||
|
||||
**WordPiece solution**: Digit pre-tokenization
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Digits
|
||||
|
||||
# Split digits individually or as groups
|
||||
tokenizer.pre_tokenizer = Digits(individual_digits=True)
|
||||
|
||||
"123" → ['1', '2', '3']
|
||||
```
|
||||
|
||||
**Unigram solution**: Learns common number patterns
|
||||
```python
|
||||
# Learns patterns during training
|
||||
"2023" → ['202', '3'] or ['20', '23']
|
||||
```
|
||||
|
||||
### Handling case sensitivity
|
||||
|
||||
**Lowercase (BERT)**:
|
||||
```python
|
||||
from tokenizers.normalizers import Lowercase
|
||||
|
||||
tokenizer.normalizer = Lowercase()
|
||||
|
||||
"Hello WORLD" → "hello world" → ['hello', 'world']
|
||||
```
|
||||
|
||||
**Preserve case (GPT-2)**:
|
||||
```python
|
||||
# No case normalization
|
||||
tokenizer.normalizer = None
|
||||
|
||||
"Hello WORLD" → ['Hello', 'WORLD']
|
||||
```
|
||||
|
||||
**Cased tokens (RoBERTa)**:
|
||||
```python
|
||||
# Learns separate tokens for different cases
|
||||
Vocabulary: ['Hello', 'hello', 'HELLO', 'world', 'WORLD']
|
||||
```
|
||||
|
||||
### Handling emojis and special characters
|
||||
|
||||
**Byte-level (GPT-2)**:
|
||||
```python
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
"Hello 🌍 👋" → byte-level representation (always works)
|
||||
```
|
||||
|
||||
**Unicode normalization**:
|
||||
```python
|
||||
from tokenizers.normalizers import NFKC
|
||||
|
||||
tokenizer.normalizer = NFKC()
|
||||
|
||||
"é" (composed) ↔ "é" (decomposed) → normalized to one form
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Poor subword splitting
|
||||
|
||||
**Symptom**:
|
||||
```
|
||||
"running" → ['r', 'u', 'n', 'n', 'i', 'n', 'g'] (too granular)
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Increase vocabulary size
|
||||
2. Train longer (more merge iterations)
|
||||
3. Lower `min_frequency` threshold
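A sketch of how those knobs might look on a retrain (the vocabulary size and `min_frequency` values are illustrative, and `corpus.txt` is assumed to exist):

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

trainer = BpeTrainer(
    vocab_size=50000,   # larger vocabulary -> fewer character-level splits
    min_frequency=1,    # keep rarer merges instead of dropping them
    special_tokens=["[UNK]"]
)
tokenizer.train(files=["corpus.txt"], trainer=trainer)
```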
|
||||
|
||||
### Issue: Too many unknown tokens
|
||||
|
||||
**Symptom**:
|
||||
```
|
||||
5% of tokens are [UNK]
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Increase vocabulary size
|
||||
2. Use byte-level BPE (no UNK possible)
|
||||
3. Verify training corpus is representative
|
||||
|
||||
### Issue: Inconsistent tokenization
|
||||
|
||||
**Symptom**:
|
||||
```
|
||||
"running" → ['run', 'ning']
|
||||
"runner" → ['r', 'u', 'n', 'n', 'e', 'r']
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Check normalization consistency
|
||||
2. Ensure pre-tokenization is deterministic
|
||||
3. Use Unigram for probabilistic variance
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Match algorithm to model architecture**:
|
||||
- BERT-style → WordPiece
|
||||
- GPT-style → BPE
|
||||
- T5-style → Unigram
|
||||
|
||||
2. **Use byte-level for multilingual**:
|
||||
- Handles any Unicode
|
||||
- No unknown tokens
|
||||
|
||||
3. **Test on representative data**:
|
||||
- Measure compression ratio
|
||||
- Check unknown token rate
|
||||
- Inspect sample tokenizations
|
||||
|
||||
4. **Version control tokenizers**:
|
||||
- Save with model
|
||||
- Document special tokens
|
||||
- Track vocabulary changes
|
||||
@@ -0,0 +1,637 @@
|
||||
# Transformers Integration
|
||||
|
||||
Complete guide to using HuggingFace Tokenizers with the Transformers library.
|
||||
|
||||
## AutoTokenizer
|
||||
|
||||
The easiest way to load tokenizers.
|
||||
|
||||
### Loading pretrained tokenizers
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Load from HuggingFace Hub
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Check if using fast tokenizer (Rust-based)
|
||||
print(tokenizer.is_fast) # True
|
||||
|
||||
# Access underlying tokenizers.Tokenizer
|
||||
if tokenizer.is_fast:
|
||||
fast_tokenizer = tokenizer.backend_tokenizer
|
||||
print(type(fast_tokenizer)) # <class 'tokenizers.Tokenizer'>
|
||||
```
|
||||
|
||||
### Fast vs slow tokenizers
|
||||
|
||||
| Feature | Fast (Rust) | Slow (Python) |
|
||||
|--------------------------|----------------|---------------|
|
||||
| Speed | 5-10× faster | Baseline |
|
||||
| Alignment tracking | ✅ Full support | ❌ Limited |
|
||||
| Batch processing | ✅ Optimized | ⚠️ Slower |
|
||||
| Offset mapping | ✅ Yes | ❌ No |
|
||||
| Installation | `tokenizers` | Built-in |
|
||||
|
||||
**Always use fast tokenizers when available.**
|
||||
|
||||
### Check available tokenizers
|
||||
|
||||
```python
|
||||
from transformers import TOKENIZER_MAPPING
|
||||
|
||||
# List all fast tokenizers
|
||||
for config_class, (slow, fast) in TOKENIZER_MAPPING.items():
|
||||
if fast is not None:
|
||||
print(f"{config_class.__name__}: {fast.__name__}")
|
||||
```
|
||||
|
||||
## PreTrainedTokenizerFast
|
||||
|
||||
Wrap custom tokenizers for transformers.
|
||||
|
||||
### Convert custom tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
# Train custom tokenizer
|
||||
tokenizer = Tokenizer(BPE())
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=30000,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
||||
)
|
||||
tokenizer.train(files=["corpus.txt"], trainer=trainer)
|
||||
|
||||
# Save tokenizer
|
||||
tokenizer.save("my-tokenizer.json")
|
||||
|
||||
# Wrap for transformers
|
||||
transformers_tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_file="my-tokenizer.json",
|
||||
unk_token="[UNK]",
|
||||
sep_token="[SEP]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
mask_token="[MASK]"
|
||||
)
|
||||
|
||||
# Save in transformers format
|
||||
transformers_tokenizer.save_pretrained("my-tokenizer")
|
||||
```
|
||||
|
||||
**Result**: Directory with `tokenizer.json` + `tokenizer_config.json` + `special_tokens_map.json`
|
||||
|
||||
### Use like any transformers tokenizer
|
||||
|
||||
```python
|
||||
# Load
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("my-tokenizer")
|
||||
|
||||
# Encode with all transformers features
|
||||
outputs = tokenizer(
|
||||
"Hello world",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=128,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
print(outputs.keys())
|
||||
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
|
||||
```
|
||||
|
||||
## Special tokens
|
||||
|
||||
### Default special tokens
|
||||
|
||||
| Model Family | CLS/BOS | SEP/EOS | PAD | UNK | MASK |
|
||||
|--------------|---------|---------------|---------|---------|---------|
|
||||
| BERT | [CLS] | [SEP] | [PAD] | [UNK] | [MASK] |
|
||||
| GPT-2 | - | <\|endoftext\|> | <\|endoftext\|> | <\|endoftext\|> | - |
|
||||
| RoBERTa | <s> | </s> | <pad> | <unk> | <mask> |
|
||||
| T5 | - | </s> | <pad> | <unk> | - |
|
||||
|
||||
### Adding special tokens
|
||||
|
||||
```python
|
||||
# Add new special tokens
|
||||
special_tokens_dict = {
|
||||
"additional_special_tokens": ["<|image|>", "<|video|>", "<|audio|>"]
|
||||
}
|
||||
|
||||
num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
||||
print(f"Added {num_added_tokens} tokens")
|
||||
|
||||
# Resize model embeddings
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Use new tokens
|
||||
text = "This is an image: <|image|>"
|
||||
tokens = tokenizer.encode(text)
|
||||
```
|
||||
|
||||
### Adding regular tokens
|
||||
|
||||
```python
|
||||
# Add domain-specific tokens
|
||||
new_tokens = ["COVID-19", "mRNA", "vaccine"]
|
||||
num_added = tokenizer.add_tokens(new_tokens)
|
||||
|
||||
# These are NOT special tokens (can be split if needed)
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=False)
|
||||
|
||||
# These ARE special tokens (never split)
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
```
|
||||
|
||||
## Encoding and decoding
|
||||
|
||||
### Basic encoding
|
||||
|
||||
```python
|
||||
# Single sentence
|
||||
text = "Hello, how are you?"
|
||||
encoded = tokenizer(text)
|
||||
|
||||
print(encoded)
|
||||
# {'input_ids': [101, 7592, 1010, 2129, 2024, 2017, 1029, 102],
|
||||
# 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0],
|
||||
# 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
|
||||
```
|
||||
|
||||
### Batch encoding
|
||||
|
||||
```python
|
||||
# Multiple sentences
|
||||
texts = ["Hello world", "How are you?", "I am fine"]
|
||||
encoded = tokenizer(texts, padding=True, truncation=True, max_length=10)
|
||||
|
||||
print(encoded['input_ids'])
|
||||
# [[101, 7592, 2088, 102, 0, 0, 0, 0, 0, 0],
|
||||
# [101, 2129, 2024, 2017, 1029, 102, 0, 0, 0, 0],
|
||||
# [101, 1045, 2572, 2986, 102, 0, 0, 0, 0, 0]]
|
||||
```
|
||||
|
||||
### Return tensors
|
||||
|
||||
```python
|
||||
# Return PyTorch tensors
|
||||
outputs = tokenizer("Hello world", return_tensors="pt")
|
||||
print(outputs['input_ids'].shape) # torch.Size([1, 4]): [CLS] hello world [SEP]
|
||||
|
||||
# Return TensorFlow tensors
|
||||
outputs = tokenizer("Hello world", return_tensors="tf")
|
||||
|
||||
# Return NumPy arrays
|
||||
outputs = tokenizer("Hello world", return_tensors="np")
|
||||
|
||||
# Return lists (default)
|
||||
outputs = tokenizer("Hello world", return_tensors=None)
|
||||
```
|
||||
|
||||
### Decoding
|
||||
|
||||
```python
|
||||
# Decode token IDs
|
||||
ids = [101, 7592, 2088, 102]
|
||||
text = tokenizer.decode(ids)
|
||||
print(text) # "[CLS] hello world [SEP]"
|
||||
|
||||
# Skip special tokens
|
||||
text = tokenizer.decode(ids, skip_special_tokens=True)
|
||||
print(text) # "hello world"
|
||||
|
||||
# Batch decode
|
||||
batch_ids = [[101, 7592, 102], [101, 2088, 102]]
|
||||
texts = tokenizer.batch_decode(batch_ids, skip_special_tokens=True)
|
||||
print(texts) # ["hello", "world"]
|
||||
```
|
||||
|
||||
## Padding and truncation
|
||||
|
||||
### Padding strategies
|
||||
|
||||
```python
|
||||
# Pad to max length in batch
|
||||
tokenizer(texts, padding="longest")
|
||||
|
||||
# Pad to model max length
|
||||
tokenizer(texts, padding="max_length", max_length=128)
|
||||
|
||||
# No padding
|
||||
tokenizer(texts, padding=False)
|
||||
|
||||
# Pad to multiple of value (for efficient computation)
|
||||
tokenizer(texts, padding="max_length", max_length=128, pad_to_multiple_of=8)
|
||||
# Result: length will be 128 (already multiple of 8)
|
||||
```
|
||||
|
||||
### Truncation strategies
|
||||
|
||||
```python
|
||||
# Truncate to max length
|
||||
tokenizer(text, truncation=True, max_length=10)
|
||||
|
||||
# Only truncate first sequence (for pairs)
|
||||
tokenizer(text1, text2, truncation="only_first", max_length=20)
|
||||
|
||||
# Only truncate second sequence
|
||||
tokenizer(text1, text2, truncation="only_second", max_length=20)
|
||||
|
||||
# Truncate longest first (default for pairs)
|
||||
tokenizer(text1, text2, truncation="longest_first", max_length=20)
|
||||
|
||||
# No truncation (error if too long)
|
||||
tokenizer(text, truncation=False)
|
||||
```
|
||||
|
||||
### Stride for long documents
|
||||
|
||||
```python
|
||||
# For documents longer than max_length
|
||||
text = "Very long document " * 1000
|
||||
|
||||
# Encode with overlap
|
||||
encodings = tokenizer(
|
||||
text,
|
||||
max_length=512,
|
||||
stride=128, # Overlap between chunks
|
||||
truncation=True,
|
||||
return_overflowing_tokens=True,
|
||||
return_offsets_mapping=True
|
||||
)
|
||||
|
||||
# Get all chunks
|
||||
num_chunks = len(encodings['input_ids'])
|
||||
print(f"Split into {num_chunks} chunks")
|
||||
|
||||
# Each chunk overlaps by stride tokens
|
||||
for i, chunk in enumerate(encodings['input_ids']):
|
||||
print(f"Chunk {i}: {len(chunk)} tokens")
|
||||
```
|
||||
|
||||
**Use case**: Long document QA, sliding window inference
|
||||
|
||||
## Alignment and offsets
|
||||
|
||||
### Offset mapping
|
||||
|
||||
```python
|
||||
# Get character offsets for each token
|
||||
encoded = tokenizer("Hello, world!", return_offsets_mapping=True)
|
||||
|
||||
for token, (start, end) in zip(
|
||||
encoded.tokens(),
|
||||
encoded['offset_mapping']
|
||||
):
|
||||
print(f"{token:10s} → [{start:2d}, {end:2d})")
|
||||
|
||||
# Output:
|
||||
# [CLS] → [ 0, 0)
|
||||
# Hello → [ 0, 5)
|
||||
# , → [ 5, 6)
|
||||
# world → [ 7, 12)
|
||||
# ! → [12, 13)
|
||||
# [SEP] → [ 0, 0)
|
||||
```
|
||||
|
||||
### Word IDs
|
||||
|
||||
```python
|
||||
# Get word index for each token
|
||||
encoded = tokenizer("Hello world", return_offsets_mapping=True)
|
||||
word_ids = encoded.word_ids()
|
||||
|
||||
print(word_ids)
|
||||
# [None, 0, 1, None]
|
||||
# None = special token, 0 = first word, 1 = second word
|
||||
```
|
||||
|
||||
**Use case**: Token classification (NER, POS tagging)
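For example, a common pattern (sketched here with hypothetical word-level labels) aligns labels to tokens with `word_ids()`, assigning special tokens the ignore index:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

words = ["Alice", "visited", "Berlin"]
word_labels = [1, 0, 2]  # hypothetical per-word NER label ids

encoded = tokenizer(words, is_split_into_words=True)

token_labels = []
for word_id in encoded.word_ids():
    if word_id is None:
        token_labels.append(-100)                  # special tokens: ignored by the loss
    else:
        token_labels.append(word_labels[word_id])  # copy the label of the source word

print(encoded.tokens())
print(token_labels)  # e.g. [-100, 1, 0, 2, -100]
```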
|
||||
|
||||
### Character to token mapping
|
||||
|
||||
```python
|
||||
text = "Machine learning is awesome"
|
||||
encoded = tokenizer(text, return_offsets_mapping=True)
|
||||
|
||||
# Find token for character position
|
||||
char_pos = 8 # "l" in "learning"
|
||||
token_idx = encoded.char_to_token(char_pos)
|
||||
|
||||
print(f"Character {char_pos} is in token {token_idx}: {encoded.tokens()[token_idx]}")
|
||||
# Character 8 is in token 2: learning
|
||||
```
|
||||
|
||||
**Use case**: Question answering (map answer character span to tokens)
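A small sketch of that mapping for a hypothetical answer span (the character positions are made up for the example):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

question = "Where does Alice live?"
context = "Alice lives in Berlin with her family."
answer_start, answer_end = 15, 21  # character span of "Berlin" in the context

encoded = tokenizer(question, context, return_offsets_mapping=True)

# sequence_index=1 restricts the lookup to the context (the second sequence)
start_token = encoded.char_to_token(answer_start, sequence_index=1)
end_token = encoded.char_to_token(answer_end - 1, sequence_index=1)

print(encoded.tokens()[start_token:end_token + 1])  # e.g. ['berlin']
```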
|
||||
|
||||
### Sequence pairs
|
||||
|
||||
```python
|
||||
# Encode sentence pair
|
||||
encoded = tokenizer("Question here", "Answer here", return_offsets_mapping=True)
|
||||
|
||||
# Get sequence IDs (which sequence each token belongs to)
|
||||
sequence_ids = encoded.sequence_ids()
|
||||
print(sequence_ids)
|
||||
# [None, 0, 0, 0, None, 1, 1, 1, None]
|
||||
# None = special token, 0 = question, 1 = answer
|
||||
```
|
||||
|
||||
## Model integration
|
||||
|
||||
### Use with transformers models
|
||||
|
||||
```python
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
import torch
|
||||
|
||||
# Load model and tokenizer
|
||||
model = AutoModel.from_pretrained("bert-base-uncased")
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# Tokenize
|
||||
text = "Hello world"
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
|
||||
# Forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Get embeddings
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
print(last_hidden_state.shape) # [1, seq_len, hidden_size]
|
||||
```
|
||||
|
||||
### Custom model with custom tokenizer
|
||||
|
||||
```python
|
||||
from transformers import BertConfig, BertModel
|
||||
|
||||
# Train custom tokenizer
|
||||
from tokenizers import Tokenizer, models, trainers
|
||||
tokenizer = Tokenizer(models.BPE())
|
||||
trainer = trainers.BpeTrainer(vocab_size=30000)
|
||||
tokenizer.train(files=["data.txt"], trainer=trainer)
|
||||
|
||||
# Wrap for transformers
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
fast_tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=tokenizer,
|
||||
unk_token="[UNK]",
|
||||
pad_token="[PAD]"
|
||||
)
|
||||
|
||||
# Create model with custom vocab size
|
||||
config = BertConfig(vocab_size=30000)
|
||||
model = BertModel(config)
|
||||
|
||||
# Use together
|
||||
inputs = fast_tokenizer("Hello world", return_tensors="pt")
|
||||
outputs = model(**inputs)
|
||||
```
|
||||
|
||||
### Save and load together
|
||||
|
||||
```python
|
||||
# Save both
|
||||
model.save_pretrained("my-model")
|
||||
tokenizer.save_pretrained("my-model")
|
||||
|
||||
# Directory structure:
|
||||
# my-model/
|
||||
# ├── config.json
|
||||
# ├── pytorch_model.bin
|
||||
# ├── tokenizer.json
|
||||
# ├── tokenizer_config.json
|
||||
# └── special_tokens_map.json
|
||||
|
||||
# Load both
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
model = AutoModel.from_pretrained("my-model")
|
||||
tokenizer = AutoTokenizer.from_pretrained("my-model")
|
||||
```
|
||||
|
||||
## Advanced features
|
||||
|
||||
### Multimodal tokenization
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# LLaVA-style (image + text)
|
||||
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
|
||||
# Add image placeholder token
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
|
||||
|
||||
# Use in prompt
|
||||
text = "Describe this image: <image>"
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
```
|
||||
|
||||
### Template formatting
|
||||
|
||||
```python
|
||||
# Chat template
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Hi! How can I help?"},
|
||||
{"role": "user", "content": "What's the weather?"}
|
||||
]
|
||||
|
||||
# Apply chat template (if tokenizer has one)
|
||||
if hasattr(tokenizer, "apply_chat_template"):
|
||||
text = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
```
|
||||
|
||||
### Custom template
|
||||
|
||||
```python
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
|
||||
|
||||
# Define chat template
|
||||
tokenizer.chat_template = """
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'system' %}
|
||||
System: {{ message['content'] }}\\n
|
||||
{%- elif message['role'] == 'user' %}
|
||||
User: {{ message['content'] }}\\n
|
||||
{%- elif message['role'] == 'assistant' %}
|
||||
Assistant: {{ message['content'] }}\\n
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
Assistant:
|
||||
"""
|
||||
|
||||
# Use template
|
||||
text = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### Batch processing
|
||||
|
||||
```python
|
||||
# Process large datasets efficiently
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("imdb", split="train[:1000]")
|
||||
|
||||
# Tokenize in batches
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(
|
||||
examples["text"],
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=512
|
||||
)
|
||||
|
||||
# Map over dataset (batched)
|
||||
tokenized_dataset = dataset.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
batch_size=1000,
|
||||
num_proc=4 # Parallel processing
|
||||
)
|
||||
```
|
||||
|
||||
### Caching
|
||||
|
||||
```python
|
||||
# Enable caching for repeated tokenization
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"bert-base-uncased",
|
||||
use_fast=True,
|
||||
cache_dir="./cache" # Cache tokenizer files
|
||||
)
|
||||
|
||||
# Tokenize with caching
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache(maxsize=10000)
|
||||
def cached_tokenize(text):
|
||||
return tuple(tokenizer.encode(text))
|
||||
|
||||
# Reuses cached results for repeated inputs
|
||||
```
|
||||
|
||||
### Memory efficiency
|
||||
|
||||
```python
|
||||
# For very large datasets, use streaming
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("pile", split="train", streaming=True)
|
||||
|
||||
def process_batch(batch):
|
||||
# Tokenize
|
||||
tokens = tokenizer(batch["text"], truncation=True, max_length=512)
|
||||
|
||||
# Process tokens...
|
||||
|
||||
return tokens
|
||||
|
||||
# Process in chunks (memory efficient)
|
||||
for batch in dataset.batch(batch_size=1000):
|
||||
processed = process_batch(batch)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Tokenizer not fast
|
||||
|
||||
**Symptom**:
|
||||
```python
|
||||
tokenizer.is_fast # False
|
||||
```
|
||||
|
||||
**Solution**: Install tokenizers library
|
||||
```bash
|
||||
pip install tokenizers
|
||||
```
|
||||
|
||||
### Issue: Special tokens not working
|
||||
|
||||
**Symptom**: Special tokens are split into subwords
|
||||
|
||||
**Solution**: Add as special tokens, not regular tokens
|
||||
```python
|
||||
# Wrong
|
||||
tokenizer.add_tokens(["<|image|>"])
|
||||
|
||||
# Correct
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": ["<|image|>"]})
|
||||
```
|
||||
|
||||
### Issue: Offset mapping not available
|
||||
|
||||
**Symptom**:
|
||||
```python
|
||||
tokenizer("text", return_offsets_mapping=True)
|
||||
# Error: return_offsets_mapping not supported
|
||||
```
|
||||
|
||||
**Solution**: Use fast tokenizer
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Load fast version
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
|
||||
```
|
||||
|
||||
### Issue: Padding inconsistent
|
||||
|
||||
**Symptom**: Some sequences padded, others not
|
||||
|
||||
**Solution**: Specify padding strategy
|
||||
```python
|
||||
# Explicit padding
|
||||
tokenizer(
|
||||
texts,
|
||||
padding="max_length", # or "longest"
|
||||
max_length=128
|
||||
)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Always use fast tokenizers**:
|
||||
- 5-10× faster
|
||||
- Full alignment tracking
|
||||
- Better batch processing
|
||||
|
||||
2. **Save tokenizer with model**:
|
||||
- Ensures reproducibility
|
||||
- Prevents version mismatches
|
||||
|
||||
3. **Use batch processing for datasets**:
|
||||
- Tokenize with `.map(batched=True)`
|
||||
- Set `num_proc` for parallelism
|
||||
|
||||
4. **Enable caching for repeated inputs**:
|
||||
- Use `lru_cache` for inference
|
||||
- Cache tokenizer files with `cache_dir`
|
||||
|
||||
5. **Handle special tokens properly**:
|
||||
- Use `add_special_tokens()` for never-split tokens
|
||||
- Resize embeddings after adding tokens
|
||||
|
||||
6. **Test alignment for downstream tasks**:
|
||||
- Verify `offset_mapping` is correct
|
||||
- Test `char_to_token()` on samples
|
||||
|
||||
7. **Version control tokenizer config**:
|
||||
- Save `tokenizer_config.json`
|
||||
- Document custom templates
|
||||
- Track vocabulary changes
|
||||
@@ -0,0 +1,723 @@
|
||||
# Tokenization Pipeline Components
|
||||
|
||||
Complete guide to normalizers, pre-tokenizers, models, post-processors, and decoders.
|
||||
|
||||
## Pipeline overview
|
||||
|
||||
**Full tokenization pipeline**:
|
||||
```
|
||||
Raw Text
|
||||
↓
|
||||
Normalization (cleaning, lowercasing)
|
||||
↓
|
||||
Pre-tokenization (split into words)
|
||||
↓
|
||||
Model (apply BPE/WordPiece/Unigram)
|
||||
↓
|
||||
Post-processing (add special tokens)
|
||||
↓
|
||||
Token IDs
|
||||
```
|
||||
|
||||
**Decoding reverses the process**:
|
||||
```
|
||||
Token IDs
|
||||
↓
|
||||
Decoder (handle special encodings)
|
||||
↓
|
||||
Raw Text
|
||||
```
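Each stage is detailed in the sections below; as a preview, a minimal end-to-end assembly might look like this (the component choices are one plausible combination, not a recommendation, and `corpus.txt` is assumed to exist):

```python
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.normalizers import Sequence, NFD, Lowercase, StripAccents
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing
from tokenizers.decoders import WordPiece as WordPieceDecoder
from tokenizers.trainers import WordPieceTrainer

tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))

# 1. Normalization
tokenizer.normalizer = Sequence([NFD(), Lowercase(), StripAccents()])
# 2. Pre-tokenization
tokenizer.pre_tokenizer = Whitespace()

# Train the model stage
trainer = WordPieceTrainer(vocab_size=8000, special_tokens=["[UNK]", "[CLS]", "[SEP]"])
tokenizer.train(files=["corpus.txt"], trainer=trainer)

# 3. Post-processing
tokenizer.post_processor = TemplateProcessing(
    single="[CLS] $A [SEP]",
    special_tokens=[("[CLS]", tokenizer.token_to_id("[CLS]")),
                    ("[SEP]", tokenizer.token_to_id("[SEP]"))],
)
# 4. Decoding
tokenizer.decoder = WordPieceDecoder(prefix="##")

output = tokenizer.encode("Héllo WORLD")
print(output.tokens)
print(tokenizer.decode(output.ids))
```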
|
||||
|
||||
## Normalizers
|
||||
|
||||
Clean and standardize input text.
|
||||
|
||||
### Common normalizers
|
||||
|
||||
**Lowercase**:
|
||||
```python
|
||||
from tokenizers.normalizers import Lowercase
|
||||
|
||||
tokenizer.normalizer = Lowercase()
|
||||
|
||||
# Input: "Hello WORLD"
|
||||
# Output: "hello world"
|
||||
```
|
||||
|
||||
**Unicode normalization**:
|
||||
```python
|
||||
from tokenizers.normalizers import NFD, NFC, NFKD, NFKC
|
||||
|
||||
# NFD: Canonical decomposition
|
||||
tokenizer.normalizer = NFD()
|
||||
# "é" → "e" + "́" (separate characters)
|
||||
|
||||
# NFC: Canonical composition (default)
|
||||
tokenizer.normalizer = NFC()
|
||||
# "e" + "́" → "é" (composed)
|
||||
|
||||
# NFKD: Compatibility decomposition
|
||||
tokenizer.normalizer = NFKD()
|
||||
# "fi" → "f" + "i"
|
||||
|
||||
# NFKC: Compatibility composition
|
||||
tokenizer.normalizer = NFKC()
|
||||
# Most aggressive normalization
|
||||
```
|
||||
|
||||
**Strip accents**:
|
||||
```python
|
||||
from tokenizers.normalizers import StripAccents
|
||||
|
||||
tokenizer.normalizer = StripAccents()
|
||||
|
||||
# Input: "café"
|
||||
# Output: "cafe"
|
||||
```
|
||||
|
||||
**Whitespace handling**:
|
||||
```python
|
||||
from tokenizers.normalizers import Strip
|
||||
|
||||
# Remove leading/trailing whitespace
|
||||
tokenizer.normalizer = Strip()
|
||||
|
||||
# Input: " hello "
|
||||
# Output: "hello"
|
||||
```
|
||||
|
||||
**Replace patterns**:
|
||||
```python
|
||||
from tokenizers.normalizers import Replace
|
||||
|
||||
# Replace newlines with spaces
|
||||
tokenizer.normalizer = Replace("\\n", " ")
|
||||
|
||||
# Input: "hello\\nworld"
|
||||
# Output: "hello world"
|
||||
```
|
||||
|
||||
### Combining normalizers
|
||||
|
||||
```python
|
||||
from tokenizers.normalizers import Sequence, NFD, Lowercase, StripAccents
|
||||
|
||||
# BERT-style normalization
|
||||
tokenizer.normalizer = Sequence([
|
||||
NFD(), # Unicode decomposition
|
||||
Lowercase(), # Convert to lowercase
|
||||
StripAccents() # Remove accents
|
||||
])
|
||||
|
||||
# Input: "Café au Lait"
|
||||
# After NFD: "Café au Lait" (e + ́)
|
||||
# After Lowercase: "café au lait"
|
||||
# After StripAccents: "cafe au lait"
|
||||
```
|
||||
|
||||
### Use case examples
|
||||
|
||||
**Case-insensitive model (BERT)**:
|
||||
```python
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
|
||||
# All-in-one BERT normalization
|
||||
tokenizer.normalizer = BertNormalizer(
|
||||
clean_text=True, # Remove control characters
|
||||
handle_chinese_chars=True, # Add spaces around Chinese
|
||||
strip_accents=True, # Remove accents
|
||||
lowercase=True # Lowercase
|
||||
)
|
||||
```
|
||||
|
||||
**Case-sensitive model (GPT-2)**:
|
||||
```python
|
||||
# Minimal normalization
|
||||
tokenizer.normalizer = NFC() # Only normalize Unicode
|
||||
```
|
||||
|
||||
**Multilingual (mBERT)**:
|
||||
```python
|
||||
# Preserve scripts, normalize form
|
||||
tokenizer.normalizer = NFKC()
|
||||
```
|
||||
|
||||
## Pre-tokenizers
|
||||
|
||||
Split text into word-like units before tokenization.
|
||||
|
||||
### Whitespace splitting
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Whitespace
|
||||
|
||||
tokenizer.pre_tokenizer = Whitespace()
|
||||
|
||||
# Input: "Hello world! How are you?"
|
||||
# Output: [("Hello", (0, 5)), ("world!", (6, 12)), ("How", (13, 16)), ("are", (17, 20)), ("you?", (21, 25))]
|
||||
```
|
||||
|
||||
### Punctuation isolation
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Punctuation
|
||||
|
||||
tokenizer.pre_tokenizer = Punctuation()
|
||||
|
||||
# Input: "Hello, world!"
|
||||
# Output: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
|
||||
```
|
||||
|
||||
### Byte-level (GPT-2)
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
|
||||
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
|
||||
|
||||
# Input: "Hello world"
|
||||
# Output: Byte-level tokens with Ġ prefix for spaces
|
||||
# [("ĠHello", ...), ("Ġworld", ...)]
|
||||
```
|
||||
|
||||
**Key feature**: Works directly on raw bytes (a base alphabet of just 256 byte values), so any Unicode text can be represented without unknown tokens
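Because it works on bytes, nothing is ever out of vocabulary; a quick check with the standalone pre-tokenizer:

```python
from tokenizers.pre_tokenizers import ByteLevel

pre = ByteLevel(add_prefix_space=False)
print(pre.pre_tokenize_str("héllo 🌍"))
# Accented characters and the emoji are all mapped to printable byte symbols,
# so no [UNK] token is ever needed downstream.
```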
|
||||
|
||||
### Metaspace (SentencePiece)
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Metaspace
|
||||
|
||||
tokenizer.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)
|
||||
|
||||
# Input: "Hello world"
|
||||
# Output: [("▁Hello", ...), ("▁world", ...)]
|
||||
```
|
||||
|
||||
**Used by**: T5, ALBERT (via SentencePiece)
|
||||
|
||||
### Digits splitting
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Digits
|
||||
|
||||
# Split digits individually
|
||||
tokenizer.pre_tokenizer = Digits(individual_digits=True)
|
||||
|
||||
# Input: "Room 123"
|
||||
# Output: [("Room", ...), ("1", ...), ("2", ...), ("3", ...)]
|
||||
|
||||
# Keep digits together
|
||||
tokenizer.pre_tokenizer = Digits(individual_digits=False)
|
||||
|
||||
# Input: "Room 123"
|
||||
# Output: [("Room", ...), ("123", ...)]
|
||||
```
|
||||
|
||||
### BERT pre-tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import BertPreTokenizer
|
||||
|
||||
tokenizer.pre_tokenizer = BertPreTokenizer()
|
||||
|
||||
# Splits on whitespace and punctuation, and splits CJK characters individually
|
||||
# Input: "Hello, 世界!"
|
||||
# Output: [("Hello", ...), (",", ...), ("世", ...), ("界", ...), ("!", ...)]
|
||||
```
|
||||
|
||||
### Combining pre-tokenizers
|
||||
|
||||
```python
|
||||
from tokenizers.pre_tokenizers import Sequence, Whitespace, Punctuation
|
||||
|
||||
tokenizer.pre_tokenizer = Sequence([
|
||||
Whitespace(), # Split on whitespace first
|
||||
Punctuation() # Then isolate punctuation
|
||||
])
|
||||
|
||||
# Input: "Hello, world!"
|
||||
# After Whitespace: [("Hello,", ...), ("world!", ...)]
|
||||
# After Punctuation: [("Hello", ...), (",", ...), ("world", ...), ("!", ...)]
|
||||
```
|
||||
|
||||
### Pre-tokenizer comparison
|
||||
|
||||
| Pre-tokenizer | Use Case | Example |
|
||||
|-------------------|---------------------------------|--------------------------------------------|
|
||||
| Whitespace | Simple English | "Hello world" → ["Hello", "world"] |
|
||||
| Punctuation | Isolate symbols | "world!" → ["world", "!"] |
|
||||
| ByteLevel | Multilingual, emojis | "🌍" → byte tokens |
|
||||
| Metaspace | SentencePiece-style | "Hello" → ["▁Hello"] |
|
||||
| BertPreTokenizer | BERT-style (CJK aware) | "世界" → ["世", "界"] |
|
||||
| Digits | Handle numbers | "123" → ["1", "2", "3"] or ["123"] |
|
||||
|
||||
## Models
|
||||
|
||||
Core tokenization algorithms.
|
||||
|
||||
### BPE Model
|
||||
|
||||
```python
|
||||
from tokenizers.models import BPE
|
||||
|
||||
model = BPE(
|
||||
vocab=None, # Or provide pre-built vocab
|
||||
merges=None, # Or provide merge rules
|
||||
unk_token="[UNK]", # Unknown token
|
||||
continuing_subword_prefix="",
|
||||
end_of_word_suffix="",
|
||||
fuse_unk=False # Keep unknown tokens separate
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer(model)
|
||||
```
|
||||
|
||||
**Parameters**:
|
||||
- `vocab`: Dict of token → id
|
||||
- `merges`: List of merge rules `["a b", "ab c"]`
|
||||
- `unk_token`: Token for unknown words
|
||||
- `continuing_subword_prefix`: Prefix for subwords (empty for GPT-2)
|
||||
- `end_of_word_suffix`: Suffix for last subword (empty for GPT-2)
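To reuse an already-trained vocabulary, the model can be built from exported files; a sketch assuming `vocab.json` and `merges.txt` produced by an earlier training run (hypothetical paths):

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE

# vocab.json / merges.txt are assumed outputs of a previous BPE training run
model = BPE.from_file("vocab.json", "merges.txt", unk_token="[UNK]")
tokenizer = Tokenizer(model)
```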
|
||||
|
||||
### WordPiece Model
|
||||
|
||||
```python
|
||||
from tokenizers.models import WordPiece
|
||||
|
||||
model = WordPiece(
|
||||
vocab=None,
|
||||
unk_token="[UNK]",
|
||||
max_input_chars_per_word=100, # Max word length
|
||||
continuing_subword_prefix="##" # BERT-style prefix
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer(model)
|
||||
```
|
||||
|
||||
**Key difference**: Uses `##` prefix for continuing subwords.
|
||||
|
||||
### Unigram Model
|
||||
|
||||
```python
|
||||
from tokenizers.models import Unigram
|
||||
|
||||
model = Unigram(
|
||||
vocab=None, # List of (token, score) tuples
|
||||
unk_id=0, # ID for unknown token
|
||||
byte_fallback=False # Fall back to bytes if no match
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer(model)
|
||||
```
|
||||
|
||||
**Probabilistic**: Selects tokenization with highest probability.
|
||||
|
||||
### WordLevel Model
|
||||
|
||||
```python
|
||||
from tokenizers.models import WordLevel
|
||||
|
||||
# Simple word-to-ID mapping (no subwords)
|
||||
model = WordLevel(
|
||||
vocab=None,
|
||||
unk_token="[UNK]"
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer(model)
|
||||
```
|
||||
|
||||
**Warning**: Requires huge vocabulary (one token per word).
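Since every word must already be in the vocabulary, a WordLevel model is usually trained with its matching trainer; a minimal sketch assuming a local `corpus.txt`:

```python
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

trainer = WordLevelTrainer(vocab_size=50000, special_tokens=["[UNK]"])
tokenizer.train(files=["corpus.txt"], trainer=trainer)  # hypothetical corpus file
```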
|
||||
|
||||
## Post-processors
|
||||
|
||||
Add special tokens and format output.
|
||||
|
||||
### Template processing
|
||||
|
||||
**BERT-style** (`[CLS] sentence [SEP]`):
|
||||
```python
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
pair="[CLS] $A [SEP] $B [SEP]",
|
||||
special_tokens=[
|
||||
("[CLS]", 101),
|
||||
("[SEP]", 102),
|
||||
],
|
||||
)
|
||||
|
||||
# Single sentence
|
||||
output = tokenizer.encode("Hello world")
|
||||
# [101, ..., 102] ([CLS] hello world [SEP])
|
||||
|
||||
# Sentence pair
|
||||
output = tokenizer.encode("Hello", "world")
|
||||
# [101, ..., 102, ..., 102] ([CLS] hello [SEP] world [SEP])
|
||||
```
|
||||
|
||||
**GPT-2 style** (`sentence <|endoftext|>`):
|
||||
```python
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="$A <|endoftext|>",
|
||||
special_tokens=[
|
||||
("<|endoftext|>", 50256),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
**RoBERTa style** (`<s> sentence </s>`):
|
||||
```python
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="<s> $A </s>",
|
||||
pair="<s> $A </s> </s> $B </s>",
|
||||
special_tokens=[
|
||||
("<s>", 0),
|
||||
("</s>", 2),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
**T5 style** (appends `</s>`, no `[CLS]`/`[SEP]`):
```python
# T5 only appends an end-of-sequence token (</s>, id 1 in the original vocab)
tokenizer.post_processor = TemplateProcessing(
    single="$A </s>",
    pair="$A </s> $B </s>",
    special_tokens=[("</s>", 1)],
)
```
|
||||
|
||||
### RobertaProcessing
|
||||
|
||||
```python
|
||||
from tokenizers.processors import RobertaProcessing
|
||||
|
||||
tokenizer.post_processor = RobertaProcessing(
|
||||
sep=("</s>", 2),
|
||||
cls=("<s>", 0),
|
||||
add_prefix_space=True, # Add space before first token
|
||||
trim_offsets=True # Trim leading space from offsets
|
||||
)
|
||||
```
|
||||
|
||||
### ByteLevelProcessing
|
||||
|
||||
```python
|
||||
from tokenizers.processors import ByteLevel as ByteLevelProcessing
|
||||
|
||||
tokenizer.post_processor = ByteLevelProcessing(
|
||||
trim_offsets=True # Remove Ġ from offsets
|
||||
)
|
||||
```
|
||||
|
||||
## Decoders
|
||||
|
||||
Convert token IDs back to text.
|
||||
|
||||
### ByteLevel decoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import ByteLevel
|
||||
|
||||
tokenizer.decoder = ByteLevel()
|
||||
|
||||
# Handles byte-level tokens
|
||||
# ["ĠHello", "Ġworld"] → "Hello world"
|
||||
```
|
||||
|
||||
### WordPiece decoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import WordPiece
|
||||
|
||||
tokenizer.decoder = WordPiece(prefix="##")
|
||||
|
||||
# Removes ## prefix and concatenates
|
||||
# ["token", "##ization"] → "tokenization"
|
||||
```
|
||||
|
||||
### Metaspace decoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import Metaspace
|
||||
|
||||
tokenizer.decoder = Metaspace(replacement="▁", add_prefix_space=True)
|
||||
|
||||
# Converts ▁ back to spaces
|
||||
# ["▁Hello", "▁world"] → "Hello world"
|
||||
```
|
||||
|
||||
### BPEDecoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import BPEDecoder
|
||||
|
||||
tokenizer.decoder = BPEDecoder(suffix="</w>")
|
||||
|
||||
# Removes suffix and concatenates
|
||||
# ["token", "ization</w>"] → "tokenization"
|
||||
```
|
||||
|
||||
### Sequence decoder
|
||||
|
||||
```python
|
||||
from tokenizers.decoders import Sequence, ByteLevel, Strip
|
||||
|
||||
tokenizer.decoder = Sequence([
|
||||
ByteLevel(), # Decode byte-level first
|
||||
Strip(' ', 1, 1) # Strip leading/trailing spaces
|
||||
])
|
||||
```
|
||||
|
||||
## Complete pipeline examples
|
||||
|
||||
### BERT tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import WordPiece
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
from tokenizers.pre_tokenizers import BertPreTokenizer
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
from tokenizers.decoders import WordPiece as WordPieceDecoder
|
||||
|
||||
# Model
|
||||
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
|
||||
|
||||
# Normalization
|
||||
tokenizer.normalizer = BertNormalizer(lowercase=True)
|
||||
|
||||
# Pre-tokenization
|
||||
tokenizer.pre_tokenizer = BertPreTokenizer()
|
||||
|
||||
# Post-processing
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
pair="[CLS] $A [SEP] $B [SEP]",
|
||||
special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
|
||||
)
|
||||
|
||||
# Decoder
|
||||
tokenizer.decoder = WordPieceDecoder(prefix="##")
|
||||
|
||||
# Enable padding
|
||||
tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")
|
||||
|
||||
# Enable truncation
|
||||
tokenizer.enable_truncation(max_length=512)
|
||||
```
|
||||
|
||||
### GPT-2 tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.normalizers import NFC
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
|
||||
# Model
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Normalization (minimal)
|
||||
tokenizer.normalizer = NFC()
|
||||
|
||||
# Byte-level pre-tokenization
|
||||
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
|
||||
|
||||
# Post-processing
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="$A <|endoftext|>",
|
||||
special_tokens=[("<|endoftext|>", 50256)],
|
||||
)
|
||||
|
||||
# Byte-level decoder
|
||||
tokenizer.decoder = ByteLevelDecoder()
|
||||
```
|
||||
|
||||
### T5 tokenizer (SentencePiece-style)
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import Unigram
|
||||
from tokenizers.normalizers import NFKC
|
||||
from tokenizers.pre_tokenizers import Metaspace
|
||||
from tokenizers.decoders import Metaspace as MetaspaceDecoder
|
||||
|
||||
# Model
|
||||
tokenizer = Tokenizer(Unigram())
|
||||
|
||||
# Normalization
|
||||
tokenizer.normalizer = NFKC()
|
||||
|
||||
# Metaspace pre-tokenization
|
||||
tokenizer.pre_tokenizer = Metaspace(replacement="▁", add_prefix_space=True)
|
||||
|
||||
# No post-processing (T5 doesn't add CLS/SEP)
|
||||
tokenizer.post_processor = None
|
||||
|
||||
# Metaspace decoder
|
||||
tokenizer.decoder = MetaspaceDecoder(replacement="▁", add_prefix_space=True)
|
||||
```
|
||||
|
||||
## Alignment tracking
|
||||
|
||||
Track token positions in original text.
|
||||
|
||||
### Basic alignment
|
||||
|
||||
```python
|
||||
text = "Hello, world!"
|
||||
output = tokenizer.encode(text)
|
||||
|
||||
for token, (start, end) in zip(output.tokens, output.offsets):
|
||||
print(f"{token:10s} → [{start:2d}, {end:2d}): {text[start:end]!r}")
|
||||
|
||||
# Output:
|
||||
# [CLS] → [ 0, 0): ''
|
||||
# hello → [ 0, 5): 'Hello'
|
||||
# , → [ 5, 6): ','
|
||||
# world → [ 7, 12): 'world'
|
||||
# ! → [12, 13): '!'
|
||||
# [SEP] → [ 0, 0): ''
|
||||
```
|
||||
|
||||
### Word-level alignment
|
||||
|
||||
```python
|
||||
# Get word_ids (which word each token belongs to)
|
||||
encoding = tokenizer.encode("Hello world")
|
||||
word_ids = encoding.word_ids
|
||||
|
||||
print(word_ids)
|
||||
# [None, 0, 0, 1, None]
|
||||
# None = special token, 0 = first word, 1 = second word
|
||||
```
|
||||
|
||||
**Use case**: Token classification (NER)
|
||||
```python
|
||||
# Align predictions to words
|
||||
predictions = ["O", "B-PER", "I-PER", "O", "O"]
|
||||
word_predictions = {}
|
||||
|
||||
for token_idx, word_idx in enumerate(encoding.word_ids):
|
||||
if word_idx is not None and word_idx not in word_predictions:
|
||||
word_predictions[word_idx] = predictions[token_idx]
|
||||
|
||||
print(word_predictions)
|
||||
# {0: "B-PER", 1: "O"} # First word is PERSON, second is OTHER
|
||||
```
|
||||
|
||||
### Span alignment
|
||||
|
||||
```python
|
||||
# Find token span for character span
|
||||
text = "Machine learning is awesome"
|
||||
char_start, char_end = 8, 16 # "learning"
|
||||
|
||||
encoding = tokenizer.encode(text)
|
||||
|
||||
# Find token span
|
||||
token_start = encoding.char_to_token(char_start)
|
||||
token_end = encoding.char_to_token(char_end - 1) + 1
|
||||
|
||||
print(f"Tokens {token_start}:{token_end} = {encoding.tokens[token_start:token_end]}")
|
||||
# Tokens 2:3 = ['learning']
|
||||
```
|
||||
|
||||
**Use case**: Question answering (extract answer span)
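In the other direction, predicted token positions map back to a character span through the offsets; a sketch assuming start/end token indices produced by a QA-style model:

```python
text = "Machine learning is awesome"
encoding = tokenizer.encode(text)

# Hypothetical start/end token indices predicted by a QA model head
start_token, end_token = 2, 2

char_start, _ = encoding.offsets[start_token]
_, char_end = encoding.offsets[end_token]
print(text[char_start:char_end])   # e.g. "learning"
```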
|
||||
|
||||
## Custom components
|
||||
|
||||
### Custom normalizer
|
||||
|
||||
```python
|
||||
from tokenizers import NormalizedString
from tokenizers.normalizers import Normalizer
|
||||
|
||||
class CustomNormalizer:
|
||||
def normalize(self, normalized: NormalizedString):
|
||||
# Custom normalization logic
|
||||
normalized.lowercase()
|
||||
normalized.replace(" ", " ") # Replace double spaces
|
||||
|
||||
# Use custom normalizer
|
||||
tokenizer.normalizer = Normalizer.custom(CustomNormalizer())
|
||||
```
|
||||
|
||||
### Custom pre-tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import PreTokenizedString
from tokenizers.pre_tokenizers import PreTokenizer
|
||||
|
||||
class CustomPreTokenizer:
|
||||
def pre_tokenize(self, pretok: PreTokenizedString):
|
||||
# Custom pre-tokenization logic
|
||||
pretok.split(lambda i, char: char.isspace())
|
||||
|
||||
tokenizer.pre_tokenizer = PreTokenizer.custom(CustomPreTokenizer())
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Misaligned offsets
|
||||
|
||||
**Symptom**: Offsets don't match original text
|
||||
```python
|
||||
text = " hello" # Leading spaces
|
||||
offsets = [(0, 5)] # Expects " hel"
|
||||
```
|
||||
|
||||
**Solution**: Check whether the normalizer strips characters; stripping there shifts the offsets
|
||||
```python
|
||||
# Avoid stripping in the normalizer when you need accurate offsets
|
||||
tokenizer.normalizer = Sequence([
|
||||
Strip(), # This changes offsets!
|
||||
])
|
||||
|
||||
# Use trim_offsets in post-processor instead
|
||||
tokenizer.post_processor = ByteLevelProcessing(trim_offsets=True)
|
||||
```
|
||||
|
||||
### Issue: Special tokens not added
|
||||
|
||||
**Symptom**: No [CLS] or [SEP] in output
|
||||
|
||||
**Solution**: Check post-processor is set
|
||||
```python
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
special_tokens=[("[CLS]", 101), ("[SEP]", 102)],
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Incorrect decoding
|
||||
|
||||
**Symptom**: Decoded text has ## or ▁
|
||||
|
||||
**Solution**: Set correct decoder
|
||||
```python
|
||||
# For WordPiece
|
||||
tokenizer.decoder = WordPieceDecoder(prefix="##")
|
||||
|
||||
# For SentencePiece
|
||||
tokenizer.decoder = MetaspaceDecoder(replacement="▁")
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Match pipeline to model architecture**:
|
||||
- BERT → BertNormalizer + BertPreTokenizer + WordPiece
|
||||
- GPT-2 → NFC + ByteLevel + BPE
|
||||
- T5 → NFKC + Metaspace + Unigram
|
||||
|
||||
2. **Test pipeline on sample inputs** (see the sketch after this list):
|
||||
- Check normalization doesn't over-normalize
|
||||
- Verify pre-tokenization splits correctly
|
||||
- Ensure decoding reconstructs text
|
||||
|
||||
3. **Preserve alignment for downstream tasks**:
|
||||
- Use `trim_offsets` instead of stripping in normalizer
|
||||
- Test `char_to_token()` on sample spans
|
||||
|
||||
4. **Document your pipeline**:
|
||||
- Save complete tokenizer config
|
||||
- Document special tokens
|
||||
- Note any custom components
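A minimal sanity check for best practice 2, assuming `tokenizer` is one of the pipelines assembled above:

```python
sample = "Hello, world! Café costs $5."

print(tokenizer.normalizer.normalize_str(sample))   # normalization only
encoding = tokenizer.encode(sample)
print(encoding.tokens)                               # final tokens
print(tokenizer.decode(encoding.ids))                # decoded text

# Lossless pipelines (e.g. byte-level GPT-2 style) should reproduce the input;
# lossy ones (lowercasing, accent stripping) will differ in predictable ways.
```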
|
||||
@@ -0,0 +1,565 @@
|
||||
# Training Custom Tokenizers
|
||||
|
||||
Complete guide to training tokenizers from scratch.
|
||||
|
||||
## Training workflow
|
||||
|
||||
### Step 1: Choose tokenization algorithm
|
||||
|
||||
**Decision tree**:
|
||||
- **GPT-style model** → BPE
|
||||
- **BERT-style model** → WordPiece
|
||||
- **Multilingual/No word boundaries** → Unigram
|
||||
|
||||
### Step 2: Prepare training data
|
||||
|
||||
```python
|
||||
# Option 1: From files
|
||||
files = ["train.txt", "validation.txt"]
|
||||
|
||||
# Option 2: From Python list
|
||||
texts = [
|
||||
"This is the first sentence.",
|
||||
"This is the second sentence.",
|
||||
# ... more texts
|
||||
]
|
||||
|
||||
# Option 3: From dataset iterator
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
|
||||
|
||||
def batch_iterator(batch_size=1000):
|
||||
for i in range(0, len(dataset), batch_size):
|
||||
yield dataset[i:i + batch_size]["text"]
|
||||
```
|
||||
|
||||
### Step 3: Initialize tokenizer
|
||||
|
||||
**BPE example**:
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
|
||||
|
||||
tokenizer = Tokenizer(BPE())
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
tokenizer.decoder = ByteLevelDecoder()
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50000,
|
||||
min_frequency=2,
|
||||
special_tokens=["<|endoftext|>", "<|padding|>"],
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
**WordPiece example**:
|
||||
```python
|
||||
from tokenizers.models import WordPiece
|
||||
from tokenizers.trainers import WordPieceTrainer
|
||||
from tokenizers.normalizers import BertNormalizer
|
||||
from tokenizers.pre_tokenizers import BertPreTokenizer
|
||||
|
||||
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
|
||||
tokenizer.normalizer = BertNormalizer(lowercase=True)
|
||||
tokenizer.pre_tokenizer = BertPreTokenizer()
|
||||
|
||||
trainer = WordPieceTrainer(
|
||||
vocab_size=30522,
|
||||
min_frequency=2,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
continuing_subword_prefix="##",
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
**Unigram example**:
|
||||
```python
|
||||
from tokenizers.models import Unigram
|
||||
from tokenizers.trainers import UnigramTrainer
|
||||
|
||||
tokenizer = Tokenizer(Unigram())
|
||||
|
||||
trainer = UnigramTrainer(
|
||||
vocab_size=8000,
|
||||
special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
|
||||
unk_token="<unk>",
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
### Step 4: Train
|
||||
|
||||
```python
|
||||
# From files
|
||||
tokenizer.train(files=files, trainer=trainer)
|
||||
|
||||
# From iterator (recommended for large datasets)
|
||||
tokenizer.train_from_iterator(
|
||||
batch_iterator(),
|
||||
trainer=trainer,
|
||||
length=len(dataset) # Optional, for progress bar
|
||||
)
|
||||
```
|
||||
|
||||
**Training time** (30k vocab on 16-core CPU):
|
||||
- 10 MB: 15-30 seconds
|
||||
- 100 MB: 1-3 minutes
|
||||
- 1 GB: 15-30 minutes
|
||||
- 10 GB: 2-4 hours
|
||||
|
||||
### Step 5: Add post-processing
|
||||
|
||||
```python
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
|
||||
# BERT-style
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="[CLS] $A [SEP]",
|
||||
pair="[CLS] $A [SEP] $B [SEP]",
|
||||
special_tokens=[
|
||||
("[CLS]", tokenizer.token_to_id("[CLS]")),
|
||||
("[SEP]", tokenizer.token_to_id("[SEP]")),
|
||||
],
|
||||
)
|
||||
|
||||
# GPT-2 style
|
||||
tokenizer.post_processor = TemplateProcessing(
|
||||
single="$A <|endoftext|>",
|
||||
special_tokens=[
|
||||
("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
### Step 6: Save
|
||||
|
||||
```python
|
||||
# Save to JSON
|
||||
tokenizer.save("my-tokenizer.json")
|
||||
|
||||
# Save to directory (for transformers)
|
||||
tokenizer.save("my-tokenizer-dir/tokenizer.json")
|
||||
|
||||
# Convert to transformers format
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
transformers_tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=tokenizer,
|
||||
unk_token="[UNK]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
sep_token="[SEP]",
|
||||
mask_token="[MASK]"
|
||||
)
|
||||
|
||||
transformers_tokenizer.save_pretrained("my-tokenizer-dir")
|
||||
```
|
||||
|
||||
## Trainer configuration
|
||||
|
||||
### BpeTrainer parameters
|
||||
|
||||
```python
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=30000, # Target vocabulary size
|
||||
min_frequency=2, # Minimum frequency for merges
|
||||
special_tokens=["[UNK]"], # Special tokens (added first)
|
||||
limit_alphabet=1000, # Limit initial alphabet size
|
||||
initial_alphabet=[], # Pre-defined initial characters
|
||||
show_progress=True, # Show progress bar
|
||||
continuing_subword_prefix="", # Prefix for continuing subwords
|
||||
end_of_word_suffix="" # Suffix for end of words
|
||||
)
|
||||
```
|
||||
|
||||
**Parameter tuning**:
|
||||
- **vocab_size**: Start with 30k for English, 50k for multilingual
|
||||
- **min_frequency**: 2-5 for large corpora, 1 for small
|
||||
- **limit_alphabet**: Reduce for non-English (CJK languages)
|
||||
|
||||
### WordPieceTrainer parameters
|
||||
|
||||
```python
|
||||
from tokenizers.trainers import WordPieceTrainer
|
||||
|
||||
trainer = WordPieceTrainer(
|
||||
vocab_size=30522, # BERT uses 30,522
|
||||
min_frequency=2,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
|
||||
limit_alphabet=1000,
|
||||
continuing_subword_prefix="##", # BERT-style prefix
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
### UnigramTrainer parameters
|
||||
|
||||
```python
|
||||
from tokenizers.trainers import UnigramTrainer
|
||||
|
||||
trainer = UnigramTrainer(
|
||||
vocab_size=8000, # Typically smaller than BPE/WordPiece
|
||||
special_tokens=["<unk>", "<s>", "</s>"],
|
||||
unk_token="<unk>",
|
||||
max_piece_length=16, # Maximum token length
|
||||
n_sub_iterations=2, # EM algorithm iterations
|
||||
shrinking_factor=0.75, # Vocabulary reduction rate
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
## Training from large datasets
|
||||
|
||||
### Memory-efficient training
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("wikipedia", "20220301.en", split="train", streaming=True)
|
||||
|
||||
# Create iterator (yields batches)
|
||||
def batch_iterator(batch_size=1000):
|
||||
batch = []
|
||||
for sample in dataset:
|
||||
batch.append(sample["text"])
|
||||
if len(batch) >= batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
# Initialize tokenizer
|
||||
tokenizer = Tokenizer(BPE())
|
||||
trainer = BpeTrainer(vocab_size=50000, special_tokens=["<|endoftext|>"])
|
||||
|
||||
# Train (memory efficient - streams data)
|
||||
tokenizer.train_from_iterator(
|
||||
batch_iterator(),
|
||||
trainer=trainer
|
||||
)
|
||||
```
|
||||
|
||||
**Memory usage**: ~200 MB (vs 10+ GB loading full dataset)
|
||||
|
||||
### Multi-file training
|
||||
|
||||
```python
|
||||
import glob
|
||||
|
||||
# Find all training files
|
||||
files = glob.glob("data/train/*.txt")
|
||||
print(f"Training on {len(files)} files")
|
||||
|
||||
# Train on all files
|
||||
tokenizer.train(files=files, trainer=trainer)
|
||||
```
|
||||
|
||||
### Parallel training (multi-processing)
|
||||
|
||||
```python
|
||||
from multiprocessing import Pool, cpu_count
|
||||
import os
|
||||
|
||||
def train_shard(shard_files):
|
||||
"""Train tokenizer on a shard of files."""
|
||||
tokenizer = Tokenizer(BPE())
|
||||
trainer = BpeTrainer(vocab_size=50000)
|
||||
tokenizer.train(files=shard_files, trainer=trainer)
|
||||
return tokenizer.get_vocab()
|
||||
|
||||
# Split files into shards
|
||||
num_shards = cpu_count()
|
||||
file_shards = [files[i::num_shards] for i in range(num_shards)]
|
||||
|
||||
# Train shards in parallel
|
||||
with Pool(num_shards) as pool:
|
||||
vocab_shards = pool.map(train_shard, file_shards)
|
||||
|
||||
# Merge vocabularies (custom logic needed)
|
||||
# This is a simplified example - real implementation would merge intelligently
|
||||
final_vocab = {}
|
||||
for vocab in vocab_shards:
|
||||
final_vocab.update(vocab)
|
||||
```
|
||||
|
||||
## Domain-specific tokenizers
|
||||
|
||||
### Code tokenizer
|
||||
|
||||
```python
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
from tokenizers.normalizers import Sequence, NFC
|
||||
|
||||
# Code-optimized configuration
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Minimal normalization (preserve case, whitespace)
|
||||
tokenizer.normalizer = NFC() # Only normalize Unicode
|
||||
|
||||
# Byte-level pre-tokenization (handles all characters)
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
# Train on code corpus
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50000,
|
||||
special_tokens=["<|endoftext|>", "<|pad|>"],
|
||||
min_frequency=2
|
||||
)
|
||||
|
||||
tokenizer.train(files=["code_corpus.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
### Medical/scientific tokenizer
|
||||
|
||||
```python
|
||||
# Preserve case and special characters
|
||||
from tokenizers.normalizers import NFKC
|
||||
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence
|
||||
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Minimal normalization
|
||||
tokenizer.normalizer = NFKC()
|
||||
|
||||
# Preserve medical terms
|
||||
tokenizer.pre_tokenizer = Sequence([
|
||||
Whitespace(),
|
||||
Punctuation(behavior="isolated") # Keep punctuation separate
|
||||
])
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50000,
|
||||
special_tokens=["[UNK]", "[CLS]", "[SEP]"],
|
||||
min_frequency=3 # Higher threshold for rare medical terms
|
||||
)
|
||||
|
||||
tokenizer.train(files=["pubmed_corpus.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
### Multilingual tokenizer
|
||||
|
||||
```python
|
||||
# Handle multiple scripts
|
||||
from tokenizers.normalizers import NFKC, Lowercase, Sequence
|
||||
|
||||
tokenizer = Tokenizer(BPE())
|
||||
|
||||
# Normalize but don't lowercase (preserves script differences)
|
||||
tokenizer.normalizer = NFKC()
|
||||
|
||||
# Byte-level handles all Unicode
|
||||
from tokenizers.pre_tokenizers import ByteLevel
|
||||
tokenizer.pre_tokenizer = ByteLevel()
|
||||
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=100000, # Larger vocab for multiple languages
|
||||
special_tokens=["<unk>", "<s>", "</s>"],
|
||||
limit_alphabet=None # No limit (handles all scripts)
|
||||
)
|
||||
|
||||
# Train on multilingual corpus
|
||||
tokenizer.train(files=["multilingual_corpus.txt"], trainer=trainer)
|
||||
```
|
||||
|
||||
## Vocabulary size selection
|
||||
|
||||
### Guidelines by task
|
||||
|
||||
| Task | Recommended Vocab Size | Rationale |
|
||||
|-----------------------|------------------------|-----------|
|
||||
| English (monolingual) | 30,000 - 50,000 | Balanced coverage |
|
||||
| Multilingual | 50,000 - 250,000 | More languages = more tokens |
|
||||
| Code | 30,000 - 50,000 | Similar to English |
|
||||
| Domain-specific | 10,000 - 30,000 | Smaller, focused vocabulary |
|
||||
| Character-level tasks | 1,000 - 5,000 | Only characters + subwords |
|
||||
|
||||
### Vocabulary size impact
|
||||
|
||||
**Small vocab (10k)**:
|
||||
- Pros: Faster training, smaller model, less memory
|
||||
- Cons: More tokens per sentence, worse OOV handling
|
||||
|
||||
**Medium vocab (30k-50k)**:
|
||||
- Pros: Good balance, standard choice
|
||||
- Cons: None (recommended default)
|
||||
|
||||
**Large vocab (100k+)**:
|
||||
- Pros: Fewer tokens per sentence, better OOV
|
||||
- Cons: Slower training, larger embedding table
|
||||
|
||||
### Empirical testing
|
||||
|
||||
```python
|
||||
# Train multiple tokenizers with different vocab sizes
|
||||
vocab_sizes = [10000, 30000, 50000, 100000]
|
||||
|
||||
for vocab_size in vocab_sizes:
|
||||
tokenizer = Tokenizer(BPE())
|
||||
trainer = BpeTrainer(vocab_size=vocab_size)
|
||||
tokenizer.train(files=["sample.txt"], trainer=trainer)
|
||||
|
||||
# Evaluate on test set
|
||||
test_text = "Test sentence for evaluation..."
|
||||
tokens = tokenizer.encode(test_text).ids
|
||||
|
||||
print(f"Vocab: {vocab_size:6d} | Tokens: {len(tokens):3d} | Avg: {len(test_text)/len(tokens):.2f} chars/token")
|
||||
|
||||
# Example output:
|
||||
# Vocab: 10000 | Tokens: 12 | Avg: 2.33 chars/token
|
||||
# Vocab: 30000 | Tokens: 8 | Avg: 3.50 chars/token
|
||||
# Vocab: 50000 | Tokens: 7 | Avg: 4.00 chars/token
|
||||
# Vocab: 100000 | Tokens: 6 | Avg: 4.67 chars/token
|
||||
```
|
||||
|
||||
## Testing tokenizer quality
|
||||
|
||||
### Coverage test
|
||||
|
||||
```python
|
||||
# Test on held-out data
|
||||
test_corpus = load_dataset("wikitext", "wikitext-103-raw-v1", split="test")
|
||||
|
||||
total_tokens = 0
|
||||
unk_tokens = 0
|
||||
unk_id = tokenizer.token_to_id("[UNK]")
|
||||
|
||||
for text in test_corpus["text"]:
|
||||
if text.strip():
|
||||
encoding = tokenizer.encode(text)
|
||||
total_tokens += len(encoding.ids)
|
||||
unk_tokens += encoding.ids.count(unk_id)
|
||||
|
||||
unk_rate = unk_tokens / total_tokens
|
||||
print(f"Unknown token rate: {unk_rate:.2%}")
|
||||
|
||||
# Good quality: <1% unknown tokens
|
||||
# Acceptable: 1-5%
|
||||
# Poor: >5%
|
||||
```
|
||||
|
||||
### Compression test
|
||||
|
||||
```python
|
||||
# Measure tokenization efficiency
|
||||
import numpy as np
|
||||
|
||||
token_lengths = []
|
||||
|
||||
for text in test_corpus["text"][:1000]:
|
||||
if text.strip():
|
||||
encoding = tokenizer.encode(text)
|
||||
chars_per_token = len(text) / len(encoding.ids)
|
||||
token_lengths.append(chars_per_token)
|
||||
|
||||
avg_chars_per_token = np.mean(token_lengths)
|
||||
print(f"Average characters per token: {avg_chars_per_token:.2f}")
|
||||
|
||||
# Good: 4-6 chars/token (English)
|
||||
# Acceptable: 3-4 chars/token
|
||||
# Poor: <3 chars/token (under-compression)
|
||||
```
|
||||
|
||||
### Semantic test
|
||||
|
||||
```python
|
||||
# Manually inspect tokenization of common words/phrases
|
||||
test_phrases = [
|
||||
"tokenization",
|
||||
"machine learning",
|
||||
"artificial intelligence",
|
||||
"preprocessing",
|
||||
"hello world"
|
||||
]
|
||||
|
||||
for phrase in test_phrases:
|
||||
tokens = tokenizer.encode(phrase).tokens
|
||||
print(f"{phrase:25s} → {tokens}")
|
||||
|
||||
# Good tokenization:
|
||||
# tokenization → ['token', 'ization']
|
||||
# machine learning → ['machine', 'learning']
|
||||
# artificial intelligence → ['artificial', 'intelligence']
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Training too slow
|
||||
|
||||
**Solutions**:
|
||||
1. Reduce vocabulary size
|
||||
2. Increase `min_frequency`
|
||||
3. Use `limit_alphabet` to reduce initial alphabet
|
||||
4. Train on subset first
|
||||
|
||||
```python
|
||||
# Fast training configuration
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=20000, # Smaller vocab
|
||||
min_frequency=5, # Higher threshold
|
||||
limit_alphabet=500, # Limit alphabet
|
||||
show_progress=True
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: High unknown token rate
|
||||
|
||||
**Solutions**:
|
||||
1. Increase vocabulary size
|
||||
2. Decrease `min_frequency`
|
||||
3. Check normalization (might be too aggressive)
|
||||
|
||||
```python
|
||||
# Better coverage configuration
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=50000, # Larger vocab
|
||||
min_frequency=1, # Lower threshold
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Poor quality tokenization
|
||||
|
||||
**Solutions**:
|
||||
1. Verify normalization matches your use case
|
||||
2. Check pre-tokenization splits correctly
|
||||
3. Ensure training data is representative
|
||||
4. Try different algorithm (BPE vs WordPiece vs Unigram)
|
||||
|
||||
```python
|
||||
# Debug tokenization pipeline
|
||||
text = "Sample text to debug"
|
||||
|
||||
# Check normalization
|
||||
normalized = tokenizer.normalizer.normalize_str(text)
|
||||
print(f"Normalized: {normalized}")
|
||||
|
||||
# Check pre-tokenization
|
||||
pre_tokens = tokenizer.pre_tokenizer.pre_tokenize_str(text)
|
||||
print(f"Pre-tokens: {pre_tokens}")
|
||||
|
||||
# Check final tokenization
|
||||
tokens = tokenizer.encode(text).tokens
|
||||
print(f"Tokens: {tokens}")
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use representative training data** - Match your target domain
|
||||
2. **Start with standard configs** - BERT WordPiece or GPT-2 BPE
|
||||
3. **Test on held-out data** - Measure unknown token rate
|
||||
4. **Iterate on vocabulary size** - Test 30k, 50k, 100k
|
||||
5. **Save tokenizer with model** - Ensure reproducibility
|
||||
6. **Version your tokenizers** - Track changes for reproducibility (see the sketch after this list)
|
||||
7. **Document special tokens** - Critical for model training
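For best practice 6, one lightweight approach is to record a fingerprint of the saved tokenizer next to the training config; a sketch assuming the `my-tokenizer.json` file from Step 6:

```python
import hashlib
import json

with open("my-tokenizer.json", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

with open("tokenizer_version.json", "w") as f:
    json.dump(
        {"tokenizer_sha256": digest, "vocab_size": tokenizer.get_vocab_size()},
        f,
        indent=2,
    )
```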
|
||||
skills/mlops-knowledge/instructor/SKILL.md (new file, 740 lines)
@@ -0,0 +1,740 @@
|
||||
---
|
||||
name: instructor
|
||||
description: Extract structured data from LLM responses with Pydantic validation, retry failed extractions automatically, parse complex JSON with type safety, and stream partial results with Instructor - battle-tested structured output library
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Prompt Engineering, Instructor, Structured Output, Pydantic, Data Extraction, JSON Parsing, Type Safety, Validation, Streaming, OpenAI, Anthropic]
|
||||
dependencies: [instructor, pydantic, openai, anthropic]
|
||||
---
|
||||
|
||||
# Instructor: Structured LLM Outputs
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use Instructor when you need to:
|
||||
- **Extract structured data** from LLM responses reliably
|
||||
- **Validate outputs** against Pydantic schemas automatically
|
||||
- **Retry failed extractions** with automatic error handling
|
||||
- **Parse complex JSON** with type safety and validation
|
||||
- **Stream partial results** for real-time processing
|
||||
- **Support multiple LLM providers** with consistent API
|
||||
|
||||
**GitHub Stars**: 15,000+ | **Battle-tested**: 100,000+ developers
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Base installation
|
||||
pip install instructor
|
||||
|
||||
# With specific providers
|
||||
pip install "instructor[anthropic]" # Anthropic Claude
|
||||
pip install "instructor[openai]" # OpenAI
|
||||
pip install "instructor[all]" # All providers
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Example: Extract User Data
|
||||
|
||||
```python
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
from anthropic import Anthropic
|
||||
|
||||
# Define output structure
|
||||
class User(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
|
||||
# Create instructor client
|
||||
client = instructor.from_anthropic(Anthropic())
|
||||
|
||||
# Extract structured data
|
||||
user = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "John Doe is 30 years old. His email is john@example.com"
|
||||
}],
|
||||
response_model=User
|
||||
)
|
||||
|
||||
print(user.name) # "John Doe"
|
||||
print(user.age) # 30
|
||||
print(user.email) # "john@example.com"
|
||||
```
|
||||
|
||||
### With OpenAI
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = instructor.from_openai(OpenAI())
|
||||
|
||||
user = client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
response_model=User,
|
||||
messages=[{"role": "user", "content": "Extract: Alice, 25, alice@email.com"}]
|
||||
)
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Response Models (Pydantic)
|
||||
|
||||
Response models define the structure and validation rules for LLM outputs.
|
||||
|
||||
#### Basic Model
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Article(BaseModel):
|
||||
title: str = Field(description="Article title")
|
||||
author: str = Field(description="Author name")
|
||||
word_count: int = Field(description="Number of words", gt=0)
|
||||
tags: list[str] = Field(description="List of relevant tags")
|
||||
|
||||
article = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Analyze this article: [article text]"
|
||||
}],
|
||||
response_model=Article
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Type safety with Python type hints
|
||||
- Automatic validation (word_count > 0)
|
||||
- Self-documenting with Field descriptions
|
||||
- IDE autocomplete support
|
||||
|
||||
#### Nested Models
|
||||
|
||||
```python
|
||||
class Address(BaseModel):
|
||||
street: str
|
||||
city: str
|
||||
country: str
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
address: Address # Nested model
|
||||
|
||||
person = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "John lives at 123 Main St, Boston, USA"
|
||||
}],
|
||||
response_model=Person
|
||||
)
|
||||
|
||||
print(person.address.city) # "Boston"
|
||||
```
|
||||
|
||||
#### Optional Fields
|
||||
|
||||
```python
|
||||
from typing import Optional
|
||||
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
discount: Optional[float] = None # Optional
|
||||
description: str = Field(default="No description") # Default value
|
||||
|
||||
# LLM doesn't need to provide discount or description
|
||||
```
|
||||
|
||||
#### Enums for Constraints
|
||||
|
||||
```python
|
||||
from enum import Enum
|
||||
|
||||
class Sentiment(str, Enum):
|
||||
POSITIVE = "positive"
|
||||
NEGATIVE = "negative"
|
||||
NEUTRAL = "neutral"
|
||||
|
||||
class Review(BaseModel):
|
||||
text: str
|
||||
sentiment: Sentiment # Only these 3 values allowed
|
||||
|
||||
review = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "This product is amazing!"
|
||||
}],
|
||||
response_model=Review
|
||||
)
|
||||
|
||||
print(review.sentiment) # Sentiment.POSITIVE
|
||||
```
|
||||
|
||||
### 2. Validation
|
||||
|
||||
Pydantic validates LLM outputs automatically. If validation fails, Instructor retries.
|
||||
|
||||
#### Built-in Validators
|
||||
|
||||
```python
|
||||
from pydantic import Field, EmailStr, HttpUrl
|
||||
|
||||
class Contact(BaseModel):
|
||||
name: str = Field(min_length=2, max_length=100)
|
||||
age: int = Field(ge=0, le=120) # 0 <= age <= 120
|
||||
email: EmailStr # Validates email format
|
||||
website: HttpUrl # Validates URL format
|
||||
|
||||
# If LLM provides invalid data, Instructor retries automatically
|
||||
```
|
||||
|
||||
#### Custom Validators
|
||||
|
||||
```python
|
||||
from pydantic import field_validator
|
||||
|
||||
class Event(BaseModel):
|
||||
name: str
|
||||
date: str
|
||||
attendees: int
|
||||
|
||||
@field_validator('date')
|
||||
def validate_date(cls, v):
|
||||
"""Ensure date is in YYYY-MM-DD format."""
|
||||
import re
|
||||
        if not re.fullmatch(r'\d{4}-\d{2}-\d{2}', v):
|
||||
raise ValueError('Date must be YYYY-MM-DD format')
|
||||
return v
|
||||
|
||||
@field_validator('attendees')
|
||||
def validate_attendees(cls, v):
|
||||
"""Ensure positive attendees."""
|
||||
if v < 1:
|
||||
raise ValueError('Must have at least 1 attendee')
|
||||
return v
|
||||
```
|
||||
|
||||
#### Model-Level Validation
|
||||
|
||||
```python
|
||||
from pydantic import model_validator
|
||||
|
||||
class DateRange(BaseModel):
|
||||
start_date: str
|
||||
end_date: str
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_dates(self):
|
||||
"""Ensure end_date is after start_date."""
|
||||
from datetime import datetime
|
||||
start = datetime.strptime(self.start_date, '%Y-%m-%d')
|
||||
end = datetime.strptime(self.end_date, '%Y-%m-%d')
|
||||
|
||||
if end < start:
|
||||
raise ValueError('end_date must be after start_date')
|
||||
return self
|
||||
```
|
||||
|
||||
### 3. Automatic Retrying
|
||||
|
||||
Instructor retries automatically when validation fails, providing error feedback to the LLM.
|
||||
|
||||
```python
|
||||
# Retries up to 3 times if validation fails
|
||||
user = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Extract user from: John, age unknown"
|
||||
}],
|
||||
response_model=User,
|
||||
max_retries=3 # Default is 3
|
||||
)
|
||||
|
||||
# If age can't be extracted, Instructor tells the LLM:
|
||||
# "Validation error: age - field required"
|
||||
# LLM tries again with better extraction
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
1. LLM generates output
|
||||
2. Pydantic validates
|
||||
3. If invalid: Error message sent back to LLM
|
||||
4. LLM tries again with error feedback
|
||||
5. Repeats up to max_retries
|
||||
|
||||
### 4. Streaming
|
||||
|
||||
Stream partial results for real-time processing.
|
||||
|
||||
#### Streaming Partial Objects
|
||||
|
||||
```python
|
||||
from instructor import Partial
|
||||
|
||||
class Story(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
tags: list[str]
|
||||
|
||||
# Stream partial updates as LLM generates
|
||||
for partial_story in client.messages.create_partial(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Write a short sci-fi story"
|
||||
}],
|
||||
response_model=Story
|
||||
):
|
||||
print(f"Title: {partial_story.title}")
|
||||
print(f"Content so far: {partial_story.content[:100]}...")
|
||||
# Update UI in real-time
|
||||
```
|
||||
|
||||
#### Streaming Iterables
|
||||
|
||||
```python
|
||||
class Task(BaseModel):
|
||||
title: str
|
||||
priority: str
|
||||
|
||||
# Stream list items as they're generated
|
||||
tasks = client.messages.create_iterable(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Generate 10 project tasks"
|
||||
}],
|
||||
response_model=Task
|
||||
)
|
||||
|
||||
for task in tasks:
|
||||
print(f"- {task.title} ({task.priority})")
|
||||
# Process each task as it arrives
|
||||
```
|
||||
|
||||
## Provider Configuration
|
||||
|
||||
### Anthropic Claude
|
||||
|
||||
```python
|
||||
import instructor
|
||||
from anthropic import Anthropic
|
||||
|
||||
client = instructor.from_anthropic(
|
||||
Anthropic(api_key="your-api-key")
|
||||
)
|
||||
|
||||
# Use with Claude models
|
||||
response = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[...],
|
||||
response_model=YourModel
|
||||
)
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = instructor.from_openai(
|
||||
OpenAI(api_key="your-api-key")
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
response_model=YourModel,
|
||||
messages=[...]
|
||||
)
|
||||
```
|
||||
|
||||
### Local Models (Ollama)
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# Point to local Ollama server
|
||||
client = instructor.from_openai(
|
||||
OpenAI(
|
||||
base_url="http://localhost:11434/v1",
|
||||
api_key="ollama" # Required but ignored
|
||||
),
|
||||
mode=instructor.Mode.JSON
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="llama3.1",
|
||||
response_model=YourModel,
|
||||
messages=[...]
|
||||
)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern 1: Data Extraction from Text
|
||||
|
||||
```python
|
||||
class CompanyInfo(BaseModel):
|
||||
name: str
|
||||
founded_year: int
|
||||
industry: str
|
||||
employees: int
|
||||
headquarters: str
|
||||
|
||||
text = """
|
||||
Tesla, Inc. was founded in 2003. It operates in the automotive and energy
|
||||
industry with approximately 140,000 employees. The company is headquartered
|
||||
in Austin, Texas.
|
||||
"""
|
||||
|
||||
company = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Extract company information from: {text}"
|
||||
}],
|
||||
response_model=CompanyInfo
|
||||
)
|
||||
```
|
||||
|
||||
### Pattern 2: Classification
|
||||
|
||||
```python
|
||||
class Category(str, Enum):
|
||||
TECHNOLOGY = "technology"
|
||||
FINANCE = "finance"
|
||||
HEALTHCARE = "healthcare"
|
||||
EDUCATION = "education"
|
||||
OTHER = "other"
|
||||
|
||||
class ArticleClassification(BaseModel):
|
||||
category: Category
|
||||
confidence: float = Field(ge=0.0, le=1.0)
|
||||
keywords: list[str]
|
||||
|
||||
classification = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Classify this article: [article text]"
|
||||
}],
|
||||
response_model=ArticleClassification
|
||||
)
|
||||
```
|
||||
|
||||
### Pattern 3: Multi-Entity Extraction
|
||||
|
||||
```python
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
role: str
|
||||
|
||||
class Organization(BaseModel):
|
||||
name: str
|
||||
industry: str
|
||||
|
||||
class Entities(BaseModel):
|
||||
people: list[Person]
|
||||
organizations: list[Organization]
|
||||
locations: list[str]
|
||||
|
||||
text = "Tim Cook, CEO of Apple, announced at the event in Cupertino..."
|
||||
|
||||
entities = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Extract all entities from: {text}"
|
||||
}],
|
||||
response_model=Entities
|
||||
)
|
||||
|
||||
for person in entities.people:
|
||||
print(f"{person.name} - {person.role}")
|
||||
```
|
||||
|
||||
### Pattern 4: Structured Analysis
|
||||
|
||||
```python
|
||||
class SentimentAnalysis(BaseModel):
|
||||
overall_sentiment: Sentiment
|
||||
positive_aspects: list[str]
|
||||
negative_aspects: list[str]
|
||||
suggestions: list[str]
|
||||
score: float = Field(ge=-1.0, le=1.0)
|
||||
|
||||
review = "The product works well but setup was confusing..."
|
||||
|
||||
analysis = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Analyze this review: {review}"
|
||||
}],
|
||||
response_model=SentimentAnalysis
|
||||
)
|
||||
```
|
||||
|
||||
### Pattern 5: Batch Processing
|
||||
|
||||
```python
|
||||
def extract_person(text: str) -> Person:
|
||||
return client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Extract person from: {text}"
|
||||
}],
|
||||
response_model=Person
|
||||
)
|
||||
|
||||
texts = [
|
||||
"John Doe is a 30-year-old engineer",
|
||||
"Jane Smith, 25, works in marketing",
|
||||
"Bob Johnson, age 40, software developer"
|
||||
]
|
||||
|
||||
people = [extract_person(text) for text in texts]
|
||||
```
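For larger batches, the synchronous loop above can be replaced with concurrent requests through Instructor's async client wrapper; a sketch assuming `instructor.from_anthropic` accepts `AsyncAnthropic` in your installed version:

```python
import asyncio

import instructor
from anthropic import AsyncAnthropic

async_client = instructor.from_anthropic(AsyncAnthropic())

async def extract_person_async(text: str) -> Person:
    return await async_client.messages.create(
        model="claude-sonnet-4-5-20250929",
        max_tokens=1024,
        messages=[{"role": "user", "content": f"Extract person from: {text}"}],
        response_model=Person,
    )

async def extract_all(texts: list[str]) -> list[Person]:
    # Requests run concurrently instead of one-by-one
    return await asyncio.gather(*(extract_person_async(t) for t in texts))

people = asyncio.run(extract_all(texts))
```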
|
||||
|
||||
## Advanced Features
|
||||
|
||||
### Union Types
|
||||
|
||||
```python
|
||||
from typing import Union
|
||||
|
||||
class TextContent(BaseModel):
|
||||
type: str = "text"
|
||||
content: str
|
||||
|
||||
class ImageContent(BaseModel):
|
||||
type: str = "image"
|
||||
url: HttpUrl
|
||||
caption: str
|
||||
|
||||
class Post(BaseModel):
|
||||
title: str
|
||||
content: Union[TextContent, ImageContent] # Either type
|
||||
|
||||
# LLM chooses appropriate type based on content
|
||||
```
|
||||
|
||||
### Dynamic Models
|
||||
|
||||
```python
|
||||
from pydantic import create_model
|
||||
|
||||
# Create model at runtime
|
||||
DynamicUser = create_model(
|
||||
'User',
|
||||
name=(str, ...),
|
||||
age=(int, Field(ge=0)),
|
||||
email=(EmailStr, ...)
|
||||
)
|
||||
|
||||
user = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[...],
|
||||
response_model=DynamicUser
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Modes
|
||||
|
||||
```python
|
||||
# For providers without native structured outputs
|
||||
client = instructor.from_anthropic(
|
||||
Anthropic(),
|
||||
mode=instructor.Mode.JSON # JSON mode
|
||||
)
|
||||
|
||||
# Available modes:
|
||||
# - Mode.ANTHROPIC_TOOLS (recommended for Claude)
|
||||
# - Mode.JSON (fallback)
|
||||
# - Mode.TOOLS (OpenAI tools)
|
||||
```
|
||||
|
||||
### Context Management
|
||||
|
||||
```python
|
||||
# Single-use client
|
||||
with instructor.from_anthropic(Anthropic()) as client:
|
||||
result = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[...],
|
||||
response_model=YourModel
|
||||
)
|
||||
# Client closed automatically
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Handling Validation Errors
|
||||
|
||||
```python
|
||||
from pydantic import ValidationError
|
||||
|
||||
try:
|
||||
user = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[...],
|
||||
response_model=User,
|
||||
max_retries=3
|
||||
)
|
||||
except ValidationError as e:
|
||||
print(f"Failed after retries: {e}")
|
||||
# Handle gracefully
|
||||
|
||||
except Exception as e:
|
||||
print(f"API error: {e}")
|
||||
```
|
||||
|
||||
### Custom Error Messages
|
||||
|
||||
```python
|
||||
class ValidatedUser(BaseModel):
|
||||
name: str = Field(description="Full name, 2-100 characters")
|
||||
age: int = Field(description="Age between 0 and 120", ge=0, le=120)
|
||||
email: EmailStr = Field(description="Valid email address")
|
||||
|
||||
class Config:
|
||||
        # Schema-level examples guide the model toward valid output
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
{
|
||||
"name": "John Doe",
|
||||
"age": 30,
|
||||
"email": "john@example.com"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Clear Field Descriptions
|
||||
|
||||
```python
|
||||
# ❌ Bad: Vague
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
|
||||
# ✅ Good: Descriptive
|
||||
class Product(BaseModel):
|
||||
name: str = Field(description="Product name from the text")
|
||||
price: float = Field(description="Price in USD, without currency symbol")
|
||||
```
|
||||
|
||||
### 2. Use Appropriate Validation
|
||||
|
||||
```python
|
||||
# ✅ Good: Constrain values
|
||||
class Rating(BaseModel):
|
||||
score: int = Field(ge=1, le=5, description="Rating from 1 to 5 stars")
|
||||
review: str = Field(min_length=10, description="Review text, at least 10 chars")
|
||||
```
|
||||
|
||||
### 3. Provide Examples in Prompts
|
||||
|
||||
```python
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": """Extract person info from: "John, 30, engineer"
|
||||
|
||||
Example format:
|
||||
{
|
||||
"name": "John Doe",
|
||||
"age": 30,
|
||||
"occupation": "engineer"
|
||||
}"""
|
||||
}]
|
||||
```
|
||||
|
||||
### 4. Use Enums for Fixed Categories
|
||||
|
||||
```python
|
||||
# ✅ Good: Enum ensures valid values
|
||||
class Status(str, Enum):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
REJECTED = "rejected"
|
||||
|
||||
class Application(BaseModel):
|
||||
status: Status # LLM must choose from enum
|
||||
```
|
||||
|
||||
### 5. Handle Missing Data Gracefully
|
||||
|
||||
```python
|
||||
class PartialData(BaseModel):
|
||||
required_field: str
|
||||
optional_field: Optional[str] = None
|
||||
default_field: str = "default_value"
|
||||
|
||||
# LLM only needs to provide required_field
|
||||
```
|
||||
|
||||
## Comparison to Alternatives
|
||||
|
||||
| Feature | Instructor | Manual JSON | LangChain | DSPy |
|
||||
|---------|------------|-------------|-----------|------|
|
||||
| Type Safety | ✅ Yes | ❌ No | ⚠️ Partial | ✅ Yes |
|
||||
| Auto Validation | ✅ Yes | ❌ No | ❌ No | ⚠️ Limited |
|
||||
| Auto Retry | ✅ Yes | ❌ No | ❌ No | ✅ Yes |
|
||||
| Streaming | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
|
||||
| Multi-Provider | ✅ Yes | ⚠️ Manual | ✅ Yes | ✅ Yes |
|
||||
| Learning Curve | Low | Low | Medium | High |
|
||||
|
||||
**When to choose Instructor:**
|
||||
- Need structured, validated outputs
|
||||
- Want type safety and IDE support
|
||||
- Require automatic retries
|
||||
- Building data extraction systems
|
||||
|
||||
**When to choose alternatives:**
|
||||
- DSPy: Need prompt optimization
|
||||
- LangChain: Building complex chains
|
||||
- Manual: Simple, one-off extractions
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://python.useinstructor.com
|
||||
- **GitHub**: https://github.com/jxnl/instructor (15k+ stars)
|
||||
- **Cookbook**: https://python.useinstructor.com/examples
|
||||
- **Discord**: Community support available
|
||||
|
||||
## See Also
|
||||
|
||||
- `references/validation.md` - Advanced validation patterns
|
||||
- `references/providers.md` - Provider-specific configuration
|
||||
- `references/examples.md` - Real-world use cases
|
||||
|
||||
|
||||
skills/mlops-knowledge/instructor/references/examples.md (new file, 107 lines)
@@ -0,0 +1,107 @@
|
||||
# Real-World Examples
|
||||
|
||||
Practical examples of using Instructor for structured data extraction.
|
||||
|
||||
## Data Extraction
|
||||
|
||||
```python
|
||||
class CompanyInfo(BaseModel):
|
||||
name: str
|
||||
founded: int
|
||||
industry: str
|
||||
employees: int
|
||||
|
||||
text = "Apple was founded in 1976 in the technology industry with 164,000 employees."
|
||||
|
||||
company = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": f"Extract: {text}"}],
|
||||
response_model=CompanyInfo
|
||||
)
|
||||
```
|
||||
|
||||
## Classification
|
||||
|
||||
```python
|
||||
class Sentiment(str, Enum):
|
||||
POSITIVE = "positive"
|
||||
NEGATIVE = "negative"
|
||||
NEUTRAL = "neutral"
|
||||
|
||||
class Review(BaseModel):
|
||||
sentiment: Sentiment
|
||||
confidence: float = Field(ge=0.0, le=1.0)
|
||||
|
||||
review = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": "This product is amazing!"}],
|
||||
response_model=Review
|
||||
)
|
||||
```
|
||||
|
||||
## Multi-Entity Extraction
|
||||
|
||||
```python
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
role: str
|
||||
|
||||
class Entities(BaseModel):
|
||||
people: list[Person]
|
||||
organizations: list[str]
|
||||
locations: list[str]
|
||||
|
||||
entities = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": "Tim Cook, CEO of Apple, spoke in Cupertino..."}],
|
||||
response_model=Entities
|
||||
)
|
||||
```
|
||||
|
||||
## Structured Analysis
|
||||
|
||||
```python
|
||||
class Analysis(BaseModel):
|
||||
summary: str
|
||||
key_points: list[str]
|
||||
sentiment: Sentiment
|
||||
actionable_items: list[str]
|
||||
|
||||
analysis = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": "Analyze: [long text]"}],
|
||||
response_model=Analysis
|
||||
)
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
```python
|
||||
texts = ["text1", "text2", "text3"]
|
||||
results = [
|
||||
client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": text}],
|
||||
response_model=YourModel
|
||||
)
|
||||
for text in texts
|
||||
]
|
||||
```
|
||||
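The list comprehension above runs sequentially; for higher throughput the same calls can be issued concurrently. A sketch assuming the async Anthropic client is patched the same way and `YourModel` is the response model from the previous block:

```python
import asyncio

import instructor
from anthropic import AsyncAnthropic

async_client = instructor.from_anthropic(AsyncAnthropic())


async def extract_all(texts: list[str]):
    # Issue all requests concurrently instead of one at a time
    calls = [
        async_client.messages.create(
            model="claude-sonnet-4-5-20250929",
            max_tokens=1024,
            messages=[{"role": "user", "content": text}],
            response_model=YourModel,
        )
        for text in texts
    ]
    return await asyncio.gather(*calls)


results = asyncio.run(extract_all(["text1", "text2", "text3"]))
```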
|
||||
## Streaming
|
||||
|
||||
```python
|
||||
for partial in client.messages.create_partial(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": "Generate report..."}],
|
||||
response_model=Report
|
||||
):
|
||||
print(f"Progress: {partial.title}")
|
||||
# Update UI in real-time
|
||||
```
|
||||
skills/mlops-knowledge/instructor/references/providers.md (new file, 70 lines)
@@ -0,0 +1,70 @@
|
||||
# Provider Configuration
|
||||
|
||||
Guide to using Instructor with different LLM providers.
|
||||
|
||||
## Anthropic Claude
|
||||
|
||||
```python
|
||||
import instructor
|
||||
from anthropic import Anthropic
|
||||
|
||||
# Basic setup
|
||||
client = instructor.from_anthropic(Anthropic())
|
||||
|
||||
# With API key
|
||||
client = instructor.from_anthropic(
|
||||
Anthropic(api_key="your-api-key")
|
||||
)
|
||||
|
||||
# Recommended mode
|
||||
client = instructor.from_anthropic(
|
||||
Anthropic(),
|
||||
mode=instructor.Mode.ANTHROPIC_TOOLS
|
||||
)
|
||||
|
||||
# Usage
|
||||
result = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": "..."}],
|
||||
response_model=YourModel
|
||||
)
|
||||
```
|
||||
|
||||
## OpenAI
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = instructor.from_openai(OpenAI())
|
||||
|
||||
result = client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
response_model=YourModel,
|
||||
messages=[{"role": "user", "content": "..."}]
|
||||
)
|
||||
```
|
||||
|
||||
## Local Models (Ollama)
|
||||
|
||||
```python
|
||||
client = instructor.from_openai(
|
||||
OpenAI(
|
||||
base_url="http://localhost:11434/v1",
|
||||
api_key="ollama"
|
||||
),
|
||||
mode=instructor.Mode.JSON
|
||||
)
|
||||
|
||||
result = client.chat.completions.create(
|
||||
model="llama3.1",
|
||||
response_model=YourModel,
|
||||
messages=[...]
|
||||
)
|
||||
```
|
||||
|
||||
## Modes
|
||||
|
||||
- `Mode.ANTHROPIC_TOOLS`: Recommended for Claude
|
||||
- `Mode.TOOLS`: OpenAI function calling
|
||||
- `Mode.JSON`: Fallback for unsupported providers
|
||||
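The mode is chosen once, when the client is patched; for example, pinning the OpenAI client to function calling (same imports as the OpenAI section above):

```python
client = instructor.from_openai(
    OpenAI(),
    mode=instructor.Mode.TOOLS,  # explicit function-calling mode
)
```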
skills/mlops-knowledge/instructor/references/validation.md (new file, 606 lines)
@@ -0,0 +1,606 @@
|
||||
# Advanced Validation Patterns
|
||||
|
||||
Complete guide to validation in Instructor using Pydantic.
|
||||
|
||||
## Table of Contents
|
||||
- Built-in Validators
|
||||
- Custom Field Validators
|
||||
- Model-Level Validation
|
||||
- Complex Validation Patterns
|
||||
- Error Handling
|
||||
|
||||
## Built-in Validators
|
||||
|
||||
### Numeric Constraints
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Product(BaseModel):
|
||||
price: float = Field(gt=0, description="Price must be positive")
|
||||
discount: float = Field(ge=0, le=100, description="Discount 0-100%")
|
||||
quantity: int = Field(ge=1, description="At least 1 item")
|
||||
rating: float = Field(ge=0.0, le=5.0, description="Rating 0-5 stars")
|
||||
|
||||
# If LLM provides invalid values, automatic retry with error feedback
|
||||
```
|
||||
|
||||
**Available constraints:**
|
||||
- `gt`: Greater than
|
||||
- `ge`: Greater than or equal
|
||||
- `lt`: Less than
|
||||
- `le`: Less than or equal
|
||||
- `multiple_of`: Must be multiple of this number
|
||||
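The automatic retry mentioned in the comment above is controlled by `max_retries` on the patched client; a minimal sketch using the `Product` model just defined (prompt text and retry count are illustrative):

```python
product = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{"role": "user", "content": "Extract the product from: 'Widget, $19.99, qty 3, rated 4.5'"}],
    response_model=Product,
    max_retries=2,  # validation errors are sent back to the model before each retry
)
```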
|
||||
### String Constraints
|
||||
|
||||
```python
|
||||
class User(BaseModel):
|
||||
username: str = Field(
|
||||
min_length=3,
|
||||
max_length=20,
|
||||
pattern=r'^[a-zA-Z0-9_]+$',
|
||||
description="3-20 alphanumeric characters"
|
||||
)
|
||||
bio: str = Field(max_length=500, description="Bio up to 500 chars")
|
||||
status: str = Field(pattern=r'^(active|inactive|pending)$')
|
||||
|
||||
# pattern validates against regex
|
||||
```
|
||||
|
||||
### Email and URL Validation
|
||||
|
||||
```python
|
||||
from pydantic import EmailStr, HttpUrl, AnyUrl
|
||||
|
||||
class Contact(BaseModel):
|
||||
email: EmailStr # Validates email format
|
||||
website: HttpUrl # Validates HTTP/HTTPS URLs
|
||||
portfolio: AnyUrl # Any valid URL scheme
|
||||
|
||||
contact = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Extract: john@example.com, https://example.com"
|
||||
}],
|
||||
response_model=Contact
|
||||
)
|
||||
```
|
||||
|
||||
### Date and DateTime Validation
|
||||
|
||||
```python
|
||||
from datetime import date, datetime
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
class Event(BaseModel):
|
||||
event_date: date # Validates date format
|
||||
created_at: datetime # Validates datetime format
|
||||
year: int = Field(ge=1900, le=2100)
|
||||
|
||||
@field_validator('event_date')
|
||||
def future_date(cls, v):
|
||||
"""Ensure event is in the future."""
|
||||
if v < date.today():
|
||||
raise ValueError('Event must be in the future')
|
||||
return v
|
||||
```
|
||||
|
||||
### List and Dict Validation
|
||||
|
||||
```python
|
||||
class Document(BaseModel):
|
||||
tags: list[str] = Field(min_length=1, max_length=10)
|
||||
keywords: list[str] = Field(min_length=3, description="At least 3 keywords")
|
||||
metadata: dict[str, str] = Field(description="String key-value pairs")
|
||||
|
||||
@field_validator('tags')
|
||||
def unique_tags(cls, v):
|
||||
"""Ensure tags are unique."""
|
||||
if len(v) != len(set(v)):
|
||||
raise ValueError('Tags must be unique')
|
||||
return v
|
||||
```
|
||||
|
||||
## Custom Field Validators
|
||||
|
||||
### Basic Field Validator
|
||||
|
||||
```python
|
||||
from pydantic import field_validator
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
@field_validator('name')
|
||||
def name_must_not_be_empty(cls, v):
|
||||
"""Validate name is not empty or just whitespace."""
|
||||
if not v or not v.strip():
|
||||
raise ValueError('Name cannot be empty')
|
||||
return v.strip()
|
||||
|
||||
@field_validator('age')
|
||||
def age_must_be_reasonable(cls, v):
|
||||
"""Validate age is between 0 and 120."""
|
||||
if v < 0 or v > 120:
|
||||
raise ValueError('Age must be between 0 and 120')
|
||||
return v
|
||||
```
|
||||
|
||||
### Validator with Field Info
|
||||
|
||||
```python
|
||||
from pydantic import ValidationInfo
|
||||
|
||||
class Article(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
|
||||
@field_validator('content')
|
||||
def content_length(cls, v, info: ValidationInfo):
|
||||
"""Validate content is longer than title."""
|
||||
if 'title' in info.data:
|
||||
title_len = len(info.data['title'])
|
||||
if len(v) < title_len * 2:
|
||||
raise ValueError('Content should be at least 2x title length')
|
||||
return v
|
||||
```
|
||||
|
||||
### Multiple Fields Validation
|
||||
|
||||
```python
|
||||
class TimeRange(BaseModel):
|
||||
start_time: str
|
||||
end_time: str
|
||||
|
||||
@field_validator('start_time', 'end_time')
|
||||
def valid_time_format(cls, v):
|
||||
"""Validate both times are in HH:MM format."""
|
||||
import re
|
||||
if not re.match(r'^\d{2}:\d{2}$', v):
|
||||
raise ValueError('Time must be in HH:MM format')
|
||||
return v
|
||||
```
|
||||
|
||||
### Transform and Validate
|
||||
|
||||
```python
|
||||
class URL(BaseModel):
|
||||
url: str
|
||||
|
||||
@field_validator('url')
|
||||
def normalize_url(cls, v):
|
||||
"""Add https:// if missing."""
|
||||
if not v.startswith(('http://', 'https://')):
|
||||
v = f'https://{v}'
|
||||
return v
|
||||
```
|
||||
|
||||
## Model-Level Validation
|
||||
|
||||
### Cross-Field Validation
|
||||
|
||||
```python
|
||||
from pydantic import model_validator
|
||||
|
||||
class DateRange(BaseModel):
|
||||
start_date: str
|
||||
end_date: str
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_dates(self):
|
||||
"""Ensure end_date is after start_date."""
|
||||
from datetime import datetime
|
||||
start = datetime.strptime(self.start_date, '%Y-%m-%d')
|
||||
end = datetime.strptime(self.end_date, '%Y-%m-%d')
|
||||
|
||||
if end < start:
|
||||
raise ValueError('end_date must be after start_date')
|
||||
return self
|
||||
|
||||
class PriceRange(BaseModel):
|
||||
min_price: float
|
||||
max_price: float
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_price_range(self):
|
||||
"""Ensure max > min."""
|
||||
if self.max_price <= self.min_price:
|
||||
raise ValueError('max_price must be greater than min_price')
|
||||
return self
|
||||
```
|
||||
|
||||
### Conditional Validation
|
||||
|
||||
```python
|
||||
class Order(BaseModel):
|
||||
order_type: str # "standard" or "express"
|
||||
delivery_date: str
|
||||
delivery_time: Optional[str] = None
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_delivery_time(self):
|
||||
"""Express orders need delivery time."""
|
||||
if self.order_type == "express" and not self.delivery_time:
|
||||
raise ValueError('Express orders require delivery_time')
|
||||
return self
|
||||
```
|
||||
|
||||
### Complex Business Logic
|
||||
|
||||
```python
|
||||
class Discount(BaseModel):
|
||||
code: str
|
||||
percentage: float = Field(ge=0, le=100)
|
||||
min_purchase: float = Field(ge=0)
|
||||
max_discount: float = Field(ge=0)
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_discount(self):
|
||||
"""Ensure discount logic is sound."""
|
||||
# Max discount can't exceed percentage of min_purchase
|
||||
theoretical_max = (self.percentage / 100) * self.min_purchase
|
||||
if self.max_discount > theoretical_max:
|
||||
self.max_discount = theoretical_max
|
||||
return self
|
||||
```
|
||||
|
||||
## Complex Validation Patterns
|
||||
|
||||
### Nested Model Validation
|
||||
|
||||
```python
|
||||
class Address(BaseModel):
|
||||
street: str
|
||||
city: str
|
||||
country: str
|
||||
postal_code: str
|
||||
|
||||
@field_validator('postal_code')
|
||||
def validate_postal_code(cls, v, info: ValidationInfo):
|
||||
"""Validate postal code format based on country."""
|
||||
        if 'country' in info.data:
            import re
            country = info.data['country']
            if country == "USA":
                if not re.match(r'^\d{5}(-\d{4})?$', v):
                    raise ValueError('Invalid US postal code')
            elif country == "Canada":
                if not re.match(r'^[A-Z]\d[A-Z] \d[A-Z]\d$', v):
                    raise ValueError('Invalid Canadian postal code')
|
||||
return v
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
address: Address
|
||||
|
||||
# Nested validation runs automatically
|
||||
```
|
||||
|
||||
### List of Models
|
||||
|
||||
```python
|
||||
class Task(BaseModel):
|
||||
title: str = Field(min_length=1)
|
||||
priority: int = Field(ge=1, le=5)
|
||||
|
||||
class Project(BaseModel):
|
||||
name: str
|
||||
tasks: list[Task] = Field(min_length=1, description="At least 1 task")
|
||||
|
||||
@field_validator('tasks')
|
||||
def at_least_one_high_priority(cls, v):
|
||||
"""Ensure at least one task has priority >= 4."""
|
||||
if not any(task.priority >= 4 for task in v):
|
||||
raise ValueError('Project needs at least one high-priority task')
|
||||
return v
|
||||
```
|
||||
|
||||
### Union Type Validation
|
||||
|
||||
```python
|
||||
from typing import Union
|
||||
|
||||
class TextBlock(BaseModel):
|
||||
type: str = "text"
|
||||
content: str = Field(min_length=1)
|
||||
|
||||
class ImageBlock(BaseModel):
|
||||
type: str = "image"
|
||||
url: HttpUrl
|
||||
alt_text: str
|
||||
|
||||
class Page(BaseModel):
|
||||
title: str
|
||||
blocks: list[Union[TextBlock, ImageBlock]]
|
||||
|
||||
@field_validator('blocks')
|
||||
def validate_block_types(cls, v):
|
||||
"""Ensure first block is TextBlock."""
|
||||
if v and not isinstance(v[0], TextBlock):
|
||||
raise ValueError('First block must be text')
|
||||
return v
|
||||
```
|
||||
|
||||
### Dependent Fields
|
||||
|
||||
```python
|
||||
class Subscription(BaseModel):
|
||||
plan: str # "free", "pro", "enterprise"
|
||||
max_users: int
|
||||
features: list[str]
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_plan_limits(self):
|
||||
"""Enforce plan-specific limits."""
|
||||
limits = {
|
||||
"free": {"max_users": 1, "required_features": ["basic"]},
|
||||
"pro": {"max_users": 10, "required_features": ["basic", "advanced"]},
|
||||
"enterprise": {"max_users": 999, "required_features": ["basic", "advanced", "premium"]}
|
||||
}
|
||||
|
||||
if self.plan in limits:
|
||||
limit = limits[self.plan]
|
||||
|
||||
if self.max_users > limit["max_users"]:
|
||||
raise ValueError(f'{self.plan} plan limited to {limit["max_users"]} users')
|
||||
|
||||
for feature in limit["required_features"]:
|
||||
if feature not in self.features:
|
||||
raise ValueError(f'{self.plan} plan requires {feature} feature')
|
||||
|
||||
return self
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Graceful Degradation
|
||||
|
||||
```python
|
||||
class OptionalExtraction(BaseModel):
|
||||
# Required fields
|
||||
title: str
|
||||
|
||||
# Optional fields with defaults
|
||||
author: Optional[str] = None
|
||||
date: Optional[str] = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
|
||||
# LLM can succeed even if it can't extract everything
|
||||
```
|
||||
|
||||
### Partial Validation
|
||||
|
||||
```python
|
||||
from pydantic import ValidationError
|
||||
|
||||
def extract_with_fallback(text: str):
|
||||
"""Try full extraction, fall back to partial."""
|
||||
try:
|
||||
# Try full extraction
|
||||
return client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": text}],
|
||||
response_model=FullModel
|
||||
)
|
||||
except ValidationError:
|
||||
# Fall back to partial model
|
||||
return client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": text}],
|
||||
response_model=PartialModel
|
||||
)
|
||||
```
|
||||
|
||||
### Validation Error Inspection
|
||||
|
||||
```python
|
||||
from pydantic import ValidationError
|
||||
|
||||
try:
|
||||
result = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
max_tokens=1024,
|
||||
messages=[...],
|
||||
response_model=MyModel,
|
||||
max_retries=3
|
||||
)
|
||||
except ValidationError as e:
|
||||
# Inspect specific errors
|
||||
for error in e.errors():
|
||||
field = error['loc'][0]
|
||||
message = error['msg']
|
||||
print(f"Field '{field}' failed: {message}")
|
||||
|
||||
# Custom handling per field
|
||||
if field == 'email':
|
||||
# Handle email validation failure
|
||||
pass
|
||||
```
|
||||
|
||||
### Custom Error Messages
|
||||
|
||||
```python
|
||||
class DetailedModel(BaseModel):
|
||||
name: str = Field(
|
||||
min_length=2,
|
||||
max_length=100,
|
||||
description="Name between 2-100 characters"
|
||||
)
|
||||
age: int = Field(
|
||||
ge=0,
|
||||
le=120,
|
||||
description="Age between 0 and 120 years"
|
||||
)
|
||||
|
||||
@field_validator('name')
|
||||
def validate_name(cls, v):
|
||||
"""Provide helpful error message."""
|
||||
if not v.strip():
|
||||
raise ValueError(
|
||||
'Name cannot be empty. '
|
||||
'Please provide a valid name from the text.'
|
||||
)
|
||||
return v
|
||||
|
||||
# When validation fails, LLM sees these helpful messages
|
||||
```
|
||||
|
||||
## Validation Best Practices
|
||||
|
||||
### 1. Be Specific
|
||||
|
||||
```python
|
||||
# ❌ Bad: Vague validation
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
|
||||
# ✅ Good: Specific constraints
|
||||
class Item(BaseModel):
|
||||
name: str = Field(
|
||||
min_length=1,
|
||||
max_length=200,
|
||||
description="Item name, 1-200 characters"
|
||||
)
|
||||
```
|
||||
|
||||
### 2. Provide Context
|
||||
|
||||
```python
|
||||
# ✅ Good: Explain why validation failed
|
||||
@field_validator('price')
|
||||
def validate_price(cls, v):
|
||||
if v <= 0:
|
||||
raise ValueError(
|
||||
'Price must be positive. '
|
||||
'Extract numeric price from text without currency symbols.'
|
||||
)
|
||||
return v
|
||||
```
|
||||
|
||||
### 3. Use Enums for Fixed Sets
|
||||
|
||||
```python
|
||||
# ❌ Bad: String validation
|
||||
status: str
|
||||
|
||||
@field_validator('status')
|
||||
def validate_status(cls, v):
|
||||
if v not in ['active', 'inactive', 'pending']:
|
||||
raise ValueError('Invalid status')
|
||||
return v
|
||||
|
||||
# ✅ Good: Enum
|
||||
class Status(str, Enum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
PENDING = "pending"
|
||||
|
||||
status: Status # Validation automatic
|
||||
```
|
||||
|
||||
### 4. Balance Strictness
|
||||
|
||||
```python
|
||||
# Too strict: May fail unnecessarily
|
||||
class StrictModel(BaseModel):
|
||||
date: str = Field(pattern=r'^\d{4}-\d{2}-\d{2}$')
|
||||
# Fails if LLM uses "2024-1-5" instead of "2024-01-05"
|
||||
|
||||
# Better: Normalize in validator
|
||||
class FlexibleModel(BaseModel):
|
||||
date: str
|
||||
|
||||
@field_validator('date')
|
||||
def normalize_date(cls, v):
|
||||
from datetime import datetime
|
||||
# Parse flexible formats
|
||||
for fmt in ['%Y-%m-%d', '%Y/%m/%d', '%m/%d/%Y']:
|
||||
try:
|
||||
dt = datetime.strptime(v, fmt)
|
||||
return dt.strftime('%Y-%m-%d') # Normalize
|
||||
except ValueError:
|
||||
continue
|
||||
raise ValueError('Invalid date format')
|
||||
```
|
||||
|
||||
### 5. Test Validation
|
||||
|
||||
```python
|
||||
# Test your validators with edge cases
|
||||
def test_validation():
|
||||
# Should succeed
|
||||
valid = MyModel(field="valid_value")
|
||||
|
||||
# Should fail
|
||||
try:
|
||||
invalid = MyModel(field="invalid")
|
||||
assert False, "Should have raised ValidationError"
|
||||
except ValidationError:
|
||||
pass # Expected
|
||||
|
||||
# Run tests before using in production
|
||||
```
|
||||
|
||||
## Advanced Techniques
|
||||
|
||||
### Conditional Required Fields
|
||||
|
||||
```python
|
||||
from typing import Optional
|
||||
|
||||
class ConditionalModel(BaseModel):
|
||||
type: str
|
||||
detail_a: Optional[str] = None
|
||||
detail_b: Optional[str] = None
|
||||
|
||||
@model_validator(mode='after')
|
||||
def check_required_details(self):
|
||||
"""Require different fields based on type."""
|
||||
if self.type == "type_a" and not self.detail_a:
|
||||
raise ValueError('type_a requires detail_a')
|
||||
if self.type == "type_b" and not self.detail_b:
|
||||
raise ValueError('type_b requires detail_b')
|
||||
return self
|
||||
```
|
||||
|
||||
### Validation with External Data
|
||||
|
||||
```python
|
||||
class Product(BaseModel):
|
||||
sku: str
|
||||
name: str
|
||||
|
||||
@field_validator('sku')
|
||||
def validate_sku(cls, v):
|
||||
"""Check SKU exists in database."""
|
||||
# Query database or API
|
||||
if not database.sku_exists(v):
|
||||
raise ValueError(f'SKU {v} not found in catalog')
|
||||
return v
|
||||
```
|
||||
|
||||
### Progressive Validation
|
||||
|
||||
```python
|
||||
# Start with loose validation
|
||||
class Stage1(BaseModel):
|
||||
data: str # Any string
|
||||
|
||||
# Then strict validation
|
||||
class Stage2(BaseModel):
|
||||
data: str = Field(pattern=r'^[A-Z]{3}-\d{6}$')
|
||||
|
||||
# Use Stage1 for initial extraction
|
||||
# Use Stage2 for final validation
|
||||
```
|
||||
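One way the two stages might be chained, reusing the patched client from earlier (the prompt wording is illustrative):

```python
def extract_reference_code(text: str) -> Stage2:
    # Loose pass: accept whatever string the model found
    rough = client.messages.create(
        model="claude-sonnet-4-5-20250929",
        max_tokens=1024,
        messages=[{"role": "user", "content": f"Extract the reference code from: {text}"}],
        response_model=Stage1,
    )
    # Strict pass: the final format is enforced locally, with no extra API call
    return Stage2(data=rough.data)
```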
|
||||
## Resources
|
||||
|
||||
- **Pydantic Docs**: https://docs.pydantic.dev/latest/concepts/validators/
|
||||
- **Instructor Examples**: https://python.useinstructor.com/examples
|
||||
skills/mlops-knowledge/lambda-labs/SKILL.md (new file, 545 lines)
@@ -0,0 +1,545 @@
|
||||
---
|
||||
name: lambda-labs-gpu-cloud
|
||||
description: Reserved and on-demand GPU cloud instances for ML training and inference. Use when you need dedicated GPU instances with simple SSH access, persistent filesystems, or high-performance multi-node clusters for large-scale training.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Infrastructure, GPU Cloud, Training, Inference, Lambda Labs]
|
||||
dependencies: [lambda-cloud-client>=1.0.0]
|
||||
---
|
||||
|
||||
# Lambda Labs GPU Cloud
|
||||
|
||||
Comprehensive guide to running ML workloads on Lambda Labs GPU cloud with on-demand instances and 1-Click Clusters.
|
||||
|
||||
## When to use Lambda Labs
|
||||
|
||||
**Use Lambda Labs when:**
|
||||
- Need dedicated GPU instances with full SSH access
|
||||
- Running long training jobs (hours to days)
|
||||
- Want simple pricing with no egress fees
|
||||
- Need persistent storage across sessions
|
||||
- Require high-performance multi-node clusters (16-512 GPUs)
|
||||
- Want pre-installed ML stack (Lambda Stack with PyTorch, CUDA, NCCL)
|
||||
|
||||
**Key features:**
|
||||
- **GPU variety**: B200, H100, GH200, A100, A10, A6000, V100
|
||||
- **Lambda Stack**: Pre-installed PyTorch, TensorFlow, CUDA, cuDNN, NCCL
|
||||
- **Persistent filesystems**: Keep data across instance restarts
|
||||
- **1-Click Clusters**: 16-512 GPU Slurm clusters with InfiniBand
|
||||
- **Simple pricing**: Pay-per-minute, no egress fees
|
||||
- **Global regions**: 12+ regions worldwide
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Modal**: For serverless, auto-scaling workloads
|
||||
- **SkyPilot**: For multi-cloud orchestration and cost optimization
|
||||
- **RunPod**: For cheaper spot instances and serverless endpoints
|
||||
- **Vast.ai**: For GPU marketplace with lowest prices
|
||||
|
||||
## Quick start
|
||||
|
||||
### Account setup
|
||||
|
||||
1. Create account at https://lambda.ai
|
||||
2. Add payment method
|
||||
3. Generate API key from dashboard
|
||||
4. Add SSH key (required before launching instances)
|
||||
|
||||
### Launch via console
|
||||
|
||||
1. Go to https://cloud.lambda.ai/instances
|
||||
2. Click "Launch instance"
|
||||
3. Select GPU type and region
|
||||
4. Choose SSH key
|
||||
5. Optionally attach filesystem
|
||||
6. Launch and wait 3-15 minutes
|
||||
|
||||
### Connect via SSH
|
||||
|
||||
```bash
|
||||
# Get instance IP from console
|
||||
ssh ubuntu@<INSTANCE-IP>
|
||||
|
||||
# Or with specific key
|
||||
ssh -i ~/.ssh/lambda_key ubuntu@<INSTANCE-IP>
|
||||
```
|
||||
|
||||
## GPU instances
|
||||
|
||||
### Available GPUs
|
||||
|
||||
| GPU | VRAM | Price/GPU/hr | Best For |
|-----|------|--------------|----------|
| B200 SXM6 | 180 GB | $4.99 | Largest models, fastest training |
| H100 SXM | 80 GB | $2.99-3.29 | Large model training |
| H100 PCIe | 80 GB | $2.49 | Cost-effective H100 |
| GH200 | 96 GB | $1.49 | Single-GPU large models |
| A100 80GB | 80 GB | $1.79 | Production training |
| A100 40GB | 40 GB | $1.29 | Standard training |
| A10 | 24 GB | $0.75 | Inference, fine-tuning |
| A6000 | 48 GB | $0.80 | Good VRAM/price ratio |
| V100 | 16 GB | $0.55 | Budget training |
|
||||
|
||||
### Instance configurations
|
||||
|
||||
```
|
||||
8x GPU: Best for distributed training (DDP, FSDP)
|
||||
4x GPU: Large models, multi-GPU training
|
||||
2x GPU: Medium workloads
|
||||
1x GPU: Fine-tuning, inference, development
|
||||
```
|
||||
|
||||
### Launch times
|
||||
|
||||
- Single-GPU: 3-5 minutes
|
||||
- Multi-GPU: 10-15 minutes
|
||||
|
||||
## Lambda Stack
|
||||
|
||||
All instances come with Lambda Stack pre-installed:
|
||||
|
||||
```bash
|
||||
# Included software
|
||||
- Ubuntu 22.04 LTS
|
||||
- NVIDIA drivers (latest)
|
||||
- CUDA 12.x
|
||||
- cuDNN 8.x
|
||||
- NCCL (for multi-GPU)
|
||||
- PyTorch (latest)
|
||||
- TensorFlow (latest)
|
||||
- JAX
|
||||
- JupyterLab
|
||||
```
|
||||
|
||||
### Verify installation
|
||||
|
||||
```bash
|
||||
# Check GPU
|
||||
nvidia-smi
|
||||
|
||||
# Check PyTorch
|
||||
python -c "import torch; print(torch.cuda.is_available())"
|
||||
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
```
|
||||
|
||||
## Python API
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip install lambda-cloud-client
|
||||
```
|
||||
|
||||
### Authentication
|
||||
|
||||
```python
|
||||
import os
|
||||
import lambda_cloud_client
|
||||
|
||||
# Configure with API key
|
||||
configuration = lambda_cloud_client.Configuration(
|
||||
host="https://cloud.lambdalabs.com/api/v1",
|
||||
access_token=os.environ["LAMBDA_API_KEY"]
|
||||
)
|
||||
```
|
||||
|
||||
### List available instances
|
||||
|
||||
```python
|
||||
with lambda_cloud_client.ApiClient(configuration) as api_client:
|
||||
api = lambda_cloud_client.DefaultApi(api_client)
|
||||
|
||||
# Get available instance types
|
||||
types = api.instance_types()
|
||||
for name, info in types.data.items():
|
||||
print(f"{name}: {info.instance_type.description}")
|
||||
```
|
||||
|
||||
### Launch instance
|
||||
|
||||
```python
|
||||
from lambda_cloud_client.models import LaunchInstanceRequest
|
||||
|
||||
request = LaunchInstanceRequest(
|
||||
region_name="us-west-1",
|
||||
instance_type_name="gpu_1x_h100_sxm5",
|
||||
ssh_key_names=["my-ssh-key"],
|
||||
file_system_names=["my-filesystem"], # Optional
|
||||
name="training-job"
|
||||
)
|
||||
|
||||
response = api.launch_instance(request)
|
||||
instance_id = response.data.instance_ids[0]
|
||||
print(f"Launched: {instance_id}")
|
||||
```
|
||||
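Launching returns before the instance is reachable (see launch times above). A small sketch that polls with the same `api` object until the instance reports `active`; the `id`, `status`, and `ip` attribute names are assumed to match the API's instance objects:

```python
import time

ip = None
while ip is None:
    for inst in api.list_instances().data:
        if inst.id == instance_id and inst.status == "active":
            ip = inst.ip
            break
    else:
        time.sleep(30)  # still booting; check again shortly

print(f"Ready: ssh ubuntu@{ip}")
```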
|
||||
### List running instances
|
||||
|
||||
```python
|
||||
instances = api.list_instances()
|
||||
for instance in instances.data:
|
||||
print(f"{instance.name}: {instance.ip} ({instance.status})")
|
||||
```
|
||||
|
||||
### Terminate instance
|
||||
|
||||
```python
|
||||
from lambda_cloud_client.models import TerminateInstanceRequest
|
||||
|
||||
request = TerminateInstanceRequest(
|
||||
instance_ids=[instance_id]
|
||||
)
|
||||
api.terminate_instance(request)
|
||||
```
|
||||
|
||||
### SSH key management
|
||||
|
||||
```python
|
||||
from lambda_cloud_client.models import AddSshKeyRequest
|
||||
|
||||
# Add SSH key
|
||||
request = AddSshKeyRequest(
|
||||
name="my-key",
|
||||
public_key="ssh-rsa AAAA..."
|
||||
)
|
||||
api.add_ssh_key(request)
|
||||
|
||||
# List keys
|
||||
keys = api.list_ssh_keys()
|
||||
|
||||
# Delete key
|
||||
api.delete_ssh_key(key_id)
|
||||
```
|
||||
|
||||
## CLI with curl
|
||||
|
||||
### List instance types
|
||||
|
||||
```bash
|
||||
curl -u $LAMBDA_API_KEY: \
|
||||
https://cloud.lambdalabs.com/api/v1/instance-types | jq
|
||||
```
|
||||
|
||||
### Launch instance
|
||||
|
||||
```bash
|
||||
curl -u $LAMBDA_API_KEY: \
|
||||
-X POST https://cloud.lambdalabs.com/api/v1/instance-operations/launch \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"region_name": "us-west-1",
|
||||
"instance_type_name": "gpu_1x_h100_sxm5",
|
||||
"ssh_key_names": ["my-key"]
|
||||
}' | jq
|
||||
```
|
||||
|
||||
### Terminate instance
|
||||
|
||||
```bash
|
||||
curl -u $LAMBDA_API_KEY: \
|
||||
-X POST https://cloud.lambdalabs.com/api/v1/instance-operations/terminate \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"instance_ids": ["<INSTANCE-ID>"]}' | jq
|
||||
```
|
||||
|
||||
## Persistent storage
|
||||
|
||||
### Filesystems
|
||||
|
||||
Filesystems persist data across instance restarts:
|
||||
|
||||
```bash
|
||||
# Mount location
|
||||
/lambda/nfs/<FILESYSTEM_NAME>
|
||||
|
||||
# Example: save checkpoints
|
||||
python train.py --checkpoint-dir /lambda/nfs/my-storage/checkpoints
|
||||
```
|
||||
|
||||
### Create filesystem
|
||||
|
||||
1. Go to Storage in Lambda console
|
||||
2. Click "Create filesystem"
|
||||
3. Select region (must match instance region)
|
||||
4. Name and create
|
||||
|
||||
### Attach to instance
|
||||
|
||||
Filesystems must be attached at instance launch time:
|
||||
- Via console: Select filesystem when launching
|
||||
- Via API: Include `file_system_names` in launch request
|
||||
|
||||
### Best practices
|
||||
|
||||
```bash
|
||||
# Store on filesystem (persists)
|
||||
/lambda/nfs/storage/
|
||||
├── datasets/
|
||||
├── checkpoints/
|
||||
├── models/
|
||||
└── outputs/
|
||||
|
||||
# Local SSD (faster, ephemeral)
|
||||
/home/ubuntu/
|
||||
└── working/ # Temporary files
|
||||
```
|
||||
|
||||
## SSH configuration
|
||||
|
||||
### Add SSH key
|
||||
|
||||
```bash
|
||||
# Generate key locally
|
||||
ssh-keygen -t ed25519 -f ~/.ssh/lambda_key
|
||||
|
||||
# Add public key to Lambda console
|
||||
# Or via API
|
||||
```
|
||||
|
||||
### Multiple keys
|
||||
|
||||
```bash
|
||||
# On instance, add more keys
|
||||
echo 'ssh-rsa AAAA...' >> ~/.ssh/authorized_keys
|
||||
```
|
||||
|
||||
### Import from GitHub
|
||||
|
||||
```bash
|
||||
# On instance
|
||||
ssh-import-id gh:username
|
||||
```
|
||||
|
||||
### SSH tunneling
|
||||
|
||||
```bash
|
||||
# Forward Jupyter
|
||||
ssh -L 8888:localhost:8888 ubuntu@<IP>
|
||||
|
||||
# Forward TensorBoard
|
||||
ssh -L 6006:localhost:6006 ubuntu@<IP>
|
||||
|
||||
# Multiple ports
|
||||
ssh -L 8888:localhost:8888 -L 6006:localhost:6006 ubuntu@<IP>
|
||||
```
|
||||
|
||||
## JupyterLab
|
||||
|
||||
### Launch from console
|
||||
|
||||
1. Go to Instances page
|
||||
2. Click "Launch" in Cloud IDE column
|
||||
3. JupyterLab opens in browser
|
||||
|
||||
### Manual access
|
||||
|
||||
```bash
|
||||
# On instance
|
||||
jupyter lab --ip=0.0.0.0 --port=8888
|
||||
|
||||
# From local machine with tunnel
|
||||
ssh -L 8888:localhost:8888 ubuntu@<IP>
|
||||
# Open http://localhost:8888
|
||||
```
|
||||
|
||||
## Training workflows
|
||||
|
||||
### Single-GPU training
|
||||
|
||||
```bash
|
||||
# SSH to instance
|
||||
ssh ubuntu@<IP>
|
||||
|
||||
# Clone repo
|
||||
git clone https://github.com/user/project
|
||||
cd project
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Train
|
||||
python train.py --epochs 100 --checkpoint-dir /lambda/nfs/storage/checkpoints
|
||||
```
|
||||
|
||||
### Multi-GPU training (single node)
|
||||
|
||||
```python
|
||||
# train_ddp.py
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
def main():
|
||||
dist.init_process_group("nccl")
|
||||
rank = dist.get_rank()
|
||||
device = rank % torch.cuda.device_count()
|
||||
|
||||
model = MyModel().to(device)
|
||||
model = DDP(model, device_ids=[device])
|
||||
|
||||
# Training loop...
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
```bash
|
||||
# Launch with torchrun (8 GPUs)
|
||||
torchrun --nproc_per_node=8 train_ddp.py
|
||||
```
|
||||
|
||||
### Checkpoint to filesystem
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
checkpoint_dir = "/lambda/nfs/my-storage/checkpoints"
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# Save checkpoint
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': loss,
|
||||
}, f"{checkpoint_dir}/checkpoint_{epoch}.pt")
|
||||
```
|
||||
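Because the filesystem outlives the instance, resuming after an interruption is just a load at startup; a sketch assuming `model` and `optimizer` are already constructed by the training script:

```python
import glob
import os

import torch

checkpoint_dir = "/lambda/nfs/my-storage/checkpoints"
checkpoints = sorted(glob.glob(f"{checkpoint_dir}/checkpoint_*.pt"), key=os.path.getmtime)

start_epoch = 0
if checkpoints:
    # The newest checkpoint survived the previous instance on the shared filesystem
    state = torch.load(checkpoints[-1])
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    start_epoch = state['epoch'] + 1
```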
|
||||
## 1-Click Clusters
|
||||
|
||||
### Overview
|
||||
|
||||
High-performance Slurm clusters with:
|
||||
- 16-512 NVIDIA H100 or B200 GPUs
|
||||
- NVIDIA Quantum-2 400 Gb/s InfiniBand
|
||||
- GPUDirect RDMA at 3200 Gb/s
|
||||
- Pre-installed distributed ML stack
|
||||
|
||||
### Included software
|
||||
|
||||
- Ubuntu 22.04 LTS + Lambda Stack
|
||||
- NCCL, Open MPI
|
||||
- PyTorch with DDP and FSDP
|
||||
- TensorFlow
|
||||
- OFED drivers
|
||||
|
||||
### Storage
|
||||
|
||||
- 24 TB NVMe per compute node (ephemeral)
|
||||
- Lambda filesystems for persistent data
|
||||
|
||||
### Multi-node training
|
||||
|
||||
```bash
|
||||
# On Slurm cluster
|
||||
srun --nodes=4 --ntasks-per-node=8 --gpus-per-node=8 \
|
||||
torchrun --nnodes=4 --nproc_per_node=8 \
|
||||
--rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29500 \
|
||||
train.py
|
||||
```
|
||||
|
||||
## Networking
|
||||
|
||||
### Bandwidth
|
||||
|
||||
- Inter-instance (same region): up to 200 Gbps
|
||||
- Internet outbound: 20 Gbps max
|
||||
|
||||
### Firewall
|
||||
|
||||
- Default: Only port 22 (SSH) open
|
||||
- Configure additional ports in Lambda console
|
||||
- ICMP traffic allowed by default
|
||||
|
||||
### Private IPs
|
||||
|
||||
```bash
|
||||
# Find private IP
|
||||
ip addr show | grep 'inet '
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Fine-tuning LLM
|
||||
|
||||
```bash
|
||||
# 1. Launch 8x H100 instance with filesystem
|
||||
|
||||
# 2. SSH and setup
|
||||
ssh ubuntu@<IP>
|
||||
pip install transformers accelerate peft
|
||||
|
||||
# 3. Download model to filesystem
|
||||
python -c "
|
||||
from transformers import AutoModelForCausalLM
|
||||
model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
|
||||
model.save_pretrained('/lambda/nfs/storage/models/llama-2-7b')
|
||||
"
|
||||
|
||||
# 4. Fine-tune with checkpoints on filesystem
|
||||
accelerate launch --num_processes 8 train.py \
|
||||
--model_path /lambda/nfs/storage/models/llama-2-7b \
|
||||
--output_dir /lambda/nfs/storage/outputs \
|
||||
--checkpoint_dir /lambda/nfs/storage/checkpoints
|
||||
```
|
||||
|
||||
### Workflow 2: Batch inference
|
||||
|
||||
```bash
|
||||
# 1. Launch A10 instance (cost-effective for inference)
|
||||
|
||||
# 2. Run inference
|
||||
python inference.py \
|
||||
--model /lambda/nfs/storage/models/fine-tuned \
|
||||
--input /lambda/nfs/storage/data/inputs.jsonl \
|
||||
--output /lambda/nfs/storage/data/outputs.jsonl
|
||||
```
|
||||
|
||||
## Cost optimization
|
||||
|
||||
### Choose right GPU
|
||||
|
||||
| Task | Recommended GPU |
|------|-----------------|
| LLM fine-tuning (7B) | A100 40GB |
| LLM fine-tuning (70B) | 8x H100 |
| Inference | A10, A6000 |
| Development | V100, A10 |
| Maximum performance | B200 |
|
||||
|
||||
### Reduce costs
|
||||
|
||||
1. **Use filesystems**: Avoid re-downloading data
|
||||
2. **Checkpoint frequently**: Resume interrupted training
|
||||
3. **Right-size**: Don't over-provision GPUs
|
||||
4. **Terminate idle**: No auto-stop, manually terminate
|
||||
|
||||
### Monitor usage
|
||||
|
||||
- Dashboard shows real-time GPU utilization
|
||||
- API for programmatic monitoring
|
||||
|
||||
## Common issues
|
||||
|
||||
| Issue | Solution |
|-------|----------|
| Instance won't launch | Check region availability, try different GPU |
| SSH connection refused | Wait for instance to initialize (3-15 min) |
| Data lost after terminate | Use persistent filesystems |
| Slow data transfer | Use filesystem in same region |
| GPU not detected | Reboot instance, check drivers |
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Multi-node training, API automation
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://docs.lambda.ai
|
||||
- **Console**: https://cloud.lambda.ai
|
||||
- **Pricing**: https://lambda.ai/instances
|
||||
- **Support**: https://support.lambdalabs.com
|
||||
- **Blog**: https://lambda.ai/blog
|
||||
skills/mlops-knowledge/lambda-labs/references/advanced-usage.md (new file, 611 lines)
@@ -0,0 +1,611 @@
|
||||
# Lambda Labs Advanced Usage Guide
|
||||
|
||||
## Multi-Node Distributed Training
|
||||
|
||||
### PyTorch DDP across nodes
|
||||
|
||||
```python
|
||||
# train_multi_node.py
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
def setup_distributed():
|
||||
# Environment variables set by launcher
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
rank=rank,
|
||||
world_size=world_size
|
||||
)
|
||||
|
||||
torch.cuda.set_device(local_rank)
|
||||
return rank, world_size, local_rank
|
||||
|
||||
def main():
|
||||
rank, world_size, local_rank = setup_distributed()
|
||||
|
||||
model = MyModel().cuda(local_rank)
|
||||
model = DDP(model, device_ids=[local_rank])
|
||||
|
||||
# Training loop with synchronized gradients
|
||||
for epoch in range(num_epochs):
|
||||
train_one_epoch(model, dataloader)
|
||||
|
||||
# Save checkpoint on rank 0 only
|
||||
if rank == 0:
|
||||
torch.save(model.module.state_dict(), f"checkpoint_{epoch}.pt")
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
### Launch on multiple instances
|
||||
|
||||
```bash
|
||||
# On Node 0 (master)
|
||||
export MASTER_ADDR=<NODE0_PRIVATE_IP>
|
||||
export MASTER_PORT=29500
|
||||
|
||||
torchrun \
|
||||
--nnodes=2 \
|
||||
--nproc_per_node=8 \
|
||||
--node_rank=0 \
|
||||
--master_addr=$MASTER_ADDR \
|
||||
--master_port=$MASTER_PORT \
|
||||
train_multi_node.py
|
||||
|
||||
# On Node 1
|
||||
export MASTER_ADDR=<NODE0_PRIVATE_IP>
|
||||
export MASTER_PORT=29500
|
||||
|
||||
torchrun \
|
||||
--nnodes=2 \
|
||||
--nproc_per_node=8 \
|
||||
--node_rank=1 \
|
||||
--master_addr=$MASTER_ADDR \
|
||||
--master_port=$MASTER_PORT \
|
||||
train_multi_node.py
|
||||
```
|
||||
|
||||
### FSDP for large models
|
||||
|
||||
```python
|
||||
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
|
||||
|
||||
# Wrap policy for transformer models
|
||||
auto_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls={LlamaDecoderLayer}
|
||||
)
|
||||
|
||||
model = FSDP(
|
||||
model,
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
mixed_precision=MixedPrecision(
|
||||
param_dtype=torch.bfloat16,
|
||||
reduce_dtype=torch.bfloat16,
|
||||
buffer_dtype=torch.bfloat16,
|
||||
),
|
||||
device_id=local_rank,
|
||||
)
|
||||
```
|
||||
|
||||
### DeepSpeed ZeRO
|
||||
|
||||
A minimal `ds_config.json`:

```json
|
||||
{
|
||||
"train_batch_size": 64,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"fp16": {"enabled": true},
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {"device": "cpu"},
|
||||
"offload_param": {"device": "cpu"}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
# Launch with DeepSpeed
|
||||
deepspeed --num_nodes=2 \
|
||||
--num_gpus=8 \
|
||||
--hostfile=hostfile.txt \
|
||||
train.py --deepspeed ds_config.json
|
||||
```
|
||||
|
||||
### Hostfile for multi-node
|
||||
|
||||
```bash
|
||||
# hostfile.txt
|
||||
node0_ip slots=8
|
||||
node1_ip slots=8
|
||||
```
|
||||
|
||||
## API Automation
|
||||
|
||||
### Auto-launch training jobs
|
||||
|
||||
```python
|
||||
import os
|
||||
import time
|
||||
import lambda_cloud_client
|
||||
from lambda_cloud_client.models import LaunchInstanceRequest
|
||||
|
||||
class LambdaJobManager:
|
||||
def __init__(self, api_key: str):
|
||||
self.config = lambda_cloud_client.Configuration(
|
||||
host="https://cloud.lambdalabs.com/api/v1",
|
||||
access_token=api_key
|
||||
)
|
||||
|
||||
    def find_available_gpu(self, gpu_types: list[str], regions: list[str] | None = None):
|
||||
"""Find first available GPU type across regions."""
|
||||
with lambda_cloud_client.ApiClient(self.config) as client:
|
||||
api = lambda_cloud_client.DefaultApi(client)
|
||||
types = api.instance_types()
|
||||
|
||||
for gpu_type in gpu_types:
|
||||
if gpu_type in types.data:
|
||||
info = types.data[gpu_type]
|
||||
for region in info.regions_with_capacity_available:
|
||||
if regions is None or region.name in regions:
|
||||
return gpu_type, region.name
|
||||
|
||||
return None, None
|
||||
|
||||
def launch_and_wait(self, instance_type: str, region: str,
|
||||
                        ssh_key: str, filesystem: str | None = None,
|
||||
timeout: int = 900) -> dict:
|
||||
"""Launch instance and wait for it to be ready."""
|
||||
with lambda_cloud_client.ApiClient(self.config) as client:
|
||||
api = lambda_cloud_client.DefaultApi(client)
|
||||
|
||||
request = LaunchInstanceRequest(
|
||||
region_name=region,
|
||||
instance_type_name=instance_type,
|
||||
ssh_key_names=[ssh_key],
|
||||
file_system_names=[filesystem] if filesystem else [],
|
||||
)
|
||||
|
||||
response = api.launch_instance(request)
|
||||
instance_id = response.data.instance_ids[0]
|
||||
|
||||
# Poll until ready
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
instance = api.get_instance(instance_id)
|
||||
if instance.data.status == "active":
|
||||
return {
|
||||
"id": instance_id,
|
||||
"ip": instance.data.ip,
|
||||
"status": "active"
|
||||
}
|
||||
time.sleep(30)
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} not ready after {timeout}s")
|
||||
|
||||
def terminate(self, instance_ids: list[str]):
|
||||
"""Terminate instances."""
|
||||
from lambda_cloud_client.models import TerminateInstanceRequest
|
||||
|
||||
with lambda_cloud_client.ApiClient(self.config) as client:
|
||||
api = lambda_cloud_client.DefaultApi(client)
|
||||
request = TerminateInstanceRequest(instance_ids=instance_ids)
|
||||
api.terminate_instance(request)
|
||||
|
||||
|
||||
# Usage
|
||||
manager = LambdaJobManager(os.environ["LAMBDA_API_KEY"])
|
||||
|
||||
# Find available H100 or A100
|
||||
gpu_type, region = manager.find_available_gpu(
|
||||
["gpu_8x_h100_sxm5", "gpu_8x_a100_80gb_sxm4"],
|
||||
regions=["us-west-1", "us-east-1"]
|
||||
)
|
||||
|
||||
if gpu_type:
|
||||
instance = manager.launch_and_wait(
|
||||
gpu_type, region,
|
||||
ssh_key="my-key",
|
||||
filesystem="training-data"
|
||||
)
|
||||
print(f"Ready: ssh ubuntu@{instance['ip']}")
|
||||
```
|
||||
|
||||
### Batch job submission
|
||||
|
||||
```python
|
||||
import subprocess
|
||||
import paramiko
|
||||
|
||||
def run_remote_job(ip: str, ssh_key_path: str, commands: list[str]):
|
||||
"""Execute commands on remote instance."""
|
||||
client = paramiko.SSHClient()
|
||||
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
client.connect(ip, username="ubuntu", key_filename=ssh_key_path)
|
||||
|
||||
for cmd in commands:
|
||||
stdin, stdout, stderr = client.exec_command(cmd)
|
||||
print(stdout.read().decode())
|
||||
        err = stderr.read().decode()
        if err:
            print(f"Error: {err}")
|
||||
|
||||
client.close()
|
||||
|
||||
# Submit training job
|
||||
commands = [
|
||||
"cd /lambda/nfs/storage/project",
|
||||
"git pull",
|
||||
"pip install -r requirements.txt",
|
||||
"nohup torchrun --nproc_per_node=8 train.py > train.log 2>&1 &"
|
||||
]
|
||||
|
||||
run_remote_job(instance["ip"], "~/.ssh/lambda_key", commands)
|
||||
```
|
||||
|
||||
### Monitor training progress
|
||||
|
||||
```python
|
||||
def monitor_job(ip: str, ssh_key_path: str, log_file: str = "train.log"):
|
||||
"""Stream training logs from remote instance."""
|
||||
import time
|
||||
|
||||
client = paramiko.SSHClient()
|
||||
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
client.connect(ip, username="ubuntu", key_filename=ssh_key_path)
|
||||
|
||||
# Tail log file
|
||||
stdin, stdout, stderr = client.exec_command(f"tail -f {log_file}")
|
||||
|
||||
try:
|
||||
for line in stdout:
|
||||
print(line.strip())
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
client.close()
|
||||
```
|
||||
|
||||
## 1-Click Cluster Workflows
|
||||
|
||||
### Slurm job submission
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=llm-training
|
||||
#SBATCH --nodes=4
|
||||
#SBATCH --ntasks-per-node=8
|
||||
#SBATCH --gpus-per-node=8
|
||||
#SBATCH --time=24:00:00
|
||||
#SBATCH --output=logs/%j.out
|
||||
#SBATCH --error=logs/%j.err
|
||||
|
||||
# Set up distributed environment
|
||||
export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
|
||||
export MASTER_PORT=29500
|
||||
|
||||
# Launch training
|
||||
srun torchrun \
|
||||
--nnodes=$SLURM_NNODES \
|
||||
--nproc_per_node=$SLURM_GPUS_PER_NODE \
|
||||
--rdzv_backend=c10d \
|
||||
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
|
||||
train.py \
|
||||
--config config.yaml
|
||||
```
|
||||
|
||||
### Interactive cluster session
|
||||
|
||||
```bash
|
||||
# Request interactive session
|
||||
srun --nodes=1 --ntasks=1 --gpus=8 --time=4:00:00 --pty bash
|
||||
|
||||
# Now on compute node with 8 GPUs
|
||||
nvidia-smi
|
||||
python train.py
|
||||
```
|
||||
|
||||
### Monitoring cluster jobs
|
||||
|
||||
```bash
|
||||
# View job queue
|
||||
squeue
|
||||
|
||||
# View job details
|
||||
scontrol show job <JOB_ID>
|
||||
|
||||
# Cancel job
|
||||
scancel <JOB_ID>
|
||||
|
||||
# View node status
|
||||
sinfo
|
||||
|
||||
# View GPU usage across cluster
|
||||
srun --nodes=4 nvidia-smi --query-gpu=name,utilization.gpu --format=csv
|
||||
```
|
||||
|
||||
## Advanced Filesystem Usage
|
||||
|
||||
### Data staging workflow
|
||||
|
||||
```bash
|
||||
# Stage data from S3 to filesystem (one-time)
|
||||
aws s3 sync s3://my-bucket/dataset /lambda/nfs/storage/datasets/
|
||||
|
||||
# Or use rclone
|
||||
rclone sync s3:my-bucket/dataset /lambda/nfs/storage/datasets/
|
||||
```
|
||||
|
||||
### Shared filesystem across instances
|
||||
|
||||
```python
|
||||
# Instance 1: Write checkpoints
|
||||
checkpoint_path = "/lambda/nfs/shared/checkpoints/model_step_1000.pt"
|
||||
torch.save(model.state_dict(), checkpoint_path)
|
||||
|
||||
# Instance 2: Read checkpoints
|
||||
model.load_state_dict(torch.load(checkpoint_path))
|
||||
```
|
||||
|
||||
### Filesystem best practices
|
||||
|
||||
```bash
|
||||
# Organize for ML workflows
|
||||
/lambda/nfs/storage/
|
||||
├── datasets/
|
||||
│ ├── raw/ # Original data
|
||||
│ └── processed/ # Preprocessed data
|
||||
├── models/
|
||||
│ ├── pretrained/ # Base models
|
||||
│ └── fine-tuned/ # Your trained models
|
||||
├── checkpoints/
|
||||
│ └── experiment_1/ # Per-experiment checkpoints
|
||||
├── logs/
|
||||
│ └── tensorboard/ # Training logs
|
||||
└── outputs/
|
||||
└── inference/ # Inference results
|
||||
```
|
||||
|
||||
## Environment Management
|
||||
|
||||
### Custom Python environments
|
||||
|
||||
```bash
|
||||
# Don't modify system Python, create venv
|
||||
python -m venv ~/myenv
|
||||
source ~/myenv/bin/activate
|
||||
|
||||
# Install packages
|
||||
pip install torch transformers accelerate
|
||||
|
||||
# Save to filesystem for reuse
|
||||
cp -r ~/myenv /lambda/nfs/storage/envs/myenv
|
||||
```
|
||||
|
||||
### Conda environments
|
||||
|
||||
```bash
|
||||
# Install miniconda (if not present)
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
|
||||
bash Miniconda3-latest-Linux-x86_64.sh -b -p ~/miniconda3
|
||||
|
||||
# Create environment
|
||||
~/miniconda3/bin/conda create -n ml python=3.10 pytorch pytorch-cuda=12.1 -c pytorch -c nvidia -y
|
||||
|
||||
# Activate
|
||||
source ~/miniconda3/bin/activate ml
|
||||
```
|
||||
|
||||
### Docker containers
|
||||
|
||||
```bash
|
||||
# Pull and run NVIDIA container
|
||||
docker run --gpus all -it --rm \
|
||||
-v /lambda/nfs/storage:/data \
|
||||
nvcr.io/nvidia/pytorch:24.01-py3
|
||||
|
||||
# Run training in container
|
||||
docker run --gpus all -d \
|
||||
-v /lambda/nfs/storage:/data \
|
||||
-v $(pwd):/workspace \
|
||||
nvcr.io/nvidia/pytorch:24.01-py3 \
|
||||
python /workspace/train.py
|
||||
```
|
||||
|
||||
## Monitoring and Observability
|
||||
|
||||
### GPU monitoring
|
||||
|
||||
```bash
|
||||
# Real-time GPU stats
|
||||
watch -n 1 nvidia-smi
|
||||
|
||||
# GPU utilization over time
|
||||
nvidia-smi dmon -s u -d 1
|
||||
|
||||
# Detailed GPU info
|
||||
nvidia-smi -q
|
||||
```
|
||||
|
||||
### System monitoring
|
||||
|
||||
```bash
|
||||
# CPU and memory
|
||||
htop
|
||||
|
||||
# Disk I/O
|
||||
iostat -x 1
|
||||
|
||||
# Network
|
||||
iftop
|
||||
|
||||
# All resources
|
||||
glances
|
||||
```
|
||||
|
||||
### TensorBoard integration
|
||||
|
||||
```bash
|
||||
# Start TensorBoard
|
||||
tensorboard --logdir /lambda/nfs/storage/logs --port 6006 --bind_all
|
||||
|
||||
# SSH tunnel from local machine
|
||||
ssh -L 6006:localhost:6006 ubuntu@<IP>
|
||||
|
||||
# Access at http://localhost:6006
|
||||
```
|
||||
|
||||
### Weights & Biases integration
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
# Initialize with API key
|
||||
wandb.login(key=os.environ["WANDB_API_KEY"])
|
||||
|
||||
# Start run
|
||||
wandb.init(
|
||||
project="lambda-training",
|
||||
config={"learning_rate": 1e-4, "epochs": 100}
|
||||
)
|
||||
|
||||
# Log metrics
|
||||
wandb.log({"loss": loss, "accuracy": acc})
|
||||
|
||||
# Save artifacts to filesystem + W&B
|
||||
wandb.save("/lambda/nfs/storage/checkpoints/best_model.pt")
|
||||
```
|
||||
|
||||
## Cost Optimization Strategies
|
||||
|
||||
### Checkpointing for interruption recovery
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
def save_checkpoint(model, optimizer, epoch, loss, path):
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': loss,
|
||||
}, path)
|
||||
|
||||
def load_checkpoint(path, model, optimizer):
|
||||
if os.path.exists(path):
|
||||
checkpoint = torch.load(path)
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
return checkpoint['epoch'], checkpoint['loss']
|
||||
return 0, float('inf')
|
||||
|
||||
# Save every N steps to filesystem
|
||||
checkpoint_path = "/lambda/nfs/storage/checkpoints/latest.pt"
|
||||
if step % 1000 == 0:
|
||||
save_checkpoint(model, optimizer, epoch, loss, checkpoint_path)
|
||||
```
|
||||
|
||||
### Instance selection by workload
|
||||
|
||||
```python
|
||||
def recommend_instance(model_params: int, batch_size: int, task: str) -> str:
|
||||
"""Recommend Lambda instance based on workload."""
|
||||
|
||||
if task == "inference":
|
||||
if model_params < 7e9:
|
||||
return "gpu_1x_a10" # $0.75/hr
|
||||
elif model_params < 13e9:
|
||||
return "gpu_1x_a6000" # $0.80/hr
|
||||
else:
|
||||
return "gpu_1x_h100_pcie" # $2.49/hr
|
||||
|
||||
elif task == "fine-tuning":
|
||||
if model_params < 7e9:
|
||||
return "gpu_1x_a100" # $1.29/hr
|
||||
elif model_params < 13e9:
|
||||
return "gpu_4x_a100" # $5.16/hr
|
||||
else:
|
||||
return "gpu_8x_h100_sxm5" # $23.92/hr
|
||||
|
||||
elif task == "pretraining":
|
||||
return "gpu_8x_h100_sxm5" # Maximum performance
|
||||
|
||||
return "gpu_1x_a100" # Default
|
||||
```
|
||||
|
||||
### Auto-terminate idle instances
|
||||
|
||||
```python
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
def auto_terminate_idle(api_key: str, idle_threshold_hours: float = 2):
|
||||
"""Terminate instances idle for too long."""
|
||||
manager = LambdaJobManager(api_key)
|
||||
|
||||
with lambda_cloud_client.ApiClient(manager.config) as client:
|
||||
api = lambda_cloud_client.DefaultApi(client)
|
||||
instances = api.list_instances()
|
||||
|
||||
for instance in instances.data:
|
||||
# Check if instance has been running without activity
|
||||
# (You'd need to track this separately)
|
||||
launch_time = instance.launched_at
|
||||
if datetime.now() - launch_time > timedelta(hours=idle_threshold_hours):
|
||||
print(f"Terminating idle instance: {instance.id}")
|
||||
manager.terminate([instance.id])
|
||||
```
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
### SSH key rotation
|
||||
|
||||
```bash
|
||||
# Generate new key pair
|
||||
ssh-keygen -t ed25519 -f ~/.ssh/lambda_key_new -C "lambda-$(date +%Y%m)"
|
||||
|
||||
# Add new key via Lambda console or API
|
||||
# Update authorized_keys on running instances
|
||||
ssh ubuntu@<IP> "echo '$(cat ~/.ssh/lambda_key_new.pub)' >> ~/.ssh/authorized_keys"
|
||||
|
||||
# Test new key
|
||||
ssh -i ~/.ssh/lambda_key_new ubuntu@<IP>
|
||||
|
||||
# Remove old key from Lambda console
|
||||
```
|
||||
|
||||
### Firewall configuration
|
||||
|
||||
```bash
|
||||
# Lambda console: Only open necessary ports
|
||||
# Recommended:
|
||||
# - 22 (SSH) - Always needed
|
||||
# - 6006 (TensorBoard) - If using
|
||||
# - 8888 (Jupyter) - If using
|
||||
# - 29500 (PyTorch distributed) - For multi-node only
|
||||
```
|
||||
|
||||
### Secrets management
|
||||
|
||||
```bash
|
||||
# Don't hardcode API keys in code
|
||||
# Use environment variables
|
||||
export HF_TOKEN="hf_..."
|
||||
export WANDB_API_KEY="..."
|
||||
|
||||
# Or use .env file (add to .gitignore)
|
||||
source .env
|
||||
|
||||
# On instance, store in ~/.bashrc
|
||||
echo 'export HF_TOKEN="..."' >> ~/.bashrc
|
||||
```
|
||||
skills/mlops-knowledge/lambda-labs/references/troubleshooting.md (new file, 530 lines)
@@ -0,0 +1,530 @@
|
||||
# Lambda Labs Troubleshooting Guide
|
||||
|
||||
## Instance Launch Issues
|
||||
|
||||
### No instances available
|
||||
|
||||
**Error**: "No capacity available" or instance type not listed
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check availability via API
|
||||
curl -u $LAMBDA_API_KEY: \
|
||||
https://cloud.lambdalabs.com/api/v1/instance-types | jq '.data | to_entries[] | select(.value.regions_with_capacity_available | length > 0) | .key'
|
||||
|
||||
# Try different regions
|
||||
# US regions: us-west-1, us-east-1, us-south-1
|
||||
# International: eu-west-1, asia-northeast-1, etc.
|
||||
|
||||
# Try alternative GPU types
|
||||
# H100 not available? Try A100
|
||||
# A100 not available? Try A10 or A6000
|
||||
```
|
||||
|
||||
### Instance stuck launching
|
||||
|
||||
**Problem**: Instance shows "booting" for over 20 minutes
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Single-GPU: Should be ready in 3-5 minutes
|
||||
# Multi-GPU (8x): May take 10-15 minutes
|
||||
|
||||
# If stuck longer:
|
||||
# 1. Terminate the instance
|
||||
# 2. Try a different region
|
||||
# 3. Try a different instance type
|
||||
# 4. Contact Lambda support if persistent
|
||||
```
|
||||
|
||||
### API authentication fails
|
||||
|
||||
**Error**: `401 Unauthorized` or `403 Forbidden`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Verify API key format (should start with specific prefix)
|
||||
echo $LAMBDA_API_KEY
|
||||
|
||||
# Test API key
|
||||
curl -u $LAMBDA_API_KEY: \
|
||||
https://cloud.lambdalabs.com/api/v1/instance-types
|
||||
|
||||
# Generate new API key from Lambda console if needed
|
||||
# Settings > API keys > Generate
|
||||
```
|
||||
|
||||
### Quota limits reached
|
||||
|
||||
**Error**: "Instance limit reached" or "Quota exceeded"
|
||||
|
||||
**Solutions**:
|
||||
- Check current running instances in console
|
||||
- Terminate unused instances
|
||||
- Contact Lambda support to request quota increase
|
||||
- Use 1-Click Clusters for large-scale needs
|
||||
|
||||
## SSH Connection Issues
|
||||
|
||||
### Connection refused
|
||||
|
||||
**Error**: `ssh: connect to host <IP> port 22: Connection refused`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Wait for instance to fully initialize
|
||||
# Single-GPU: 3-5 minutes
|
||||
# Multi-GPU: 10-15 minutes
|
||||
|
||||
# Check instance status in console (should be "active")
|
||||
|
||||
# Verify correct IP address
|
||||
curl -u $LAMBDA_API_KEY: \
|
||||
https://cloud.lambdalabs.com/api/v1/instances | jq '.data[].ip'
|
||||
```
|
||||
|
||||
### Permission denied
|
||||
|
||||
**Error**: `Permission denied (publickey)`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Verify SSH key matches
|
||||
ssh -v -i ~/.ssh/lambda_key ubuntu@<IP>
|
||||
|
||||
# Check key permissions
|
||||
chmod 600 ~/.ssh/lambda_key
|
||||
chmod 644 ~/.ssh/lambda_key.pub
|
||||
|
||||
# Verify key was added to Lambda console before launch
|
||||
# Keys must be added BEFORE launching instance
|
||||
|
||||
# Check authorized_keys on instance (if you have another way in)
|
||||
cat ~/.ssh/authorized_keys
|
||||
```
|
||||
|
||||
### Host key verification failed
|
||||
|
||||
**Error**: `WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# This happens when IP is reused by different instance
|
||||
# Remove old key
|
||||
ssh-keygen -R <IP>
|
||||
|
||||
# Then connect again
|
||||
ssh ubuntu@<IP>
|
||||
```
|
||||
|
||||
### Timeout during SSH
|
||||
|
||||
**Error**: `ssh: connect to host <IP> port 22: Operation timed out`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check if instance is in "active" state
|
||||
|
||||
# Verify firewall allows SSH (port 22)
|
||||
# Lambda console > Firewall
|
||||
|
||||
# Check your local network allows outbound SSH
|
||||
|
||||
# Try from different network/VPN
|
||||
```
|
||||
|
||||
## GPU Issues
|
||||
|
||||
### GPU not detected
|
||||
|
||||
**Error**: `nvidia-smi: command not found` or no GPUs shown
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Reboot instance
|
||||
sudo reboot
|
||||
|
||||
# Reinstall NVIDIA drivers (if needed)
|
||||
wget -nv -O- https://lambdalabs.com/install-lambda-stack.sh | sh -
|
||||
sudo reboot
|
||||
|
||||
# Check driver status
|
||||
nvidia-smi
|
||||
lsmod | grep nvidia
|
||||
```
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check GPU memory
|
||||
import torch
|
||||
print(torch.cuda.get_device_properties(0).total_memory / 1e9, "GB")
|
||||
|
||||
# Clear cache
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Reduce batch size
|
||||
batch_size = batch_size // 2
|
||||
|
||||
# Enable gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Use mixed precision
|
||||
from torch.cuda.amp import autocast
|
||||
with autocast():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Use larger GPU instance
|
||||
# A100-40GB → A100-80GB → H100
|
||||
```
|
||||
|
||||
### CUDA version mismatch
|
||||
|
||||
**Error**: `CUDA driver version is insufficient for CUDA runtime version`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check versions
|
||||
nvidia-smi # Shows driver CUDA version
|
||||
nvcc --version # Shows toolkit version
|
||||
|
||||
# Lambda Stack should have compatible versions
|
||||
# If mismatch, reinstall Lambda Stack
|
||||
wget -nv -O- https://lambdalabs.com/install-lambda-stack.sh | sh -
|
||||
sudo reboot
|
||||
|
||||
# Or install specific PyTorch version
|
||||
pip install torch==2.1.0+cu121 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
```
|
||||
|
||||
### Multi-GPU not working
|
||||
|
||||
**Error**: Only one GPU being used
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check all GPUs visible
|
||||
import torch
|
||||
print(f"GPUs available: {torch.cuda.device_count()}")
|
||||
|
||||
# Verify CUDA_VISIBLE_DEVICES not set restrictively
|
||||
import os
|
||||
print(os.environ.get("CUDA_VISIBLE_DEVICES", "not set"))
|
||||
|
||||
# Use DataParallel or DistributedDataParallel
|
||||
model = torch.nn.DataParallel(model)
|
||||
# or
|
||||
model = torch.nn.parallel.DistributedDataParallel(model)
|
||||
```
|
||||
|
||||
## Filesystem Issues
|
||||
|
||||
### Filesystem not mounted
|
||||
|
||||
**Error**: `/lambda/nfs/<name>` doesn't exist
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Filesystem must be attached at launch time
|
||||
# Cannot attach to running instance
|
||||
|
||||
# Verify filesystem was selected during launch
|
||||
|
||||
# Check mount points
|
||||
df -h | grep lambda
|
||||
|
||||
# If missing, terminate and relaunch with filesystem
|
||||
```
|
||||
|
||||
### Slow filesystem performance
|
||||
|
||||
**Problem**: Reading/writing to filesystem is slow
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Use local SSD for temporary/intermediate files
|
||||
# /home/ubuntu has fast NVMe storage
|
||||
|
||||
# Copy frequently accessed data to local storage
|
||||
cp -r /lambda/nfs/storage/dataset /home/ubuntu/dataset
|
||||
|
||||
# Use filesystem for checkpoints and final outputs only
|
||||
|
||||
# Check network bandwidth
|
||||
iperf3 -c <filesystem_server>
|
||||
```
|
||||
|
||||
### Data lost after termination
|
||||
|
||||
**Problem**: Files disappeared after instance terminated
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Root volume (/home/ubuntu) is EPHEMERAL
|
||||
# Data there is lost on termination
|
||||
|
||||
# ALWAYS use filesystem for persistent data
|
||||
/lambda/nfs/<filesystem_name>/
|
||||
|
||||
# Sync important local files before terminating
|
||||
rsync -av /home/ubuntu/outputs/ /lambda/nfs/storage/outputs/
|
||||
```
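
To make this harder to forget, you can mirror local outputs to the filesystem from inside the training script; a minimal sketch (paths follow the example above):

```python
import subprocess


def sync_to_nfs(local_dir: str = "/home/ubuntu/outputs",
                nfs_dir: str = "/lambda/nfs/storage/outputs") -> None:
    """Mirror ephemeral local outputs to the persistent filesystem."""
    subprocess.run(["rsync", "-a", f"{local_dir}/", f"{nfs_dir}/"], check=True)


# Call after every checkpoint save, e.g.:
# model.save_pretrained("/home/ubuntu/outputs/checkpoint-1000")
# sync_to_nfs()
```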
|
||||
|
||||
### Filesystem full
|
||||
|
||||
**Error**: `No space left on device`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check filesystem usage
|
||||
df -h /lambda/nfs/storage
|
||||
|
||||
# Find large files
|
||||
du -sh /lambda/nfs/storage/* | sort -h
|
||||
|
||||
# Clean up old checkpoints
|
||||
find /lambda/nfs/storage/checkpoints -mtime +7 -delete
|
||||
|
||||
# Increase filesystem size in Lambda console
|
||||
# (may require support request)
|
||||
```
|
||||
|
||||
## Network Issues
|
||||
|
||||
### Port not accessible
|
||||
|
||||
**Error**: Cannot connect to service (TensorBoard, Jupyter, etc.)
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Lambda default: Only port 22 is open
|
||||
# Configure firewall in Lambda console
|
||||
|
||||
# Or use SSH tunneling (recommended)
|
||||
ssh -L 6006:localhost:6006 ubuntu@<IP>
|
||||
# Access at http://localhost:6006
|
||||
|
||||
# For Jupyter
|
||||
ssh -L 8888:localhost:8888 ubuntu@<IP>
|
||||
```
|
||||
|
||||
### Slow data download
|
||||
|
||||
**Problem**: Downloading datasets is slow
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check available bandwidth
|
||||
speedtest-cli
|
||||
|
||||
# Use multi-threaded download
|
||||
aria2c -x 16 <URL>
|
||||
|
||||
# For HuggingFace models
|
||||
export HF_HUB_ENABLE_HF_TRANSFER=1
|
||||
pip install hf_transfer
|
||||
|
||||
# For S3, use parallel transfer
|
||||
aws s3 sync s3://bucket/data /local/data --quiet
|
||||
```
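
The hf_transfer path can also be driven from Python; a short sketch using `huggingface_hub` (the repo ID and local path are placeholders):

```python
import os

# Must be set before huggingface_hub starts any download
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import snapshot_download

# Downloads the full repo using the accelerated transfer backend
path = snapshot_download(
    repo_id="meta-llama/Llama-2-7b-hf",                  # placeholder repo
    local_dir="/home/ubuntu/models/llama-2-7b-hf",       # fast local NVMe
)
print(path)
```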
|
||||
|
||||
### Inter-node communication fails
|
||||
|
||||
**Error**: Distributed training can't connect between nodes
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Verify nodes in same region (required)
|
||||
|
||||
# Check private IPs can communicate
|
||||
ping <other_node_private_ip>
|
||||
|
||||
# Verify NCCL settings
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_IB_DISABLE=0 # Enable InfiniBand if available
|
||||
|
||||
# Check firewall allows distributed ports
|
||||
# Need: 29500 (PyTorch), or configured MASTER_PORT
|
||||
```
|
||||
|
||||
## Software Issues
|
||||
|
||||
### Package installation fails
|
||||
|
||||
**Error**: `pip install` errors
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Use virtual environment (don't modify system Python)
|
||||
python -m venv ~/myenv
|
||||
source ~/myenv/bin/activate
|
||||
pip install <package>
|
||||
|
||||
# For CUDA packages, match CUDA version
|
||||
pip install torch --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# Clear pip cache if corrupted
|
||||
pip cache purge
|
||||
```
|
||||
|
||||
### Python version issues
|
||||
|
||||
**Error**: Package requires different Python version
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install alternate Python (don't replace system Python)
|
||||
sudo apt install python3.11 python3.11-venv python3.11-dev
|
||||
|
||||
# Create venv with specific Python
|
||||
python3.11 -m venv ~/py311env
|
||||
source ~/py311env/bin/activate
|
||||
```
|
||||
|
||||
### ImportError or ModuleNotFoundError
|
||||
|
||||
**Error**: Module not found despite installation
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Verify correct Python environment
|
||||
which python
|
||||
pip list | grep <module>
|
||||
|
||||
# Ensure virtual environment is activated
|
||||
source ~/myenv/bin/activate
|
||||
|
||||
# Reinstall in correct environment
|
||||
pip uninstall <package>
|
||||
pip install <package>
|
||||
```
|
||||
|
||||
## Training Issues
|
||||
|
||||
### Training hangs
|
||||
|
||||
**Problem**: Training stops progressing, no output
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check GPU utilization
|
||||
watch -n 1 nvidia-smi
|
||||
|
||||
# If GPUs at 0%, likely data loading bottleneck
|
||||
# Increase num_workers in DataLoader
|
||||
|
||||
# Check for deadlocks in distributed training
|
||||
export NCCL_DEBUG=INFO
|
||||
|
||||
# Add timeouts
|
||||
dist.init_process_group(..., timeout=timedelta(minutes=30))
|
||||
```
|
||||
|
||||
### Checkpoint corruption
|
||||
|
||||
**Error**: `RuntimeError: storage has wrong size` or similar
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import os
import torch

# Use safe saving pattern
|
||||
checkpoint_path = "/lambda/nfs/storage/checkpoint.pt"
|
||||
temp_path = checkpoint_path + ".tmp"
|
||||
|
||||
# Save to temp first
|
||||
torch.save(state_dict, temp_path)
|
||||
# Then atomic rename
|
||||
os.rename(temp_path, checkpoint_path)
|
||||
|
||||
# For loading corrupted checkpoint
|
||||
try:
|
||||
state = torch.load(checkpoint_path)
|
||||
except Exception:
|
||||
# Fall back to previous checkpoint
|
||||
state = torch.load(checkpoint_path + ".backup")
|
||||
```
|
||||
|
||||
### Memory leak
|
||||
|
||||
**Problem**: Memory usage grows over time
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Clear CUDA cache periodically
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Detach tensors when logging
|
||||
loss_value = loss.detach().cpu().item()
|
||||
|
||||
# Don't accumulate gradients unintentionally
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Use gradient accumulation properly
|
||||
if (step + 1) % accumulation_steps == 0:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
## Billing Issues
|
||||
|
||||
### Unexpected charges
|
||||
|
||||
**Problem**: Bill higher than expected
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check for forgotten running instances
|
||||
curl -u $LAMBDA_API_KEY: \
|
||||
https://cloud.lambdalabs.com/api/v1/instances | jq '.data[].id'
|
||||
|
||||
# Terminate all instances
|
||||
# Lambda console > Instances > Terminate all
|
||||
|
||||
# Lambda charges by the minute
|
||||
# Note: there is no "stop" feature - instances keep billing until terminated
|
||||
```
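
A quick audit script makes forgotten instances obvious; a sketch against the same API (the `id` and `ip` fields appear in the jq examples above, other fields are assumptions):

```python
import os

import requests

api_key = os.environ["LAMBDA_API_KEY"]
resp = requests.get(
    "https://cloud.lambdalabs.com/api/v1/instances",
    auth=(api_key, ""),
    timeout=30,
)
resp.raise_for_status()

instances = resp.json()["data"]
if not instances:
    print("No running instances - nothing is billing.")
for inst in instances:
    # instance_type name is an assumed field; id and ip are confirmed above
    print(inst["id"], inst.get("ip"), inst.get("instance_type", {}).get("name"))
```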
|
||||
|
||||
### Instance terminated unexpectedly
|
||||
|
||||
**Problem**: Instance disappeared without manual termination
|
||||
|
||||
**Possible causes**:
|
||||
- Payment issue (card declined)
|
||||
- Account suspension
|
||||
- Instance health check failure
|
||||
|
||||
**Solutions**:
|
||||
- Check email for Lambda notifications
|
||||
- Verify payment method in console
|
||||
- Contact Lambda support
|
||||
- Always checkpoint to filesystem
|
||||
|
||||
## Common Error Messages
|
||||
|
||||
| Error | Cause | Solution |
|
||||
|-------|-------|----------|
|
||||
| `No capacity available` | Region/GPU sold out | Try different region or GPU type |
|
||||
| `Permission denied (publickey)` | SSH key mismatch | Re-add key, check permissions |
|
||||
| `CUDA out of memory` | Model too large | Reduce batch size, use larger GPU |
|
||||
| `No space left on device` | Disk full | Clean up or use filesystem |
|
||||
| `Connection refused` | Instance not ready | Wait 3-15 minutes for boot |
|
||||
| `Module not found` | Wrong Python env | Activate correct virtualenv |
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **Documentation**: https://docs.lambda.ai
|
||||
2. **Support**: https://support.lambdalabs.com
|
||||
3. **Email**: support@lambdalabs.com
|
||||
4. **Status**: Check Lambda status page for outages
|
||||
|
||||
### Information to Include
|
||||
|
||||
When contacting support, include:
|
||||
- Instance ID
|
||||
- Region
|
||||
- Instance type
|
||||
- Error message (full traceback)
|
||||
- Steps to reproduce
|
||||
- Time of occurrence
|
||||
skills/mlops-knowledge/llama-cpp/SKILL.md (new file, 258 lines)
|
||||
---
|
||||
name: llama-cpp
|
||||
description: Runs LLM inference on CPU, Apple Silicon, and consumer GPUs without NVIDIA hardware. Use for edge deployment, M1/M2/M3 Macs, AMD/Intel GPUs, or when CUDA is unavailable. Supports GGUF quantization (1.5-8 bit) for reduced memory and 4-10× speedup vs PyTorch on CPU.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Inference Serving, Llama.cpp, CPU Inference, Apple Silicon, Edge Deployment, GGUF, Quantization, Non-NVIDIA, AMD GPUs, Intel GPUs, Embedded]
|
||||
dependencies: [llama-cpp-python]
|
||||
---
|
||||
|
||||
# llama.cpp
|
||||
|
||||
Pure C/C++ LLM inference with minimal dependencies, optimized for CPUs and non-NVIDIA hardware.
|
||||
|
||||
## When to use llama.cpp
|
||||
|
||||
**Use llama.cpp when:**
|
||||
- Running on CPU-only machines
|
||||
- Deploying on Apple Silicon (M1/M2/M3/M4)
|
||||
- Using AMD or Intel GPUs (no CUDA)
|
||||
- Edge deployment (Raspberry Pi, embedded systems)
|
||||
- Need simple deployment without Docker/Python
|
||||
|
||||
**Use TensorRT-LLM instead when:**
|
||||
- Have NVIDIA GPUs (A100/H100)
|
||||
- Need maximum throughput (100K+ tok/s)
|
||||
- Running in datacenter with CUDA
|
||||
|
||||
**Use vLLM instead when:**
|
||||
- Have NVIDIA GPUs
|
||||
- Need Python-first API
|
||||
- Want PagedAttention
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# macOS/Linux
|
||||
brew install llama.cpp
|
||||
|
||||
# Or build from source
|
||||
git clone https://github.com/ggerganov/llama.cpp
|
||||
cd llama.cpp
|
||||
make
|
||||
|
||||
# With Metal (Apple Silicon)
|
||||
make LLAMA_METAL=1
|
||||
|
||||
# With CUDA (NVIDIA)
|
||||
make LLAMA_CUDA=1
|
||||
|
||||
# With ROCm (AMD)
|
||||
make LLAMA_HIP=1
|
||||
```
|
||||
|
||||
### Download model
|
||||
|
||||
```bash
|
||||
# Download from HuggingFace (GGUF format)
|
||||
huggingface-cli download \
|
||||
TheBloke/Llama-2-7B-Chat-GGUF \
|
||||
llama-2-7b-chat.Q4_K_M.gguf \
|
||||
--local-dir models/
|
||||
|
||||
# Or convert from HuggingFace
|
||||
python convert_hf_to_gguf.py models/llama-2-7b-chat/
|
||||
```
|
||||
|
||||
### Run inference
|
||||
|
||||
```bash
|
||||
# Simple chat
|
||||
./llama-cli \
|
||||
-m models/llama-2-7b-chat.Q4_K_M.gguf \
|
||||
-p "Explain quantum computing" \
|
||||
-n 256 # Max tokens
|
||||
|
||||
# Interactive chat
|
||||
./llama-cli \
|
||||
-m models/llama-2-7b-chat.Q4_K_M.gguf \
|
||||
--interactive
|
||||
```
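
The same GGUF file can be used from Python through the `llama-cpp-python` bindings listed in this skill's dependencies; a minimal sketch:

```python
from llama_cpp import Llama

llm = Llama(
    model_path="models/llama-2-7b-chat.Q4_K_M.gguf",
    n_ctx=4096,        # context window
    n_gpu_layers=-1,   # offload all layers if a GPU backend is available
)

out = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Explain quantum computing"}],
    max_tokens=256,
    temperature=0.7,
)
print(out["choices"][0]["message"]["content"])
```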
|
||||
|
||||
### Server mode
|
||||
|
||||
```bash
|
||||
# Start OpenAI-compatible server
|
||||
./llama-server \
|
||||
-m models/llama-2-7b-chat.Q4_K_M.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-ngl 32 # Offload 32 layers to GPU
|
||||
|
||||
# Client request
|
||||
curl http://localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llama-2-7b-chat",
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100
|
||||
}'
|
||||
```
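
Because the server speaks the OpenAI protocol, any OpenAI-compatible client works; a sketch with the `openai` Python package pointed at the local endpoint (the API key is a dummy value):

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8080/v1",
    api_key="sk-no-key-required",  # llama-server does not check it by default
)

resp = client.chat.completions.create(
    model="llama-2-7b-chat",
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.7,
    max_tokens=100,
)
print(resp.choices[0].message.content)
```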
|
||||
|
||||
## Quantization formats
|
||||
|
||||
### GGUF format overview
|
||||
|
||||
| Format | Bits | Size (7B) | Speed | Quality | Use Case |
|
||||
|--------|------|-----------|-------|---------|----------|
|
||||
| **Q4_K_M** | 4.5 | 4.1 GB | Fast | Good | **Recommended default** |
|
||||
| Q4_K_S | 4.3 | 3.9 GB | Faster | Lower | Speed critical |
|
||||
| Q5_K_M | 5.5 | 4.8 GB | Medium | Better | Quality critical |
|
||||
| Q6_K | 6.5 | 5.5 GB | Slower | Best | Maximum quality |
|
||||
| Q8_0 | 8.0 | 7.0 GB | Slow | Excellent | Minimal degradation |
|
||||
| Q2_K | 2.5 | 2.7 GB | Fastest | Poor | Testing only |
|
||||
|
||||
### Choosing quantization
|
||||
|
||||
```bash
|
||||
# General use (balanced)
|
||||
Q4_K_M # 4-bit, medium quality
|
||||
|
||||
# Maximum speed (more degradation)
|
||||
Q2_K or Q3_K_M
|
||||
|
||||
# Maximum quality (slower)
|
||||
Q6_K or Q8_0
|
||||
|
||||
# Very large models (70B, 405B)
|
||||
Q3_K_M or Q4_K_S # Lower bits to fit in memory
|
||||
```
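
If you want to automate the choice, the rules of thumb above can be encoded directly; a small helper sketch (the thresholds are the ones listed, not hard rules):

```python
def choose_quant(model_params_b: float, priority: str = "balanced") -> str:
    """Pick a GGUF quantization following the rules of thumb above."""
    if model_params_b >= 70:
        return "Q3_K_M"   # or Q4_K_S, to fit very large models in memory
    if priority == "speed":
        return "Q3_K_M"   # more degradation, fastest
    if priority == "quality":
        return "Q6_K"     # or Q8_0, slower
    return "Q4_K_M"       # balanced default


print(choose_quant(7))    # Q4_K_M
print(choose_quant(70))   # Q3_K_M
```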
|
||||
|
||||
## Hardware acceleration
|
||||
|
||||
### Apple Silicon (Metal)
|
||||
|
||||
```bash
|
||||
# Build with Metal
|
||||
make LLAMA_METAL=1
|
||||
|
||||
# Run with GPU acceleration (automatic)
|
||||
./llama-cli -m model.gguf -ngl 999 # Offload all layers
|
||||
|
||||
# Performance: M3 Max 40-60 tokens/sec (Llama 2-7B Q4_K_M)
|
||||
```
|
||||
|
||||
### NVIDIA GPUs (CUDA)
|
||||
|
||||
```bash
|
||||
# Build with CUDA
|
||||
make LLAMA_CUDA=1
|
||||
|
||||
# Offload layers to GPU
|
||||
./llama-cli -m model.gguf -ngl 35 # Offload 35/40 layers
|
||||
|
||||
# Hybrid CPU+GPU for large models
|
||||
./llama-cli -m llama-70b.Q4_K_M.gguf -ngl 20 # GPU: 20 layers, CPU: rest
|
||||
```
|
||||
|
||||
### AMD GPUs (ROCm)
|
||||
|
||||
```bash
|
||||
# Build with ROCm
|
||||
make LLAMA_HIP=1
|
||||
|
||||
# Run with AMD GPU
|
||||
./llama-cli -m model.gguf -ngl 999
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Batch processing
|
||||
|
||||
```bash
|
||||
# Process multiple prompts from file
|
||||
cat prompts.txt | ./llama-cli \
|
||||
-m model.gguf \
|
||||
--batch-size 512 \
|
||||
-n 100
|
||||
```
|
||||
|
||||
### Constrained generation
|
||||
|
||||
```bash
|
||||
# JSON output with grammar
|
||||
./llama-cli \
|
||||
-m model.gguf \
|
||||
-p "Generate a person: " \
|
||||
--grammar-file grammars/json.gbnf
|
||||
|
||||
# Outputs valid JSON only
|
||||
```
|
||||
|
||||
### Context size
|
||||
|
||||
```bash
|
||||
# Increase context (default 512)
|
||||
./llama-cli \
|
||||
-m model.gguf \
|
||||
-c 4096 # 4K context window
|
||||
|
||||
# Very long context (if model supports)
|
||||
./llama-cli -m model.gguf -c 32768 # 32K context
|
||||
```
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
### CPU performance (Llama 2-7B Q4_K_M)
|
||||
|
||||
| CPU | Threads | Speed | Cost |
|
||||
|-----|---------|-------|------|
|
||||
| Apple M3 Max | 16 | 50 tok/s | $0 (local) |
|
||||
| AMD Ryzen 9 7950X | 32 | 35 tok/s | $0.50/hour |
|
||||
| Intel i9-13900K | 32 | 30 tok/s | $0.40/hour |
|
||||
| AWS c7i.16xlarge | 64 | 40 tok/s | $2.88/hour |
|
||||
|
||||
### GPU acceleration (Llama 2-7B Q4_K_M)
|
||||
|
||||
| GPU | Speed | vs CPU | Cost |
|
||||
|-----|-------|--------|------|
|
||||
| NVIDIA RTX 4090 | 120 tok/s | 3-4× | $0 (local) |
|
||||
| NVIDIA A10 | 80 tok/s | 2-3× | $1.00/hour |
|
||||
| AMD MI250 | 70 tok/s | 2× | $2.00/hour |
|
||||
| Apple M3 Max (Metal) | 50 tok/s | ~Same | $0 (local) |
|
||||
|
||||
## Supported models
|
||||
|
||||
**LLaMA family**:
|
||||
- Llama 2 (7B, 13B, 70B)
|
||||
- Llama 3 (8B, 70B, 405B)
|
||||
- Code Llama
|
||||
|
||||
**Mistral family**:
|
||||
- Mistral 7B
|
||||
- Mixtral 8x7B, 8x22B
|
||||
|
||||
**Other**:
|
||||
- Falcon, BLOOM, GPT-J
|
||||
- Phi-3, Gemma, Qwen
|
||||
- LLaVA (vision), Whisper (audio)
|
||||
|
||||
**Find models**: https://huggingface.co/models?library=gguf
|
||||
|
||||
## References
|
||||
|
||||
- **[Quantization Guide](references/quantization.md)** - GGUF formats, conversion, quality comparison
|
||||
- **[Server Deployment](references/server.md)** - API endpoints, Docker, monitoring
|
||||
- **[Optimization](references/optimization.md)** - Performance tuning, hybrid CPU+GPU
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/ggerganov/llama.cpp
|
||||
- **Models**: https://huggingface.co/models?library=gguf
|
||||
- **Discord**: https://discord.gg/llama-cpp
|
||||
|
||||
|
||||
skills/mlops-knowledge/llama-cpp/references/optimization.md (new file, 89 lines)
|
||||
# Performance Optimization Guide
|
||||
|
||||
Maximize llama.cpp inference speed and efficiency.
|
||||
|
||||
## CPU Optimization
|
||||
|
||||
### Thread tuning
|
||||
```bash
|
||||
# Set threads (default: physical cores)
|
||||
./llama-cli -m model.gguf -t 8
|
||||
|
||||
# For AMD Ryzen 9 7950X (16 cores, 32 threads)
|
||||
-t 16 # Best: physical cores
|
||||
|
||||
# Avoid hyperthreading (slower for matrix ops)
|
||||
```
|
||||
|
||||
### BLAS acceleration
|
||||
```bash
|
||||
# OpenBLAS (faster matrix ops)
|
||||
make LLAMA_OPENBLAS=1
|
||||
|
||||
# BLAS gives 2-3× speedup
|
||||
```
|
||||
|
||||
## GPU Offloading
|
||||
|
||||
### Layer offloading
|
||||
```bash
|
||||
# Offload 35 layers to GPU (hybrid mode)
|
||||
./llama-cli -m model.gguf -ngl 35
|
||||
|
||||
# Offload all layers
|
||||
./llama-cli -m model.gguf -ngl 999
|
||||
|
||||
# Find optimal value:
|
||||
# Start with -ngl 999
|
||||
# If OOM, reduce by 5 until fits
|
||||
```
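
The "start high, back off on OOM" search can be scripted; a sketch that shells out to `llama-cli` and steps `-ngl` down by 5 until a one-token generation succeeds (binary and model path are placeholders):

```python
import subprocess


def find_max_ngl(model: str = "model.gguf", start: int = 999, step: int = 5) -> int:
    """Return the largest -ngl value at which llama-cli runs without failing."""
    ngl = start
    while ngl >= 0:
        result = subprocess.run(
            ["./llama-cli", "-m", model, "-ngl", str(ngl), "-p", "hi", "-n", "1"],
            capture_output=True,
        )
        if result.returncode == 0:
            return ngl
        ngl -= step  # likely out of VRAM, try fewer offloaded layers
    return 0


print(find_max_ngl())
```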
|
||||
|
||||
### Memory usage
|
||||
```bash
|
||||
# Check VRAM usage
|
||||
nvidia-smi dmon
|
||||
|
||||
# Reduce context if needed
|
||||
./llama-cli -m model.gguf -c 2048 # 2K context instead of 4K
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
```bash
|
||||
# Increase batch size for throughput
|
||||
./llama-cli -m model.gguf -b 512 # Default: 512
|
||||
|
||||
# Physical batch (GPU)
|
||||
--ubatch 128 # Process 128 tokens at once
|
||||
```
|
||||
|
||||
## Context Management
|
||||
|
||||
```bash
|
||||
# Default context (512 tokens)
|
||||
-c 512
|
||||
|
||||
# Longer context (slower, more memory)
|
||||
-c 4096
|
||||
|
||||
# Very long context (if model supports)
|
||||
-c 32768
|
||||
```
|
||||
|
||||
## Benchmarks
|
||||
|
||||
### CPU Performance (Llama 2-7B Q4_K_M)
|
||||
|
||||
| Setup | Speed | Notes |
|
||||
|-------|-------|-------|
|
||||
| Apple M3 Max | 50 tok/s | Metal acceleration |
|
||||
| AMD 7950X (16c) | 35 tok/s | OpenBLAS |
|
||||
| Intel i9-13900K | 30 tok/s | AVX2 |
|
||||
|
||||
### GPU Offloading (RTX 4090)
|
||||
|
||||
| Layers GPU | Speed | VRAM |
|
||||
|------------|-------|------|
|
||||
| 0 (CPU only) | 30 tok/s | 0 GB |
|
||||
| 20 (hybrid) | 80 tok/s | 8 GB |
|
||||
| 35 (all) | 120 tok/s | 12 GB |
|
||||
skills/mlops-knowledge/llama-cpp/references/quantization.md (new file, 213 lines)
|
||||
# GGUF Quantization Guide
|
||||
|
||||
Complete guide to GGUF quantization formats and model conversion.
|
||||
|
||||
## Quantization Overview
|
||||
|
||||
**GGUF** (GPT-Generated Unified Format) - Standard format for llama.cpp models.
|
||||
|
||||
### Format Comparison
|
||||
|
||||
| Format | Perplexity | Size (7B) | Tokens/sec | Notes |
|
||||
|--------|------------|-----------|------------|-------|
|
||||
| FP16 | 5.9565 (baseline) | 13.0 GB | 15 tok/s | Original quality |
|
||||
| Q8_0 | 5.9584 (+0.03%) | 7.0 GB | 25 tok/s | Nearly lossless |
|
||||
| **Q6_K** | 5.9642 (+0.13%) | 5.5 GB | 30 tok/s | Best quality/size |
|
||||
| **Q5_K_M** | 5.9796 (+0.39%) | 4.8 GB | 35 tok/s | Balanced |
|
||||
| **Q4_K_M** | 6.0565 (+1.68%) | 4.1 GB | 40 tok/s | **Recommended** |
|
||||
| Q4_K_S | 6.1125 (+2.62%) | 3.9 GB | 42 tok/s | Faster, lower quality |
|
||||
| Q3_K_M | 6.3184 (+6.07%) | 3.3 GB | 45 tok/s | Small models only |
|
||||
| Q2_K | 6.8673 (+15.3%) | 2.7 GB | 50 tok/s | Not recommended |
|
||||
|
||||
**Recommendation**: Use **Q4_K_M** for best balance of quality and speed.
|
||||
|
||||
## Converting Models
|
||||
|
||||
### HuggingFace to GGUF
|
||||
|
||||
```bash
|
||||
# 1. Download HuggingFace model
|
||||
huggingface-cli download meta-llama/Llama-2-7b-chat-hf \
|
||||
--local-dir models/llama-2-7b-chat/
|
||||
|
||||
# 2. Convert to FP16 GGUF
|
||||
python convert_hf_to_gguf.py \
|
||||
models/llama-2-7b-chat/ \
|
||||
--outtype f16 \
|
||||
--outfile models/llama-2-7b-chat-f16.gguf
|
||||
|
||||
# 3. Quantize to Q4_K_M
|
||||
./llama-quantize \
|
||||
models/llama-2-7b-chat-f16.gguf \
|
||||
models/llama-2-7b-chat-Q4_K_M.gguf \
|
||||
Q4_K_M
|
||||
```
|
||||
|
||||
### Batch quantization
|
||||
|
||||
```bash
|
||||
# Quantize to multiple formats
|
||||
for quant in Q4_K_M Q5_K_M Q6_K Q8_0; do
|
||||
./llama-quantize \
|
||||
model-f16.gguf \
|
||||
model-${quant}.gguf \
|
||||
$quant
|
||||
done
|
||||
```
|
||||
|
||||
## K-Quantization Methods
|
||||
|
||||
**K-quants** use mixed precision for better quality:
|
||||
- Attention weights: Higher precision
|
||||
- Feed-forward weights: Lower precision
|
||||
|
||||
**Variants**:
|
||||
- `_S` (Small): Faster, lower quality
|
||||
- `_M` (Medium): Balanced (recommended)
|
||||
- `_L` (Large): Better quality, larger size
|
||||
|
||||
**Example**: `Q4_K_M`
|
||||
- `Q4`: 4-bit quantization
|
||||
- `K`: Mixed precision method
|
||||
- `M`: Medium quality
|
||||
|
||||
## Quality Testing
|
||||
|
||||
```bash
|
||||
# Calculate perplexity (quality metric)
|
||||
./llama-perplexity \
|
||||
-m model.gguf \
|
||||
-f wikitext-2-raw/wiki.test.raw \
|
||||
-c 512
|
||||
|
||||
# Lower perplexity = better quality
|
||||
# Baseline (FP16): ~5.96
|
||||
# Q4_K_M: ~6.06 (+1.7%)
|
||||
# Q2_K: ~6.87 (+15.3% - too much degradation)
|
||||
```
|
||||
|
||||
## Use Case Guide
|
||||
|
||||
### General purpose (chatbots, assistants)
|
||||
```
|
||||
Q4_K_M - Best balance
|
||||
Q5_K_M - If you have extra RAM
|
||||
```
|
||||
|
||||
### Code generation
|
||||
```
|
||||
Q5_K_M or Q6_K - Higher precision helps with code
|
||||
```
|
||||
|
||||
### Creative writing
|
||||
```
|
||||
Q4_K_M - Sufficient quality
|
||||
Q3_K_M - Acceptable for draft generation
|
||||
```
|
||||
|
||||
### Technical/medical
|
||||
```
|
||||
Q6_K or Q8_0 - Maximum accuracy
|
||||
```
|
||||
|
||||
### Edge devices (Raspberry Pi)
|
||||
```
|
||||
Q2_K or Q3_K_S - Fit in limited RAM
|
||||
```
|
||||
|
||||
## Model Size Scaling
|
||||
|
||||
### 7B parameter models
|
||||
|
||||
| Format | Size | RAM needed |
|
||||
|--------|------|------------|
|
||||
| Q2_K | 2.7 GB | 5 GB |
|
||||
| Q3_K_M | 3.3 GB | 6 GB |
|
||||
| Q4_K_M | 4.1 GB | 7 GB |
|
||||
| Q5_K_M | 4.8 GB | 8 GB |
|
||||
| Q6_K | 5.5 GB | 9 GB |
|
||||
| Q8_0 | 7.0 GB | 11 GB |
|
||||
|
||||
### 13B parameter models
|
||||
|
||||
| Format | Size | RAM needed |
|
||||
|--------|------|------------|
|
||||
| Q2_K | 5.1 GB | 8 GB |
|
||||
| Q3_K_M | 6.2 GB | 10 GB |
|
||||
| Q4_K_M | 7.9 GB | 12 GB |
|
||||
| Q5_K_M | 9.2 GB | 14 GB |
|
||||
| Q6_K | 10.7 GB | 16 GB |
|
||||
|
||||
### 70B parameter models
|
||||
|
||||
| Format | Size | RAM needed |
|
||||
|--------|------|------------|
|
||||
| Q2_K | 26 GB | 32 GB |
|
||||
| Q3_K_M | 32 GB | 40 GB |
|
||||
| Q4_K_M | 41 GB | 48 GB |
|
||||
| Q4_K_S | 39 GB | 46 GB |
|
||||
| Q5_K_M | 48 GB | 56 GB |
|
||||
|
||||
**Recommendation for 70B**: Use Q3_K_M or Q4_K_S to fit in consumer hardware.
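
The sizes in these tables follow roughly from parameter count times bits per weight; a quick back-of-the-envelope helper (real files add overhead, so treat the result as a lower bound):

```python
# Effective bits per weight for common GGUF formats (from the overview table)
BITS_PER_WEIGHT = {
    "Q2_K": 2.5, "Q4_K_S": 4.3, "Q4_K_M": 4.5,
    "Q5_K_M": 5.5, "Q6_K": 6.5, "Q8_0": 8.0,
}


def estimate_gguf_gb(params_b: float, fmt: str) -> float:
    """Rough GGUF file size in GB: parameters * bits per weight / 8."""
    return params_b * 1e9 * BITS_PER_WEIGHT[fmt] / 8 / 1e9


print(f"{estimate_gguf_gb(7, 'Q4_K_M'):.1f} GB")   # ~3.9 GB vs 4.1 GB in the table
print(f"{estimate_gguf_gb(70, 'Q4_K_M'):.1f} GB")  # ~39 GB vs 41 GB in the table
```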
|
||||
|
||||
## Finding Pre-Quantized Models
|
||||
|
||||
**TheBloke** on HuggingFace:
|
||||
- https://huggingface.co/TheBloke
|
||||
- Most models available in all GGUF formats
|
||||
- No conversion needed
|
||||
|
||||
**Example**:
|
||||
```bash
|
||||
# Download pre-quantized Llama 2-7B
|
||||
huggingface-cli download \
|
||||
TheBloke/Llama-2-7B-Chat-GGUF \
|
||||
llama-2-7b-chat.Q4_K_M.gguf \
|
||||
--local-dir models/
|
||||
```
|
||||
|
||||
## Importance Matrices (imatrix)
|
||||
|
||||
**What**: Calibration data to improve quantization quality.
|
||||
|
||||
**Benefits**:
|
||||
- 10-20% perplexity improvement with Q4
|
||||
- Essential for Q3 and below
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
# 1. Generate importance matrix
|
||||
./llama-imatrix \
|
||||
-m model-f16.gguf \
|
||||
-f calibration-data.txt \
|
||||
-o model.imatrix
|
||||
|
||||
# 2. Quantize with imatrix
|
||||
./llama-quantize \
|
||||
--imatrix model.imatrix \
|
||||
model-f16.gguf \
|
||||
model-Q4_K_M.gguf \
|
||||
Q4_K_M
|
||||
```
|
||||
|
||||
**Calibration data**:
|
||||
- Use domain-specific text (e.g., code for code models)
|
||||
- ~100MB of representative text
|
||||
- Higher quality data = better quantization
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Model outputs gibberish**:
|
||||
- Quantization too aggressive (Q2_K)
|
||||
- Try Q4_K_M or Q5_K_M
|
||||
- Verify model converted correctly
|
||||
|
||||
**Out of memory**:
|
||||
- Use lower quantization (Q4_K_S instead of Q5_K_M)
|
||||
- Offload fewer layers to GPU (`-ngl`)
|
||||
- Use smaller context (`-c 2048`)
|
||||
|
||||
**Slow inference**:
|
||||
- Higher quantization uses more compute
|
||||
- Q8_0 much slower than Q4_K_M
|
||||
- Consider speed vs quality trade-off
|
||||
skills/mlops-knowledge/llama-cpp/references/server.md (new file, 125 lines)
|
||||
# Server Deployment Guide
|
||||
|
||||
Production deployment of llama.cpp server with OpenAI-compatible API.
|
||||
|
||||
## Server Modes
|
||||
|
||||
### llama-server
|
||||
|
||||
```bash
|
||||
# Basic server
|
||||
./llama-server \
|
||||
-m models/llama-2-7b-chat.Q4_K_M.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-c 4096 # Context size
|
||||
|
||||
# With GPU acceleration
|
||||
./llama-server \
|
||||
-m models/llama-2-70b.Q4_K_M.gguf \
|
||||
-ngl 40 # Offload 40 layers to GPU
|
||||
```
|
||||
|
||||
## OpenAI-Compatible API
|
||||
|
||||
### Chat completions
|
||||
```bash
|
||||
curl http://localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llama-2",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful"},
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100
|
||||
}'
|
||||
```
|
||||
|
||||
### Streaming
|
||||
```bash
|
||||
curl http://localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llama-2",
|
||||
"messages": [{"role": "user", "content": "Count to 10"}],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
## Docker Deployment
|
||||
|
||||
**Dockerfile**:
|
||||
```dockerfile
|
||||
FROM ubuntu:22.04
|
||||
RUN apt-get update && apt-get install -y git build-essential
|
||||
RUN git clone https://github.com/ggerganov/llama.cpp
|
||||
WORKDIR /llama.cpp
|
||||
RUN make LLAMA_CUDA=1
|
||||
COPY models/ /models/
|
||||
EXPOSE 8080
|
||||
CMD ["./llama-server", "-m", "/models/model.gguf", "--host", "0.0.0.0", "--port", "8080"]
|
||||
```
|
||||
|
||||
**Run**:
|
||||
```bash
|
||||
docker run --gpus all -p 8080:8080 llama-cpp:latest
|
||||
```
|
||||
|
||||
## Monitoring
|
||||
|
||||
```bash
|
||||
# Server metrics endpoint
|
||||
curl http://localhost:8080/metrics
|
||||
|
||||
# Health check
|
||||
curl http://localhost:8080/health
|
||||
```
|
||||
|
||||
**Metrics**:
|
||||
- requests_total
|
||||
- tokens_generated
|
||||
- prompt_tokens
|
||||
- completion_tokens
|
||||
- kv_cache_tokens
|
||||
|
||||
## Load Balancing
|
||||
|
||||
**NGINX**:
|
||||
```nginx
|
||||
upstream llama_cpp {
|
||||
server llama1:8080;
|
||||
server llama2:8080;
|
||||
}
|
||||
|
||||
server {
|
||||
location / {
|
||||
proxy_pass http://llama_cpp;
|
||||
proxy_read_timeout 300s;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
**Parallel requests**:
|
||||
```bash
|
||||
./llama-server \
|
||||
-m model.gguf \
|
||||
-np 4 # 4 parallel slots
|
||||
```
|
||||
|
||||
**Continuous batching**:
|
||||
```bash
|
||||
./llama-server \
|
||||
-m model.gguf \
|
||||
--cont-batching # Enable continuous batching
|
||||
```
|
||||
|
||||
**Context caching**:
|
||||
```bash
|
||||
./llama-server \
|
||||
-m model.gguf \
|
||||
--cache-prompt # Cache processed prompts
|
||||
```
|
||||
skills/mlops-knowledge/llava/SKILL.md (new file, 304 lines)
|
||||
---
|
||||
name: llava
|
||||
description: Large Language and Vision Assistant. Enables visual instruction tuning and image-based conversations. Combines CLIP vision encoder with Vicuna/LLaMA language models. Supports multi-turn image chat, visual question answering, and instruction following. Use for vision-language chatbots or image understanding tasks. Best for conversational image analysis.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [LLaVA, Vision-Language, Multimodal, Visual Question Answering, Image Chat, CLIP, Vicuna, Conversational AI, Instruction Tuning, VQA]
|
||||
dependencies: [transformers, torch, pillow]
|
||||
---
|
||||
|
||||
# LLaVA - Large Language and Vision Assistant
|
||||
|
||||
Open-source vision-language model for conversational image understanding.
|
||||
|
||||
## When to use LLaVA
|
||||
|
||||
**Use when:**
|
||||
- Building vision-language chatbots
|
||||
- Visual question answering (VQA)
|
||||
- Image description and captioning
|
||||
- Multi-turn image conversations
|
||||
- Visual instruction following
|
||||
- Document understanding with images
|
||||
|
||||
**Metrics**:
|
||||
- **23,000+ GitHub stars**
|
||||
- GPT-4V level capabilities (targeted)
|
||||
- Apache 2.0 License
|
||||
- Multiple model sizes (7B-34B params)
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **GPT-4V**: Highest quality, API-based
|
||||
- **CLIP**: Simple zero-shot classification
|
||||
- **BLIP-2**: Better for captioning only
|
||||
- **Flamingo**: Research, not open-source
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://github.com/haotian-liu/LLaVA
|
||||
cd LLaVA
|
||||
|
||||
# Install
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### Basic usage
|
||||
|
||||
```python
|
||||
from llava.model.builder import load_pretrained_model
|
||||
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
|
||||
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
||||
from llava.conversation import conv_templates
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
# Load model
|
||||
model_path = "liuhaotian/llava-v1.5-7b"
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
||||
model_path=model_path,
|
||||
model_base=None,
|
||||
model_name=get_model_name_from_path(model_path)
|
||||
)
|
||||
|
||||
# Load image
|
||||
image = Image.open("image.jpg")
|
||||
image_tensor = process_images([image], image_processor, model.config)
|
||||
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
|
||||
|
||||
# Create conversation
|
||||
conv = conv_templates["llava_v1"].copy()
|
||||
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
|
||||
# Generate response
|
||||
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
images=image_tensor,
|
||||
do_sample=True,
|
||||
temperature=0.2,
|
||||
max_new_tokens=512
|
||||
)
|
||||
|
||||
response = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
|
||||
print(response)
|
||||
```
|
||||
|
||||
## Available models
|
||||
|
||||
| Model | Parameters | VRAM | Quality |
|
||||
|-------|------------|------|---------|
|
||||
| LLaVA-v1.5-7B | 7B | ~14 GB | Good |
|
||||
| LLaVA-v1.5-13B | 13B | ~28 GB | Better |
|
||||
| LLaVA-v1.6-34B | 34B | ~70 GB | Best |
|
||||
|
||||
```python
|
||||
# Load different models
|
||||
model_7b = "liuhaotian/llava-v1.5-7b"
|
||||
model_13b = "liuhaotian/llava-v1.5-13b"
|
||||
model_34b = "liuhaotian/llava-v1.6-34b"
|
||||
|
||||
# 4-bit quantization for lower VRAM
|
||||
load_4bit = True # Reduces VRAM by ~4×
|
||||
```
|
||||
|
||||
## CLI usage
|
||||
|
||||
```bash
|
||||
# Single image query
|
||||
python -m llava.serve.cli \
|
||||
--model-path liuhaotian/llava-v1.5-7b \
|
||||
--image-file image.jpg \
|
||||
--query "What is in this image?"
|
||||
|
||||
# Multi-turn conversation
|
||||
python -m llava.serve.cli \
|
||||
--model-path liuhaotian/llava-v1.5-7b \
|
||||
--image-file image.jpg
|
||||
# Then type questions interactively
|
||||
```
|
||||
|
||||
## Web UI (Gradio)
|
||||
|
||||
```bash
|
||||
# Launch Gradio interface
|
||||
python -m llava.serve.gradio_web_server \
|
||||
--model-path liuhaotian/llava-v1.5-7b \
|
||||
--load-4bit # Optional: reduce VRAM
|
||||
|
||||
# Access at http://localhost:7860
|
||||
```
|
||||
|
||||
## Multi-turn conversations
|
||||
|
||||
```python
|
||||
# Initialize conversation
|
||||
conv = conv_templates["llava_v1"].copy()
|
||||
|
||||
# Turn 1
|
||||
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
response1 = generate(conv, model, image) # "A dog playing in a park"
|
||||
|
||||
# Turn 2
|
||||
conv.messages[-1][1] = response1 # Add previous response
|
||||
conv.append_message(conv.roles[0], "What breed is the dog?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
response2 = generate(conv, model, image) # "Golden Retriever"
|
||||
|
||||
# Turn 3
|
||||
conv.messages[-1][1] = response2
|
||||
conv.append_message(conv.roles[0], "What time of day is it?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
response3 = generate(conv, model, image)
|
||||
```
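
The `generate()` call above and the `ask()` helper used in the task recipes below are not part of the LLaVA package; a minimal sketch of such a helper, reusing the objects loaded in the basic-usage example (`tokenizer`, `conv_templates`, `DEFAULT_IMAGE_TOKEN`, etc.):

```python
def ask(model, image_tensor, question, conv_mode="llava_v1", temperature=0.2):
    """One-shot question about an already-preprocessed image (sketch, not an official API)."""
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(model.device)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=512,
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
```

For the recipes below, pass the `image_tensor` produced by `process_images` rather than the raw PIL image.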
|
||||
|
||||
## Common tasks
|
||||
|
||||
### Image captioning
|
||||
|
||||
```python
|
||||
question = "Describe this image in detail."
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Visual question answering
|
||||
|
||||
```python
|
||||
question = "How many people are in the image?"
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Object detection (textual)
|
||||
|
||||
```python
|
||||
question = "List all the objects you can see in this image."
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Scene understanding
|
||||
|
||||
```python
|
||||
question = "What is happening in this scene?"
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Document understanding
|
||||
|
||||
```python
|
||||
question = "What is the main topic of this document?"
|
||||
response = ask(model, document_image, question)
|
||||
```
|
||||
|
||||
## Training custom model
|
||||
|
||||
```bash
|
||||
# Stage 1: Feature alignment (558K image-caption pairs)
|
||||
bash scripts/v1_5/pretrain.sh
|
||||
|
||||
# Stage 2: Visual instruction tuning (150K instruction data)
|
||||
bash scripts/v1_5/finetune.sh
|
||||
```
|
||||
|
||||
## Quantization (reduce VRAM)
|
||||
|
||||
```python
|
||||
# 4-bit quantization
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
||||
model_path="liuhaotian/llava-v1.5-13b",
|
||||
model_base=None,
|
||||
model_name=get_model_name_from_path("liuhaotian/llava-v1.5-13b"),
|
||||
load_4bit=True # Reduces VRAM ~4×
|
||||
)
|
||||
|
||||
# 8-bit quantization
|
||||
load_8bit=True # Reduces VRAM ~2×
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with 7B model** - Good quality, manageable VRAM
|
||||
2. **Use 4-bit quantization** - Reduces VRAM significantly
|
||||
3. **GPU required** - CPU inference extremely slow
|
||||
4. **Clear prompts** - Specific questions get better answers
|
||||
5. **Multi-turn conversations** - Maintain conversation context
|
||||
6. **Temperature 0.2-0.7** - Balance creativity/consistency
|
||||
7. **max_new_tokens 512-1024** - For detailed responses
|
||||
8. **Batch processing** - Process multiple images sequentially
|
||||
|
||||
## Performance
|
||||
|
||||
| Model | VRAM (FP16) | VRAM (4-bit) | Speed (tokens/s) |
|
||||
|-------|-------------|--------------|------------------|
|
||||
| 7B | ~14 GB | ~4 GB | ~20 |
|
||||
| 13B | ~28 GB | ~8 GB | ~12 |
|
||||
| 34B | ~70 GB | ~18 GB | ~5 |
|
||||
|
||||
*On A100 GPU*
|
||||
|
||||
## Benchmarks
|
||||
|
||||
LLaVA achieves competitive scores on:
|
||||
- **VQAv2**: 78.5%
|
||||
- **GQA**: 62.0%
|
||||
- **MM-Vet**: 35.4%
|
||||
- **MMBench**: 64.3%
|
||||
|
||||
## Limitations
|
||||
|
||||
1. **Hallucinations** - May describe things not in image
|
||||
2. **Spatial reasoning** - Struggles with precise locations
|
||||
3. **Small text** - Difficulty reading fine print
|
||||
4. **Object counting** - Imprecise for many objects
|
||||
5. **VRAM requirements** - Need powerful GPU
|
||||
6. **Inference speed** - Slower than CLIP
|
||||
|
||||
## Integration with frameworks
|
||||
|
||||
### LangChain
|
||||
|
||||
```python
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
class LLaVALLM(LLM):
|
||||
def _call(self, prompt, stop=None):
|
||||
# Custom LLaVA inference
|
||||
return response
|
||||
|
||||
llm = LLaVALLM()
|
||||
```
|
||||
|
||||
### Gradio App
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
|
||||
def chat(text, history, image):
|
||||
response = ask_llava(model, image, text)
|
||||
return response
|
||||
|
||||
demo = gr.ChatInterface(
|
||||
chat,
|
||||
additional_inputs=[gr.Image(type="pil")],
|
||||
title="LLaVA Chat"
|
||||
)
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/haotian-liu/LLaVA ⭐ 23,000+
|
||||
- **Paper**: https://arxiv.org/abs/2304.08485
|
||||
- **Demo**: https://llava.hliu.cc
|
||||
- **Models**: https://huggingface.co/liuhaotian
|
||||
- **License**: Apache 2.0
|
||||
|
||||
|
||||
skills/mlops-knowledge/llava/references/training.md (new file, 197 lines)
|
||||
# LLaVA Training Guide
|
||||
|
||||
Guide to training and fine-tuning LLaVA models.
|
||||
|
||||
## Training stages
|
||||
|
||||
### Stage 1: Feature alignment (Pretraining)
|
||||
|
||||
**Purpose**: Align vision encoder with language model
|
||||
|
||||
**Data**: 558K image-caption pairs (CC3M subset)
|
||||
|
||||
```bash
|
||||
# Download pretrained projector or train from scratch
|
||||
bash scripts/v1_5/pretrain.sh
|
||||
```
|
||||
|
||||
**Configuration:**
|
||||
- Base model: Vicuna-7B or LLaMA-2-7B
|
||||
- Vision encoder: CLIP ViT-L/14
|
||||
- Training time: ~20 hours on 8× A100
|
||||
|
||||
### Stage 2: Visual instruction tuning
|
||||
|
||||
**Purpose**: Teach model to follow visual instructions
|
||||
|
||||
**Data**: 150K GPT-generated multimodal instruction data
|
||||
|
||||
```bash
|
||||
# Fine-tune with instruction data
|
||||
bash scripts/v1_5/finetune.sh
|
||||
```
|
||||
|
||||
**Configuration:**
|
||||
- Epochs: 1
|
||||
- Batch size: 128 (across 8 GPUs)
|
||||
- Learning rate: 2e-5
|
||||
- Training time: ~24 hours on 8× A100
|
||||
|
||||
## Data format
|
||||
|
||||
### Instruction data format
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"id": "001",
|
||||
"image": "path/to/image.jpg",
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "<image>\nWhat is in this image?"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "The image shows a dog playing in a park."
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "What breed is the dog?"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "It appears to be a Golden Retriever."
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## Fine-tuning on custom data
|
||||
|
||||
### Prepare your data
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
# Create instruction data
|
||||
data = []
|
||||
for image_path, qa_pairs in your_dataset:
|
||||
conversations = []
|
||||
for q, a in qa_pairs:
|
||||
conversations.append({"from": "human", "value": f"<image>\n{q}"})
|
||||
conversations.append({"from": "gpt", "value": a})
|
||||
|
||||
data.append({
|
||||
"id": str(len(data)),
|
||||
"image": image_path,
|
||||
"conversations": conversations
|
||||
})
|
||||
|
||||
# Save
|
||||
with open("custom_data.json", "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
```
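
It can save a failed run to validate the file before launching training; a small sketch that checks the fields shown in the format above (the image folder path matches the fine-tune script below):

```python
import json
import os

with open("custom_data.json") as f:
    records = json.load(f)

for rec in records:
    assert {"id", "image", "conversations"} <= rec.keys(), f"missing keys in {rec.get('id')}"
    assert os.path.exists(os.path.join("path/to/images", rec["image"])), rec["image"]
    turns = rec["conversations"]
    # Turns must alternate human/gpt and start with a human message
    assert turns and turns[0]["from"] == "human"
    assert all(t["from"] == ("human" if i % 2 == 0 else "gpt") for i, t in enumerate(turns))
    assert "<image>" in turns[0]["value"], "first human turn should contain the <image> token"

print(f"{len(records)} records look structurally valid")
```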
|
||||
|
||||
### Fine-tune script
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
# Set paths
|
||||
DATA_PATH="custom_data.json"
|
||||
IMAGE_FOLDER="path/to/images"
|
||||
MODEL_PATH="liuhaotian/llava-v1.5-7b"
|
||||
OUTPUT_DIR="./checkpoints/llava-custom"
|
||||
|
||||
# Fine-tune
|
||||
deepspeed llava/train/train_mem.py \
|
||||
--deepspeed ./scripts/zero2.json \
|
||||
--model_name_or_path $MODEL_PATH \
|
||||
--version v1 \
|
||||
--data_path $DATA_PATH \
|
||||
--image_folder $IMAGE_FOLDER \
|
||||
--vision_tower openai/clip-vit-large-patch14-336 \
|
||||
--mm_projector_type mlp2x_gelu \
|
||||
--mm_vision_select_layer -2 \
|
||||
--mm_use_im_start_end False \
|
||||
--mm_use_im_patch_token False \
|
||||
--image_aspect_ratio pad \
|
||||
--group_by_modality_length True \
|
||||
--bf16 True \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--num_train_epochs 1 \
|
||||
--per_device_train_batch_size 16 \
|
||||
--per_device_eval_batch_size 4 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--evaluation_strategy "no" \
|
||||
--save_strategy "steps" \
|
||||
--save_steps 50000 \
|
||||
--save_total_limit 1 \
|
||||
--learning_rate 2e-5 \
|
||||
--weight_decay 0. \
|
||||
--warmup_ratio 0.03 \
|
||||
--lr_scheduler_type "cosine" \
|
||||
--logging_steps 1 \
|
||||
--tf32 True \
|
||||
--model_max_length 2048 \
|
||||
--gradient_checkpointing True \
|
||||
--dataloader_num_workers 4 \
|
||||
--lazy_preprocess True \
|
||||
--report_to wandb
|
||||
```
|
||||
|
||||
## LoRA fine-tuning (memory efficient)
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
# LoRA config
|
||||
lora_config = LoraConfig(
|
||||
r=8, # LoRA rank
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
model = get_peft_model(base_model, lora_config)
|
||||
|
||||
# Train with much lower memory
|
||||
```
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
### Full fine-tuning
|
||||
|
||||
- **7B model**: 8× A100 (40GB)
|
||||
- **13B model**: 8× A100 (80GB)
|
||||
- **Training time**: 20-48 hours
|
||||
|
||||
### LoRA fine-tuning
|
||||
|
||||
- **7B model**: 1× A100 (40GB)
|
||||
- **13B model**: 2× A100 (40GB)
|
||||
- **Training time**: 10-24 hours
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with pretrained** - Don't train from scratch
|
||||
2. **Use LoRA for efficiency** - 10× less memory
|
||||
3. **Quality over quantity** - 1K high-quality > 10K low-quality
|
||||
4. **Multi-turn conversations** - More engaging than single Q&A
|
||||
5. **Diverse images** - Cover different scenarios
|
||||
6. **Clear instructions** - Specific questions get better answers
|
||||
7. **Monitor loss** - Should decrease smoothly
|
||||
8. **Save checkpoints** - Training can fail
|
||||
9. **Test regularly** - Validate on held-out set
|
||||
10. **Use DeepSpeed** - For multi-GPU training
|
||||
|
||||
## Resources
|
||||
|
||||
- **Training script**: https://github.com/haotian-liu/LLaVA/tree/main/scripts
|
||||
- **Data format**: https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md
|
||||
- **Paper**: https://arxiv.org/abs/2304.08485
|
||||
skills/mlops-knowledge/lm-evaluation-harness/SKILL.md (new file, 490 lines)
|
||||
---
|
||||
name: evaluating-llms-harness
|
||||
description: Evaluates LLMs across 60+ academic benchmarks (MMLU, HumanEval, GSM8K, TruthfulQA, HellaSwag). Use when benchmarking model quality, comparing models, reporting academic results, or tracking training progress. Industry standard used by EleutherAI, HuggingFace, and major labs. Supports HuggingFace, vLLM, APIs.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Evaluation, LM Evaluation Harness, Benchmarking, MMLU, HumanEval, GSM8K, EleutherAI, Model Quality, Academic Benchmarks, Industry Standard]
|
||||
dependencies: [lm-eval, transformers, vllm]
|
||||
---
|
||||
|
||||
# lm-evaluation-harness - LLM Benchmarking
|
||||
|
||||
## Quick start
|
||||
|
||||
lm-evaluation-harness evaluates LLMs across 60+ academic benchmarks using standardized prompts and metrics.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install lm-eval
|
||||
```
|
||||
|
||||
**Evaluate any HuggingFace model**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu,gsm8k,hellaswag \
|
||||
--device cuda:0 \
|
||||
--batch_size 8
|
||||
```
|
||||
|
||||
**View available tasks**:
|
||||
```bash
|
||||
lm_eval --tasks list
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Standard benchmark evaluation
|
||||
|
||||
Evaluate model on core benchmarks (MMLU, GSM8K, HumanEval).
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
Benchmark Evaluation:
|
||||
- [ ] Step 1: Choose benchmark suite
|
||||
- [ ] Step 2: Configure model
|
||||
- [ ] Step 3: Run evaluation
|
||||
- [ ] Step 4: Analyze results
|
||||
```
|
||||
|
||||
**Step 1: Choose benchmark suite**
|
||||
|
||||
**Core reasoning benchmarks**:
|
||||
- **MMLU** (Massive Multitask Language Understanding) - 57 subjects, multiple choice
|
||||
- **GSM8K** - Grade school math word problems
|
||||
- **HellaSwag** - Common sense reasoning
|
||||
- **TruthfulQA** - Truthfulness and factuality
|
||||
- **ARC** (AI2 Reasoning Challenge) - Science questions
|
||||
|
||||
**Code benchmarks**:
|
||||
- **HumanEval** - Python code generation (164 problems)
|
||||
- **MBPP** (Mostly Basic Python Problems) - Python coding
|
||||
|
||||
**Standard suite** (recommended for model releases):
|
||||
```bash
|
||||
--tasks mmlu,gsm8k,hellaswag,truthfulqa,arc_challenge
|
||||
```
|
||||
|
||||
**Step 2: Configure model**
|
||||
|
||||
**HuggingFace model**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf,dtype=bfloat16 \
|
||||
--tasks mmlu \
|
||||
--device cuda:0 \
|
||||
--batch_size auto # Auto-detect optimal batch size
|
||||
```
|
||||
|
||||
**Quantized model (4-bit/8-bit)**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf,load_in_4bit=True \
|
||||
--tasks mmlu \
|
||||
--device cuda:0
|
||||
```
|
||||
|
||||
**Custom checkpoint**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=/path/to/my-model,tokenizer=/path/to/tokenizer \
|
||||
--tasks mmlu \
|
||||
--device cuda:0
|
||||
```
|
||||
|
||||
**Step 3: Run evaluation**
|
||||
|
||||
```bash
|
||||
# Full MMLU evaluation (57 subjects)
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu \
|
||||
    --num_fewshot 5 \
|
||||
--batch_size 8 \
|
||||
--output_path results/ \
|
||||
--log_samples # Save individual predictions
|
||||
|
||||
# Multiple benchmarks at once
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu,gsm8k,hellaswag,truthfulqa,arc_challenge \
|
||||
--num_fewshot 5 \
|
||||
--batch_size 8 \
|
||||
--output_path results/llama2-7b-eval.json
|
||||
```
|
||||
|
||||
**Step 4: Analyze results**
|
||||
|
||||
Results saved to `results/llama2-7b-eval.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"results": {
|
||||
"mmlu": {
|
||||
"acc": 0.459,
|
||||
"acc_stderr": 0.004
|
||||
},
|
||||
"gsm8k": {
|
||||
"exact_match": 0.142,
|
||||
"exact_match_stderr": 0.006
|
||||
},
|
||||
"hellaswag": {
|
||||
"acc_norm": 0.765,
|
||||
"acc_norm_stderr": 0.004
|
||||
}
|
||||
},
|
||||
"config": {
|
||||
"model": "hf",
|
||||
"model_args": "pretrained=meta-llama/Llama-2-7b-hf",
|
||||
"num_fewshot": 5
|
||||
}
|
||||
}
|
||||
```
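
A few lines of Python turn that file into a readable summary; a sketch that prints the primary metric (and its stderr when present) for each task:

```python
import json

with open("results/llama2-7b-eval.json") as f:
    results = json.load(f)["results"]

for task, metrics in results.items():
    # Each task reports one or more primary metrics plus matching *_stderr entries
    for name, value in metrics.items():
        if name.endswith("_stderr"):
            continue
        stderr = metrics.get(f"{name}_stderr")
        line = f"{task:<12} {name}: {value:.3f}"
        if stderr is not None:
            line += f" +/- {stderr:.3f}"
        print(line)
```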
|
||||
|
||||
### Workflow 2: Track training progress
|
||||
|
||||
Evaluate checkpoints during training.
|
||||
|
||||
```
|
||||
Training Progress Tracking:
|
||||
- [ ] Step 1: Set up periodic evaluation
|
||||
- [ ] Step 2: Choose quick benchmarks
|
||||
- [ ] Step 3: Automate evaluation
|
||||
- [ ] Step 4: Plot learning curves
|
||||
```
|
||||
|
||||
**Step 1: Set up periodic evaluation**
|
||||
|
||||
Evaluate every N training steps:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
# eval_checkpoint.sh - quick 0-shot evaluation of a training checkpoint
|
||||
|
||||
CHECKPOINT_DIR=$1
|
||||
STEP=$2
|
||||
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=$CHECKPOINT_DIR/checkpoint-$STEP \
|
||||
--tasks gsm8k,hellaswag \
|
||||
    --num_fewshot 0 \
|
||||
--batch_size 16 \
|
||||
--output_path results/step-$STEP.json
|
||||
```
|
||||
|
||||
**Step 2: Choose quick benchmarks**
|
||||
|
||||
Fast benchmarks for frequent evaluation:
|
||||
- **HellaSwag**: ~10 minutes on 1 GPU
|
||||
- **GSM8K**: ~5 minutes
|
||||
- **PIQA**: ~2 minutes
|
||||
|
||||
Avoid for frequent eval (too slow):
|
||||
- **MMLU**: ~2 hours (57 subjects)
|
||||
- **HumanEval**: Requires code execution
|
||||
|
||||
**Step 3: Automate evaluation**
|
||||
|
||||
Integrate with training script:
|
||||
|
||||
```python
|
||||
# In training loop
|
||||
if step % eval_interval == 0:
|
||||
model.save_pretrained(f"checkpoints/step-{step}")
|
||||
|
||||
# Run evaluation
|
||||
os.system(f"./eval_checkpoint.sh checkpoints step-{step}")
|
||||
```
|
||||
|
||||
Or use PyTorch Lightning callbacks:
|
||||
|
||||
```python
|
||||
import os

from pytorch_lightning import Callback
|
||||
|
||||
class EvalHarnessCallback(Callback):
|
||||
def on_validation_epoch_end(self, trainer, pl_module):
|
||||
step = trainer.global_step
|
||||
checkpoint_path = f"checkpoints/step-{step}"
|
||||
|
||||
# Save checkpoint
|
||||
trainer.save_checkpoint(checkpoint_path)
|
||||
|
||||
# Run lm-eval
|
||||
os.system(f"lm_eval --model hf --model_args pretrained={checkpoint_path} ...")
|
||||
```
|
||||
|
||||
**Step 4: Plot learning curves**
|
||||
|
||||
```python
|
||||
import glob
import json
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Load all results
|
||||
steps = []
|
||||
mmlu_scores = []
|
||||
|
||||
for file in sorted(glob.glob("results/step-*.json")):
|
||||
with open(file) as f:
|
||||
data = json.load(f)
|
||||
step = int(file.split("-")[1].split(".")[0])
|
||||
steps.append(step)
|
||||
mmlu_scores.append(data["results"]["mmlu"]["acc"])
|
||||
|
||||
# Plot
|
||||
plt.plot(steps, mmlu_scores)
|
||||
plt.xlabel("Training Step")
|
||||
plt.ylabel("MMLU Accuracy")
|
||||
plt.title("Training Progress")
|
||||
plt.savefig("training_curve.png")
|
||||
```
|
||||
|
||||
### Workflow 3: Compare multiple models
|
||||
|
||||
Benchmark suite for model comparison.
|
||||
|
||||
```
|
||||
Model Comparison:
|
||||
- [ ] Step 1: Define model list
|
||||
- [ ] Step 2: Run evaluations
|
||||
- [ ] Step 3: Generate comparison table
|
||||
```
|
||||
|
||||
**Step 1: Define model list**
|
||||
|
||||
```bash
|
||||
# models.txt
|
||||
meta-llama/Llama-2-7b-hf
|
||||
meta-llama/Llama-2-13b-hf
|
||||
mistralai/Mistral-7B-v0.1
|
||||
microsoft/phi-2
|
||||
```
|
||||
|
||||
**Step 2: Run evaluations**
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
# eval_all_models.sh
|
||||
|
||||
TASKS="mmlu,gsm8k,hellaswag,truthfulqa"
|
||||
|
||||
while read model; do
|
||||
echo "Evaluating $model"
|
||||
|
||||
# Extract model name for output file
|
||||
model_name=$(echo $model | sed 's/\//-/g')
|
||||
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=$model,dtype=bfloat16 \
|
||||
--tasks $TASKS \
|
||||
--num_fewshot 5 \
|
||||
--batch_size auto \
|
||||
--output_path results/$model_name.json
|
||||
|
||||
done < models.txt
|
||||
```
|
||||
|
||||
**Step 3: Generate comparison table**
|
||||
|
||||
```python
|
||||
import json

import pandas as pd

models = [
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-13b-hf",
    "mistralai/Mistral-7B-v0.1",
    "microsoft/phi-2",
]

tasks = ["mmlu", "gsm8k", "hellaswag", "truthfulqa"]

results = []
for model in models:
    # Result files were written with "/" replaced by "-" (see eval_all_models.sh)
    file_name = model.replace("/", "-")
    with open(f"results/{file_name}.json") as f:
        data = json.load(f)
    row = {"Model": model}
    for task in tasks:
        # Get primary metric for each task
        metrics = data["results"][task]
        if "acc" in metrics:
            row[task.upper()] = f"{metrics['acc']:.3f}"
        elif "exact_match" in metrics:
            row[task.upper()] = f"{metrics['exact_match']:.3f}"
    results.append(row)

df = pd.DataFrame(results)
print(df.to_markdown(index=False))
|
||||
```
|
||||
|
||||
Output:
|
||||
```
|
||||
| Model | MMLU | GSM8K | HELLASWAG | TRUTHFULQA |
|
||||
|------------------------|-------|-------|-----------|------------|
|
||||
| meta-llama/Llama-2-7b | 0.459 | 0.142 | 0.765 | 0.391 |
|
||||
| meta-llama/Llama-2-13b | 0.549 | 0.287 | 0.801 | 0.430 |
|
||||
| mistralai/Mistral-7B | 0.626 | 0.395 | 0.812 | 0.428 |
|
||||
| microsoft/phi-2 | 0.560 | 0.613 | 0.682 | 0.447 |
|
||||
```
|
||||
|
||||
### Workflow 4: Evaluate with vLLM (faster inference)
|
||||
|
||||
Use vLLM backend for 5-10x faster evaluation.
|
||||
|
||||
```
|
||||
vLLM Evaluation:
|
||||
- [ ] Step 1: Install vLLM
|
||||
- [ ] Step 2: Configure vLLM backend
|
||||
- [ ] Step 3: Run evaluation
|
||||
```
|
||||
|
||||
**Step 1: Install vLLM**
|
||||
|
||||
```bash
|
||||
pip install vllm
|
||||
```
|
||||
|
||||
**Step 2: Configure vLLM backend**
|
||||
|
||||
```bash
|
||||
lm_eval --model vllm \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8 \
|
||||
--tasks mmlu \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Step 3: Run evaluation**
|
||||
|
||||
vLLM is 5-10× faster than standard HuggingFace:
|
||||
|
||||
```bash
|
||||
# Standard HF: ~2 hours for MMLU on 7B model
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu \
|
||||
--batch_size 8
|
||||
|
||||
# vLLM: ~15-20 minutes for MMLU on 7B model
|
||||
lm_eval --model vllm \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf,tensor_parallel_size=2 \
|
||||
--tasks mmlu \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use lm-evaluation-harness when:**
|
||||
- Benchmarking models for academic papers
|
||||
- Comparing model quality across standard tasks
|
||||
- Tracking training progress
|
||||
- Reporting standardized metrics (everyone uses same prompts)
|
||||
- Need reproducible evaluation
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **HELM** (Stanford): Broader evaluation (fairness, efficiency, calibration)
|
||||
- **AlpacaEval**: Instruction-following evaluation with LLM judges
|
||||
- **MT-Bench**: Conversational multi-turn evaluation
|
||||
- **Custom scripts**: Domain-specific evaluation
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Evaluation too slow**
|
||||
|
||||
Use vLLM backend:
|
||||
```bash
|
||||
lm_eval --model vllm \
|
||||
--model_args pretrained=model-name,tensor_parallel_size=2
|
||||
```
|
||||
|
||||
Or reduce fewshot examples:
|
||||
```bash
|
||||
--num_fewshot 0 # Instead of 5
|
||||
```
|
||||
|
||||
Or evaluate subset of MMLU:
|
||||
```bash
|
||||
--tasks mmlu_stem # Only STEM subjects
|
||||
```
|
||||
|
||||
**Issue: Out of memory**
|
||||
|
||||
Reduce batch size:
|
||||
```bash
|
||||
--batch_size 1 # Or --batch_size auto
|
||||
```
|
||||
|
||||
Use quantization:
|
||||
```bash
|
||||
--model_args pretrained=model-name,load_in_8bit=True
|
||||
```
|
||||
|
||||
Enable CPU offloading:
|
||||
```bash
|
||||
--model_args pretrained=model-name,device_map=auto,offload_folder=offload
|
||||
```
|
||||
|
||||
**Issue: Different results than reported**
|
||||
|
||||
Check fewshot count:
|
||||
```bash
|
||||
--num_fewshot 5 # Most papers use 5-shot
|
||||
```
|
||||
|
||||
Check exact task name:
|
||||
```bash
|
||||
--tasks mmlu # Not mmlu_direct or mmlu_fewshot
|
||||
```
|
||||
|
||||
Verify model and tokenizer match:
|
||||
```bash
|
||||
--model_args pretrained=model-name,tokenizer=same-model-name
|
||||
```
|
||||
|
||||
**Issue: HumanEval not executing code**
|
||||
|
||||
Install execution dependencies:
|
||||
```bash
|
||||
pip install human-eval
|
||||
```
|
||||
|
||||
Enable code execution:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=model-name \
|
||||
--tasks humaneval \
|
||||
--allow_code_execution # Required for HumanEval
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Benchmark descriptions**: See [references/benchmark-guide.md](references/benchmark-guide.md) for detailed description of all 60+ tasks, what they measure, and interpretation.
|
||||
|
||||
**Custom tasks**: See [references/custom-tasks.md](references/custom-tasks.md) for creating domain-specific evaluation tasks.
|
||||
|
||||
**API evaluation**: See [references/api-evaluation.md](references/api-evaluation.md) for evaluating OpenAI, Anthropic, and other API models.
|
||||
|
||||
**Multi-GPU strategies**: See [references/distributed-eval.md](references/distributed-eval.md) for data parallel and tensor parallel evaluation.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA (CUDA 11.8+), works on CPU (very slow)
|
||||
- **VRAM**:
|
||||
- 7B model: 16GB (bf16) or 8GB (8-bit)
|
||||
- 13B model: 28GB (bf16) or 14GB (8-bit)
|
||||
- 70B model: Requires multi-GPU or quantization
|
||||
- **Time** (7B model, single A100):
|
||||
- HellaSwag: 10 minutes
|
||||
- GSM8K: 5 minutes
|
||||
- MMLU (full): 2 hours
|
||||
- HumanEval: 20 minutes
|
||||
|
||||
## Resources
|
||||
|
||||
- GitHub: https://github.com/EleutherAI/lm-evaluation-harness
|
||||
- Docs: https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs
|
||||
- Task library: 60+ tasks including MMLU, GSM8K, HumanEval, TruthfulQA, HellaSwag, ARC, WinoGrande, etc.
|
||||
- Leaderboard: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard (uses this harness)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,490 @@
|
||||
# API Evaluation
|
||||
|
||||
Guide to evaluating OpenAI, Anthropic, and other API-based language models.
|
||||
|
||||
## Overview
|
||||
|
||||
The lm-evaluation-harness supports evaluating API-based models through a unified `TemplateAPI` interface. This allows benchmarking of:
|
||||
- OpenAI models (GPT-4, GPT-3.5, etc.)
|
||||
- Anthropic models (Claude 3, Claude 2, etc.)
|
||||
- Local OpenAI-compatible APIs
|
||||
- Custom API endpoints
|
||||
|
||||
**Why evaluate API models**:
|
||||
- Benchmark closed-source models
|
||||
- Compare API models to open models
|
||||
- Validate API performance
|
||||
- Track model updates over time
|
||||
|
||||
## Supported API Models
|
||||
|
||||
| Provider | Model Type | Request Types | Logprobs |
|
||||
|----------|------------|---------------|----------|
|
||||
| OpenAI (completions) | `openai-completions` | All | ✅ Yes |
|
||||
| OpenAI (chat) | `openai-chat-completions` | `generate_until` only | ❌ No |
|
||||
| Anthropic (completions) | `anthropic-completions` | All | ❌ No |
|
||||
| Anthropic (chat) | `anthropic-chat` | `generate_until` only | ❌ No |
|
||||
| Local (OpenAI-compatible) | `local-completions` | Depends on server | Varies |
|
||||
|
||||
**Note**: Models without logprobs can only be evaluated on generation tasks, not perplexity or loglikelihood tasks.
|
||||
|
||||
## OpenAI Models
|
||||
|
||||
### Setup
|
||||
|
||||
```bash
|
||||
export OPENAI_API_KEY=sk-...
|
||||
```
|
||||
|
||||
### Completion Models (Legacy)
|
||||
|
||||
**Available models**: `davinci-002`, `babbage-002`
|
||||
|
||||
```bash
|
||||
lm_eval --model openai-completions \
|
||||
--model_args model=davinci-002 \
|
||||
--tasks lambada_openai,hellaswag \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Supports**:
|
||||
- `generate_until`: ✅
|
||||
- `loglikelihood`: ✅
|
||||
- `loglikelihood_rolling`: ✅
|
||||
|
||||
### Chat Models
|
||||
|
||||
**Available models**: `gpt-4`, `gpt-4-turbo`, `gpt-3.5-turbo`
|
||||
|
||||
```bash
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args model=gpt-4-turbo \
|
||||
--tasks mmlu,gsm8k,humaneval \
|
||||
--num_fewshot 5 \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Supports**:
|
||||
- `generate_until`: ✅
|
||||
- `loglikelihood`: ❌ (no logprobs)
|
||||
- `loglikelihood_rolling`: ❌
|
||||
|
||||
**Important**: Chat models don't provide logprobs, so they can only be used with generation tasks (MMLU, GSM8K, HumanEval), not perplexity tasks.
|
||||
|
||||
### Configuration Options
|
||||
|
||||
```bash
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args \
|
||||
model=gpt-4-turbo,\
|
||||
base_url=https://api.openai.com/v1,\
|
||||
num_concurrent=5,\
|
||||
max_retries=3,\
|
||||
timeout=60 \
  --batch_size auto
|
||||
```
|
||||
|
||||
**Parameters**:
|
||||
- `model`: Model identifier (required)
|
||||
- `base_url`: API endpoint (default: OpenAI)
|
||||
- `num_concurrent`: Concurrent requests (default: 5)
|
||||
- `max_retries`: Retry failed requests (default: 3)
|
||||
- `timeout`: Request timeout in seconds (default: 60)
|
||||
- `tokenizer`: Tokenizer to use (default: matches model)
|
||||
- `tokenizer_backend`: `"tiktoken"` or `"huggingface"`
|
||||
|
||||
### Cost Management
|
||||
|
||||
OpenAI charges per token. Estimate costs before running:
|
||||
|
||||
```python
|
||||
# Rough estimate
|
||||
num_samples = 1000
|
||||
avg_tokens_per_sample = 500 # input + output
|
||||
cost_per_1k_tokens = 0.01 # GPT-3.5 Turbo
|
||||
|
||||
total_cost = (num_samples * avg_tokens_per_sample / 1000) * cost_per_1k_tokens
|
||||
print(f"Estimated cost: ${total_cost:.2f}")
|
||||
```
|
||||
|
||||
**Cost-saving tips**:
|
||||
- Use `--limit N` for testing
|
||||
- Start with `gpt-3.5-turbo` before `gpt-4`
|
||||
- Set `max_gen_toks` to minimum needed
|
||||
- Use `num_fewshot=0` for zero-shot when possible
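Put together, the tips above amount to a cheap smoke test like the following before committing to a full run (model, task, and limits are illustrative):

```bash
# Cheap smoke test: smaller model, zero-shot, few samples, short generations
lm_eval --model openai-chat-completions \
  --model_args model=gpt-3.5-turbo \
  --tasks gsm8k \
  --num_fewshot 0 \
  --limit 50 \
  --gen_kwargs max_gen_toks=128
```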
|
||||
|
||||
## Anthropic Models
|
||||
|
||||
### Setup
|
||||
|
||||
```bash
|
||||
export ANTHROPIC_API_KEY=sk-ant-...
|
||||
```
|
||||
|
||||
### Completion Models (Legacy)
|
||||
|
||||
```bash
|
||||
lm_eval --model anthropic-completions \
|
||||
--model_args model=claude-2.1 \
|
||||
--tasks lambada_openai,hellaswag \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
### Chat Models (Recommended)
|
||||
|
||||
**Available models**: `claude-3-5-sonnet-20241022`, `claude-3-opus-20240229`, `claude-3-sonnet-20240229`, `claude-3-haiku-20240307`
|
||||
|
||||
```bash
|
||||
lm_eval --model anthropic-chat \
|
||||
--model_args model=claude-3-5-sonnet-20241022 \
|
||||
--tasks mmlu,gsm8k,humaneval \
|
||||
--num_fewshot 5 \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Aliases**: `anthropic-chat-completions` (same as `anthropic-chat`)
|
||||
|
||||
### Configuration Options
|
||||
|
||||
```bash
|
||||
lm_eval --model anthropic-chat \
|
||||
--model_args \
|
||||
model=claude-3-5-sonnet-20241022,\
|
||||
base_url=https://api.anthropic.com,\
|
||||
num_concurrent=5,\
|
||||
max_retries=3,\
|
||||
timeout=60
|
||||
```
|
||||
|
||||
### Cost Management
|
||||
|
||||
Anthropic pricing (as of 2024):
|
||||
- Claude 3.5 Sonnet: $3.00 / 1M input, $15.00 / 1M output
|
||||
- Claude 3 Opus: $15.00 / 1M input, $75.00 / 1M output
|
||||
- Claude 3 Haiku: $0.25 / 1M input, $1.25 / 1M output
|
||||
|
||||
**Budget-friendly strategy**:
|
||||
```bash
|
||||
# Test on small sample first
|
||||
lm_eval --model anthropic-chat \
|
||||
--model_args model=claude-3-haiku-20240307 \
|
||||
--tasks mmlu \
|
||||
--limit 100
|
||||
|
||||
# Then run full eval on best model
|
||||
lm_eval --model anthropic-chat \
|
||||
--model_args model=claude-3-5-sonnet-20241022 \
|
||||
--tasks mmlu \
|
||||
--num_fewshot 5
|
||||
```
|
||||
|
||||
## Local OpenAI-Compatible APIs
|
||||
|
||||
Many local inference servers expose OpenAI-compatible APIs (vLLM, Text Generation Inference, llama.cpp, Ollama).
|
||||
|
||||
### vLLM Local Server
|
||||
|
||||
**Start server**:
|
||||
```bash
|
||||
vllm serve meta-llama/Llama-2-7b-hf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8000
|
||||
```
|
||||
|
||||
**Evaluate**:
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
model=meta-llama/Llama-2-7b-hf,\
|
||||
base_url=http://localhost:8000/v1,\
|
||||
num_concurrent=1 \
|
||||
--tasks mmlu,gsm8k \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
### Text Generation Inference (TGI)
|
||||
|
||||
**Start server**:
|
||||
```bash
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 \
|
||||
ghcr.io/huggingface/text-generation-inference:latest \
|
||||
--model-id meta-llama/Llama-2-7b-hf
|
||||
```
|
||||
|
||||
**Evaluate**:
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
model=meta-llama/Llama-2-7b-hf,\
|
||||
base_url=http://localhost:8080/v1 \
|
||||
--tasks hellaswag,arc_challenge
|
||||
```
|
||||
|
||||
### Ollama
|
||||
|
||||
**Start server**:
|
||||
```bash
|
||||
ollama serve
|
||||
ollama pull llama2:7b
|
||||
```
|
||||
|
||||
**Evaluate**:
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
model=llama2:7b,\
|
||||
base_url=http://localhost:11434/v1 \
|
||||
--tasks mmlu
|
||||
```
|
||||
|
||||
### llama.cpp Server
|
||||
|
||||
**Start server**:
|
||||
```bash
|
||||
./server -m models/llama-2-7b.gguf --host 0.0.0.0 --port 8080
|
||||
```
|
||||
|
||||
**Evaluate**:
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
model=llama2,\
|
||||
base_url=http://localhost:8080/v1 \
|
||||
--tasks gsm8k
|
||||
```
|
||||
|
||||
## Custom API Implementation
|
||||
|
||||
For custom API endpoints, subclass `TemplateAPI`:
|
||||
|
||||
### Create `my_api.py`
|
||||
|
||||
```python
|
||||
from lm_eval.models.api_models import TemplateAPI
|
||||
import requests
|
||||
|
||||
class MyCustomAPI(TemplateAPI):
|
||||
"""Custom API model."""
|
||||
|
||||
def __init__(self, base_url, api_key, **kwargs):
|
||||
super().__init__(base_url=base_url, **kwargs)
|
||||
self.api_key = api_key
|
||||
|
||||
def _create_payload(self, messages, gen_kwargs):
|
||||
"""Create API request payload."""
|
||||
return {
|
||||
"messages": messages,
|
||||
"api_key": self.api_key,
|
||||
**gen_kwargs
|
||||
}
|
||||
|
||||
def parse_generations(self, response):
|
||||
"""Parse generation response."""
|
||||
return response.json()["choices"][0]["text"]
|
||||
|
||||
def parse_logprobs(self, response):
|
||||
"""Parse logprobs (if available)."""
|
||||
# Return None if API doesn't provide logprobs
|
||||
logprobs = response.json().get("logprobs")
|
||||
if logprobs:
|
||||
return logprobs["token_logprobs"]
|
||||
return None
|
||||
```
|
||||
|
||||
### Register and Use
|
||||
|
||||
```python
|
||||
from lm_eval import evaluator
|
||||
from my_api import MyCustomAPI
|
||||
|
||||
model = MyCustomAPI(
|
||||
base_url="https://api.example.com/v1",
|
||||
api_key="your-key"
|
||||
)
|
||||
|
||||
results = evaluator.simple_evaluate(
|
||||
model=model,
|
||||
tasks=["mmlu", "gsm8k"],
|
||||
num_fewshot=5,
|
||||
batch_size="auto"
|
||||
)
|
||||
```
|
||||
|
||||
## Comparing API and Open Models
|
||||
|
||||
### Side-by-Side Evaluation
|
||||
|
||||
```bash
|
||||
# Evaluate OpenAI GPT-4
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args model=gpt-4-turbo \
|
||||
--tasks mmlu,gsm8k,hellaswag \
|
||||
--num_fewshot 5 \
|
||||
--output_path results/gpt4.json
|
||||
|
||||
# Evaluate open Llama 2 70B
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-70b-hf,dtype=bfloat16 \
|
||||
--tasks mmlu,gsm8k,hellaswag \
|
||||
--num_fewshot 5 \
|
||||
--output_path results/llama2-70b.json
|
||||
|
||||
# Compare results
|
||||
python scripts/compare_results.py \
|
||||
results/gpt4.json \
|
||||
results/llama2-70b.json
|
||||
```
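`scripts/compare_results.py` is a placeholder path; if your checkout has no such script, a minimal stand-in that diffs the shared tasks from two result files could look like this:

```python
import json
import sys

# Usage: python compare_results.py results/gpt4.json results/llama2-70b.json
a_path, b_path = sys.argv[1], sys.argv[2]
with open(a_path) as f:
    a = json.load(f)["results"]
with open(b_path) as f:
    b = json.load(f)["results"]

for task in sorted(set(a) & set(b)):
    # Compare the first non-stderr metric reported for each task
    metric = next(k for k in a[task] if not k.endswith("_stderr"))
    if metric in b[task]:
        print(f"{task:15s} {metric:12s} {a[task][metric]:.3f} vs {b[task][metric]:.3f}")
```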
|
||||
|
||||
### Typical Comparisons
|
||||
|
||||
| Model | MMLU | GSM8K | HumanEval | Cost |
|
||||
|-------|------|-------|-----------|------|
|
||||
| GPT-4 Turbo | 86.4% | 92.0% | 67.0% | $$$$ |
|
||||
| Claude 3 Opus | 86.8% | 95.0% | 84.9% | $$$$ |
|
||||
| GPT-3.5 Turbo | 70.0% | 57.1% | 48.1% | $$ |
|
||||
| Llama 2 70B | 68.9% | 56.8% | 29.9% | Free (self-host) |
|
||||
| Mixtral 8x7B | 70.6% | 58.4% | 40.2% | Free (self-host) |
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Rate Limiting
|
||||
|
||||
Respect API rate limits:
|
||||
```bash
|
||||
# Lower concurrency and a longer timeout help stay under rate limits
lm_eval --model openai-chat-completions \
  --model_args \
model=gpt-4-turbo,\
num_concurrent=3,\
timeout=120 \
  --tasks mmlu
|
||||
```
|
||||
|
||||
### Reproducibility
|
||||
|
||||
Set temperature to 0 for deterministic results:
|
||||
```bash
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args model=gpt-4-turbo \
|
||||
--tasks mmlu \
|
||||
--gen_kwargs temperature=0.0
|
||||
```
|
||||
|
||||
Or use `seed` for sampling:
|
||||
```bash
|
||||
lm_eval --model anthropic-chat \
|
||||
--model_args model=claude-3-5-sonnet-20241022 \
|
||||
--tasks gsm8k \
|
||||
--gen_kwargs temperature=0.7,seed=42
|
||||
```
|
||||
|
||||
### Caching
|
||||
|
||||
API models automatically cache responses to avoid redundant calls:
|
||||
```bash
|
||||
# First run: makes API calls
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args model=gpt-4-turbo \
|
||||
--tasks mmlu \
|
||||
--limit 100
|
||||
|
||||
# Second run: uses cache (instant, free)
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args model=gpt-4-turbo \
|
||||
--tasks mmlu \
|
||||
--limit 100
|
||||
```
|
||||
|
||||
Cache location: `~/.cache/lm_eval/`
|
||||
|
||||
### Error Handling
|
||||
|
||||
APIs can fail. Use retries:
|
||||
```bash
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args \
|
||||
model=gpt-4-turbo,\
|
||||
max_retries=5,\
|
||||
timeout=120 \
|
||||
--tasks mmlu
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Authentication failed"
|
||||
|
||||
Check API key:
|
||||
```bash
|
||||
echo $OPENAI_API_KEY # Should print sk-...
|
||||
echo $ANTHROPIC_API_KEY # Should print sk-ant-...
|
||||
```
|
||||
|
||||
### "Rate limit exceeded"
|
||||
|
||||
Reduce concurrency:
|
||||
```bash
|
||||
--model_args num_concurrent=1
|
||||
```
|
||||
|
||||
Or add delays between requests.
|
||||
|
||||
### "Timeout error"
|
||||
|
||||
Increase timeout:
|
||||
```bash
|
||||
--model_args timeout=180
|
||||
```
|
||||
|
||||
### "Model not found"
|
||||
|
||||
For local APIs, verify server is running:
|
||||
```bash
|
||||
curl http://localhost:8000/v1/models
|
||||
```
|
||||
|
||||
### Cost Runaway
|
||||
|
||||
Use `--limit` for testing:
|
||||
```bash
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args model=gpt-4-turbo \
|
||||
--tasks mmlu \
|
||||
--limit 50 # Only 50 samples
|
||||
```
|
||||
|
||||
## Advanced Features
|
||||
|
||||
### Custom Headers
|
||||
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
base_url=http://api.example.com/v1,\
|
||||
header="Authorization: Bearer token,X-Custom: value"
|
||||
```
|
||||
|
||||
### Disable SSL Verification (Development Only)
|
||||
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
base_url=https://localhost:8000/v1,\
|
||||
verify_certificate=false
|
||||
```
|
||||
|
||||
### Custom Tokenizer
|
||||
|
||||
```bash
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args \
|
||||
model=gpt-4-turbo,\
|
||||
tokenizer=gpt2,\
|
||||
tokenizer_backend=huggingface
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- OpenAI API: https://platform.openai.com/docs/api-reference
|
||||
- Anthropic API: https://docs.anthropic.com/claude/reference
|
||||
- TemplateAPI: `lm_eval/models/api_models.py`
|
||||
- OpenAI models: `lm_eval/models/openai_completions.py`
|
||||
- Anthropic models: `lm_eval/models/anthropic_llms.py`
|
||||
@@ -0,0 +1,488 @@
|
||||
# Benchmark Guide
|
||||
|
||||
Complete guide to all 60+ evaluation tasks in lm-evaluation-harness, what they measure, and how to interpret results.
|
||||
|
||||
## Overview
|
||||
|
||||
The lm-evaluation-harness includes 60+ benchmarks spanning:
|
||||
- Language understanding (MMLU, GLUE)
|
||||
- Mathematical reasoning (GSM8K, MATH)
|
||||
- Code generation (HumanEval, MBPP)
|
||||
- Instruction following (IFEval, AlpacaEval)
|
||||
- Long-context understanding (LongBench)
|
||||
- Multilingual capabilities (AfroBench, NorEval)
|
||||
- Reasoning (BBH, ARC)
|
||||
- Truthfulness (TruthfulQA)
|
||||
|
||||
**List all tasks**:
|
||||
```bash
|
||||
lm_eval --tasks list
|
||||
```
|
||||
|
||||
## Major Benchmarks
|
||||
|
||||
### MMLU (Massive Multitask Language Understanding)
|
||||
|
||||
**What it measures**: Broad knowledge across 57 subjects (STEM, humanities, social sciences, law).
|
||||
|
||||
**Task variants**:
|
||||
- `mmlu`: Original 57-subject benchmark
|
||||
- `mmlu_pro`: More challenging version with reasoning-focused questions
|
||||
- `mmlu_prox`: Multilingual extension
|
||||
|
||||
**Format**: Multiple choice (4 options)
|
||||
|
||||
**Example**:
|
||||
```
|
||||
Question: What is the capital of France?
|
||||
A. Berlin
|
||||
B. Paris
|
||||
C. London
|
||||
D. Madrid
|
||||
Answer: B
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu \
|
||||
--num_fewshot 5
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Random: 25% (chance)
|
||||
- GPT-3 (175B): 43.9%
|
||||
- GPT-4: 86.4%
|
||||
- Human expert: ~90%
|
||||
|
||||
**Good for**: Assessing general knowledge and domain expertise.
|
||||
|
||||
### GSM8K (Grade School Math 8K)
|
||||
|
||||
**What it measures**: Mathematical reasoning on grade-school level word problems.
|
||||
|
||||
**Task variants**:
|
||||
- `gsm8k`: Base task
|
||||
- `gsm8k_cot`: With chain-of-thought prompting
|
||||
- `gsm_plus`: Adversarial variant with perturbations
|
||||
|
||||
**Format**: Free-form generation, extract numerical answer
|
||||
|
||||
**Example**:
|
||||
```
|
||||
Question: A baker made 200 cookies. He sold 3/5 of them in the morning and 1/4 of the remaining in the afternoon. How many cookies does he have left?
|
||||
Answer: 60
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks gsm8k \
|
||||
--num_fewshot 5
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Random: ~0%
|
||||
- GPT-3 (175B): 17.0%
|
||||
- GPT-4: 92.0%
|
||||
- Llama 2 70B: 56.8%
|
||||
|
||||
**Good for**: Testing multi-step reasoning and arithmetic.
|
||||
|
||||
### HumanEval
|
||||
|
||||
**What it measures**: Python code generation from docstrings (functional correctness).
|
||||
|
||||
**Task variants**:
|
||||
- `humaneval`: Standard benchmark
|
||||
- `humaneval_instruct`: For instruction-tuned models
|
||||
|
||||
**Format**: Code generation, execution-based evaluation
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
def has_close_elements(numbers: List[float], threshold: float) -> bool:
|
||||
""" Check if in given list of numbers, are any two numbers closer to each other than
|
||||
given threshold.
|
||||
>>> has_close_elements([1.0, 2.0, 3.0], 0.5)
|
||||
False
|
||||
>>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
|
||||
True
|
||||
"""
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=codellama/CodeLlama-7b-hf \
|
||||
--tasks humaneval \
|
||||
--batch_size 1
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Random: 0%
|
||||
- GPT-3 (175B): 0%
|
||||
- Codex: 28.8%
|
||||
- GPT-4: 67.0%
|
||||
- Code Llama 34B: 53.7%
|
||||
|
||||
**Good for**: Evaluating code generation capabilities.
|
||||
|
||||
### BBH (BIG-Bench Hard)
|
||||
|
||||
**What it measures**: 23 challenging reasoning tasks where models previously failed to beat humans.
|
||||
|
||||
**Categories**:
|
||||
- Logical reasoning
|
||||
- Math word problems
|
||||
- Social understanding
|
||||
- Algorithmic reasoning
|
||||
|
||||
**Format**: Multiple choice and free-form
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks bbh \
|
||||
--num_fewshot 3
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Random: ~25%
|
||||
- GPT-3 (175B): 33.9%
|
||||
- PaLM 540B: 58.3%
|
||||
- GPT-4: 86.7%
|
||||
|
||||
**Good for**: Testing advanced reasoning capabilities.
|
||||
|
||||
### IFEval (Instruction-Following Evaluation)
|
||||
|
||||
**What it measures**: Ability to follow specific, verifiable instructions.
|
||||
|
||||
**Instruction types**:
|
||||
- Format constraints (e.g., "answer in 3 sentences")
|
||||
- Length constraints (e.g., "use at least 100 words")
|
||||
- Content constraints (e.g., "include the word 'banana'")
|
||||
- Structural constraints (e.g., "use bullet points")
|
||||
|
||||
**Format**: Free-form generation with rule-based verification
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-chat-hf \
|
||||
--tasks ifeval \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Measures: Instruction adherence (not quality)
|
||||
- GPT-4: 86% instruction following
|
||||
- Claude 2: 84%
|
||||
|
||||
**Good for**: Evaluating chat/instruct models.
|
||||
|
||||
### GLUE (General Language Understanding Evaluation)
|
||||
|
||||
**What it measures**: Natural language understanding across 9 tasks.
|
||||
|
||||
**Tasks**:
|
||||
- `cola`: Grammatical acceptability
|
||||
- `sst2`: Sentiment analysis
|
||||
- `mrpc`: Paraphrase detection
|
||||
- `qqp`: Question pairs
|
||||
- `stsb`: Semantic similarity
|
||||
- `mnli`: Natural language inference
|
||||
- `qnli`: Question answering NLI
|
||||
- `rte`: Recognizing textual entailment
|
||||
- `wnli`: Winograd schemas
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=bert-base-uncased \
|
||||
--tasks glue \
|
||||
--num_fewshot 0
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- BERT Base: 78.3 (GLUE score)
|
||||
- RoBERTa Large: 88.5
|
||||
- Human baseline: 87.1
|
||||
|
||||
**Good for**: Encoder-only models, fine-tuning baselines.
|
||||
|
||||
### LongBench
|
||||
|
||||
**What it measures**: Long-context understanding (4K-32K tokens).
|
||||
|
||||
**21 tasks covering**:
|
||||
- Single-document QA
|
||||
- Multi-document QA
|
||||
- Summarization
|
||||
- Few-shot learning
|
||||
- Code completion
|
||||
- Synthetic tasks
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks longbench \
|
||||
--batch_size 1
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Tests context utilization
|
||||
- Many models struggle beyond 4K tokens
|
||||
- GPT-4 Turbo: 54.3%
|
||||
|
||||
**Good for**: Evaluating long-context models.
|
||||
|
||||
## Additional Benchmarks
|
||||
|
||||
### TruthfulQA
|
||||
|
||||
**What it measures**: Model's propensity to be truthful vs. generate plausible-sounding falsehoods.
|
||||
|
||||
**Format**: Multiple choice with 4-5 options
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks truthfulqa_mc2 \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Larger models often score worse (more convincing lies)
|
||||
- GPT-3: 58.8%
|
||||
- GPT-4: 59.0%
|
||||
- Human: ~94%
|
||||
|
||||
### ARC (AI2 Reasoning Challenge)
|
||||
|
||||
**What it measures**: Grade-school science questions.
|
||||
|
||||
**Variants**:
|
||||
- `arc_easy`: Easier questions
|
||||
- `arc_challenge`: Harder questions requiring reasoning
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks arc_challenge \
|
||||
--num_fewshot 25
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- ARC-Easy: Most models >80%
|
||||
- ARC-Challenge random: 25%
|
||||
- GPT-4: 96.3%
|
||||
|
||||
### HellaSwag
|
||||
|
||||
**What it measures**: Commonsense reasoning about everyday situations.
|
||||
|
||||
**Format**: Choose most plausible continuation
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks hellaswag \
|
||||
--num_fewshot 10
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Random: 25%
|
||||
- GPT-3: 78.9%
|
||||
- Llama 2 70B: 85.3%
|
||||
|
||||
### WinoGrande
|
||||
|
||||
**What it measures**: Commonsense reasoning via pronoun resolution.
|
||||
|
||||
**Example**:
|
||||
```
|
||||
The trophy doesn't fit in the brown suitcase because _ is too large.
|
||||
A. the trophy
|
||||
B. the suitcase
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks winogrande \
|
||||
--num_fewshot 5
|
||||
```
|
||||
|
||||
### PIQA
|
||||
|
||||
**What it measures**: Physical commonsense reasoning.
|
||||
|
||||
**Example**: "To clean a keyboard, use compressed air or..."
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks piqa
|
||||
```
|
||||
|
||||
## Multilingual Benchmarks
|
||||
|
||||
### AfroBench
|
||||
|
||||
**What it measures**: Performance across 64 African languages.
|
||||
|
||||
**15 tasks**: NLU, text generation, knowledge, QA, math reasoning
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks afrobench
|
||||
```
|
||||
|
||||
### NorEval
|
||||
|
||||
**What it measures**: Norwegian language understanding (9 task categories).
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=NbAiLab/nb-gpt-j-6B \
|
||||
--tasks noreval
|
||||
```
|
||||
|
||||
## Domain-Specific Benchmarks
|
||||
|
||||
### MATH
|
||||
|
||||
**What it measures**: High-school competition math problems.
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks math \
|
||||
--num_fewshot 4
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Very challenging
|
||||
- GPT-4: 42.5%
|
||||
- Minerva 540B: 33.6%
|
||||
|
||||
### MBPP (Mostly Basic Python Problems)
|
||||
|
||||
**What it measures**: Python programming from natural language descriptions.
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=codellama/CodeLlama-7b-hf \
|
||||
--tasks mbpp \
|
||||
--batch_size 1
|
||||
```
|
||||
|
||||
### DROP
|
||||
|
||||
**What it measures**: Reading comprehension requiring discrete reasoning.
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks drop
|
||||
```
|
||||
|
||||
## Benchmark Selection Guide
|
||||
|
||||
### For General Purpose Models
|
||||
|
||||
Run this suite:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu,gsm8k,hellaswag,arc_challenge,truthfulqa_mc2 \
|
||||
--num_fewshot 5
|
||||
```
|
||||
|
||||
### For Code Models
|
||||
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=codellama/CodeLlama-7b-hf \
|
||||
--tasks humaneval,mbpp \
|
||||
--batch_size 1
|
||||
```
|
||||
|
||||
### For Chat/Instruct Models
|
||||
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-chat-hf \
|
||||
--tasks ifeval,mmlu,gsm8k_cot \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
### For Long Context Models
|
||||
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-3.1-8B \
|
||||
--tasks longbench \
|
||||
--batch_size 1
|
||||
```
|
||||
|
||||
## Interpreting Results
|
||||
|
||||
### Understanding Metrics
|
||||
|
||||
**Accuracy**: Percentage of correct answers (most common)
|
||||
|
||||
**Exact Match (EM)**: Requires exact string match (strict)
|
||||
|
||||
**F1 Score**: Balances precision and recall
|
||||
|
||||
**BLEU/ROUGE**: Text generation similarity
|
||||
|
||||
**Pass@k**: Percentage passing when generating k samples
|
||||
|
||||
### Typical Score Ranges
|
||||
|
||||
| Model Size | MMLU | GSM8K | HumanEval | HellaSwag |
|
||||
|------------|------|-------|-----------|-----------|
|
||||
| 7B | 40-50% | 10-20% | 5-15% | 70-80% |
|
||||
| 13B | 45-55% | 20-35% | 15-25% | 75-82% |
|
||||
| 70B | 60-70% | 50-65% | 35-50% | 82-87% |
|
||||
| GPT-4 | 86% | 92% | 67% | 95% |
|
||||
|
||||
### Red Flags
|
||||
|
||||
- **All tasks at random chance**: Model not trained properly
|
||||
- **Exact 0% on generation tasks**: Likely format/parsing issue
|
||||
- **Huge variance across runs**: Check seed/sampling settings
|
||||
- **Better than GPT-4 on everything**: Likely contamination
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Always report few-shot setting**: 0-shot, 5-shot, etc.
|
||||
2. **Run multiple seeds**: Report mean ± std
|
||||
3. **Check for data contamination**: Search training data for benchmark examples (see the sketch after this list)
|
||||
4. **Compare to published baselines**: Validate your setup
|
||||
5. **Report all hyperparameters**: Model, batch size, max tokens, temperature
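For point 3, an exhaustive search is rarely practical; a common first pass is an n-gram overlap check between benchmark examples and training documents, sketched below (the 13-gram window is a conventional but arbitrary choice):

```python
def ngrams(text: str, n: int = 13) -> set:
    """Lower-cased word n-grams of a document."""
    tokens = text.lower().split()
    return {" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}

def is_contaminated(benchmark_example: str, training_doc: str, n: int = 13) -> bool:
    """Flag a training document sharing any n-gram with a benchmark example."""
    return bool(ngrams(benchmark_example, n) & ngrams(training_doc, n))

# Example usage over an iterable of (doc_id, text) training documents
# flagged = [doc_id for doc_id, doc in corpus if is_contaminated(question, doc)]
```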
|
||||
|
||||
## References
|
||||
|
||||
- Task list: `lm_eval --tasks list`
|
||||
- Task README: `lm_eval/tasks/README.md`
|
||||
- Papers: See individual benchmark papers
|
||||
@@ -0,0 +1,602 @@
|
||||
# Custom Tasks
|
||||
|
||||
Complete guide to creating domain-specific evaluation tasks in lm-evaluation-harness.
|
||||
|
||||
## Overview
|
||||
|
||||
Custom tasks allow you to evaluate models on your own datasets and metrics. Tasks are defined using YAML configuration files with optional Python utilities for complex logic.
|
||||
|
||||
**Why create custom tasks**:
|
||||
- Evaluate on proprietary/domain-specific data
|
||||
- Test specific capabilities not covered by existing benchmarks
|
||||
- Create evaluation pipelines for internal models
|
||||
- Reproduce research experiments
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Minimal Custom Task
|
||||
|
||||
Create `my_tasks/simple_qa.yaml`:
|
||||
|
||||
```yaml
|
||||
task: simple_qa
|
||||
dataset_path: data/simple_qa.jsonl
|
||||
output_type: generate_until
|
||||
doc_to_text: "Question: {{question}}\nAnswer:"
|
||||
doc_to_target: "{{answer}}"
|
||||
metric_list:
|
||||
- metric: exact_match
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
```
|
||||
|
||||
**Run it**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks simple_qa \
|
||||
--include_path my_tasks/
|
||||
```
|
||||
|
||||
## Task Configuration Reference
|
||||
|
||||
### Essential Fields
|
||||
|
||||
```yaml
|
||||
# Task identification
|
||||
task: my_custom_task # Unique task name (required)
|
||||
task_alias: "My Task" # Display name
|
||||
tag: # Tags for grouping
|
||||
- custom
|
||||
- domain_specific
|
||||
|
||||
# Dataset configuration
|
||||
dataset_path: data/my_data.jsonl # HuggingFace dataset or local path
|
||||
dataset_name: default # Subset name (if applicable)
|
||||
training_split: train
|
||||
validation_split: validation
|
||||
test_split: test
|
||||
|
||||
# Evaluation configuration
|
||||
output_type: generate_until # or loglikelihood, multiple_choice
|
||||
num_fewshot: 5 # Number of few-shot examples
|
||||
batch_size: auto # Batch size
|
||||
|
||||
# Prompt templates (Jinja2)
|
||||
doc_to_text: "Question: {{question}}"
|
||||
doc_to_target: "{{answer}}"
|
||||
|
||||
# Metrics
|
||||
metric_list:
|
||||
- metric: exact_match
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
|
||||
# Metadata
|
||||
metadata:
|
||||
version: 1.0
|
||||
```
|
||||
|
||||
### Output Types
|
||||
|
||||
**`generate_until`**: Free-form generation
|
||||
```yaml
|
||||
output_type: generate_until
|
||||
generation_kwargs:
|
||||
max_gen_toks: 256
|
||||
until:
|
||||
- "\n"
|
||||
- "."
|
||||
temperature: 0.0
|
||||
```
|
||||
|
||||
**`loglikelihood`**: Compute log probability of targets
|
||||
```yaml
|
||||
output_type: loglikelihood
|
||||
# Used for perplexity, classification
|
||||
```
|
||||
|
||||
**`multiple_choice`**: Choose from options
|
||||
```yaml
|
||||
output_type: multiple_choice
|
||||
doc_to_choice: "{{choices}}" # List of choices
|
||||
```
|
||||
|
||||
## Data Formats
|
||||
|
||||
### Local JSONL File
|
||||
|
||||
`data/my_data.jsonl`:
|
||||
```json
|
||||
{"question": "What is 2+2?", "answer": "4"}
|
||||
{"question": "Capital of France?", "answer": "Paris"}
|
||||
```
|
||||
|
||||
**Task config**:
|
||||
```yaml
|
||||
dataset_path: json  # use the HF "json" loader; the file itself goes in data_files below
|
||||
dataset_kwargs:
|
||||
data_files:
|
||||
test: data/my_data.jsonl
|
||||
```
|
||||
|
||||
### HuggingFace Dataset
|
||||
|
||||
```yaml
|
||||
dataset_path: squad
|
||||
dataset_name: plain_text
|
||||
test_split: validation
|
||||
```
|
||||
|
||||
### CSV File
|
||||
|
||||
`data/my_data.csv`:
|
||||
```csv
|
||||
question,answer,category
|
||||
What is 2+2?,4,math
|
||||
Capital of France?,Paris,geography
|
||||
```
|
||||
|
||||
**Task config**:
|
||||
```yaml
|
||||
dataset_path: csv  # use the HF "csv" loader; the file itself goes in data_files below
|
||||
dataset_kwargs:
|
||||
data_files:
|
||||
test: data/my_data.csv
|
||||
```
|
||||
|
||||
## Prompt Engineering
|
||||
|
||||
### Simple Template
|
||||
|
||||
```yaml
|
||||
doc_to_text: "Question: {{question}}\nAnswer:"
|
||||
doc_to_target: "{{answer}}"
|
||||
```
|
||||
|
||||
### Conditional Logic
|
||||
|
||||
```yaml
|
||||
doc_to_text: |
|
||||
{% if context %}
|
||||
Context: {{context}}
|
||||
{% endif %}
|
||||
Question: {{question}}
|
||||
Answer:
|
||||
```
|
||||
|
||||
### Multiple Choice
|
||||
|
||||
```yaml
|
||||
doc_to_text: |
|
||||
Question: {{question}}
|
||||
A. {{choices[0]}}
|
||||
B. {{choices[1]}}
|
||||
C. {{choices[2]}}
|
||||
D. {{choices[3]}}
|
||||
Answer:
|
||||
|
||||
doc_to_target: "{{ 'ABCD'[answer_idx] }}"
|
||||
doc_to_choice: ["A", "B", "C", "D"]
|
||||
```
|
||||
|
||||
### Few-Shot Formatting
|
||||
|
||||
```yaml
|
||||
fewshot_delimiter: "\n\n" # Between examples
|
||||
target_delimiter: " " # Between question and answer
|
||||
doc_to_text: "Q: {{question}}"
|
||||
doc_to_target: "A: {{answer}}"
|
||||
```
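With those settings and the JSONL rows from earlier, a 1-shot prompt is assembled roughly like this (illustrative rendering; the model is expected to continue with "A: Paris"):

```
Q: What is 2+2? A: 4

Q: Capital of France?
```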
|
||||
|
||||
## Custom Python Functions
|
||||
|
||||
For complex logic, use Python functions in `utils.py`.
|
||||
|
||||
### Create `my_tasks/utils.py`
|
||||
|
||||
```python
|
||||
def process_docs(dataset):
|
||||
"""Preprocess documents."""
|
||||
def _process(doc):
|
||||
# Custom preprocessing
|
||||
doc["question"] = doc["question"].strip().lower()
|
||||
return doc
|
||||
|
||||
return dataset.map(_process)
|
||||
|
||||
def doc_to_text(doc):
|
||||
"""Custom prompt formatting."""
|
||||
context = doc.get("context", "")
|
||||
question = doc["question"]
|
||||
|
||||
if context:
|
||||
return f"Context: {context}\nQuestion: {question}\nAnswer:"
|
||||
return f"Question: {question}\nAnswer:"
|
||||
|
||||
def doc_to_target(doc):
|
||||
"""Custom target extraction."""
|
||||
return doc["answer"].strip().lower()
|
||||
|
||||
def aggregate_scores(items):
|
||||
"""Custom metric aggregation."""
|
||||
correct = sum(1 for item in items if item == 1.0)
|
||||
total = len(items)
|
||||
return correct / total if total > 0 else 0.0
|
||||
```
|
||||
|
||||
### Use in Task Config
|
||||
|
||||
```yaml
|
||||
task: my_custom_task
|
||||
dataset_path: data/my_data.jsonl
|
||||
|
||||
# Use Python functions
|
||||
process_docs: !function utils.process_docs
|
||||
doc_to_text: !function utils.doc_to_text
|
||||
doc_to_target: !function utils.doc_to_target
|
||||
|
||||
metric_list:
|
||||
- metric: exact_match
|
||||
aggregation: !function utils.aggregate_scores
|
||||
higher_is_better: true
|
||||
```
|
||||
|
||||
## Real-World Examples
|
||||
|
||||
### Example 1: Domain QA Task
|
||||
|
||||
**Goal**: Evaluate medical question answering.
|
||||
|
||||
`medical_qa/medical_qa.yaml`:
|
||||
```yaml
|
||||
task: medical_qa
|
||||
dataset_path: data/medical_qa.jsonl
|
||||
output_type: generate_until
|
||||
num_fewshot: 3
|
||||
|
||||
doc_to_text: |
|
||||
Medical Question: {{question}}
|
||||
Context: {{context}}
|
||||
Answer (be concise):
|
||||
|
||||
doc_to_target: "{{answer}}"
|
||||
|
||||
generation_kwargs:
|
||||
max_gen_toks: 100
|
||||
until:
|
||||
- "\n\n"
|
||||
temperature: 0.0
|
||||
|
||||
metric_list:
|
||||
- metric: exact_match
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
- metric: !function utils.medical_f1
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
|
||||
filter_list:
|
||||
- name: lowercase
|
||||
filter:
|
||||
- function: lowercase
|
||||
- function: remove_whitespace
|
||||
|
||||
metadata:
|
||||
version: 1.0
|
||||
domain: medical
|
||||
```
|
||||
|
||||
`medical_qa/utils.py`:
|
||||
```python
|
||||
from sklearn.metrics import f1_score
|
||||
import re
|
||||
|
||||
def medical_f1(predictions, references):
|
||||
"""Custom F1 for medical terms."""
|
||||
pred_terms = set(extract_medical_terms(predictions[0]))
|
||||
ref_terms = set(extract_medical_terms(references[0]))
|
||||
|
||||
if not pred_terms and not ref_terms:
|
||||
return 1.0
|
||||
if not pred_terms or not ref_terms:
|
||||
return 0.0
|
||||
|
||||
tp = len(pred_terms & ref_terms)
|
||||
fp = len(pred_terms - ref_terms)
|
||||
fn = len(ref_terms - pred_terms)
|
||||
|
||||
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
||||
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
|
||||
|
||||
return 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
||||
|
||||
def extract_medical_terms(text):
|
||||
"""Extract medical terminology."""
|
||||
# Custom logic
|
||||
return re.findall(r'\b[A-Z][a-z]+(?:[A-Z][a-z]+)*\b', text)
|
||||
```
|
||||
|
||||
### Example 2: Code Evaluation
|
||||
|
||||
`code_eval/python_challenges.yaml`:
|
||||
```yaml
|
||||
task: python_challenges
|
||||
dataset_path: data/python_problems.jsonl
|
||||
output_type: generate_until
|
||||
num_fewshot: 0
|
||||
|
||||
doc_to_text: |
|
||||
Write a Python function to solve:
|
||||
{{problem_statement}}
|
||||
|
||||
Function signature:
|
||||
{{function_signature}}
|
||||
|
||||
doc_to_target: "{{canonical_solution}}"
|
||||
|
||||
generation_kwargs:
|
||||
max_gen_toks: 512
|
||||
until:
|
||||
- "\n\nclass"
|
||||
- "\n\ndef"
|
||||
temperature: 0.2
|
||||
|
||||
metric_list:
|
||||
- metric: !function utils.execute_code
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
|
||||
process_results: !function utils.process_code_results
|
||||
|
||||
metadata:
|
||||
version: 1.0
|
||||
```
|
||||
|
||||
`code_eval/utils.py`:
|
||||
```python
|
||||
import subprocess
|
||||
import json
|
||||
|
||||
def execute_code(predictions, references):
|
||||
"""Execute generated code against test cases."""
|
||||
generated_code = predictions[0]
|
||||
test_cases = json.loads(references[0])
|
||||
|
||||
try:
|
||||
# Execute code with test cases
|
||||
for test_input, expected_output in test_cases:
|
||||
result = execute_with_timeout(generated_code, test_input, timeout=5)
|
||||
if result != expected_output:
|
||||
return 0.0
|
||||
return 1.0
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
def execute_with_timeout(code, input_data, timeout=5):
|
||||
"""Safely execute code with timeout."""
|
||||
    # Minimal sketch: run the candidate code in a fresh interpreter, feed the
    # test input on stdin, and compare captured stdout (TimeoutExpired is
    # caught by the caller's except block).
    result = subprocess.run(
        ["python", "-c", code],
        input=str(input_data),
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return result.stdout.strip()
|
||||
|
||||
def process_code_results(doc, results):
|
||||
"""Process code execution results."""
|
||||
return {
|
||||
"passed": results[0] == 1.0,
|
||||
"generated_code": results[1]
|
||||
}
|
||||
```
|
||||
|
||||
### Example 3: Instruction Following
|
||||
|
||||
`instruction_eval/instruction_eval.yaml`:
|
||||
```yaml
|
||||
task: instruction_following
|
||||
dataset_path: data/instructions.jsonl
|
||||
output_type: generate_until
|
||||
num_fewshot: 0
|
||||
|
||||
doc_to_text: |
|
||||
Instruction: {{instruction}}
|
||||
{% if constraints %}
|
||||
Constraints: {{constraints}}
|
||||
{% endif %}
|
||||
Response:
|
||||
|
||||
doc_to_target: "{{expected_response}}"
|
||||
|
||||
generation_kwargs:
|
||||
max_gen_toks: 256
|
||||
temperature: 0.7
|
||||
|
||||
metric_list:
|
||||
- metric: !function utils.check_constraints
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
- metric: !function utils.semantic_similarity
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
|
||||
process_docs: !function utils.add_constraint_checkers
|
||||
```
|
||||
|
||||
`instruction_eval/utils.py`:
|
||||
```python
|
||||
import json

from sentence_transformers import SentenceTransformer, util
|
||||
|
||||
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
|
||||
def check_constraints(predictions, references):
|
||||
"""Check if response satisfies constraints."""
|
||||
response = predictions[0]
|
||||
constraints = json.loads(references[0])
|
||||
|
||||
satisfied = 0
|
||||
total = len(constraints)
|
||||
|
||||
for constraint in constraints:
|
||||
if verify_constraint(response, constraint):
|
||||
satisfied += 1
|
||||
|
||||
return satisfied / total if total > 0 else 1.0
|
||||
|
||||
def verify_constraint(response, constraint):
|
||||
"""Verify single constraint."""
|
||||
if constraint["type"] == "length":
|
||||
return len(response.split()) >= constraint["min_words"]
|
||||
elif constraint["type"] == "contains":
|
||||
return constraint["keyword"] in response.lower()
|
||||
# Add more constraint types
|
||||
return True
|
||||
|
||||
def semantic_similarity(predictions, references):
|
||||
"""Compute semantic similarity."""
|
||||
pred_embedding = model.encode(predictions[0])
|
||||
ref_embedding = model.encode(references[0])
|
||||
return float(util.cos_sim(pred_embedding, ref_embedding))
|
||||
|
||||
def add_constraint_checkers(dataset):
|
||||
"""Parse constraints into verifiable format."""
|
||||
def _parse(doc):
|
||||
# Parse constraint string into structured format
|
||||
doc["parsed_constraints"] = parse_constraints(doc.get("constraints", ""))
|
||||
return doc
|
||||
return dataset.map(_parse)
|
||||
```
|
||||
|
||||
## Advanced Features
|
||||
|
||||
### Output Filtering
|
||||
|
||||
```yaml
|
||||
filter_list:
|
||||
- name: extract_answer
|
||||
filter:
|
||||
- function: regex
|
||||
regex_pattern: "Answer: (.*)"
|
||||
group: 1
|
||||
- function: lowercase
|
||||
- function: strip_whitespace
|
||||
```
|
||||
|
||||
### Multiple Metrics
|
||||
|
||||
```yaml
|
||||
metric_list:
|
||||
- metric: exact_match
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
- metric: f1
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
- metric: bleu
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
```
|
||||
|
||||
### Task Groups
|
||||
|
||||
Create `my_tasks/_default.yaml`:
|
||||
```yaml
|
||||
group: my_eval_suite
|
||||
task:
|
||||
- simple_qa
|
||||
- medical_qa
|
||||
- python_challenges
|
||||
```
|
||||
|
||||
**Run entire suite**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks my_eval_suite \
|
||||
--include_path my_tasks/
|
||||
```
|
||||
|
||||
## Testing Your Task
|
||||
|
||||
### Validate Configuration
|
||||
|
||||
```bash
|
||||
# Test task loading
|
||||
lm_eval --tasks my_custom_task --include_path my_tasks/ --limit 0
|
||||
|
||||
# Run on 5 samples
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=gpt2 \
|
||||
--tasks my_custom_task \
|
||||
--include_path my_tasks/ \
|
||||
--limit 5
|
||||
```
|
||||
|
||||
### Debug Mode
|
||||
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=gpt2 \
|
||||
--tasks my_custom_task \
|
||||
--include_path my_tasks/ \
|
||||
--limit 1 \
|
||||
--log_samples # Save input/output samples
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Start simple**: Test with minimal config first
|
||||
2. **Version your tasks**: Use `metadata.version`
|
||||
3. **Document your metrics**: Explain custom metrics in comments
|
||||
4. **Test with multiple models**: Ensure robustness
|
||||
5. **Validate on known examples**: Include sanity checks
|
||||
6. **Use filters carefully**: Can hide errors
|
||||
7. **Handle edge cases**: Empty strings, missing fields
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Classification Task
|
||||
|
||||
```yaml
|
||||
output_type: loglikelihood
|
||||
doc_to_text: "Text: {{text}}\nLabel:"
|
||||
doc_to_target: " {{label}}" # Space prefix important!
|
||||
metric_list:
|
||||
- metric: acc
|
||||
aggregation: mean
|
||||
```
|
||||
|
||||
### Perplexity Evaluation
|
||||
|
||||
```yaml
|
||||
output_type: loglikelihood_rolling
|
||||
doc_to_text: "{{text}}"
|
||||
metric_list:
|
||||
- metric: perplexity
|
||||
aggregation: perplexity
|
||||
```
|
||||
|
||||
### Ranking Task
|
||||
|
||||
```yaml
|
||||
output_type: loglikelihood
|
||||
doc_to_text: "Query: {{query}}\nPassage: {{passage}}\nRelevant:"
|
||||
doc_to_target: [" Yes", " No"]
|
||||
metric_list:
|
||||
- metric: acc
|
||||
aggregation: mean
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**"Task not found"**: Check `--include_path` and task name
|
||||
|
||||
**Empty results**: Verify `doc_to_text` and `doc_to_target` templates
|
||||
|
||||
**Metric errors**: Ensure metric names are correct (exact_match, not exact-match)
|
||||
|
||||
**Filter issues**: Test filters with `--log_samples`
|
||||
|
||||
**Python function not found**: Check `!function module.function_name` syntax
|
||||
|
||||
## References
|
||||
|
||||
- Task system: EleutherAI/lm-evaluation-harness docs
|
||||
- Example tasks: `lm_eval/tasks/` directory
|
||||
- TaskConfig: `lm_eval/api/task.py`
|
||||
@@ -0,0 +1,519 @@
|
||||
# Distributed Evaluation
|
||||
|
||||
Guide to running evaluation across multiple GPUs using data parallelism and tensor/pipeline parallelism.
|
||||
|
||||
## Overview
|
||||
|
||||
Distributed evaluation speeds up benchmarking by:
|
||||
- **Data Parallelism**: Split evaluation samples across GPUs (each GPU has full model copy)
|
||||
- **Tensor Parallelism**: Split model weights across GPUs (for large models)
|
||||
- **Pipeline Parallelism**: Split model layers across GPUs (for very large models)
|
||||
|
||||
**When to use**:
|
||||
- Data Parallel: Model fits on single GPU, want faster evaluation
|
||||
- Tensor/Pipeline Parallel: Model too large for single GPU (see the sizing sketch after this list)
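A quick way to answer the "does it fit" question is a back-of-the-envelope estimate: roughly 2 bytes per parameter for bf16 weights plus headroom for activations and KV cache. A small sketch (the 1.2× overhead factor is an assumption, not a measured constant):

```python
def fits_on_one_gpu(params_billion: float, gpu_memory_gb: float,
                    bytes_per_param: float = 2.0, overhead: float = 1.2) -> bool:
    """Rough check: bf16 weights plus ~20% headroom for activations/KV cache."""
    needed_gb = params_billion * bytes_per_param * overhead
    return needed_gb <= gpu_memory_gb

print(fits_on_one_gpu(7, 80))    # True  -> prefer data parallelism
print(fits_on_one_gpu(70, 80))   # False -> needs tensor/pipeline parallelism
```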
|
||||
|
||||
## HuggingFace Models (`hf`)
|
||||
|
||||
### Data Parallelism (Recommended)
|
||||
|
||||
Each GPU loads a full copy of the model and processes a subset of evaluation data.
|
||||
|
||||
**Single Node (8 GPUs)**:
|
||||
```bash
|
||||
accelerate launch --multi_gpu --num_processes 8 \
|
||||
-m lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf,dtype=bfloat16 \
|
||||
--tasks mmlu,gsm8k,hellaswag \
|
||||
--batch_size 16
|
||||
```
|
||||
|
||||
**Speedup**: Near-linear (8 GPUs = ~8× faster)
|
||||
|
||||
**Memory**: Each GPU needs full model (7B model ≈ 14GB × 8 = 112GB total)
|
||||
|
||||
### Tensor Parallelism (Model Sharding)
|
||||
|
||||
Split model weights across GPUs for models too large for single GPU.
|
||||
|
||||
**Without accelerate launcher**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args \
|
||||
pretrained=meta-llama/Llama-2-70b-hf,\
|
||||
parallelize=True,\
|
||||
dtype=bfloat16 \
|
||||
--tasks mmlu,gsm8k \
|
||||
--batch_size 8
|
||||
```
|
||||
|
||||
**With 8 GPUs**: 70B model (140GB) / 8 = 17.5GB per GPU ✅
|
||||
|
||||
**Advanced sharding**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args \
|
||||
pretrained=meta-llama/Llama-2-70b-hf,\
|
||||
parallelize=True,\
|
||||
device_map_option=auto,\
|
||||
max_memory_per_gpu=40GB,\
|
||||
max_cpu_memory=100GB,\
|
||||
dtype=bfloat16 \
|
||||
--tasks mmlu
|
||||
```
|
||||
|
||||
**Options**:
|
||||
- `device_map_option`: `"auto"` (default), `"balanced"`, `"balanced_low_0"`
|
||||
- `max_memory_per_gpu`: Max memory per GPU (e.g., `"40GB"`)
|
||||
- `max_cpu_memory`: Max CPU memory for offloading
|
||||
- `offload_folder`: Disk offloading directory
|
||||
|
||||
### Combined Data + Tensor Parallelism
|
||||
|
||||
Use both for very large models.
|
||||
|
||||
**Example: 70B model on 16 GPUs (2 copies, 8 GPUs each)**:
|
||||
```bash
|
||||
accelerate launch --multi_gpu --num_processes 2 \
|
||||
-m lm_eval --model hf \
|
||||
--model_args \
|
||||
pretrained=meta-llama/Llama-2-70b-hf,\
|
||||
parallelize=True,\
|
||||
dtype=bfloat16 \
|
||||
--tasks mmlu \
|
||||
--batch_size 8
|
||||
```
|
||||
|
||||
**Result**: 2× speedup from data parallelism, 70B model fits via tensor parallelism
|
||||
|
||||
### Configuration with `accelerate config`
|
||||
|
||||
Create `~/.cache/huggingface/accelerate/default_config.yaml`:
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: MULTI_GPU
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
gpu_ids: all
|
||||
mixed_precision: bf16
|
||||
```
|
||||
|
||||
**Then run**:
|
||||
```bash
|
||||
accelerate launch -m lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu
|
||||
```
|
||||
|
||||
## vLLM Models (`vllm`)

vLLM provides highly optimized distributed inference.

### Tensor Parallelism

**Single Node (4 GPUs)**:

```bash
lm_eval --model vllm \
    --model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=4,\
dtype=auto,\
gpu_memory_utilization=0.9 \
    --tasks mmlu,gsm8k \
    --batch_size auto
```

**Memory**: 70B model split across 4 GPUs = ~35GB per GPU

### Data Parallelism

**Multiple model replicas**:

```bash
lm_eval --model vllm \
    --model_args \
pretrained=meta-llama/Llama-2-7b-hf,\
data_parallel_size=4,\
dtype=auto,\
gpu_memory_utilization=0.8 \
    --tasks hellaswag,arc_challenge \
    --batch_size auto
```

**Result**: 4 model replicas = 4× throughput

### Combined Tensor + Data Parallelism

**Example: 8 GPUs = 4 TP × 2 DP**:

```bash
lm_eval --model vllm \
    --model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=4,\
data_parallel_size=2,\
dtype=auto,\
gpu_memory_utilization=0.85 \
    --tasks mmlu \
    --batch_size auto
```

**Result**: the 70B model fits (TP=4) with a 2× speedup (DP=2)

### Multi-Node vLLM

vLLM does not handle multi-node evaluation on its own; it relies on a Ray cluster. Start Ray on the head node (worker nodes join as sketched after the example), then run:

```bash
# Start Ray cluster
ray start --head --port=6379

# Run evaluation
lm_eval --model vllm \
    --model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=8,\
dtype=auto \
    --tasks mmlu
```
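Before launching from the head node, each worker has to join the same Ray cluster so its GPUs are visible to vLLM. A minimal sketch, with `<head-node-ip>` as a placeholder for the head node's address:

```bash
# On every worker node, join the cluster started on the head node
ray start --address='<head-node-ip>:6379'

# Back on the head node, confirm all nodes and GPUs are registered
ray status
```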
## NVIDIA NeMo Models (`nemo_lm`)

### Data Replication

**8 replicas on 8 GPUs**:

```bash
torchrun --nproc-per-node=8 --no-python \
    lm_eval --model nemo_lm \
    --model_args \
path=/path/to/model.nemo,\
devices=8 \
    --tasks hellaswag,arc_challenge \
    --batch_size 32
```

**Speedup**: Near-linear (≈8× with 8 replicas)

### Tensor Parallelism

**4-way tensor parallelism**:

```bash
torchrun --nproc-per-node=4 --no-python \
    lm_eval --model nemo_lm \
    --model_args \
path=/path/to/70b_model.nemo,\
devices=4,\
tensor_model_parallel_size=4 \
    --tasks mmlu,gsm8k \
    --batch_size 16
```

### Pipeline Parallelism

**2 TP × 2 PP on 4 GPUs**:

```bash
torchrun --nproc-per-node=4 --no-python \
    lm_eval --model nemo_lm \
    --model_args \
path=/path/to/model.nemo,\
devices=4,\
tensor_model_parallel_size=2,\
pipeline_model_parallel_size=2 \
    --tasks mmlu \
    --batch_size 8
```

**Constraint**: `devices = TP × PP`

### Multi-Node NeMo

Currently not supported by lm-evaluation-harness.

## SGLang Models (`sglang`)

### Tensor Parallelism

```bash
lm_eval --model sglang \
    --model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tp_size=4,\
dtype=auto \
    --tasks gsm8k \
    --batch_size auto
```

### Data Parallelism (Deprecated)

**Note**: SGLang is deprecating data parallelism. Use tensor parallelism instead.

```bash
lm_eval --model sglang \
    --model_args \
pretrained=meta-llama/Llama-2-7b-hf,\
dp_size=4,\
dtype=auto \
    --tasks mmlu
```
## Performance Comparison

### 70B Model Evaluation (MMLU, 5-shot)

| Method | GPUs | Time | Memory/GPU | Notes |
|--------|------|------|------------|-------|
| HF (no parallelism) | 1 | 8 hours | 140GB (OOM) | Won't fit |
| HF (TP=8) | 8 | 2 hours | 17.5GB | Slower, but fits |
| HF (DP=8) | 8 | 1 hour | 140GB (OOM) | Won't fit |
| vLLM (TP=4) | 4 | 30 min | 35GB | Fast |
| vLLM (TP=4, DP=2) | 8 | 15 min | 35GB | Fastest |

### 7B Model Evaluation (Multiple Tasks)

| Method | GPUs | Time | Speedup |
|--------|------|------|---------|
| HF (single GPU) | 1 | 4 hours | 1× |
| HF (DP=4) | 4 | 1 hour | 4× |
| HF (DP=8) | 8 | 30 min | 8× |
| vLLM (DP=8) | 8 | 15 min | 16× |

**Takeaway**: vLLM is significantly faster than the HuggingFace backend for inference workloads.
## Choosing Parallelism Strategy

### Decision Tree

```
Model fits on a single GPU?
├─ YES: Use data parallelism
│   ├─ HF: accelerate launch --multi_gpu --num_processes N
│   └─ vLLM: data_parallel_size=N (fastest)
│
└─ NO: Use tensor/pipeline parallelism
    ├─ Model < 70B:
    │   └─ vLLM: tensor_parallel_size=4
    ├─ Model 70-175B:
    │   ├─ vLLM: tensor_parallel_size=8
    │   └─ or HF: parallelize=True
    └─ Model > 175B:
        └─ Contact the framework authors
```

### Memory Estimation

**Rule of thumb**:

```
Memory (GB) = Parameters (B) × Precision (bytes) × 1.2 (overhead)
```

**Examples**:
- 7B FP16: 7 × 2 × 1.2 = 16.8GB ✅ Fits an A100 40GB
- 13B FP16: 13 × 2 × 1.2 = 31.2GB ✅ Fits an A100 40GB
- 70B FP16: 70 × 2 × 1.2 = 168GB ❌ Needs TP=4 or TP=8
- 70B BF16: 70 × 2 × 1.2 = 168GB (same as FP16)

**With tensor parallelism** (a scripted version of this estimate is sketched below):

```
Memory per GPU = Total Memory / TP
```

- 70B on 4 GPUs: 168GB / 4 = 42GB per GPU ✅
- 70B on 8 GPUs: 168GB / 8 = 21GB per GPU ✅
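The same arithmetic can be scripted when planning runs. The helper below is a hypothetical sketch, not part of the harness; it prints the total estimate and the per-GPU share for a given tensor-parallel degree:

```bash
#!/usr/bin/env bash
# estimate_vram.sh -- rough VRAM estimate: params (B) x bytes/param x 1.2 overhead
# Usage: ./estimate_vram.sh <params_in_billions> [bytes_per_param] [tp_size]
params=${1:?usage: estimate_vram.sh <params_in_billions> [bytes_per_param] [tp_size]}
bytes=${2:-2}   # 2 bytes per parameter for FP16/BF16
tp=${3:-1}      # tensor-parallel degree

total=$(awk -v p="$params" -v b="$bytes" 'BEGIN { printf "%.1f", p * b * 1.2 }')
per_gpu=$(awk -v t="$total" -v n="$tp" 'BEGIN { printf "%.1f", t / n }')

echo "Total: ${total} GB | Per GPU (TP=${tp}): ${per_gpu} GB"
```

For example, `./estimate_vram.sh 70 2 8` reports roughly 168 GB total and 21 GB per GPU, matching the worked examples above.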
## Multi-Node Evaluation

### HuggingFace with SLURM

**Job script** (`eval_job.sh`):

```bash
#!/bin/bash
#SBATCH --nodes=4
#SBATCH --gpus-per-node=8
#SBATCH --ntasks-per-node=1

srun accelerate launch --multi_gpu \
    --num_processes $((SLURM_NNODES * 8)) \
    -m lm_eval --model hf \
    --model_args pretrained=meta-llama/Llama-2-7b-hf \
    --tasks mmlu,gsm8k,hellaswag \
    --batch_size 16
```

**Submit**:

```bash
sbatch eval_job.sh
```

### Manual Multi-Node Setup

**On each node, run**:

```bash
accelerate launch \
    --multi_gpu \
    --num_machines 4 \
    --num_processes 32 \
    --main_process_ip $MASTER_IP \
    --main_process_port 29500 \
    --machine_rank $NODE_RANK \
    -m lm_eval --model hf \
    --model_args pretrained=meta-llama/Llama-2-7b-hf \
    --tasks mmlu
```

**Environment variables** (example exports below):
- `MASTER_IP`: IP address of the rank-0 node
- `NODE_RANK`: 0, 1, 2, 3 (one value per node)
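As a concrete sketch for a 4-node job, with placeholder addresses, the exports on the first two nodes would look like this (the remaining nodes use `NODE_RANK=2` and `NODE_RANK=3`):

```bash
# Node 0 (runs the main process)
export MASTER_IP=10.0.0.1   # placeholder: IP of the rank-0 node
export NODE_RANK=0

# Node 1
export MASTER_IP=10.0.0.1   # same head IP on every node
export NODE_RANK=1
```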
## Best Practices

### 1. Start Small

Test on a small sample first:

```bash
lm_eval --model hf \
    --model_args pretrained=meta-llama/Llama-2-70b-hf,parallelize=True \
    --tasks mmlu \
    --limit 100  # Just 100 samples
```

### 2. Monitor GPU Usage

```bash
# Terminal 1: Run evaluation
lm_eval --model hf ...

# Terminal 2: Monitor
watch -n 1 nvidia-smi
```

Look for (a logging sketch follows the list):
- GPU utilization > 90%
- Stable memory usage
- All GPUs active
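To keep a record instead of watching interactively, `nvidia-smi` can also sample utilization and memory on an interval; the 5-second interval and file name below are arbitrary choices, not from the original guide:

```bash
# Append one CSV row per GPU every 5 seconds for later inspection
nvidia-smi --query-gpu=index,utilization.gpu,memory.used,memory.total \
    --format=csv -l 5 >> gpu_usage.csv
```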
### 3. Optimize Batch Size

```bash
# Auto batch size (recommended)
--batch_size auto

# Or tune manually
--batch_size 16  # Start here
--batch_size 32  # Increase if memory allows
```

### 4. Use Mixed Precision

```bash
--model_args dtype=bfloat16  # Faster, less memory
```

### 5. Check Communication

For data parallelism, check the interconnect topology and bandwidth:

```bash
# Should show NVLink, InfiniBand NICs, or another high-speed interconnect
nvidia-smi topo -m
```
## Troubleshooting

### "CUDA out of memory"

**Solutions**:
1. Increase tensor parallelism:
   ```bash
   --model_args tensor_parallel_size=8  # Was 4
   ```

2. Reduce batch size:
   ```bash
   --batch_size 4  # Was 16
   ```
3. Quantize instead of lowering the dtype (`int8` is not a valid `dtype` value):
   ```bash
   --model_args load_in_8bit=True  # 8-bit quantization via bitsandbytes (HF backend)
   ```
### "NCCL error" or Hanging
|
||||
|
||||
**Check**:
|
||||
1. All GPUs visible: `nvidia-smi`
|
||||
2. NCCL installed: `python -c "import torch; print(torch.cuda.nccl.version())"`
|
||||
3. Network connectivity between nodes
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
export NCCL_DEBUG=INFO # Enable debug logging
|
||||
export NCCL_IB_DISABLE=0 # Use InfiniBand if available
|
||||
```
|
||||
|
||||
### Slow Evaluation
|
||||
|
||||
**Possible causes**:
|
||||
1. **Data loading bottleneck**: Preprocess dataset
|
||||
2. **Low GPU utilization**: Increase batch size
|
||||
3. **Communication overhead**: Reduce parallelism degree
|
||||
|
||||
**Profile**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu \
|
||||
--limit 100 \
|
||||
--log_samples # Check timing
|
||||
```
|
||||
|
||||
### GPUs Imbalanced
|
||||
|
||||
**Symptom**: GPU 0 at 100%, others at 50%
|
||||
|
||||
**Solution**: Use `device_map_option=balanced`:
|
||||
```bash
|
||||
--model_args parallelize=True,device_map_option=balanced
|
||||
```
|
||||
|
||||
## Example Configurations

### Small Model (7B) - Fast Evaluation

```bash
# 8 A100s, data parallel
accelerate launch --multi_gpu --num_processes 8 \
    -m lm_eval --model hf \
    --model_args \
pretrained=meta-llama/Llama-2-7b-hf,\
dtype=bfloat16 \
    --tasks mmlu,gsm8k,hellaswag,arc_challenge \
    --num_fewshot 5 \
    --batch_size 32

# Time: ~30 minutes
```

### Large Model (70B) - vLLM

```bash
# 8 H100s, tensor parallel
lm_eval --model vllm \
    --model_args \
pretrained=meta-llama/Llama-2-70b-hf,\
tensor_parallel_size=8,\
dtype=auto,\
gpu_memory_utilization=0.9 \
    --tasks mmlu,gsm8k,humaneval \
    --num_fewshot 5 \
    --batch_size auto

# Time: ~1 hour
```

### Very Large Model (175B+)

**Requires a specialized setup; contact the framework maintainers.**
## References

- HuggingFace Accelerate: https://huggingface.co/docs/accelerate/
- vLLM docs: https://docs.vllm.ai/
- NeMo docs: https://docs.nvidia.com/nemo-framework/
- lm-eval distributed guide: `docs/model_guide.md`