# Source code for cw.lib.models.mixins

#!/usr/bin/env python3
"""
Mixins for model implementations.

This module provides reusable behaviors that can be composed with
:class:`~cw.lib.models.base.BaseModel` via multiple inheritance.

Available Mixins:
    :class:`CompelPromptMixin`
        Long prompt handling (>77 tokens) and prompt weighting syntax
        ``(word:1.3)`` for CLIP-based models. Recommended for SDXL/SD15.

    :class:`CLIPTokenLimitMixin`
        Simple 77-token truncation with LoRA suffix preservation.
        Legacy mixin - prefer CompelPromptMixin for new models.

    :class:`DebugLoggingMixin`
        Conditional debug print statements via ``enable_debug_logging`` flag.

Usage Example::

    from cw.lib.models.base import BaseModel
    from cw.lib.models.mixins import CompelPromptMixin

    class MySDXLModel(CompelPromptMixin, BaseModel):
        def _create_pipeline(self):
            return StableDiffusionXLPipeline.from_pretrained(...)

Note:
    Mixins must be listed before BaseModel in the inheritance order
    (MRO) to properly override BaseModel methods.
"""

import logging
import re
from typing import Dict

logger = logging.getLogger(__name__)


class CLIPTokenLimitMixin:
    """
    Mixin for models that need CLIP 77-token limit handling

    Used by SDXL, SDXLTurbo, and SD15 models to ensure prompts fit within
    CLIP's token limit while prioritizing LoRA trigger words.

    Requires the composed model to provide:
        - ``self.pipeline.tokenizer`` (a CLIP tokenizer with ``model_max_length``)
        - ``self.get_lora_prompt_suffix()`` / ``self.get_lora_negative_prompt_suffix()``
        - ``self.supports_negative_prompt``
    """

    @staticmethod
    def _strip_a1111_lora_tags(text: str) -> str:
        """
        Remove A1111/ComfyUI <lora:...> tags which are meaningless in diffusers.

        Args:
            text: Text containing potential LoRA tags

        Returns:
            Text with LoRA tags removed and cleaned up
        """
        # Strip tags, then clean up any dangling separator left behind
        # (e.g. "girl, <lora:x:0.8>" -> "girl, " -> "girl").
        return re.sub(r"<lora:[^>]+>", "", text).strip().rstrip(",").strip()

    def _fit_prompt_to_token_limit(self, prompt: str, suffix: str) -> str:
        """
        Build a prompt that fits within CLIP's 77-token limit.

        Prioritizes the LoRA suffix (trigger words), then fills remaining
        space with as much of the prompt as possible.

        Args:
            prompt: Base prompt text
            suffix: LoRA suffix to append (prioritized)

        Returns:
            Prompt that fits within token limit
        """
        suffix = self._strip_a1111_lora_tags(suffix)
        tokenizer = self.pipeline.tokenizer
        # Reserve two slots for the BOS/EOS special tokens CLIP adds.
        max_content_tokens = tokenizer.model_max_length - 2

        # If no suffix, just truncate prompt if needed
        if not suffix:
            tokens = tokenizer.encode(prompt, add_special_tokens=False)
            if len(tokens) <= max_content_tokens:
                return prompt
            return tokenizer.decode(tokens[:max_content_tokens], skip_special_tokens=True)

        # Build suffix with separator
        suffix_with_sep = f", {suffix}"
        suffix_tokens = tokenizer.encode(suffix_with_sep, add_special_tokens=False)
        prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)

        # If both fit, return concatenated
        if len(prompt_tokens) + len(suffix_tokens) <= max_content_tokens:
            return f"{prompt}{suffix_with_sep}"

        # Calculate available space for prompt after reserving space for suffix
        available = max_content_tokens - len(suffix_tokens)
        if available <= 0:
            # Suffix alone exceeds limit, truncate suffix
            return tokenizer.decode(suffix_tokens[:max_content_tokens], skip_special_tokens=True)

        # Truncate prompt to fit available space
        truncated = tokenizer.decode(prompt_tokens[:available], skip_special_tokens=True)
        return f"{truncated}{suffix_with_sep}"

    def _build_prompts(self, params: Dict) -> Dict:
        """
        Override to apply token limiting to prompts.

        Args:
            params: Parameter dictionary with 'prompt' and optional
                'negative_prompt' keys

        Returns:
            Updated parameter dictionary with token-limited prompts
        """
        # Apply token limiting to main prompt with LoRA suffix
        lora_suffix = self.get_lora_prompt_suffix()
        params["prompt"] = self._fit_prompt_to_token_limit(params["prompt"], lora_suffix)

        # FIX: use .get() — direct indexing raised KeyError when callers
        # omitted 'negative_prompt'. FIX: run the negative prompt through the
        # same token limiter instead of blind concatenation, so the negative
        # LoRA trigger words survive CLIP truncation (the class's purpose).
        negative_prompt = params.get("negative_prompt")
        if self.supports_negative_prompt and negative_prompt:
            lora_neg_suffix = self.get_lora_negative_prompt_suffix()
            if lora_neg_suffix:
                params["negative_prompt"] = self._fit_prompt_to_token_limit(
                    negative_prompt, lora_neg_suffix
                )
        return params
