#!/usr/bin/env python3
"""
Base model class for diffusion pipelines.
This module implements the Template Method pattern for diffusion model loading
and image generation. All concrete model implementations inherit from
:class:`BaseModel` and override specific hooks to customize behavior.
The key methods are:
- :meth:`BaseModel.load_pipeline` - Template method for loading (concrete)
- :meth:`BaseModel.generate` - Template method for generation (concrete)
- :meth:`BaseModel._create_pipeline` - Hook for pipeline creation (abstract)
Example usage::
from cw.lib.models import ModelFactory
# Create model from presets config
model = ModelFactory.create_model(model_config, model_path)
model.load_pipeline()
image, metadata = model.generate("a beautiful sunset")
"""
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Optional, Tuple
import torch
from PIL import Image
logger = logging.getLogger(__name__)
class BaseModel(ABC):
    """
    Abstract base class for all diffusion models.

    Implements the Template Method pattern: :meth:`load_pipeline` and
    :meth:`generate` define the common flow, while subclasses override
    :meth:`_create_pipeline` (required) plus optional hooks such as
    :meth:`_build_prompts`, :meth:`_build_pipeline_kwargs`,
    :meth:`_apply_device_optimizations` and
    :meth:`_handle_special_prompt_requirements`.

    Behavior flags read from presets.json ``settings``:

    - ``force_default_guidance``: always use default guidance_scale (turbo models)
    - ``use_sequential_cpu_offload``: sequential vs model CPU offload
    - ``enable_vae_slicing``: enable VAE slicing for memory efficiency
    - ``enable_debug_logging``: enable verbose debug output
    - ``max_sequence_length``: maximum sequence length for the text encoder

    Attributes:
        config: Model configuration dictionary from presets.json
        model_path: Path to model weights (local or HuggingFace ID)
        pipeline: Loaded diffusion pipeline (None until load_pipeline is called)
        device: Torch device (mps, cuda, or cpu)
        current_lora: Currently loaded LoRA configuration (or None)
        settings: The ``settings`` sub-dict of config
        dtype: Torch dtype used when loading model weights
    """

    def __init__(self, model_config: Dict, model_path: str):
        """
        Initialize the base model.

        Args:
            model_config: Model configuration from presets.json
            model_path: Full path to model file or HuggingFace ID
        """
        self.config = model_config
        self.model_path = model_path
        self.pipeline = None
        self.device = None
        self.current_lora = None

        # Generation defaults, all read from the "settings" sub-dict.
        settings = model_config.get("settings", {})
        self.settings = settings
        self.default_steps = settings.get("steps", 20)
        self.default_guidance = settings.get("guidance_scale", 7.5)
        self.default_resolution = settings.get("resolution", 1024)
        self.supports_negative_prompt = settings.get("supports_negative_prompt", False)
        self.max_sequence_length = settings.get("max_sequence_length")

        # Behavior flags consumed by the template methods.
        self.force_default_guidance = settings.get("force_default_guidance", False)
        self.use_sequential_cpu_offload = settings.get("use_sequential_cpu_offload", False)
        self.enable_vae_slicing = settings.get("enable_vae_slicing", False)
        self.enable_debug_logging = settings.get("enable_debug_logging", False)

        # Scheduler configuration; the original scheduler config is stored
        # at load time so it can be restored later.
        self.default_scheduler = settings.get("scheduler")
        self._original_scheduler_config = None

        # FP8 dtype variants can't be used as torch_dtype for loading, so
        # fall back to bfloat16 (weights are upcast automatically).
        dtype_str = settings.get("dtype", "bfloat16")
        self.dtype = (
            torch.bfloat16
            if dtype_str.startswith("float8")
            else getattr(torch, dtype_str, torch.bfloat16)
        )
    def load_pipeline(self, progress_callback=None) -> str:
        """
        Template method for pipeline loading (common flow)

        Subclasses should NOT override this method. Instead, override
        _create_pipeline() to specify how to load the pipeline.

        Flow: device setup -> _create_pipeline() -> remember original
        scheduler config -> apply configured default scheduler (if any) ->
        apply device optimizations. Exceptions are caught and reported via
        the returned status string rather than raised.

        Args:
            progress_callback: Optional callback for progress updates;
                invoked as progress_callback(fraction, desc=...) at each step

        Returns:
            Status message (success or error description)
        """
        # Idempotent: a second call is a no-op once the pipeline exists.
        if self.pipeline is not None:
            logger.debug(f"Pipeline already loaded for {self.model_name}, skipping load_pipeline")
            return "Model already loaded"
        logger.info(f"Loading pipeline for {self.model_name}")
        logger.debug(f"Model path: {self.model_path}")
        logger.debug(
            f"Model settings: steps={self.default_steps}, guidance={self.default_guidance}, dtype={self.dtype}"
        )
        try:
            # Step 1: Device setup (also records self.device)
            if progress_callback:
                progress_callback(0, desc="Setting up device...")
            logger.debug("Step 1: Setting up device")
            device, device_name = self.setup_device()
            # Step 2: Load pipeline (model-specific hook)
            if progress_callback:
                progress_callback(0.3, desc=f"Loading {self.model_name} pipeline...")
            logger.debug(f"Step 2: Creating pipeline via _create_pipeline()")
            self.pipeline = self._create_pipeline()
            logger.debug(f"Pipeline created: {type(self.pipeline).__name__}")
            # Step 2.5: Store original scheduler config for potential restoration
            if hasattr(self.pipeline, "scheduler") and hasattr(self.pipeline.scheduler, "config"):
                self._original_scheduler_config = self.pipeline.scheduler.config
                logger.debug(
                    f"Stored original scheduler config: {self.pipeline.scheduler.__class__.__name__}"
                )
            # Step 2.6: Apply default scheduler if configured in settings
            if self.default_scheduler:
                if progress_callback:
                    progress_callback(0.5, desc=f"Setting scheduler: {self.default_scheduler}...")
                logger.debug(f"Step 2.5: Setting default scheduler: {self.default_scheduler}")
                self.set_scheduler(self.default_scheduler)
            # Step 3: Apply optimizations (offload / slicing / dtype fixes)
            if progress_callback:
                progress_callback(0.7, desc="Enabling optimizations...")
            logger.debug("Step 3: Applying device optimizations")
            self._apply_device_optimizations()
            if progress_callback:
                progress_callback(1.0, desc="Model loaded successfully!")
            logger.info(f"Pipeline loaded successfully: {self.model_name} on {device_name}")
            return f"{self.model_name} loaded on {device_name}"
        except Exception as e:
            # Broad catch is deliberate: loading failures are surfaced to
            # the caller as a status string instead of an exception.
            logger.error(f"Error loading pipeline for {self.model_name}: {e}", exc_info=True)
            return f"Error loading {self.model_name}: {e}"
    def generate(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        seed: Optional[int] = None,
        clip_skip: Optional[int] = None,
        scheduler: Optional[str] = None,
        progress_callback=None,
        **extra_params,
    ) -> Tuple[Image.Image, Dict]:
        """
        Template method for image generation (common flow)

        Subclasses should NOT override this method. Instead, override
        the hooks (_build_prompts, _build_pipeline_kwargs, etc.) to
        customize behavior.

        Args:
            prompt: Text prompt
            negative_prompt: Negative prompt (if supported)
            steps: Number of inference steps
            guidance_scale: Guidance scale
            width: Image width
            height: Image height
            seed: Random seed
            clip_skip: Number of CLIP layers to skip (if supported)
            scheduler: Scheduler class name to use (overrides default)
            progress_callback: Optional callback for progress updates
            **extra_params: Additional parameters passed through to hooks
                (e.g., control_image, conditioning_scale for ControlNet)

        Returns:
            Tuple of (generated image, metadata dict)

        Raises:
            RuntimeError: If called before load_pipeline()
        """
        if self.pipeline is None:
            raise RuntimeError("Pipeline not loaded")
        logger.debug(f"generate() called for {self.model_name}")
        # Truncate very long prompts in the debug log only.
        logger.debug(
            f"Input prompt: {prompt[:100]}..." if len(prompt) > 100 else f"Input prompt: {prompt}"
        )
        logger.debug(
            f"Generation params: steps={steps}, guidance={guidance_scale}, size={width}x{height}, seed={seed}"
        )
        # Step 0: Apply scheduler override if provided
        # (set_scheduler returns None on failure, in which case the
        # currently active scheduler is reported in metadata instead)
        scheduler_applied = None
        if scheduler:
            logger.debug(f"Step 0: Applying scheduler override: {scheduler}")
            scheduler_applied = self.set_scheduler(scheduler)
        # Step 1: Apply parameter defaults and overrides
        logger.debug("Step 1: Preparing generation parameters")
        params = self._prepare_generation_params(
            prompt, negative_prompt, steps, guidance_scale, width, height, seed, clip_skip
        )
        # Track which scheduler is active for metadata
        params["scheduler"] = scheduler_applied or self._get_current_scheduler_name()
        # Merge extra params (e.g., control_image, conditioning_scale, guidance_end)
        # so they're available to _build_pipeline_kwargs hooks
        params.update(extra_params)
        logger.debug(
            f"Resolved params: steps={params['steps']}, guidance={params['guidance_scale']}, scheduler={params['scheduler']}"
        )
        # Step 2: Build prompts with LoRA suffixes (hook for customization)
        logger.debug("Step 2: Building prompts with LoRA suffixes")
        params = self._build_prompts(params)
        logger.debug(f"Final prompt length: {len(params['prompt'])} chars")
        # Step 3: Build pipeline kwargs (hook for model-specific parameters)
        logger.debug("Step 3: Building pipeline kwargs")
        gen_kwargs = self._build_pipeline_kwargs(params, progress_callback)
        logger.debug(f"Pipeline kwargs keys: {list(gen_kwargs.keys())}")
        # Step 3.5: Pre-generation VAE check (for debugging black images on MPS).
        # If anything (e.g. LoRA loading) changed the VAE dtype, force it
        # back to float32 before decoding.
        if self.device.type == "mps" and hasattr(self.pipeline, "vae"):
            logger.debug(
                f"[MPS VAE check] device: {self.pipeline.vae.device}, dtype: {self.pipeline.vae.dtype}"
            )
            if self.pipeline.vae.dtype != torch.float32:
                logger.warning(f"[MPS VAE check] VAE is NOT float32! Fixing now...")
                self.pipeline.vae = self.pipeline.vae.to(dtype=torch.float32)
                logger.debug(
                    f"[MPS VAE check] VAE after fix: {self.pipeline.vae.device}, dtype: {self.pipeline.vae.dtype}"
                )
        # Step 4: Generate image
        logger.debug("Step 4: Calling pipeline for generation")
        image = self.pipeline(**gen_kwargs).images[0]
        logger.debug(f"Generation complete, image size: {image.size}")
        # Step 5: Cleanup GPU memory between generations
        logger.debug("Step 5: Clearing cache")
        self.clear_cache()
        # Step 6: Build metadata describing this generation
        logger.debug("Step 6: Building metadata")
        metadata = self._build_metadata(params)
        logger.info(
            f"Image generated successfully: {params['width']}x{params['height']}, {params['steps']} steps"
        )
        return image, metadata
    @abstractmethod
    def _create_pipeline(self):
        """
        Create and return the pipeline instance (model-specific)

        This method must be overridden by subclasses to load their specific
        pipeline type (e.g., FluxPipeline, StableDiffusionXLPipeline, etc.)

        Called exactly once by :meth:`load_pipeline`; the result is stored
        on ``self.pipeline``.

        Returns:
            Pipeline instance (exact type depends on the concrete model)
        """
        pass
