# Source code for cw.lib.pipeline.model_loader

"""Shared model loader for the adaptation pipeline.

Extracts model loading logic from AdaptationGenerator into a reusable
class that can serve Outlines generators for any Pydantic schema.
"""

from __future__ import annotations

import logging
from typing import Optional

logger = logging.getLogger(__name__)


class PipelineModelLoader:
    """Loads and caches an LLM, producing Outlines generators for arbitrary schemas.

    Designed as a shared resource: the model loads once, and
    ``get_generator()`` can be called with different Pydantic schemas cheaply.
    """

    def __init__(
        self,
        model_id: str = "Qwen/Qwen2.5-3B-Instruct",
        device: Optional[str] = None,
        load_in_4bit: bool = False,
    ):
        """Store configuration only; the heavy model load is deferred.

        Args:
            model_id: HuggingFace Hub model identifier.
            device: Target device ("mps", "cuda", "cpu"); None auto-detects
                on first load.
            load_in_4bit: Request bitsandbytes 4-bit quantization (CUDA only).
        """
        self.model_id = model_id
        self.device = device
        self.load_in_4bit = load_in_4bit
        self._model = None      # Outlines-wrapped model, populated lazily
        self._tokenizer = None  # HF tokenizer, populated lazily

    def _is_model_cached(self) -> bool:
        """Check if the model is already cached locally.

        Returns False on any failure (e.g. huggingface_hub unavailable) so
        callers can just log a "downloading" message instead of crashing.
        """
        try:
            from huggingface_hub import try_to_load_from_cache

            result = try_to_load_from_cache(self.model_id, "config.json")
            return result is not None
        except Exception:
            return False

    def _load_model(self):
        """Load the HuggingFace model and tokenizer (lazy, cached).

        Idempotent: returns immediately when a model is already loaded.
        Auto-detects the device when ``self.device`` is None (MPS > CUDA >
        CPU) and applies optional 4-bit quantization on CUDA.
        """
        if self._model is not None:
            logger.debug("Model already loaded, skipping _load_model")
            return

        import outlines
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        if self._is_model_cached():
            logger.info(f"Model '{self.model_id}' found in local cache, loading...")
        else:
            logger.info(
                f"Model '{self.model_id}' not in cache, downloading from HuggingFace Hub..."
            )

        # Detect device, preferring Apple MPS, then CUDA, then CPU.
        if self.device is None:
            if torch.backends.mps.is_available():
                self.device = "mps"
            elif torch.cuda.is_available():
                self.device = "cuda"
            else:
                self.device = "cpu"

        logger.info(f"Using device: {self.device}, load_in_4bit: {self.load_in_4bit}")

        # bfloat16 on accelerators; float32 on CPU where bf16 is often slow/unsupported.
        dtype = torch.bfloat16 if self.device != "cpu" else torch.float32
        model_kwargs = {
            "torch_dtype": dtype,
            "low_cpu_mem_usage": True,
            # bitsandbytes requires device_map="auto"; otherwise pin to the device.
            "device_map": self.device if not self.load_in_4bit else "auto",
        }

        if self.load_in_4bit:
            if self.device != "cuda":
                logger.warning(
                    f"4-bit quantization requested but device is {self.device}. "
                    "Falling back to standard loading."
                )
                # FIX: the "auto" device_map chosen above for 4-bit loading was
                # never reverted on this fallback path, so standard loading
                # ignored the detected device. Pin placement explicitly.
                model_kwargs["device_map"] = self.device
            else:
                from transformers import BitsAndBytesConfig

                logger.info("Using 4-bit quantization with bitsandbytes")
                model_kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=dtype,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_use_double_quant=True,
                )

        hf_model = AutoModelForCausalLM.from_pretrained(self.model_id, **model_kwargs)
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self._model = outlines.from_transformers(hf_model, self._tokenizer, device_dtype=dtype)
        logger.info(f"Model '{self.model_id}' loaded successfully.")

    @property
    def tokenizer(self):
        """Access the tokenizer (loads model if needed)."""
        self._load_model()
        return self._tokenizer
