"""
Django admin configuration for diffusion models.
Uses Django Unfold for tabs, display decorators, and styled actions.
"""
from pathlib import Path
from django.conf import settings
from django.contrib import admin, messages
from django.contrib.auth.admin import GroupAdmin as BaseGroupAdmin
from django.contrib.auth.admin import UserAdmin as BaseUserAdmin
from django.contrib.auth.models import Group, User
from django.http import JsonResponse
from django.shortcuts import redirect
from django.urls import path, reverse
from django.utils.html import format_html
from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy as _
from django_celery_results.admin import GroupResultAdmin as BaseGroupResultAdmin
from django_celery_results.admin import TaskResultAdmin as BaseTaskResultAdmin
from django_celery_results.models import GroupResult, TaskResult
from huggingface_hub import scan_cache_dir
from unfold.admin import ModelAdmin, TabularInline
from unfold.decorators import action, display
from .models import (
BASE_ARCHITECTURE_CHOICES,
CONTROL_TYPE_CHOICES,
ControlNetModel,
DiffusionJob,
DiffusionModel,
LoraModel,
Prompt,
)
# ---------------------------------------------------------------------------
# Re-register third-party / auth models with Unfold
# ---------------------------------------------------------------------------
# The stock admin registrations are removed here so the Unfold-styled
# subclasses defined below can be registered in their place.
admin.site.unregister(User)
admin.site.unregister(Group)
admin.site.unregister(TaskResult)
admin.site.unregister(GroupResult)
@admin.register(User)
class UserAdmin(BaseUserAdmin, ModelAdmin):
    """Stock Django user admin, re-registered with Unfold styling."""
@admin.register(Group)
class GroupAdmin(BaseGroupAdmin, ModelAdmin):
    """Stock Django group admin, re-registered with Unfold styling."""
@admin.register(TaskResult)
class TaskResultAdmin(BaseTaskResultAdmin, ModelAdmin):
    """Celery task-result admin, re-registered with Unfold styling."""
@admin.register(GroupResult)
class GroupResultAdmin(BaseGroupResultAdmin, ModelAdmin):
    """Celery group-result admin, re-registered with Unfold styling."""
# ---------------------------------------------------------------------------
# Inlines
# ---------------------------------------------------------------------------
class JobInline(TabularInline):
    """Read-only inline listing of DiffusionJobs under a Prompt."""

    model = DiffusionJob
    tab = True
    extra = 0
    fields = ["diffusion_model", "lora_model", "status", "num_images", "created_at"]
    readonly_fields = ["status", "created_at"]
    can_delete = False
    show_change_link = True

    def has_add_permission(self, request, obj=None):
        # Jobs are created via admin actions, never inline.
        return False
# ---------------------------------------------------------------------------
# DiffusionModel
# ---------------------------------------------------------------------------
@admin.register(DiffusionModel)
class DiffusionModelAdmin(ModelAdmin):
    """Admin for diffusion base models with cache and compatibility columns."""

    list_display = [
        "label",
        "base_architecture",
        "show_scheduler",
        "token_window",
        "vram_in_gb",
        "steps",
        "show_resolution",
        "show_negative_prompt",
        "show_downloaded",
        "show_active",
        "show_loras_count",
    ]
    list_filter = [
        "is_active",
        "base_architecture",
        "pipeline",
        "scheduler",
        "supports_negative_prompt",
        "dtype",
    ]
    search_fields = ["label", "slug", "path"]
    readonly_fields = ["created_at", "updated_at"]
    fieldsets = (
        (
            _("Model"),
            {
                "classes": ["tab"],
                "fields": (
                    "label",
                    ("slug", "path"),
                    ("pipeline", "base_architecture", "is_active"),
                ),
            },
        ),
        (
            _("Generation"),
            {
                "classes": ["tab"],
                "fields": (
                    ("default_width", "default_height", "max_pixels"),
                    ("steps", "guidance_scale", "scheduler"),
                    ("dtype", "max_sequence_length", "supports_negative_prompt"),
                    ("token_window", "vram_usage"),
                ),
            },
        ),
        (
            _("Metadata"),
            {
                "classes": ["tab"],
                "fields": ("created_at", "updated_at"),
            },
        ),
    )

    # The decorator previously said "Size" while a stale
    # `show_resolution.short_description = "Resolution"` assignment below it
    # overrode the header; the decorator now carries the effective label.
    @display(description=_("Resolution"))
    def show_resolution(self, obj):
        """Default generation size rendered as W×H."""
        return f"{obj.default_width}×{obj.default_height}"

    @display(description=_("Neg Prompt"), boolean=True)
    def show_negative_prompt(self, obj):
        """Whether the pipeline accepts a negative prompt."""
        return obj.supports_negative_prompt

    @display(description=_("Downloaded"), boolean=True)
    def show_downloaded(self, obj):
        """Check if model exists in the HuggingFace cache.

        NOTE: scans the whole HF cache once per rendered row; acceptable for
        the small number of models listed here.
        """
        try:
            cache_info = scan_cache_dir()
            return any(repo.repo_id == obj.path for repo in cache_info.repos)
        except Exception:
            # A missing or corrupt cache dir must not break the changelist.
            return False

    @display(description=_("Active"), boolean=True)
    def show_active(self, obj):
        return obj.is_active

    @display(description=_("LoRAs"))
    def show_loras_count(self, obj):
        """Count of LoRAs sharing this architecture, linked to their changelist."""
        count = LoraModel.objects.filter(base_architecture=obj.base_architecture).count()
        if count > 0:
            url = reverse("admin:diffusion_loramodel_changelist")
            return format_html(
                '<a href="{}?base_architecture__exact={}">{}</a>',
                url,
                obj.base_architecture,
                count,
            )
        return "0"

    @display(description=_("VRAM"))
    def vram_in_gb(self, obj):
        # Assumes vram_usage is stored in MB (÷1024 → GB) — TODO confirm.
        if obj.vram_usage:
            return f"{int(obj.vram_usage/1024)} GB"
        return "—"

    @display(description=_("Scheduler"))
    def show_scheduler(self, obj):
        if obj.scheduler:
            # Shorten common scheduler names for display
            return obj.scheduler.replace("Scheduler", "").replace("Discrete", "")
        return "—"