def _prepare_generation_params(
self,
prompt: str,
negative_prompt: Optional[str],
steps: Optional[int],
guidance_scale: Optional[float],
width: Optional[int],
height: Optional[int],
seed: Optional[int],
clip_skip: Optional[int],
) -> Dict:
"""
Prepare generation parameters with defaults and overrides
Args:
prompt: Text prompt
negative_prompt: Negative prompt
steps: Number of inference steps
guidance_scale: Guidance scale
width: Image width
height: Image height
seed: Random seed
clip_skip: CLIP skip layers
Returns:
Dictionary of prepared parameters
"""
# Apply config defaults
steps = steps or self.default_steps
width = width or self.default_resolution
height = height or self.default_resolution
# Handle guidance_scale with model-specific override behavior
guidance_scale = self._resolve_guidance_scale(guidance_scale)
# Apply LoRA clip_skip override
effective_clip_skip = self.get_lora_clip_skip() or clip_skip
return {
"prompt": prompt,
"negative_prompt": negative_prompt,
"steps": steps,
"guidance_scale": guidance_scale,
"width": width,
"height": height,
"seed": seed,
"clip_skip": effective_clip_skip,
}
def _resolve_guidance_scale(self, guidance_scale: Optional[float]) -> float:
"""
Resolve final guidance scale with LoRA and model-specific overrides
Args:
guidance_scale: Guidance scale parameter (may be None)
Returns:
Resolved guidance scale value
"""
# Check if model forces guidance override (turbo models)
if self.force_default_guidance:
# Turbo models MUST use their default guidance - don't allow overrides
return self.default_guidance
# Use parameter or default
guidance_scale = guidance_scale if guidance_scale is not None else self.default_guidance
# LoRA takes precedence (but only if model doesn't force default)
lora_guidance = self.get_lora_guidance_scale()
if lora_guidance is not None:
guidance_scale = lora_guidance
return guidance_scale
def _build_prompts(self, params: Dict) -> Dict:
"""
Build final prompts with LoRA suffixes
Override this method for token limiting or special prompt handling.
Args:
params: Parameter dictionary from _prepare_generation_params()
Returns:
Updated parameter dictionary with modified prompts
"""
# Append LoRA prompt suffix
lora_suffix = self.get_lora_prompt_suffix()
if lora_suffix:
params["prompt"] = f"{params['prompt']}, {lora_suffix}"
# Append LoRA negative prompt suffix (if model supports it)
if self.supports_negative_prompt and params["negative_prompt"]:
lora_neg_suffix = self.get_lora_negative_prompt_suffix()
if lora_neg_suffix:
params["negative_prompt"] = f"{params['negative_prompt']}, {lora_neg_suffix}"
# Model-specific prompt handling (e.g., Qwen requires space for empty negative)
params = self._handle_special_prompt_requirements(params)
return params
def _handle_special_prompt_requirements(self, params: Dict) -> Dict:
"""
Handle model-specific prompt requirements
Override this method if your model has special requirements
(e.g., Qwen requires space for empty negative prompt)
Args:
params: Parameter dictionary
Returns:
Updated parameter dictionary
"""
return params
def _build_pipeline_kwargs(self, params: Dict, progress_callback) -> Dict:
"""
Build kwargs for pipeline call
Override this method for model-specific parameters
(e.g., max_sequence_length, true_cfg_scale, callbacks)
Args:
params: Parameter dictionary from _build_prompts()
progress_callback: Optional progress callback
Returns:
Dictionary of kwargs for pipeline call
"""
# Setup generator
generator = torch.Generator(device="cpu")
if params["seed"] is not None:
generator.manual_seed(params["seed"])
# Base kwargs
gen_kwargs = {
"num_inference_steps": params["steps"],
"guidance_scale": params["guidance_scale"],
"height": params["height"],
"width": params["width"],
"generator": generator,
}
# Base kwargs
gen_kwargs["prompt"] = params["prompt"]
# Add negative prompt if supported
if self.supports_negative_prompt and params.get("negative_prompt"):
gen_kwargs["negative_prompt"] = params["negative_prompt"]
# Add clip_skip if set
if params["clip_skip"] is not None:
gen_kwargs["clip_skip"] = params["clip_skip"]
# Add max_sequence_length if configured
if self.max_sequence_length is not None:
gen_kwargs["max_sequence_length"] = self.max_sequence_length
return gen_kwargs
def _build_metadata(self, params: Dict) -> Dict:
"""
Build metadata dictionary for generated image
Args:
params: Parameter dictionary
Returns:
Metadata dictionary
"""
metadata = {
"model": self.model_name,
"prompt": params["prompt"],
"steps": params["steps"],
"guidance_scale": params["guidance_scale"],
"width": params["width"],
"height": params["height"],
"seed": params["seed"],
"scheduler": params.get("scheduler"),
"lora": self.current_lora["label"] if self.current_lora else None,
}
# Add negative prompt if supported
if self.supports_negative_prompt:
metadata["negative_prompt"] = params.get("negative_prompt")
# Add max_sequence_length if used
if self.max_sequence_length is not None:
metadata["max_sequence_length"] = self.max_sequence_length
return metadata
    def _apply_device_optimizations(self) -> None:
        """
        Apply device-specific optimizations to pipeline

        Override this method if your model needs custom optimizations.

        MPS: VAE forced to float32 (avoids NaN / black images), VAE
        slicing/tiling plus attention slicing enabled, pipeline moved to
        the device. CUDA: model or sequential CPU offload (per settings
        flag) plus attention slicing. Anything else: plain ``.to(device)``.
        """
        device = self.device
        logger.debug(f"Applying device optimizations for: {device.type}")
        if device.type == "mps":
            # For MPS, keep VAE on device but force float32 to avoid NaN values
            # Moving VAE to CPU causes device mismatch errors without offload
            if hasattr(self.pipeline, "vae"):
                logger.debug("[MPS] Converting VAE to float32 (keeping on MPS)")
                self.pipeline.vae = self.pipeline.vae.to(dtype=torch.float32)
                # Enable VAE slicing for better memory usage and numerical stability
                # NOTE(review): slicing/tiling is applied unconditionally here;
                # the enable_vae_slicing settings flag is not consulted — confirm
                # whether that is intended.
                if hasattr(self.pipeline.vae, "enable_slicing"):
                    logger.debug("[MPS] Enabling VAE slicing")
                    self.pipeline.vae.enable_slicing()
                if hasattr(self.pipeline.vae, "enable_tiling"):
                    logger.debug("[MPS] Enabling VAE tiling")
                    self.pipeline.vae.enable_tiling()
            # Move entire pipeline to MPS
            logger.debug("[MPS] Moving pipeline to MPS device")
            self.pipeline.to(device)
            logger.debug("[MPS] Enabling attention slicing")
            self.pipeline.enable_attention_slicing()
            # Ensure VAE is still float32 after pipeline.to()
            if hasattr(self.pipeline, "vae"):
                logger.debug("[MPS] Re-confirming VAE float32 after pipeline.to()")
                self.pipeline.vae = self.pipeline.vae.to(dtype=torch.float32)
        elif device.type == "cuda":
            # Default to model_cpu_offload (override if needed)
            if self.use_sequential_cpu_offload:
                logger.debug("[CUDA] Enabling sequential CPU offload")
                self.pipeline.enable_sequential_cpu_offload(device=device)
            else:
                logger.debug("[CUDA] Enabling model CPU offload")
                self.pipeline.enable_model_cpu_offload()
            logger.debug("[CUDA] Enabling attention slicing")
            self.pipeline.enable_attention_slicing()
        else:
            logger.debug(f"[CPU] Moving pipeline to {device}")
            self.pipeline = self.pipeline.to(device)
        logger.debug("Device optimizations applied successfully")
    def set_scheduler(self, scheduler_name: str) -> Optional[str]:
        """
        Set the scheduler for the pipeline by class name.

        Args:
            scheduler_name: Name of the scheduler class (e.g., 'EulerDiscreteScheduler',
                'DPMSolverMultistepScheduler', 'FlowMatchEulerDiscreteScheduler')

        Returns:
            Name of the scheduler that was set, or None if failed (pipeline
            not loaded, unknown scheduler name, or instantiation error).
            Failures never raise; the current scheduler is kept.
        """
        # Guard clauses: a scheduler can only be swapped on a loaded
        # pipeline that actually exposes one.
        if self.pipeline is None:
            logger.debug("Cannot set scheduler: pipeline not loaded")
            return None
        if not hasattr(self.pipeline, "scheduler"):
            logger.debug("Cannot set scheduler: pipeline has no scheduler attribute")
            return None
        logger.debug(f"Setting scheduler: {scheduler_name}")
        try:
            # Import diffusers schedulers dynamically (avoids a hard
            # module-level dependency; resolves the class by name)
            import diffusers
            # Get the scheduler class by name
            if not hasattr(diffusers, scheduler_name):
                logger.warning(
                    f"Scheduler '{scheduler_name}' not found in diffusers, keeping current scheduler"
                )
                return None
            scheduler_class = getattr(diffusers, scheduler_name)
            # Create new scheduler from current scheduler's config
            # This preserves model-specific scheduler parameters
            logger.debug(f"Creating {scheduler_name} from existing config")
            new_scheduler = scheduler_class.from_config(self.pipeline.scheduler.config)
            self.pipeline.scheduler = new_scheduler
            logger.info(f"Scheduler set to: {scheduler_name}")
            return scheduler_name
        except Exception as e:
            logger.warning(f"Failed to set scheduler '{scheduler_name}': {e}")
            return None
