"""Node functions for the LangGraph adaptation pipeline.
Each node follows the signature ``(state: PipelineState) -> dict`` and returns
a partial state update. ORM imports are at function level to avoid circular
imports and to match the pattern used in ``cw.tvspots.tasks``.
"""
from __future__ import annotations
import json
import logging
import time
from typing import TYPE_CHECKING
from cw.lib.prompts import render_prompt
if TYPE_CHECKING:
from cw.lib.pipeline.schemas import PipelineState
logger = logging.getLogger(__name__)
MAX_FORMAT_RETRIES = 3
MAX_CULTURAL_RETRIES = 3
MAX_CONCEPT_RETRIES = 3
MAX_BRAND_RETRIES = 3
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_generator(state, output_schema, node_key=None):
    """Build an Outlines generator for *output_schema*.

    Resolves the model from the per-node entry in ``state["model_config"]``
    when *node_key* has one configured, otherwise falls back to the
    pipeline-wide ``state["model_id"]``.

    Returns a ``(generator, loader)`` pair so callers can also reach the
    loader's tokenizer for chat templating.
    """
    from cw.lib.pipeline.model_loader import get_model_loader

    # Resolve model: per-node config → fallback to state["model_id"]
    node_config = state.get("model_config", {}).get(node_key) if node_key else None
    if node_config:
        model_id = node_config["model_id"]
        load_in_4bit = node_config.get("load_in_4bit", False)
    else:
        model_id = state["model_id"]
        load_in_4bit = state.get("load_in_4bit", False)

    logger.debug(
        f"Getting generator for schema: {output_schema.__name__}",
        extra={
            "job_id": state.get("job_id"),
            "model_id": model_id,
            "node_key": node_key,
            "load_in_4bit": load_in_4bit,
        },
    )
    loader = get_model_loader(model_id=model_id, load_in_4bit=load_in_4bit)
    return loader.get_generator(output_schema), loader
def _apply_chat_template(loader, system_message: str, user_prompt: str) -> str:
"""Format a system + user message pair using the loaded tokenizer's chat template."""
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_prompt},
]
return loader.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True,
)
def _update_job_status(job_id: int, status: str, **extra_fields):
    """Persist *status* plus any *extra_fields* on the VideoAdUnit row.

    Saves with ``update_fields`` so only the touched columns are written.
    """
    from cw.tvspots.models import VideoAdUnit

    job = VideoAdUnit.objects.get(id=job_id)
    job.status = status
    dirty = ["status"]
    for name, value in extra_fields.items():
        setattr(job, name, value)
        dirty.append(name)
    job.save(update_fields=dirty)
# ---------------------------------------------------------------------------
# Concept extraction node
# ---------------------------------------------------------------------------
def concept_node(state: PipelineState) -> dict:
    """Analyse the original script and produce a ConceptBrief.

    Returns a partial state update with the serialized brief under
    ``"concept_brief"``. Any failure is logged with job context and
    re-raised so the graph runner can mark the job failed.
    """
    from cw.lib.pipeline.schemas import ConceptBrief

    logger.info("Pipeline node: concept_extraction starting", extra={"job_id": state["job_id"]})
    _update_job_status(state["job_id"], "concept_analysis")
    start = time.time()
    try:
        logger.debug("Loading model and creating generator", extra={"job_id": state["job_id"]})
        generator, loader = _get_generator(state, ConceptBrief, node_key="concept")
        logger.debug("Rendering concept extraction prompt", extra={"job_id": state["job_id"]})
        user_prompt = render_prompt(
            "concept-extraction",
            original_json=state["original_script"],
        )
        system_message = (
            "You are an expert advertising analyst. Produce ONLY valid JSON "
            "matching the requested schema — no commentary."
        )
        prompt = _apply_chat_template(loader, system_message, user_prompt)
        logger.info(
            "Generating concept brief with LLM",
            extra={"job_id": state["job_id"], "prompt_length": len(prompt)},
        )
        raw = generator(prompt, max_new_tokens=4096)
        logger.debug("LLM generation complete, validating output", extra={"job_id": state["job_id"]})
        result = ConceptBrief.model_validate(json.loads(raw) if isinstance(raw, str) else raw)
        logger.debug("Saving concept brief to database", extra={"job_id": state["job_id"]})
        # model_dump(mode="json") yields the same JSON-compatible dict as
        # json.loads(model_dump_json()) without the extra serialize/parse pass.
        _update_job_status(
            state["job_id"], "concept_analysis", concept_brief=result.model_dump(mode="json"),
        )
        elapsed = round(time.time() - start, 2)
        logger.info(f"Pipeline node: concept_extraction done ({elapsed}s)", extra={"job_id": state["job_id"]})
        return {"concept_brief": result.model_dump_json()}
    except Exception as e:
        logger.error(
            f"Concept extraction failed: {e}",
            extra={"job_id": state["job_id"], "error": str(e)},
            exc_info=True,
        )
        raise
