Source code for cw.lib.civitai
"""
CivitAI integration for downloading LoRA files via AIR URN.

AIR format: urn:air:{ecosystem}:{type}:civitai:{modelId}@{versionId}
Download endpoint: GET https://civitai.com/api/download/models/{versionId}
Metadata endpoint: GET https://civitai.com/api/v1/model-versions/{versionId}

Both endpoints are addressed by the version ID alone; the model ID parsed
from the AIR is only used in log messages. Downloads require an API key
(sent as a Bearer token); metadata requests send one only when provided.
"""
import logging
import re
from pathlib import Path
from typing import Optional
import requests
# Module-level logger: DEBUG/INFO progress messages, ERROR when a download is aborted.
logger = logging.getLogger(__name__)
[docs]
def parse_air(air_urn: str) -> tuple[str, str]:
    """
    Pull the CivitAI model ID and version ID out of an AIR URN.

    The "civitai:<modelId>@<versionId>" segment may appear anywhere in the
    string; everything before it (ecosystem, resource type) is ignored.

    Args:
        air_urn: e.g. "urn:air:zimageturbo:lora:civitai:2344335@2636956"

    Returns:
        (model_id, version_id), both as digit strings

    Raises:
        ValueError: If no "civitai:<digits>@<digits>" segment is present
    """
    found = re.search(r"civitai:(\d+)@(\d+)", air_urn)
    if found is None:
        raise ValueError(f"Cannot parse CivitAI AIR URN: {air_urn}")
    model_id, version_id = found.groups()
    return model_id, version_id
[docs]
def fetch_model_version_metadata(version_id: str, api_key: Optional[str] = None) -> dict:
    """
    Fetch metadata for a model version from the CivitAI API.

    Args:
        version_id: CivitAI model version ID
        api_key: Optional CivitAI API key; sent as a Bearer token when given

    Returns:
        The decoded JSON body of GET /api/v1/model-versions/{version_id}

    Raises:
        requests.HTTPError: If CivitAI answers with a 4xx/5xx status
        requests.RequestException: On connection errors or the 30 s timeout
        ValueError: If the response body is not valid JSON
    """
    url = f"https://civitai.com/api/v1/model-versions/{version_id}"
    headers: dict[str, str] = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    logger.debug("Fetching metadata from CivitAI (version=%s)", version_id)
    response = requests.get(url, headers=headers, timeout=30)
    response.raise_for_status()
    data = response.json()
    # Guard the log line so an unexpected (non-object) body doesn't raise here.
    name = data.get("name", "Unknown") if isinstance(data, dict) else "Unknown"
    logger.debug("Fetched metadata for '%s'", name)
    return data
