# Source code for cw.lib.models.sd15

#!/usr/bin/env python3
"""
Stable Diffusion 1.5 model implementation
Supports full CFG guidance and negative prompts.
Handles CLIP 77-token limit by prioritizing LoRA trigger words.
"""

import logging

import torch
from diffusers import StableDiffusionPipeline

from .base import BaseModel
from .mixins import CLIPTokenLimitMixin

logger = logging.getLogger(__name__)


class SD15Model(CLIPTokenLimitMixin, BaseModel):
    """Stable Diffusion 1.5 implementation (Realistic Vision, etc.).

    Builds a ``StableDiffusionPipeline`` and, when running on an MPS device,
    casts the pipeline's VAE back to float32 after every LoRA load or unload.
    """

    def _create_pipeline(self):
        """Load the SD 1.5 pipeline from ``self.model_path``.

        Returns:
            The object returned by ``StableDiffusionPipeline.from_pretrained``,
            created with ``torch_dtype=self.dtype``.
        """
        return StableDiffusionPipeline.from_pretrained(
            self.model_path,
            torch_dtype=self.dtype,
        )

    def _apply_device_optimizations(self) -> None:
        """Apply device optimizations with MPS VAE fix."""
        # Pure delegation to the base implementation. Original author's note:
        # the base keeps the VAE on MPS in float32 (not verifiable from here).
        super()._apply_device_optimizations()

    def _post_lora_load_fixes(self) -> None:
        """Re-apply the MPS VAE float32 fix after LoRA loading."""
        # LoRA loading can change the VAE dtype - restore float32.
        self._restore_mps_vae_float32("load")

    def _post_lora_unload_fixes(self) -> None:
        """Re-apply the MPS VAE float32 fix after LoRA unloading."""
        # LoRA unloading can change the VAE dtype - restore float32.
        self._restore_mps_vae_float32("unload")

    def _restore_mps_vae_float32(self, stage: str) -> None:
        """Cast the pipeline VAE to float32 when running on MPS.

        Shared body of the post-load and post-unload hooks. It does nothing
        unless ``self.device.type`` is ``"mps"`` and the pipeline has a
        ``vae`` attribute.

        Args:
            stage: Word inserted into the log message to say which LoRA
                operation just finished (``"load"`` or ``"unload"``).
        """
        if self.device.type != "mps" or not hasattr(self.pipeline, "vae"):
            return

        logger.info("[SD15 MPS Fix] Re-applying VAE float32 after LoRA %s", stage)
        logger.info("[SD15 MPS Fix] VAE dtype before: %s", self.pipeline.vae.dtype)
        self.pipeline.vae = self.pipeline.vae.to(dtype=torch.float32)
        logger.info("[SD15 MPS Fix] VAE dtype after: %s", self.pipeline.vae.dtype)