# ---------------------------------------------------------------------------
# ControlNetModel
# ---------------------------------------------------------------------------
@admin.register(ControlNetModel)
class ControlNetModelAdmin(ModelAdmin):
    """Admin for ControlNet models with cache and compatibility columns."""

    list_display = [
        "label",
        "control_type",
        "base_architecture",
        "show_conditioning_scale",
        "show_downloaded",
        "show_active",
        "show_compatible_models",
    ]
    list_filter = ["is_active", "base_architecture", "control_type"]
    search_fields = ["label", "slug", "path"]
    readonly_fields = ["created_at", "updated_at"]
    fieldsets = (
        (
            _("ControlNet"),
            {
                "classes": ["tab"],
                "fields": (
                    "label",
                    ("slug", "path"),
                    ("control_type", "base_architecture", "is_active"),
                ),
            },
        ),
        (
            _("Conditioning"),
            {
                "classes": ["tab"],
                "fields": (
                    ("default_conditioning_scale", "default_guidance_end"),
                ),
            },
        ),
        (
            _("Metadata"),
            {
                "classes": ["tab"],
                "fields": ("created_at", "updated_at"),
            },
        ),
    )

    @display(description=_("Scale"))
    def show_conditioning_scale(self, obj):
        """Default conditioning scale, rendered with one decimal place."""
        return f"{obj.default_conditioning_scale:.1f}"

    @display(description=_("Downloaded"), boolean=True)
    def show_downloaded(self, obj):
        """Check if ControlNet model exists in the HuggingFace cache."""
        try:
            cache_info = scan_cache_dir()
            return any(repo.repo_id == obj.path for repo in cache_info.repos)
        except Exception:
            # A missing or corrupt cache dir must not break the changelist.
            return False

    @display(description=_("Active"), boolean=True)
    def show_active(self, obj):
        return obj.is_active

    @display(description=_("Compatible Models"))
    def show_compatible_models(self, obj):
        """Up to three active diffusion models sharing this architecture."""
        models = DiffusionModel.objects.filter(
            base_architecture=obj.base_architecture, is_active=True
        )
        if models.exists():
            return ", ".join(m.label for m in models[:3])
        return "—"