class CompelPromptMixin:
    """
    Mixin for models that use Compel for long prompt handling and prompt weighting

    Compel enables:
    - Prompts longer than 77 tokens (breaks into chunks, concatenates embeddings)
    - Prompt weighting syntax: (word:1.3) for emphasis, (word:0.8) for de-emphasis
    - Proper handling of LoRA trigger words without truncation

    Used by SDXL, SDXLTurbo, and SD15 models as an alternative to CLIPTokenLimitMixin.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Lazily created by _get_compel() on first prompt build.
        self._compel = None

    @staticmethod
    def _strip_a1111_lora_tags(text: str) -> str:
        """
        Remove A1111/ComfyUI <lora:...> tags which are meaningless in diffusers.

        Args:
            text: Text containing potential LoRA tags

        Returns:
            Text with LoRA tags removed and cleaned up
        """
        return re.sub(r"<lora:[^>]+>", "", text).strip().rstrip(",").strip()

    def _get_compel(self):
        """
        Lazy-initialize Compel instance for prompt encoding.

        Handles both single and dual text encoder models (SDXL).

        Returns:
            Compel instance (CompelForSDXL or CompelForSD)
        """
        if self._compel is None:
            logger.debug("Initializing Compel for prompt encoding")

            # Compel expects device as string ("mps", "cuda", "cpu")
            device_str = str(self.device.type) if hasattr(self.device, "type") else str(self.device)
            logger.debug(f"Compel device: {device_str}")

            # Check if model has dual text encoders (SDXL)
            has_text_encoder_2 = hasattr(self.pipeline, "text_encoder_2") and hasattr(
                self.pipeline, "tokenizer_2"
            )
            logger.debug(f"Has dual text encoders (SDXL): {has_text_encoder_2}")

            # Compel wrappers don't work with CPU offloading.
            # We need to ensure the text encoders are loaded before Compel
            # accesses them, so check if the text encoder is on the meta
            # device (which is where CPU offload parks it).
            text_encoder_on_meta = (
                hasattr(self.pipeline.text_encoder, "device")
                and str(self.pipeline.text_encoder.device) == "meta"
            )
            if text_encoder_on_meta:
                logger.warning("Text encoders on meta device due to CPU offload")
                logger.debug("Compel requires text encoders on real device, moving pipeline")
                # Remove the offload hooks so .to() actually materializes weights
                if hasattr(self.pipeline, "_all_hooks"):
                    self.pipeline._all_hooks = []
                # Move pipeline to device
                self.pipeline = self.pipeline.to(self.device)
                logger.debug(f"Pipeline moved to {self.device}")

            if has_text_encoder_2:
                # SDXL: Use CompelForSDXL wrapper
                from compel import CompelForSDXL

                logger.debug("Creating CompelForSDXL wrapper")
                self._compel = CompelForSDXL(
                    pipe=self.pipeline,
                    device=device_str,
                )
            else:
                # SD 1.5: Use CompelForSD wrapper
                from compel import CompelForSD

                logger.debug("Creating CompelForSD wrapper")
                self._compel = CompelForSD(
                    pipe=self.pipeline,
                    device=device_str,
                )
            logger.debug("Compel initialized successfully")
        return self._compel

    def _build_prompts(self, params: Dict) -> Dict:
        """
        Override to use Compel for prompt encoding.

        Converts text prompts to embeddings using Compel, which allows:
        - Long prompts (>77 tokens)
        - Prompt weighting syntax
        - Proper LoRA trigger word handling
        - Proper SDXL dual text encoder support

        Args:
            params: Parameter dictionary with 'prompt' and optional 'negative_prompt'

        Returns:
            Updated parameter dictionary with 'prompt_embeds' and
            'negative_prompt_embeds' (plus pooled variants for SDXL)
        """
        logger.debug("Building prompts with Compel")
        compel = self._get_compel()

        # Preserve original prompts for metadata (before encoding)
        original_prompt = params["prompt"]
        original_negative_prompt = params.get("negative_prompt")
        logger.debug(f"Original prompt length: {len(original_prompt)} chars")

        # Build final prompt with LoRA suffix.
        # `or ""` guards against the getter returning None, which would
        # crash re.sub inside _strip_a1111_lora_tags.
        lora_suffix = self._strip_a1111_lora_tags(self.get_lora_prompt_suffix() or "")
        if lora_suffix:
            full_prompt = f"{original_prompt}, {lora_suffix}"
            logger.debug(f"Added LoRA suffix: {lora_suffix[:50]}...")
        else:
            full_prompt = original_prompt

        # Encode prompt to embeddings using Compel
        # CompelForSDXL/CompelForSD return LabelledConditioning objects
        logger.debug("Encoding prompt with Compel")
        conditioning = compel(full_prompt)

        # Extract embeddings from LabelledConditioning
        logger.debug(f"Conditioning type: {type(conditioning).__name__}")

        # Check if it's a LabelledConditioning object (from CompelFor* wrappers)
        if hasattr(conditioning, "embeds"):
            # CompelForSDXL or CompelForSD wrapper
            logger.debug("Using LabelledConditioning wrapper")
            params["prompt_embeds"] = conditioning.embeds
            logger.debug(f"prompt_embeds shape: {conditioning.embeds.shape}")
            # SDXL has pooled embeddings
            if hasattr(conditioning, "pooled_embeds") and conditioning.pooled_embeds is not None:
                params["pooled_prompt_embeds"] = conditioning.pooled_embeds
                logger.debug(f"pooled_prompt_embeds shape: {conditioning.pooled_embeds.shape}")
        elif isinstance(conditioning, tuple):
            # Legacy Compel returns (prompt_embeds, pooled_prompt_embeds)
            logger.debug("Legacy tuple format from Compel")
            prompt_embeds, pooled_prompt_embeds = conditioning
            params["prompt_embeds"] = prompt_embeds
            params["pooled_prompt_embeds"] = pooled_prompt_embeds
        else:
            # Direct tensor (shouldn't happen with current wrappers)
            logger.debug("Direct tensor format from Compel")
            params["prompt_embeds"] = conditioning

        # Handle negative prompt
        if self.supports_negative_prompt:
            # BUG FIX: strip <lora:...> tags BEFORE the truthiness check.
            # Previously the raw suffix was checked first, so a suffix that
            # contained only LoRA tags stripped to "" inside the branch and
            # produced a dangling '", "' in full_negative.
            lora_neg_suffix = self._strip_a1111_lora_tags(
                self.get_lora_negative_prompt_suffix() or ""
            )
            if lora_neg_suffix:
                full_negative = (
                    f"{original_negative_prompt}, {lora_neg_suffix}"
                    if original_negative_prompt
                    else lora_neg_suffix
                )
            else:
                full_negative = original_negative_prompt if original_negative_prompt else ""

            # Encode negative prompt
            logger.debug("Encoding negative prompt with Compel")
            negative_conditioning = compel(full_negative)

            # Extract negative embeddings from LabelledConditioning
            if hasattr(negative_conditioning, "embeds"):
                # CompelForSDXL or CompelForSD wrapper
                params["negative_prompt_embeds"] = negative_conditioning.embeds
                logger.debug(f"negative_prompt_embeds shape: {negative_conditioning.embeds.shape}")
                # SDXL has pooled embeddings
                if (
                    hasattr(negative_conditioning, "pooled_embeds")
                    and negative_conditioning.pooled_embeds is not None
                ):
                    params["negative_pooled_prompt_embeds"] = negative_conditioning.pooled_embeds
                    logger.debug(
                        f"negative_pooled_prompt_embeds shape: {negative_conditioning.pooled_embeds.shape}"
                    )
            elif isinstance(negative_conditioning, tuple):
                # Legacy Compel
                negative_prompt_embeds, negative_pooled_prompt_embeds = negative_conditioning
                params["negative_prompt_embeds"] = negative_prompt_embeds
                params["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds
            else:
                # Direct tensor
                params["negative_prompt_embeds"] = negative_conditioning

        # Keep original text prompts in params for metadata
        # (Pipeline will use embeddings when both are present)
        params["prompt"] = full_prompt  # Full prompt with LoRA suffix for metadata
        if self.supports_negative_prompt:
            if lora_neg_suffix and original_negative_prompt:
                params["negative_prompt"] = f"{original_negative_prompt}, {lora_neg_suffix}"
            elif original_negative_prompt:
                params["negative_prompt"] = original_negative_prompt
            elif lora_neg_suffix:
                params["negative_prompt"] = lora_neg_suffix

        logger.debug("Prompt encoding complete")
        return params
class DebugLoggingMixin:
    """
    Mixin for models that need debug print statements

    Used by ZImageTurbo and SDXLTurbo for diagnostic output. Output is
    gated on an ``enable_debug_logging`` attribute of the composed model;
    when the attribute is absent, logging is treated as disabled.
    """

    def _debug_print(self, message: str) -> None:
        """
        Log debug message if debug logging is enabled.

        Args:
            message: Debug message to log
        """
        # Guard clause: stay silent unless the model opted in.
        if not getattr(self, "enable_debug_logging", False):
            return
        logger.debug(f"[{self.__class__.__name__}] {message}")