# ---------------------------------------------------------------------------
# Cultural research node
# ---------------------------------------------------------------------------
def culture_node(state: PipelineState) -> dict:
    """Produce a CulturalBrief for the target market.

    Returns a partial state update with the serialized brief under
    ``"cultural_brief"``. Any failure is logged with job context and
    re-raised so the graph runner can mark the job failed.
    """
    from cw.lib.pipeline.schemas import CulturalBrief

    logger.info("Pipeline node: cultural_research starting", extra={"job_id": state["job_id"]})
    _update_job_status(state["job_id"], "cultural_analysis")
    start = time.time()
    try:
        logger.debug("Loading model and creating generator", extra={"job_id": state["job_id"]})
        generator, loader = _get_generator(state, CulturalBrief, node_key="culture")
        logger.debug(
            "Rendering cultural research prompt",
            extra={"job_id": state["job_id"], "target": state["target_market_name"]},
        )
        user_prompt = render_prompt(
            "cultural-research",
            concept_brief_json=state["concept_brief"],
            target_market_name=state["target_market_name"],
            target_market_rules=state["target_market_rules"],
            original_json=state["original_script"],
        )
        system_message = (
            "You are a cultural research specialist. Produce ONLY valid JSON "
            "matching the requested schema — no commentary."
        )
        prompt = _apply_chat_template(loader, system_message, user_prompt)
        logger.info(
            "Generating cultural brief with LLM",
            extra={"job_id": state["job_id"], "prompt_length": len(prompt)},
        )
        raw = generator(prompt, max_new_tokens=4096)
        logger.debug("LLM generation complete, validating output", extra={"job_id": state["job_id"]})
        result = CulturalBrief.model_validate(json.loads(raw) if isinstance(raw, str) else raw)
        logger.debug("Saving cultural brief to database", extra={"job_id": state["job_id"]})
        # model_dump(mode="json") yields the same JSON-compatible dict as
        # json.loads(model_dump_json()) without the extra serialize/parse pass.
        _update_job_status(
            state["job_id"], "cultural_analysis", cultural_brief=result.model_dump(mode="json"),
        )
        elapsed = round(time.time() - start, 2)
        logger.info(f"Pipeline node: cultural_research done ({elapsed}s)", extra={"job_id": state["job_id"]})
        return {"cultural_brief": result.model_dump_json()}
    except Exception as e:
        logger.error(
            f"Cultural research failed: {e}",
            extra={"job_id": state["job_id"], "error": str(e)},
            exc_info=True,
        )
        raise