# ---------------------------------------------------------------------------
# LoraModel Helper Functions
# ---------------------------------------------------------------------------
def _update_lora_fields_from_metadata(lora, extracted_metadata):
"""
Update LoRA model fields from extracted CivitAI metadata.
Args:
lora: LoraModel instance to update
extracted_metadata: Dict of extracted metadata from CivitAI
Returns:
bool: True if any fields were changed, False otherwise
"""
changed = False
field_updates = [
# (field_name, extracted_key, should_update_condition)
(
"label",
"label",
lambda: (not lora.label or lora.label.startswith("CivitAI Model"))
and extracted_metadata.get("label")
and extracted_metadata["label"] != lora.label,
),
(
"base_architecture",
"base_architecture",
lambda: "base_architecture" in extracted_metadata
and lora.base_architecture != extracted_metadata["base_architecture"],
),
(
"prompt_suffix",
"prompt_suffix",
lambda: "prompt_suffix" in extracted_metadata
and lora.prompt_suffix != extracted_metadata["prompt_suffix"],
),
(
"negative_prompt_suffix",
"negative_prompt_suffix",
lambda: "negative_prompt_suffix" in extracted_metadata
and lora.negative_prompt_suffix != extracted_metadata["negative_prompt_suffix"],
),
(
"guidance_scale",
"guidance_scale",
lambda: "guidance_scale" in extracted_metadata
and lora.guidance_scale != extracted_metadata.get("guidance_scale"),
),
]
# Process field updates
for field_name, extracted_key, should_update in field_updates:
if should_update():
setattr(lora, field_name, extracted_metadata[extracted_key])
changed = True
# Notes always updated if present (stats change frequently)
if "notes" in extracted_metadata:
lora.notes = extracted_metadata["notes"]
changed = True
return changed
def _build_refresh_result_message(updated_count, skipped_count, failed_count, error_messages):
"""
Build result message for metadata refresh operation.
Args:
updated_count: Number of LoRAs successfully updated
skipped_count: Number of LoRAs unchanged
failed_count: Number of LoRAs that failed
error_messages: List of error messages
Returns:
str: Formatted result message
"""
messages_list = []
if updated_count > 0:
messages_list.append(f"{updated_count} LoRA(s) updated")
if skipped_count > 0:
messages_list.append(f"{skipped_count} unchanged")
if failed_count > 0:
messages_list.append(f"{failed_count} failed")
result_message = f"Metadata refresh complete: {', '.join(messages_list)}."
# Append error details if there are failures
if failed_count > 0:
if len(error_messages) <= 5:
result_message += f" Errors: {'; '.join(error_messages[:5])}"
else:
result_message += (
f" Errors: {'; '.join(error_messages[:5])} (and {failed_count - 5} more)"
)
return result_message
# ---------------------------------------------------------------------------
# LoraModel Admin
# ---------------------------------------------------------------------------
@admin.register(LoraModel)
class LoraModelAdmin(ModelAdmin):
    """Admin for LoRA adapters: CivitAI import, download, and metadata
    refresh workflows."""

    list_display = ["label", "theme", "base_architecture", "show_downloaded", "show_active"]
    list_filter = ["is_active", "base_architecture", "theme"]
    search_fields = ["label", "path", "air", "theme"]
    readonly_fields = ["created_at", "updated_at", "show_token_counts"]
    actions = ["refresh_metadata_bulk_action"]
    actions_list = ["import_from_civitai_action"]
    actions_row = ["refresh_metadata_action"]
    fieldsets = (
        (
            _("LoRA"),
            {
                "classes": ["tab"],
                "fields": (
                    ("label", "is_active"),
                    ("path", "air"),
                    ("base_architecture", "theme"),
                ),
            },
        ),
        (
            _("Prompt & Settings"),
            {
                "classes": ["tab"],
                "fields": (
                    ("default_strength", "guidance_scale", "clip_skip"),
                    ("prompt_suffix", "negative_prompt_suffix"),
                    "show_token_counts",
                    "notes",
                ),
            },
        ),
        (
            _("Metadata"),
            {
                "classes": ["tab"],
                "fields": ("created_at", "updated_at"),
            },
        ),
    )

    @display(description=_("CLIP Token Counts"))
    def show_token_counts(self, obj):
        """Render CLIP token usage for the stored prompt suffixes.

        NOTE: instantiates the tokenizer on every render; transformers caches
        files after the first download, but the first call can be slow.
        """
        try:
            from transformers import CLIPTokenizer

            tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
            max_tokens = 75  # 77 minus BOS/EOS
            parts = []
            if obj.prompt_suffix:
                # Strip A1111 tags for accurate count
                import re

                clean = re.sub(r"<lora:[^>]+>", "", obj.prompt_suffix).strip().rstrip(",").strip()
                count = len(tokenizer.encode(clean, add_special_tokens=False))
                remaining = max_tokens - count
                parts.append(f"Prompt suffix: {count}/75 tokens ({remaining} remaining)")
            if obj.negative_prompt_suffix:
                count = len(tokenizer.encode(obj.negative_prompt_suffix, add_special_tokens=False))
                parts.append(f"Negative suffix: {count}/75 tokens")
            return mark_safe("<br>".join(parts)) if parts else "No suffixes set"
        except Exception as e:
            return f"Error: {e}"

    @display(description=_("Compatible With"))
    def show_compatible(self, obj):
        """Labels of diffusion models sharing this LoRA's architecture."""
        models = DiffusionModel.objects.filter(base_architecture=obj.base_architecture)
        if models.exists():
            return ", ".join(m.label for m in models)
        return "—"

    @display(description=_("Downloaded"), boolean=True)
    def show_downloaded(self, obj):
        """Check if the LoRA file exists on disk, by explicit path or by the
        AIR-derived default download filename."""
        import re

        try:
            if obj.path:
                lora_path = Path(obj.path)
                if not lora_path.is_absolute():
                    lora_path = settings.MODEL_BASE_PATH / obj.path
                return lora_path.exists()
            if obj.air:
                # AIRs end in "@<version_id>"; downloads are stored as
                # civitai_<version_id>.safetensors under loras/.
                match = re.search(r"@(\d+)$", obj.air)
                if match:
                    version_id = match.group(1)
                    return (
                        settings.MODEL_BASE_PATH / "loras" / f"civitai_{version_id}.safetensors"
                    ).exists()
            return False
        except Exception:
            return False

    @display(description=_("Active"), boolean=True)
    def show_active(self, obj):
        return obj.is_active

    # --- Row actions ---
    @action(description=_("Refresh Metadata"))
    def refresh_metadata_action(self, request, object_id):
        """Refresh metadata from CivitAI for a LoRA."""
        return redirect("admin:refresh_lora_metadata", object_id)

    # --- Bulk actions ---
    @action(description=_("Refresh metadata from CivitAI"))
    def refresh_metadata_bulk_action(self, request, queryset):
        """
        Refresh metadata from CivitAI for selected LoRAs.

        Field updates are delegated to _update_lora_fields_from_metadata()
        and message building to _build_refresh_result_message().
        """
        from cw.lib.civitai import (
            extract_lora_metadata,
            fetch_model_version_metadata,
            parse_air,
        )

        # Only LoRAs with a CivitAI AIR can be refreshed.
        loras_with_air = queryset.exclude(air="")
        processable = loras_with_air.count()
        if processable == 0:
            self.message_user(
                request,
                "None of the selected LoRAs have CivitAI AIRs configured.",
                level=messages.WARNING,
            )
            return

        updated_count = 0
        failed_count = 0
        skipped_count = 0
        error_messages = []

        for lora in loras_with_air:
            try:
                model_id, version_id = parse_air(lora.air)
                raw_metadata = fetch_model_version_metadata(version_id, settings.CIVITAI_API_KEY)
                extracted = extract_lora_metadata(raw_metadata)
                changed = _update_lora_fields_from_metadata(lora, extracted)
                if changed:
                    lora.save()
                    updated_count += 1
                else:
                    skipped_count += 1
            except Exception as e:
                # One bad LoRA must not abort the rest of the batch.
                failed_count += 1
                error_messages.append(f"{lora.label}: {str(e)}")

        result_message = _build_refresh_result_message(
            updated_count, skipped_count, failed_count, error_messages
        )
        # SUCCESS if anything updated, INFO if only skips, WARNING otherwise.
        if updated_count > 0:
            level = messages.SUCCESS
        elif skipped_count > 0:
            level = messages.INFO
        else:
            level = messages.WARNING
        self.message_user(request, result_message, level=level)

    def get_urls(self):
        """Register custom download / refresh / import endpoints."""
        urls = super().get_urls()
        custom_urls = [
            path(
                "<int:lora_id>/download/",
                self.admin_site.admin_view(self.download_lora_view),
                name="download_lora",
            ),
            path(
                "<int:lora_id>/refresh-metadata/",
                self.admin_site.admin_view(self.refresh_metadata_view),
                name="refresh_lora_metadata",
            ),
            path(
                "import-from-civitai/",
                self.admin_site.admin_view(self.import_from_civitai_view),
                name="import_lora_from_civitai",
            ),
        ]
        return custom_urls + urls

    def download_lora_view(self, request, lora_id):
        """Handle the download action by queuing a Celery task."""
        from .tasks import download_lora_task

        lora = LoraModel.objects.get(pk=lora_id)
        if not lora.air:
            messages.error(request, f'LoRA "{lora.label}" has no CivitAI AIR configured.')
            return redirect("admin:diffusion_loramodel_change", lora_id)
        # Queue the download task
        download_lora_task.apply_async(args=[lora_id], queue="default")
        messages.info(
            request,
            f'Download started for "{lora.label}" in background. '
            "Check logs or refresh page to verify completion.",
        )
        return redirect("admin:diffusion_loramodel_change", lora_id)

    def refresh_metadata_view(self, request, lora_id):
        """Refresh CivitAI metadata for a single LoRA, then return to it.

        NOTE(review): get_urls() referenced this method but it was missing
        from the class; implemented here so the row action resolves instead
        of raising AttributeError.
        """
        from cw.lib.civitai import (
            extract_lora_metadata,
            fetch_model_version_metadata,
            parse_air,
        )

        lora = LoraModel.objects.get(pk=lora_id)
        if not lora.air:
            messages.error(request, f'LoRA "{lora.label}" has no CivitAI AIR configured.')
            return redirect("admin:diffusion_loramodel_change", lora_id)
        try:
            model_id, version_id = parse_air(lora.air)
            raw_metadata = fetch_model_version_metadata(version_id, settings.CIVITAI_API_KEY)
            extracted = extract_lora_metadata(raw_metadata)
            if _update_lora_fields_from_metadata(lora, extracted):
                lora.save()
                messages.success(request, f'Metadata refreshed for "{lora.label}".')
            else:
                messages.info(request, f'Metadata for "{lora.label}" is already up to date.')
        except Exception as e:
            messages.error(request, f"Metadata refresh failed: {e}")
        return redirect("admin:diffusion_loramodel_change", lora_id)

    def import_from_civitai_view(self, request):
        """Handle importing a LoRA from CivitAI by AIR.

        GET optionally pre-fills the form from CivitAI metadata when an
        ``air`` query parameter is present; POST creates the LoraModel and
        queues the download task.
        """
        from django.template.response import TemplateResponse

        from .tasks import download_lora_task

        # Handle form submission
        if request.method == "POST":
            air = request.POST.get("air", "").strip()
            label = request.POST.get("label", "").strip()
            base_architecture = request.POST.get("base_architecture", "sdxl")
            prompt_suffix = request.POST.get("prompt_suffix", "").strip()
            negative_prompt_suffix = request.POST.get("negative_prompt_suffix", "").strip()
            notes = request.POST.get("notes", "").strip()
            guidance_scale_str = request.POST.get("guidance_scale", "").strip()

            # Validate AIR format
            if not air:
                messages.error(request, "AIR is required.")
                return redirect("admin:import_lora_from_civitai")
            try:
                from cw.lib.civitai import parse_air

                model_id, version_id = parse_air(air)
            except ValueError as e:
                messages.error(request, f"Invalid AIR format: {e}")
                return redirect("admin:import_lora_from_civitai")

            # Auto-generate label if not provided
            if not label:
                label = f"CivitAI Model {model_id} v{version_id}"

            # Parse guidance_scale if provided; bad input is silently ignored.
            guidance_scale = None
            if guidance_scale_str:
                try:
                    guidance_scale = float(guidance_scale_str)
                except ValueError:
                    pass

            # Create the LoRA model with all fields
            lora = LoraModel.objects.create(
                label=label,
                air=air,
                base_architecture=base_architecture,
                prompt_suffix=prompt_suffix,
                negative_prompt_suffix=negative_prompt_suffix,
                notes=notes,
                guidance_scale=guidance_scale,
                is_active=True,
            )
            # Queue the download task
            download_lora_task.apply_async(args=[lora.id], queue="default")
            messages.success(
                request,
                f'LoRA "{label}" created and download queued. '
                "Check the LoRA details page to verify completion.",
            )
            return redirect("admin:diffusion_loramodel_change", lora.id)

        # Handle GET request - optionally fetch metadata if AIR provided
        initial_data = {
            "air": request.GET.get("air", ""),
            "label": "",
            "base_architecture": "sdxl",
            "prompt_suffix": "",
            "negative_prompt_suffix": "",
            "notes": "",
            "guidance_scale": "",
        }
        metadata_fetched = False
        fetch_error = None
        if initial_data["air"]:
            try:
                from cw.lib.civitai import (
                    extract_lora_metadata,
                    fetch_model_version_metadata,
                    parse_air,
                )

                model_id, version_id = parse_air(initial_data["air"])
                # Fetch metadata from CivitAI API
                raw_metadata = fetch_model_version_metadata(version_id, settings.CIVITAI_API_KEY)
                extracted = extract_lora_metadata(raw_metadata)
                # Update initial data with extracted metadata
                initial_data.update(extracted)
                metadata_fetched = True
                messages.info(
                    request,
                    "Metadata fetched from CivitAI. Review and edit fields below before importing.",
                )
            except Exception as e:
                # Import can still proceed manually without metadata.
                fetch_error = str(e)
                messages.warning(
                    request,
                    f"Could not fetch metadata from CivitAI: {e}. You can still import manually.",
                )

        # Render the form
        return TemplateResponse(
            request,
            "admin/diffusion/loramodel/import_from_civitai.html",
            {
                **self.admin_site.each_context(request),
                "title": _("Import LoRA from CivitAI"),
                "opts": self.model._meta,
                "base_architecture_choices": BASE_ARCHITECTURE_CHOICES,
                "initial_data": initial_data,
                "metadata_fetched": metadata_fetched,
                "fetch_error": fetch_error,
            },
        )

    @action(
        description=_("Import from CivitAI"),
        url_path="import-from-civitai-action",
    )
    def import_from_civitai_action(self, request):
        """Redirect to the import from CivitAI view."""
        return redirect("admin:import_lora_from_civitai")