def _get_current_scheduler_name(self) -> Optional[str]:
"""Get the current scheduler's class name."""
if self.pipeline is None or not hasattr(self.pipeline, "scheduler"):
return None
return self.pipeline.scheduler.__class__.__name__
def _post_lora_load_fixes(self) -> None:
"""
Re-apply device-specific fixes after LoRA loading
LoRA loading can move pipeline components or change their dtypes.
Override this method to re-apply critical fixes (e.g., VAE float32 on MPS).
Default implementation does nothing.
"""
pass
def _post_lora_unload_fixes(self) -> None:
"""
Re-apply device-specific fixes after LoRA unloading
LoRA unloading can move pipeline components or change their dtypes.
Override this method to re-apply critical fixes (e.g., VAE float32 on MPS).
Default implementation does nothing.
"""
pass
[docs]
def setup_device(self) -> Tuple[torch.device, str]:
"""
Configure device for Apple Silicon optimization
Returns:
Tuple of (device, device_name)
"""
logger.debug("Detecting available compute device")
if torch.backends.mps.is_available():
device = torch.device("mps")
device_name = "MPS (Apple Silicon)"
elif torch.cuda.is_available():
device = torch.device("cuda")
device_name = "CUDA"
else:
device = torch.device("cpu")
device_name = "CPU"
self.device = device
logger.info(f"Device configured: {device_name}")
return device, device_name
def load_lora(self, lora_path: str, lora_config: Dict) -> str:
"""
Load a LoRA adapter
Args:
lora_path: Path to LoRA file
lora_config: LoRA configuration from presets
Returns:
Status message
"""
if self.pipeline is None:
logger.warning("Attempted to load LoRA but pipeline not loaded")
return "Error: Model not loaded yet"
logger.info(f"Loading LoRA: {lora_config.get('label', 'unknown')}")
logger.debug(f"LoRA path: {lora_path}")
try:
# Unload previous LoRA if any
if self.current_lora is not None:
logger.debug(f"Unloading previous LoRA: {self.current_lora.get('label')}")
try:
self.pipeline.unload_lora_weights()
except Exception:
pass
# Load new LoRA
if not Path(lora_path).exists():
logger.error(f"LoRA file not found: {lora_path}")
return f"Error: LoRA file not found: {lora_path}"
# Get LoRA strength from settings
strength = lora_config.get("settings", {}).get("strength", 1.0)
logger.debug(f"LoRA strength: {strength}")
# Load LoRA weights
logger.debug("Loading LoRA weights into pipeline")
self.pipeline.load_lora_weights(lora_path, adapter_name="default")
# Set LoRA scale if applicable
if hasattr(self.pipeline, "set_adapters"):
logger.debug(f"Setting adapter weights: {strength}")
self.pipeline.set_adapters(["default"], adapter_weights=[strength])
# CRITICAL: Re-apply device-specific fixes after LoRA loading
# LoRA loading can move components or change dtypes
logger.debug("Applying post-LoRA load fixes")
self._post_lora_load_fixes()
self.current_lora = lora_config
logger.info(f"LoRA loaded successfully: {lora_config['label']} (strength: {strength})")
return f"LoRA loaded: {lora_config['label']} (strength: {strength})"
except Exception as e:
logger.error(f"Error loading LoRA: {e}", exc_info=True)
return f"Error loading LoRA: {e}"
    def unload_lora(self) -> str:
        """
        Unload current LoRA

        Never raises: if unloading fails, current_lora is reset anyway so
        the model returns to a consistent "no LoRA" state.

        Returns:
            Status message
        """
        if self.pipeline is None:
            logger.warning("Attempted to unload LoRA but pipeline not loaded")
            return "Error: Model not loaded yet"
        if self.current_lora is None:
            logger.debug("No LoRA to unload")
            return "No LoRA loaded"
        lora_label = self.current_lora.get("label", "unknown")
        logger.info(f"Unloading LoRA: {lora_label}")
        try:
            logger.debug("Calling pipeline.unload_lora_weights()")
            self.pipeline.unload_lora_weights()
            # CRITICAL: Re-apply device-specific fixes after LoRA unloading
            # LoRA unloading can move components or change dtypes
            logger.debug("Applying post-LoRA unload fixes")
            self._post_lora_unload_fixes()
            self.current_lora = None
            logger.info(f"LoRA unloaded successfully: {lora_label}")
            return "LoRA unloaded"
        except Exception as e:
            # Best-effort cleanup: reset bookkeeping even when the
            # pipeline call failed.
            logger.warning(f"Exception during LoRA unload (resetting anyway): {e}")
            self.current_lora = None
            return "LoRA reset"