# ---------------------------------------------------------------------------
# Writer node
# ---------------------------------------------------------------------------
def writer_node(state: PipelineState) -> dict:
    """Generate (or revise) the adapted script.

    Any non-None evaluator feedback in *state* marks this run as a revision.
    After two total failed revisions, an alternative model (if one is
    configured for the language) is swapped in and its identity is merged
    into the returned state update so subsequent nodes use it.

    Returns the serialized script plus cleared feedback keys so the
    evaluators start fresh on the new draft. Failures are logged with job
    context and re-raised, matching the other pipeline nodes.
    """
    from cw.lib.adaptation import AdaptationOutput

    format_count = state.get("format_revision_count", 0)
    cultural_count = state.get("cultural_revision_count", 0)
    concept_count = state.get("concept_revision_count", 0)
    brand_count = state.get("brand_revision_count", 0)
    total_revisions = format_count + cultural_count + concept_count + brand_count
    is_revision = (
        state.get("format_feedback") is not None
        or state.get("cultural_feedback") is not None
        or state.get("concept_feedback") is not None
        or state.get("brand_feedback") is not None
    )
    status = "revising" if is_revision else "writing"
    logger.info(
        f"Pipeline node: writer starting (revision={is_revision}, total_revisions={total_revisions})",
        extra={"job_id": state["job_id"]},
    )
    _update_job_status(state["job_id"], status)
    start = time.time()
    try:
        # Switch model after 2 failed revision attempts
        state_update_model = {}
        if total_revisions >= 2:
            from cw.lib.pipeline.state import get_alternative_model
            from cw.lib.pipeline.model_loader import get_model_loader

            alt = get_alternative_model(state.get("language_code", ""))
            if alt:
                loader = get_model_loader(
                    model_id=state["model_id"],
                    load_in_4bit=state.get("load_in_4bit", False),
                )
                logger.info(
                    f"Switching to alternative model: {alt.model_id}",
                    extra={"job_id": state["job_id"]},
                )
                loader.switch_model(alt.model_id, alt.load_in_4bit)
                # Update state so subsequent nodes use new model
                state_update_model = {"model_id": alt.model_id, "load_in_4bit": alt.load_in_4bit}

        generator, loader = _get_generator(
            state | state_update_model, AdaptationOutput, node_key="writer",
        )
        # Build revision feedback from whichever evaluator failed
        revision_feedback = (
            state.get("format_feedback")
            or state.get("cultural_feedback")
            or state.get("concept_feedback")
            or state.get("brand_feedback")
            or None
        )
        user_prompt = render_prompt(
            "adaptation",
            target_market_name=state["target_market_name"],
            target_market_language=state["target_market_language"],
            target_market_rules=state["target_market_rules"],
            target_market_code=state["target_market_code"],
            original_json=state["original_script"],
            num_script_rows=state["num_script_rows"],
            creativity=0.7,
            concept_brief=state.get("concept_brief"),
            cultural_brief=state.get("cultural_brief"),
            revision_feedback=revision_feedback,
        )
        language_code = state["target_market_language"]
        system_message = (
            "You are an expert advertising creative and localization strategist. "
            "CRITICAL LANGUAGE RULE: You MUST write ALL descriptions, stage directions, "
            f"and visual_text content in ENGLISH. Only spoken dialogue (VO) and on-screen text "
            f"(supers, titles) should use {language_code}, and these MUST include "
            "an English translation in parentheses. Never write scene descriptions in any "
            "language other than English. Produce ONLY valid JSON — no commentary."
        )
        prompt = _apply_chat_template(loader, system_message, user_prompt)
        logger.info(
            f"Generating adapted script with LLM (revision={is_revision})",
            extra={"job_id": state["job_id"], "prompt_length": len(prompt), "total_revisions": total_revisions},
        )
        raw = generator(prompt, max_new_tokens=4096)
        logger.debug("LLM generation complete, validating output", extra={"job_id": state["job_id"]})
        result = AdaptationOutput.model_validate(json.loads(raw) if isinstance(raw, str) else raw)
        script_json = result.model_dump_json()
        elapsed = round(time.time() - start, 2)
        logger.info(f"Pipeline node: writer done ({elapsed}s)", extra={"job_id": state["job_id"]})
        return {
            "adapted_script": script_json,
            # Clear feedback so evaluators start fresh on the new draft
            "format_feedback": None,
            "cultural_feedback": None,
            "concept_feedback": None,
            "brand_feedback": None,
            **state_update_model,
        }
    except Exception as e:
        logger.error(
            f"Script writing failed: {e}",
            extra={"job_id": state["job_id"], "error": str(e)},
            exc_info=True,
        )
        raise