# ---------------------------------------------------------------------------
# Prompt
# ---------------------------------------------------------------------------
class PromptStatusFilter(admin.SimpleListFilter):
    """Changelist filter splitting prompts into enhanced vs. pending."""

    title = _("Status")
    parameter_name = "is_enhanced"

    def lookups(self, request, model_admin):
        return [
            ("enhanced", _("Enhanced")),
            ("pending", _("Pending")),
        ]

    def queryset(self, request, queryset):
        # A prompt counts as enhanced once enhanced_prompt is non-empty.
        if self.value() == "enhanced":
            return queryset.exclude(enhanced_prompt="")
        if self.value() == "pending":
            return queryset.filter(enhanced_prompt="")
        return queryset
@admin.register(Prompt)
class PromptAdmin(ModelAdmin):
    """Admin for prompts: enhancement queueing and bulk job creation."""

    list_display = [
        "show_preview",
        "enhancement_style",
        "show_status",
        "show_jobs",
        "created_at",
    ]
    list_filter = ["enhancement_method", "enhancement_style", PromptStatusFilter, "created_at"]
    search_fields = ["source_prompt", "enhanced_prompt"]
    readonly_fields = ["created_at", "updated_at", "enhancement_method"]
    actions = ["enhance_prompts_action", "create_job_for_prompts"]
    actions_row = ["enhance_single_action", "create_job_action"]
    inlines = [JobInline]
    fieldsets = (
        (
            _("Prompt"),
            {
                "classes": ["tab"],
                "fields": ("source_prompt",),
            },
        ),
        (
            _("Enhancement"),
            {
                "classes": ["tab"],
                "fields": (
                    ("enhanced_prompt", "negative_prompt"),
                    ("enhancement_method", "enhancement_style", "creativity"),
                ),
            },
        ),
        (
            _("Metadata"),
            {
                "classes": ["tab"],
                "fields": ("created_at", "updated_at"),
            },
        ),
    )

    @display(description=_("Prompt"))
    def show_preview(self, obj):
        """First 80 characters of the source prompt, ellipsized."""
        text = obj.source_prompt[:80]
        if len(obj.source_prompt) > 80:
            text += "…"
        return text

    @display(
        description=_("Status"),
        label={
            "Enhanced": "success",
            "Pending": "warning",
        },
    )
    def show_status(self, obj):
        return "Enhanced" if obj.enhanced_prompt else "Pending"

    @display(description=_("Jobs"))
    def show_jobs(self, obj):
        """Job count linked to the job changelist filtered to this prompt."""
        count = obj.jobs.count()
        if count > 0:
            url = reverse("admin:diffusion_diffusionjob_changelist")
            return format_html(
                '<a href="{}?prompt__id__exact={}">{}</a>',
                url,
                obj.id,
                count,
            )
        return "0"

    # --- Row actions ---
    @action(description=_("Enhance"))
    def enhance_single_action(self, request, object_id):
        """Queue enhancement for one prompt (skipped if already enhanced)."""
        from . import tasks

        prompt = Prompt.objects.get(id=object_id)
        if prompt.enhanced_prompt:
            messages.warning(request, "Prompt is already enhanced.")
        else:
            tasks.enhance_prompt_task.apply_async(args=[object_id], queue="default")
            messages.success(request, "Prompt queued for enhancement.")
        return redirect("admin:diffusion_prompt_changelist")

    @action(description=_("Create Job"))
    def create_job_action(self, request, object_id):
        """Open the job add form pre-filled with this prompt."""
        return redirect(f"{reverse('admin:diffusion_diffusionjob_add')}?prompt={object_id}")

    # --- Bulk actions ---
    @action(description=_("Enhance selected prompts"))
    def enhance_prompts_action(self, request, queryset):
        """Queue enhancement tasks for every not-yet-enhanced prompt."""
        from . import tasks

        count = 0
        for prompt in queryset:
            if not prompt.enhanced_prompt:
                tasks.enhance_prompt_task.apply_async(args=[prompt.id], queue="default")
                count += 1
        self.message_user(request, f"{count} prompts queued for enhancement.")

    @action(description=_("Create jobs for selected prompts"))
    def create_job_for_prompts(self, request, queryset):
        """Two-step bulk action: confirmation form, then job creation.

        First pass renders a model/LoRA selection page; the second pass
        (POST with post=yes) creates and queues one job per prompt.
        """
        from django.template.response import TemplateResponse

        # Second pass: user submitted the confirmation form
        if request.POST.get("post") == "yes":
            from . import tasks

            model_id = request.POST.get("diffusion_model")
            lora_id = request.POST.get("lora_model") or None
            # NOTE(review): a non-numeric num_images would raise ValueError
            # here; the confirmation form is expected to constrain it.
            num_images = int(request.POST.get("num_images", 1))
            model = DiffusionModel.objects.get(id=model_id)
            lora = LoraModel.objects.get(id=lora_id) if lora_id else None
            count = 0
            for prompt in queryset:
                job = DiffusionJob.objects.create(
                    prompt=prompt,
                    diffusion_model=model,
                    lora_model=lora,
                    num_images=num_images,
                )
                celery_task = tasks.generate_images_task.apply_async(args=[job.id], queue="default")
                # Field name is legacy (RQ); it stores the Celery task id.
                job.rq_job_id = celery_task.id
                job.status = "queued"
                job.save()
                count += 1
            lora_label = f" with {lora.label}" if lora else ""
            self.message_user(
                request,
                f"{count} jobs created and queued using {model.label}{lora_label}.",
            )
            return

        # First pass: render confirmation page with model/LoRA selection
        models = DiffusionModel.objects.filter(is_active=True)
        loras = LoraModel.objects.filter(is_active=True)
        return TemplateResponse(
            request,
            "admin/diffusion/prompt/create_jobs_confirmation.html",
            {
                **self.admin_site.each_context(request),
                "title": _("Create jobs for selected prompts"),
                "prompts": queryset,
                "models": models,
                "loras": loras,
                "opts": self.model._meta,
            },
        )
