# Source code for cw.lib.models.qwen

#!/usr/bin/env python3
"""
Qwen-Image model implementation
Supports negative prompts and uses true_cfg_scale with 50 steps optimal
"""

import inspect
import logging
from typing import Dict

import torch
from diffusers import QwenImagePipeline

from .base import BaseModel

logger = logging.getLogger(__name__)


class QwenImageModel(BaseModel):
    """Qwen-Image-2512 implementation.

    Supports negative prompts (using a single-space placeholder when empty),
    detects ``true_cfg_scale`` vs ``guidance_scale`` from the pipeline
    signature at runtime, and re-applies the MPS float32 VAE fix around
    LoRA load/unload.
    """

    def _create_pipeline(self):
        """Load the Qwen-Image pipeline, optionally with 8-bit quantization.

        Returns:
            The ``QwenImagePipeline`` loaded from ``self.model_path``.
        """
        load_kwargs = {
            "torch_dtype": self.dtype,
            "low_cpu_mem_usage": True,
            "use_safetensors": True,
        }
        # 8-bit quantization needs device_map="auto" so the quantized
        # weights can be placed automatically.
        if self.settings.get("load_in_8bit", False):
            load_kwargs["load_in_8bit"] = True
            load_kwargs["device_map"] = "auto"
        return QwenImagePipeline.from_pretrained(self.model_path, **load_kwargs)

    def _apply_device_optimizations(self) -> None:
        """Apply base device optimizations, then enable VAE slicing."""
        # Base implementation keeps the VAE on MPS in float32.
        super()._apply_device_optimizations()
        # Slicing decodes latents one image at a time, lowering peak memory.
        if hasattr(self.pipeline, "vae"):
            self.pipeline.vae.enable_slicing()

    def _handle_special_prompt_requirements(self, params: Dict) -> Dict:
        """Replace an empty/None negative prompt with a single space.

        Qwen requires a non-empty string for the negative prompt; a lone
        space acts as the "no negative prompt" placeholder.
        """
        if params["negative_prompt"] is None or not params["negative_prompt"].strip():
            params["negative_prompt"] = " "
        return params

    def _build_pipeline_kwargs(self, params: Dict, progress_callback) -> Dict:
        """Build generation kwargs, adapting to the pipeline's call signature.

        Renames ``guidance_scale`` to ``true_cfg_scale`` when the pipeline
        expects the latter, and wires in the step-end progress callback
        when supported.
        """
        gen_kwargs = super()._build_pipeline_kwargs(params, progress_callback)

        sig = inspect.signature(self.pipeline.__call__)
        # Qwen pipelines take true_cfg_scale instead of guidance_scale.
        # Guard the pop so a missing guidance_scale is not a hard KeyError.
        if "true_cfg_scale" in sig.parameters and "guidance_scale" in gen_kwargs:
            gen_kwargs["true_cfg_scale"] = gen_kwargs.pop("guidance_scale")

        if progress_callback and "callback_on_step_end" in sig.parameters:
            gen_kwargs["callback_on_step_end"] = progress_callback
            if "callback_on_step_end_tensor_inputs" in sig.parameters:
                gen_kwargs["callback_on_step_end_tensor_inputs"] = ["latents"]

        # Lazy %-style args avoid string formatting when INFO is disabled.
        logger.info("Starting inference with %s steps", params["steps"])
        logger.info("Generation kwargs keys: %s", list(gen_kwargs.keys()))
        return gen_kwargs

    def _reapply_mps_vae_float32(self, context: str) -> None:
        """Force the VAE back to float32 on MPS.

        Shared by the post-LoRA hooks below; LoRA load/unload can reset the
        VAE dtype, which breaks decoding on MPS.

        Args:
            context: Human-readable event name for the log messages
                (e.g. ``"LoRA load"``).
        """
        if self.device.type == "mps" and hasattr(self.pipeline, "vae"):
            logger.info("[Qwen MPS Fix] Re-applying VAE float32 after %s", context)
            logger.info("[Qwen MPS Fix] VAE dtype before: %s", self.pipeline.vae.dtype)
            self.pipeline.vae = self.pipeline.vae.to(dtype=torch.float32)
            logger.info("[Qwen MPS Fix] VAE dtype after: %s", self.pipeline.vae.dtype)

    def _post_lora_load_fixes(self) -> None:
        """Re-apply the MPS VAE float32 fix after LoRA loading."""
        self._reapply_mps_vae_float32("LoRA load")

    def _post_lora_unload_fixes(self) -> None:
        """Re-apply the MPS VAE float32 fix after LoRA unloading."""
        self._reapply_mps_vae_float32("LoRA unload")