# ---------------------------------------------------------------------------
# Format / language compliance evaluation node
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Cultural evaluation node
# ---------------------------------------------------------------------------
def cultural_eval_node(state: PipelineState) -> dict:
    """Evaluate cultural compliance of the adapted script.

    Appends the verdict to the job's evaluation history. On a pass, clears
    ``cultural_feedback``; on a fail, returns the feedback JSON and bumps
    ``cultural_revision_count`` so the writer can revise.
    """
    from cw.lib.pipeline.schemas import EvaluationResult

    logger.info("Pipeline node: cultural_eval starting", extra={"job_id": state["job_id"]})
    _update_job_status(state["job_id"], "cultural_evaluation")
    start = time.time()
    try:
        logger.debug("Loading model and creating generator", extra={"job_id": state["job_id"]})
        generator, loader = _get_generator(state, EvaluationResult, node_key="culture_gate")
        logger.debug("Rendering cultural evaluation prompt", extra={"job_id": state["job_id"]})
        user_prompt = render_prompt(
            "eval-cultural",
            adapted_script_json=state["adapted_script"],
            cultural_brief_json=state["cultural_brief"],
            target_market_rules=state["target_market_rules"],
        )
        system_message = (
            "You are a cultural compliance reviewer. Produce ONLY valid JSON "
            "matching the requested schema — no commentary."
        )
        prompt = _apply_chat_template(loader, system_message, user_prompt)
        logger.info(
            "Evaluating cultural compliance with LLM",
            extra={"job_id": state["job_id"], "prompt_length": len(prompt)},
        )
        raw = generator(prompt, max_new_tokens=2048)
        logger.debug("LLM evaluation complete, validating output", extra={"job_id": state["job_id"]})
        result = EvaluationResult.model_validate(json.loads(raw) if isinstance(raw, str) else raw)

        # Record this verdict in the job's evaluation history.
        from cw.tvspots.models import VideoAdUnit

        logger.debug("Updating evaluation history in database", extra={"job_id": state["job_id"]})
        job = VideoAdUnit.objects.get(id=state["job_id"])
        eval_history = job.evaluation_history or []
        eval_history.append({"type": "cultural", **result.model_dump()})
        job.evaluation_history = eval_history
        job.save(update_fields=["evaluation_history"])

        elapsed = round(time.time() - start, 2)
        logger.info(
            f"Pipeline node: cultural_eval done ({elapsed}s, passed={result.passed})",
            extra={"job_id": state["job_id"], "passed": result.passed},
        )
    except Exception as e:
        logger.error(
            f"Cultural evaluation failed: {e}",
            extra={"job_id": state["job_id"], "error": str(e)},
            exc_info=True,
        )
        raise
    if not result.passed:
        return {
            "cultural_feedback": result.model_dump_json(),
            "cultural_revision_count": state.get("cultural_revision_count", 0) + 1,
        }
    return {"cultural_feedback": None}