[docs] def get_generator(self, output_schema): """Return an Outlines generator bound to a specific Pydantic schema. Args: output_schema: A Pydantic BaseModel class defining the output format. Returns: An ``outlines.Generator`` that produces JSON matching the schema. """ import outlines self._load_model() logger.debug(f"Creating Outlines Generator for schema {output_schema.__name__}") # Create generator - Outlines will use default sampling parameters # which helps avoid numerical instability base_generator = outlines.Generator(self._model, output_type=output_schema) # Wrap to add retry logic for probability tensor errors # Capture model_id and device from outer scope model_id = self.model_id device = self.device class GeneratorWithRetry: """Wrapper that retries generation with safer parameters on numerical errors.""" def __call__(self, prompt, max_new_tokens=4096, **kwargs): logger.debug( f"Generating with schema {output_schema.__name__}, max_tokens={max_new_tokens}" ) try: # First attempt with default parameters result = base_generator(prompt, max_new_tokens=max_new_tokens, **kwargs) return result except Exception as e: error_str = str(e).lower() # Check if it's a numerical stability error or FSM state error is_numerical_error = any( keyword in error_str for keyword in ["probability tensor", "inf", "nan", "element < 0"] ) is_fsm_error = "no next state found" in error_str if is_numerical_error or is_fsm_error: if is_fsm_error: logger.warning( f"FSM state error for schema {output_schema.__name__}: {e}. " f"This usually means the model's output doesn't match schema constraints. " f"Clearing cache and retrying with reduced token limit..." ) else: logger.warning( f"Numerical stability error detected: {e}. " f"This can happen with structured generation. " f"Clearing cache and retrying..." 
) # Clear GPU cache which can help with numerical stability import torch if torch.backends.mps.is_available(): torch.mps.empty_cache() elif torch.cuda.is_available(): torch.cuda.empty_cache() # Retry with reduced token limit for FSM errors retry_tokens = max_new_tokens // 2 if is_fsm_error else max_new_tokens logger.warning( f"Retrying generation after cache clear (max_tokens: {retry_tokens})..." ) try: result = base_generator(prompt, max_new_tokens=retry_tokens, **kwargs) logger.info("Retry after cache clear succeeded") return result except Exception as retry_e: logger.error(f"Retry also failed: {retry_e}") # For FSM errors, provide more guidance if is_fsm_error: logger.error( f"Schema {output_schema.__name__} may be too restrictive " f"for model {model_id}. Consider simplifying the schema " f"or using a more capable model." ) # Re-raise the original error logger.error( f"Generation failed for {output_schema.__name__}: {e}", exc_info=True, ) raise return GeneratorWithRetry()
[docs] def switch_model(self, model_id: str, load_in_4bit: bool = False): """Clear cache and reconfigure for a different model.""" logger.info(f"Switching model: {self.model_id} -> {model_id}") self.clear_cache() self.model_id = model_id self.load_in_4bit = load_in_4bit
[docs] def clear_cache(self): """Clear the model from memory.""" import torch if self._model is not None: del self._model self._model = None if self._tokenizer is not None: del self._tokenizer self._tokenizer = None if self.device == "mps": torch.mps.empty_cache() elif self.device == "cuda": torch.cuda.empty_cache() logger.info("Model cache cleared")
# Module-level singleton shared across the pipeline.
_model_loader: Optional[PipelineModelLoader] = None


def get_model_loader(
    model_id: str = "Qwen/Qwen2.5-3B-Instruct",
    load_in_4bit: bool = False,
) -> PipelineModelLoader:
    """Get or create the shared PipelineModelLoader singleton.

    Args:
        model_id: HuggingFace Hub model identifier.
        load_in_4bit: Whether the loader should use 4-bit quantization.

    Returns:
        The process-wide loader, recreated if the requested configuration
        differs from the current one.
    """
    global _model_loader
    if (
        _model_loader is None
        or _model_loader.model_id != model_id
        or _model_loader.load_in_4bit != load_in_4bit
    ):
        # FIX: release the stale loader's model before replacing it —
        # otherwise the old weights linger in (GPU) memory until the
        # garbage collector happens to run.
        if _model_loader is not None:
            _model_loader.clear_cache()
        _model_loader = PipelineModelLoader(model_id=model_id, load_in_4bit=load_in_4bit)
    return _model_loader