[docs]
def extract_lora_metadata(metadata: dict) -> dict:
    """
    Extract relevant LoRA fields from CivitAI model version metadata.

    The CivitAI API may send null for optional fields; those are normalised
    so every returned text field is a string.

    Args:
        metadata: Raw model-version metadata from the CivitAI API

    Returns:
        Dictionary with extracted fields suitable for LoraModel creation.
        "label", "base_architecture", "prompt_suffix",
        "negative_prompt_suffix" and "notes" are always present as strings;
        "guidance_scale" (float) is present only when the first example
        image carries a numeric, non-zero cfgScale.
    """
    # Map CivitAI baseModel to our base_architecture choices
    base_model_map = {
        "SD 1.4": "sd15",
        "SD 1.5": "sd15",
        "SD 2.1": "sd15",  # Close enough
        "SDXL 1.0": "sdxl",
        "SDXL 0.9": "sdxl",
        "SDXL Turbo": "sdxl",
        "Pony": "sdxl",  # Pony is SDXL-based
        "Flux.1 D": "flux1",
        "Flux.1 S": "flux1",
    }
    base_model = metadata.get("baseModel", "")
    base_architecture = base_model_map.get(base_model, "sdxl")  # Default to SDXL

    # Trigger words: tolerate null, non-string and blank/padded entries.
    trained_words = metadata.get("trainedWords") or []
    prompt_suffix = ", ".join(
        word.strip()
        for word in trained_words
        if isinstance(word, str) and word.strip()
    )

    # Generation settings are taken from the first example image, if any.
    images = metadata.get("images") or []
    first_image = images[0] if images else None
    first_image_meta = (first_image.get("meta") if first_image else None) or {}

    guidance_scale = None
    cfg = first_image_meta.get("cfgScale")
    if cfg:
        try:
            guidance_scale = float(cfg)
        except (TypeError, ValueError):
            # Image metadata is user-supplied; ignore non-numeric values.
            guidance_scale = None

    negative_prompt_suffix = first_image_meta.get("negativePrompt", "") or ""

    notes = metadata.get("description") or ""
    # Append stats for reference; individual counters may be null.
    stats = metadata.get("stats")
    if isinstance(stats, dict) and stats:
        downloads = stats.get("downloadCount") or 0
        rating = stats.get("rating") or 0
        rating_count = stats.get("ratingCount") or 0
        notes += (
            f"\n\nCivitAI Stats: {downloads:,} downloads, "
            f"{rating:.1f}★ ({rating_count} ratings)"
        )

    extracted = {
        "label": metadata.get("name") or "",
        "base_architecture": base_architecture,
        "prompt_suffix": prompt_suffix,
        "negative_prompt_suffix": negative_prompt_suffix,
        "notes": notes,
    }
    # Add guidance_scale only if we found a usable (non-zero) value
    if guidance_scale:
        extracted["guidance_scale"] = guidance_scale
    return extracted
[docs]
def download_lora(air_urn: str, dest_path: str, api_key: str) -> str:
    """
    Download a LoRA file from CivitAI using its AIR URN.

    The response body is streamed into a sibling "<filename>.tmp" file.
    Only once the whole body has been written is that file moved over
    dest_path. A failed or interrupted download therefore never leaves a
    truncated file at dest_path.

    Args:
        air_urn: CivitAI AIR URN containing model/version IDs
        dest_path: Local path where the file should be saved (parent
            directories are created as needed; an existing file is replaced)
        api_key: CivitAI API key for authentication

    Returns:
        The dest_path string on success

    Raises:
        RuntimeError: If api_key is empty
        ValueError: If AIR cannot be parsed, or dest_path has no file name
        requests.HTTPError: If CivitAI answers with a 4xx/5xx status
        requests.RequestException: On connection, timeout or streaming errors
        OSError: If the destination cannot be written
    """
    if not api_key:
        raise RuntimeError("CIVITAI_API_KEY is not configured")
    model_id, version_id = parse_air(air_urn)
    logger.debug("Parsed AIR: model_id=%s, version_id=%s", model_id, version_id)
    url = f"https://civitai.com/api/download/models/{version_id}"
    headers = {"Authorization": f"Bearer {api_key}"}

    dest = Path(dest_path)
    # Append (not replace) the suffix so "x.safetensors" and "x.pt" never share
    # a temp file, and the temp file can never be dest itself.
    tmp_path = dest.with_name(dest.name + ".tmp")

    logger.info("Downloading LoRA from CivitAI (model=%s, version=%s)", model_id, version_id)
    size = 0
    # With stream=True the connection stays checked out until the response is
    # closed; the context manager releases it on success and on every error.
    with requests.get(url, headers=headers, stream=True, timeout=300) as response:
        response.raise_for_status()
        logger.debug("Download request successful, streaming to file")
        dest.parent.mkdir(parents=True, exist_ok=True)
        try:
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if not chunk:  # skip keep-alive chunks
                        continue
                    f.write(chunk)
                    size += len(chunk)
            # replace() overwrites an existing dest on all platforms
            # (rename() raises on Windows if dest already exists).
            tmp_path.replace(dest)
        except BaseException:
            # BaseException so Ctrl-C / worker shutdown also cleans up; re-raised.
            logger.error("Download failed, removing temp file: %s", tmp_path)
            tmp_path.unlink(missing_ok=True)
            raise
    logger.info("Downloaded LoRA to %s (%.1f MB)", dest_path, size / 1024 / 1024)
    return dest_path