[docs]
def get_lora_prompt_suffix(self) -> str:
"""Get prompt suffix from current LoRA"""
if self.current_lora is None:
return ""
return self.current_lora.get("prompt", "")
[docs]
def get_lora_negative_prompt_suffix(self) -> str:
"""Get negative prompt suffix from current LoRA"""
if self.current_lora is None:
return ""
return self.current_lora.get("negative_prompt", "")
[docs]
def get_lora_clip_skip(self) -> Optional[int]:
"""Get clip_skip value from current LoRA"""
if self.current_lora is None:
return None
return self.current_lora.get("settings", {}).get("clip_skip")
[docs]
def get_lora_guidance_scale(self) -> Optional[float]:
"""Get guidance_scale value from current LoRA"""
if self.current_lora is None:
return None
return self.current_lora.get("settings", {}).get("guidance_scale")
[docs]
def clear_cache(self) -> None:
"""Clear GPU memory cache"""
if self.device is None:
return
logger.debug(f"Clearing {self.device.type} cache")
if self.device.type == "mps":
torch.mps.empty_cache()
elif self.device.type == "cuda":
torch.cuda.empty_cache()
[docs]
def is_loaded(self) -> bool:
"""Check if pipeline is loaded"""
return self.pipeline is not None
@property
def model_name(self) -> str:
"""Get model label"""
return self.config.get("label", "Unknown Model")
@property
def model_slug(self) -> str:
"""Get model slug"""
return self.config.get("slug", "unknown")