# ---------------------------------------------------------------------------
# Concept fidelity evaluation node
# ---------------------------------------------------------------------------
def concept_eval_node(state: PipelineState) -> dict:
    """Evaluate concept fidelity of the adapted script.

    Appends the verdict to the job's evaluation history. On a pass, clears
    ``concept_feedback``; on a fail, returns the feedback JSON and bumps
    ``concept_revision_count``. Failures are logged with job context and
    re-raised, matching ``cultural_eval_node``.
    """
    from cw.lib.pipeline.schemas import EvaluationResult

    logger.info("Pipeline node: concept_eval starting", extra={"job_id": state["job_id"]})
    _update_job_status(state["job_id"], "concept_evaluation")
    start = time.time()
    try:
        logger.debug("Loading model and creating generator", extra={"job_id": state["job_id"]})
        generator, loader = _get_generator(state, EvaluationResult, node_key="concept_gate")
        user_prompt = render_prompt(
            "eval-concept",
            adapted_script_json=state["adapted_script"],
            concept_brief_json=state["concept_brief"],
        )
        system_message = (
            "You are a brand strategy reviewer. Produce ONLY valid JSON "
            "matching the requested schema — no commentary."
        )
        prompt = _apply_chat_template(loader, system_message, user_prompt)
        raw = generator(prompt, max_new_tokens=2048)
        logger.debug("LLM evaluation complete, validating output", extra={"job_id": state["job_id"]})
        result = EvaluationResult.model_validate(json.loads(raw) if isinstance(raw, str) else raw)

        # Append to evaluation history
        from cw.tvspots.models import VideoAdUnit

        job = VideoAdUnit.objects.get(id=state["job_id"])
        history = job.evaluation_history or []
        history.append({"type": "concept", **result.model_dump()})
        job.evaluation_history = history
        job.save(update_fields=["evaluation_history"])

        elapsed = round(time.time() - start, 2)
        logger.info(
            f"Pipeline node: concept_eval done ({elapsed}s, passed={result.passed})",
            extra={"job_id": state["job_id"], "passed": result.passed},
        )
    except Exception as e:
        logger.error(
            f"Concept evaluation failed: {e}",
            extra={"job_id": state["job_id"], "error": str(e)},
            exc_info=True,
        )
        raise
    if result.passed:
        return {"concept_feedback": None}
    return {
        "concept_feedback": result.model_dump_json(),
        "concept_revision_count": state.get("concept_revision_count", 0) + 1,
    }
# ---------------------------------------------------------------------------
# Brand consistency evaluation node
# ---------------------------------------------------------------------------
def brand_eval_node(state: PipelineState) -> dict:
    """Evaluate brand consistency of the adapted script.

    Appends the verdict to the job's evaluation history. On a pass, clears
    ``brand_feedback``; on a fail, returns the feedback JSON and bumps
    ``brand_revision_count``. Failures are logged with job context and
    re-raised, matching ``cultural_eval_node``.
    """
    from cw.lib.pipeline.schemas import EvaluationResult

    logger.info("Pipeline node: brand_eval starting", extra={"job_id": state["job_id"]})
    _update_job_status(state["job_id"], "brand_evaluation")
    start = time.time()
    try:
        logger.debug("Loading model and creating generator", extra={"job_id": state["job_id"]})
        generator, loader = _get_generator(state, EvaluationResult, node_key="brand_gate")
        user_prompt = render_prompt(
            "eval-brand",
            adapted_script_json=state["adapted_script"],
            concept_brief_json=state["concept_brief"],
            cultural_brief_json=state["cultural_brief"],
            brand_guidelines=state.get("brand_guidelines", ""),
        )
        system_message = (
            "You are a brand compliance reviewer. Produce ONLY valid JSON "
            "matching the requested schema — no commentary."
        )
        prompt = _apply_chat_template(loader, system_message, user_prompt)
        raw = generator(prompt, max_new_tokens=2048)
        logger.debug("LLM evaluation complete, validating output", extra={"job_id": state["job_id"]})
        result = EvaluationResult.model_validate(json.loads(raw) if isinstance(raw, str) else raw)

        # Append to evaluation history
        from cw.tvspots.models import VideoAdUnit

        job = VideoAdUnit.objects.get(id=state["job_id"])
        history = job.evaluation_history or []
        history.append({"type": "brand", **result.model_dump()})
        job.evaluation_history = history
        job.save(update_fields=["evaluation_history"])

        elapsed = round(time.time() - start, 2)
        logger.info(
            f"Pipeline node: brand_eval done ({elapsed}s, passed={result.passed})",
            extra={"job_id": state["job_id"], "passed": result.passed},
        )
    except Exception as e:
        logger.error(
            f"Brand evaluation failed: {e}",
            extra={"job_id": state["job_id"], "error": str(e)},
            exc_info=True,
        )
        raise
    if result.passed:
        return {"brand_feedback": None}
    return {
        "brand_feedback": result.model_dump_json(),
        "brand_revision_count": state.get("brand_revision_count", 0) + 1,
    }