# ---------------------------------------------------------------------------
# DiffusionJob
# ---------------------------------------------------------------------------
@admin.register(DiffusionJob)
class DiffusionJobAdmin(ModelAdmin):
    """Admin for generation jobs: queueing, retrying, cancelling, and
    rendering results."""

    change_form_template = "admin/diffusion/diffusionjob/change_form.html"

    def get_urls(self):
        """Add the JSON endpoint used by the change form's LoRA filtering."""
        urls = super().get_urls()
        custom = [
            path(
                "compatible-loras/",
                self.admin_site.admin_view(self.compatible_loras_view),
                name="diffusion_diffusionjob_compatible_loras",
            ),
        ]
        return custom + urls

    def compatible_loras_view(self, request):
        """Return ids of active LoRAs matching the given model's architecture."""
        model_id = request.GET.get("model_id")
        if not model_id:
            return JsonResponse({"lora_ids": []})
        try:
            model = DiffusionModel.objects.get(id=model_id)
        except DiffusionModel.DoesNotExist:
            # Unknown model id: return an empty list rather than a 500.
            return JsonResponse({"lora_ids": []})
        lora_ids = list(
            LoraModel.objects.filter(
                is_active=True, base_architecture=model.base_architecture
            ).values_list("id", flat=True)
        )
        return JsonResponse({"lora_ids": lora_ids})

    list_display = [
        "show_id",
        "identifier",
        "show_status",
        "diffusion_model",
        "lora_model",
        "show_scheduler",
        "show_prompt",
        "num_images",
        "created_at",
        "show_duration",
    ]
    list_filter = ["status", "diffusion_model", "lora_model", "scheduler", "created_at"]
    search_fields = ["rq_job_id", "prompt__source_prompt"]
    readonly_fields = [
        "rq_job_id",
        "status",
        "created_at",
        "started_at",
        "completed_at",
        "show_result_images",
        "show_generation_metadata",
        "show_duration",
    ]
    actions = ["queue_jobs_action", "cancel_jobs_action", "retry_failed_jobs"]
    actions_row = ["queue_single_action", "retry_single_action", "cancel_single_action"]
    fieldsets = (
        (
            _("Configuration"),
            {
                "classes": ["tab"],
                "fields": ("diffusion_model", "lora_model", "prompt", "identifier"),
            },
        ),
        (
            _("Parameters"),
            {
                "classes": ["tab"],
                "fields": (
                    ("width", "height"),
                    ("steps", "guidance_scale"),
                    ("scheduler", "lora_strength"),
                    ("seed", "num_images"),
                ),
                "description": _("Leave blank to use model/LoRA defaults."),
            },
        ),
        (
            _("ControlNet"),
            {
                "classes": ["tab"],
                "fields": (
                    "controlnet_model",
                    "reference_image",
                    "preprocessing_type",
                    ("conditioning_scale", "control_guidance_end"),
                ),
                "description": _(
                    "Optional: attach a ControlNet and reference image for "
                    "structure-guided generation (wireframe storyboards, etc.)."
                ),
            },
        ),
        (
            _("Status"),
            {
                "classes": ["tab"],
                "fields": (
                    "status",
                    "rq_job_id",
                    "error_message",
                ),
            },
        ),
        (
            _("Results"),
            {
                "classes": ["tab"],
                "fields": ("show_result_images", "show_generation_metadata"),
            },
        ),
        (
            _("Timing"),
            {
                "classes": ["tab"],
                "fields": ("created_at", "started_at", "completed_at", "show_duration"),
            },
        ),
    )

    @display(description=_("ID"))
    def show_id(self, obj):
        return f"#{obj.pk}"

    @display(
        description=_("Status"),
        label={
            "Pending": "info",
            "Queued": "info",
            "Processing": "warning",
            "Completed": "success",
            "Failed": "danger",
            "Cancelled": "info",
        },
    )
    def show_status(self, obj):
        return obj.get_status_display()

    @display(description=_("Prompt"))
    def show_prompt(self, obj):
        """First 50 characters of the job's source prompt, ellipsized."""
        text = obj.prompt.source_prompt[:50]
        if len(obj.prompt.source_prompt) > 50:
            text += "…"
        return text

    @display(description=_("Duration"))
    def show_duration(self, obj):
        """Wall-clock runtime as s / m+s / h+m, or an em dash if unfinished."""
        if obj.started_at and obj.completed_at:
            delta = obj.completed_at - obj.started_at
            seconds = int(delta.total_seconds())
            if seconds < 60:
                return f"{seconds}s"
            elif seconds < 3600:
                return f"{seconds // 60}m {seconds % 60}s"
            return f"{seconds // 3600}h {(seconds % 3600) // 60}m"
        return "—"

    @display(description=_("Scheduler"))
    def show_scheduler(self, obj):
        # Show job override or model default
        scheduler = obj.scheduler or (
            obj.diffusion_model.scheduler if obj.diffusion_model else None
        )
        if scheduler:
            # Shorten common scheduler names for display
            name = scheduler.replace("Scheduler", "").replace("Discrete", "")
            # Indicate if it's an override vs model default
            if obj.scheduler:
                return format_html('<span title="Job override">{}</span>', name)
            return format_html('<span title="Model default" style="opacity: 0.7">{}</span>', name)
        return "—"

    @display(description=_("Generated Images"))
    def show_result_images(self, obj):
        """Thumbnails linking to the generated images.

        NOTE(review): img_url is interpolated into mark_safe HTML without
        escaping; the paths come from the worker, not user input — confirm
        before exposing to less-trusted writers. os.path.join on URL parts
        is also platform-sensitive; works on the posix deployments targeted.
        """
        if not obj.result_images:
            return "No images"
        import os

        from django.conf import settings

        html_parts = []
        for img_path in obj.result_images:
            media_url = settings.MEDIA_URL
            # Absolute paths under MEDIA_ROOT are relativized before joining.
            if img_path.startswith(str(settings.MEDIA_ROOT)):
                rel_path = os.path.relpath(img_path, settings.MEDIA_ROOT)
                img_url = os.path.join(media_url, rel_path)
            else:
                img_url = os.path.join(media_url, img_path)
            html_parts.append(
                f'<a href="{img_url}" target="_blank">'
                f'<img src="{img_url}" style="max-width: 200px; max-height: 200px; margin: 5px;" />'
                f"</a>"
            )
        return mark_safe("".join(html_parts))

    # --- Row actions ---
    @action(description=_("Queue"))
    def queue_single_action(self, request, object_id):
        """Queue a single pending job for processing."""
        from . import tasks

        job = DiffusionJob.objects.get(id=object_id)
        if job.status != "pending":
            messages.warning(request, f"Job #{object_id} is not pending (status: {job.status}).")
        else:
            celery_task = tasks.generate_images_task.apply_async(args=[object_id], queue="default")
            job.rq_job_id = celery_task.id
            job.status = "queued"
            job.save()
            messages.success(request, f"Job #{object_id} queued for processing.")
        return redirect("admin:diffusion_diffusionjob_changelist")

    @action(description=_("Retry"))
    def retry_single_action(self, request, object_id):
        """Reset a failed job and re-queue it."""
        from . import tasks

        job = DiffusionJob.objects.get(id=object_id)
        if job.status != "failed":
            messages.warning(request, f"Job #{object_id} is not failed (status: {job.status}).")
        else:
            # Reset first so the worker sees a clean record, then queue.
            job.status = "pending"
            job.error_message = ""
            job.rq_job_id = ""
            job.save()
            celery_task = tasks.generate_images_task.apply_async(args=[object_id], queue="default")
            job.rq_job_id = celery_task.id
            job.status = "queued"
            job.save()
            messages.success(request, f"Job #{object_id} queued for retry.")
        return redirect("admin:diffusion_diffusionjob_changelist")

    @action(description=_("Cancel"))
    def cancel_single_action(self, request, object_id):
        """Cancel a job that has not started running yet."""
        job = DiffusionJob.objects.get(id=object_id)
        if job.status not in ["pending", "queued"]:
            messages.warning(
                request, f"Job #{object_id} cannot be cancelled (status: {job.status})."
            )
        else:
            job.status = "cancelled"
            job.save()
            messages.success(request, f"Job #{object_id} cancelled.")
        return redirect("admin:diffusion_diffusionjob_changelist")

    # --- Bulk actions ---
    @action(description=_("Queue selected jobs"))
    def queue_jobs_action(self, request, queryset):
        """Queue every selected pending job."""
        from . import tasks

        count = 0
        for job in queryset.filter(status="pending"):
            celery_task = tasks.generate_images_task.apply_async(args=[job.id], queue="default")
            job.rq_job_id = celery_task.id
            job.status = "queued"
            job.save()
            count += 1
        self.message_user(request, f"{count} jobs queued for processing.")

    @action(description=_("Cancel selected jobs"))
    def cancel_jobs_action(self, request, queryset):
        """Cancel all selected jobs that have not started."""
        count = queryset.filter(status__in=["pending", "queued"]).update(status="cancelled")
        self.message_user(request, f"{count} jobs cancelled.")

    @action(description=_("Retry failed jobs"))
    def retry_failed_jobs(self, request, queryset):
        """Reset and re-queue every selected failed job."""
        from . import tasks

        count = 0
        for job in queryset.filter(status="failed"):
            job.status = "pending"
            job.error_message = ""
            job.rq_job_id = ""
            job.save()
            celery_task = tasks.generate_images_task.apply_async(args=[job.id], queue="default")
            job.rq_job_id = celery_task.id
            job.status = "queued"
            job.save()
            count += 1
        self.message_user(request, f"{count} failed jobs queued for retry.")

    def save_model(self, request, obj, form, change):
        """Auto-queue new jobs on save."""
        is_new = obj.pk is None
        super().save_model(request, obj, form, change)
        if is_new and obj.status == "pending":
            from . import tasks

            celery_task = tasks.generate_images_task.apply_async(args=[obj.id], queue="default")
            obj.rq_job_id = celery_task.id
            obj.status = "queued"
            obj.save()