# visibility_service.py (12.9 kB)
import asyncio
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional
from sqlmodel import Session, select
from database import engine, ensure_user_exists
from models import User, TrackedQuery, Run, RunItem, Aggregate
from orchestration import visibility_workflow, VisibilityState
from parsing import auto_detect_brands
class VisibilityService:
    """Track user queries and measure brand visibility across AI platforms.

    Persistence goes through ``sqlmodel`` sessions bound to the shared
    ``engine``; the platform calls themselves are delegated to
    ``visibility_workflow``.

    NOTE(review): the async methods still perform synchronous (blocking)
    database work on the event-loop thread; only the workflow call is awaited.
    """

    # Fallback brand set used when a tracked query has no competitors.
    # TODO: replace with real auto-detection (a preliminary check, e.g. via
    # parsing.auto_detect_brands) instead of this fixed placeholder.
    PLACEHOLDER_BRANDS = ("Zoho", "HubSpot", "Salesforce")

    def __init__(self):
        # Platforms queried when a caller does not specify any.
        self.default_platforms = ["azure_openai", "perplexity"]

    async def add_tracking_query(self, phone: str, query: str, competitors: Optional[List[str]] = None) -> Dict[str, Any]:
        """Add a tracked query for the user, or update an existing one.

        Args:
            phone: User identifier; ``ensure_user_exists`` is called first.
            query: Query text. An existing row with the same text for this
                user is reused instead of inserting a duplicate.
            competitors: Brands to compare against. ``None`` leaves an existing
                query's competitors untouched; new queries then get ``[]``.

        Returns:
            ``{"ok": True, "query_id": <id>}``.
        """
        ensure_user_exists(phone)
        with Session(engine) as session:
            tracked = session.exec(
                select(TrackedQuery).where(
                    TrackedQuery.user_phone == phone,
                    TrackedQuery.query_text == query,
                )
            ).first()
            if tracked is None:
                tracked = TrackedQuery(
                    user_phone=phone,
                    query_text=query,
                    competitors=competitors or [],
                )
            elif competitors is not None:
                tracked.competitors = competitors
            session.add(tracked)
            session.commit()
            session.refresh(tracked)
            return {"ok": True, "query_id": tracked.id}

    def list_tracked_queries(self, phone: str) -> Dict[str, Any]:
        """Return every tracked query owned by ``phone``.

        Returns:
            ``{"queries": [{"id", "query_text", "competitors", "created_at"}, ...]}``
            where ``created_at`` is an ISO-8601 string.
        """
        ensure_user_exists(phone)
        with Session(engine) as session:
            queries = session.exec(
                select(TrackedQuery).where(TrackedQuery.user_phone == phone)
            ).all()
            return {
                "queries": [
                    {
                        "id": query.id,
                        "query_text": query.query_text,
                        "competitors": query.competitors,
                        "created_at": query.created_at.isoformat(),
                    }
                    for query in queries
                ]
            }

    async def run_visibility_check(self, phone: str, query_id: int, platforms: Optional[List[str]] = None) -> Dict[str, Any]:
        """Run the visibility workflow for one tracked query and persist it.

        Args:
            phone: Owner of the query; the query must belong to this user.
            query_id: Id of the ``TrackedQuery`` to check.
            platforms: Platforms to query; falsy means ``default_platforms``.

        Returns:
            ``{"ok": True, "run_id": <id>, "summary": <workflow summary>}``, or
            ``{"error": "query_not_found"}``.
        """
        ensure_user_exists(phone)

        # Phase 1: load what the workflow needs, then release the connection
        # so it is not held for the duration of the platform calls.
        with Session(engine) as session:
            tracked = session.exec(
                select(TrackedQuery).where(
                    TrackedQuery.id == query_id,
                    TrackedQuery.user_phone == phone,
                )
            ).first()
            if not tracked:
                return {"error": "query_not_found"}
            query_text = tracked.query_text
            brands = list(tracked.competitors or [])

        if not brands:
            brands = list(self.PLACEHOLDER_BRANDS)
        # Copy so workflow state never aliases the shared default list.
        platforms = list(platforms or self.default_platforms)

        state = VisibilityState(
            phone=phone,
            query_text=query_text,
            brands=brands,
            platforms=platforms,
        )
        result = await visibility_workflow.ainvoke(state)

        # Phase 2: persist the run and its per-platform items atomically.
        with Session(engine) as session:
            run = Run(
                user_phone=phone,
                query_id=query_id,
                platforms_summary=result["summary"]["platforms"],
            )
            session.add(run)
            session.flush()  # assigns run.id so items can reference it
            run_id = run.id
            for platform, parsed in result["parsed"].items():
                if "error" in parsed:
                    continue  # failed platforms are not persisted as items
                session.add(
                    RunItem(
                        run_id=run_id,
                        platform=platform,
                        raw_answer=result["raw_results"][platform]["raw_answer"],
                        mentions=parsed["mentions"],
                        citations=parsed["citations"],
                        sentiment=parsed["sentiment"],
                        first_position_brand=parsed["first_position_brand"],
                        stats=parsed["stats"],
                    )
                )
            session.commit()

        return {"ok": True, "run_id": run_id, "summary": result["summary"]}

    @staticmethod
    def _resolve_start_date(range_days: str, end_date: datetime) -> datetime:
        """Translate a report range token into the window's start datetime.

        ``"<N>d"`` (e.g. ``"7d"``, ``"30d"``, ``"14d"``; case-insensitive, N > 0)
        means the N days before ``end_date``. Anything else, including
        ``"all"`` and ranges reaching before year 1, means no lower bound
        (``datetime.min``).
        """
        token = str(range_days or "").strip().lower()
        digits = token[:-1]
        if token.endswith("d") and digits.isdecimal():
            days = int(digits)
            if days > 0:
                try:
                    return end_date - timedelta(days=days)
                except OverflowError:
                    pass  # too far back to represent: treat as unbounded
        return datetime.min

    @staticmethod
    def _aggregate_citations(citations: List[Any]) -> List[Dict[str, Any]]:
        """Sum citation counts per domain, most-cited first (ties keep first-seen order)."""
        counts: Dict[str, Any] = {}
        for citation in citations:
            domain = citation["domain"]
            counts[domain] = counts.get(domain, 0) + citation["count"]
        ranked = sorted(counts.items(), key=lambda pair: pair[1], reverse=True)
        return [{"domain": domain, "count": count} for domain, count in ranked]

    @staticmethod
    def _average_sentiment(run_items: List[Any]) -> Dict[str, Dict[str, float]]:
        """Average each brand's positive/neutral/negative scores over the items that mention it."""
        totals: Dict[str, Dict[str, float]] = {}
        seen: Dict[str, int] = {}
        for item in run_items:
            for brand, scores in (item.sentiment or {}).items():
                bucket = totals.setdefault(
                    brand, {"positive": 0.0, "neutral": 0.0, "negative": 0.0}
                )
                bucket["positive"] += scores["positive"]
                bucket["neutral"] += scores["neutral"]
                bucket["negative"] += scores["negative"]
                seen[brand] = seen.get(brand, 0) + 1
        for brand, bucket in totals.items():
            for key in bucket:
                bucket[key] /= seen[brand]
        return totals

    @staticmethod
    def _latest_platform_status(run_items: List[Any]) -> List[Dict[str, Any]]:
        """Return one status entry per platform, taken from its last item.

        ``run_items`` is expected oldest-first, so later items overwrite
        earlier ones. Status is always "ok" because only successful platform
        results are stored as run items.
        """
        latest: Dict[str, Dict[str, Any]] = {}
        for item in run_items:
            latest[item.platform] = {
                "platform": item.platform,
                "status": "ok",
                "last_latency_ms": (item.stats or {}).get("latency_ms"),
            }
        return list(latest.values())

    def fetch_visibility_report(self, phone: str, query_id: int, range_days: str = "7d") -> Dict[str, Any]:
        """Aggregate stored runs of one query into a visibility report.

        Args:
            phone: Owner of the query.
            query_id: Id of the ``TrackedQuery``.
            range_days: ``"<N>d"`` (e.g. ``"7d"``, ``"30d"``) or ``"all"``.

        Returns:
            Report dict with share of voice, averaged sentiment, top citation
            domains, latest status per platform, and first-position entries
            (oldest first). Returns ``{"error": "query_not_found"}`` if the
            query does not exist for this user.
        """
        ensure_user_exists(phone)
        with Session(engine) as session:
            query = session.exec(
                select(TrackedQuery).where(
                    TrackedQuery.id == query_id,
                    TrackedQuery.user_phone == phone,
                )
            ).first()
            if not query:
                return {"error": "query_not_found"}

            end_date = datetime.now()
            start_date = self._resolve_start_date(range_days, end_date)

            # Single join (as in get_platform_snapshot) instead of an IN-list
            # over run ids; ordered so "latest" aggregation is deterministic.
            run_items = session.exec(
                select(RunItem)
                .join(Run)
                .where(
                    Run.user_phone == phone,
                    Run.query_id == query_id,
                    Run.run_at >= start_date,
                    Run.run_at <= end_date,
                )
                .order_by(RunItem.created_at)
            ).all()

            all_mentions: List[Any] = []
            all_citations: List[Any] = []
            first_position_data = []
            for run_item in run_items:
                all_mentions.extend(run_item.mentions or [])
                all_citations.extend(run_item.citations or [])
                if run_item.first_position_brand:
                    first_position_data.append({
                        "date": run_item.created_at.strftime("%Y-%m-%d"),
                        "brand": run_item.first_position_brand,
                        "platform": run_item.platform,
                    })

            from parsing import compute_sov
            sov_data = compute_sov(all_mentions)
            sentiment_data = self._average_sentiment(run_items)

            return {
                "query": query.query_text,
                "period": range_days,
                "share_of_voice": sov_data,
                "sentiment": [
                    {"brand": brand, **scores}
                    for brand, scores in sentiment_data.items()
                ],
                "top_citation_domains": self._aggregate_citations(all_citations),
                "platform_status": self._latest_platform_status(run_items),
                "first_position_leaderboard": first_position_data,
                "trends": {
                    "sov_over_time": []  # Would need more complex aggregation for trends
                },
            }

    def get_platform_snapshot(self, phone: str, query_id: int, platform: str, date: Optional[str] = None) -> Dict[str, Any]:
        """Return the most recent stored result for one query on one platform.

        Args:
            phone: Owner of the query.
            query_id: Id of the ``TrackedQuery``.
            platform: Platform name to filter run items by.
            date: Optional ``YYYY-MM-DD``; restricts to items created that day.

        Returns:
            Snapshot dict, or one of ``{"error": "query_not_found"}``,
            ``{"error": "invalid_date_format"}``, ``{"error": "no_data_found"}``.
        """
        ensure_user_exists(phone)
        with Session(engine) as session:
            query = session.exec(
                select(TrackedQuery).where(
                    TrackedQuery.id == query_id,
                    TrackedQuery.user_phone == phone,
                )
            ).first()
            if not query:
                return {"error": "query_not_found"}

            query_builder = select(RunItem).join(Run).where(
                Run.user_phone == phone,
                Run.query_id == query_id,
                RunItem.platform == platform,
            )
            if date:
                try:
                    target_date = datetime.strptime(date, "%Y-%m-%d")
                except ValueError:
                    return {"error": "invalid_date_format"}
                # Half-open [day, next day) window on creation time.
                query_builder = query_builder.where(
                    RunItem.created_at >= target_date,
                    RunItem.created_at < target_date + timedelta(days=1),
                )

            run_item = session.exec(
                query_builder.order_by(RunItem.created_at.desc())
            ).first()
            if not run_item:
                return {"error": "no_data_found"}

            return {
                "query": query.query_text,
                "platform": platform,
                "date": run_item.created_at.strftime("%Y-%m-%d"),
                "raw_answer": run_item.raw_answer,
                "mentions": run_item.mentions,
                "citations": run_item.citations,
                "sentiment_summary": run_item.sentiment,
            }

    async def run_all_today(self, phone: str) -> Dict[str, Any]:
        """Run a visibility check for each of the user's queries, sequentially.

        Uses ``default_platforms``. Checks that do not return ``ok`` are
        skipped, and their run ids are not included.

        Returns:
            ``{"ok": True, "run_ids": [...]}``.
        """
        ensure_user_exists(phone)
        # Collect ids and close the session before awaiting; each check opens
        # its own sessions.
        with Session(engine) as session:
            query_ids = [
                tracked.id
                for tracked in session.exec(
                    select(TrackedQuery).where(TrackedQuery.user_phone == phone)
                ).all()
            ]

        run_ids = []
        for query_id in query_ids:
            result = await self.run_visibility_check(phone, query_id, self.default_platforms)
            if result.get("ok"):
                run_ids.append(result["run_id"])
        return {"ok": True, "run_ids": run_ids}
# Shared module-level instance; importers of this module all use this object.
visibility_service: VisibilityService = VisibilityService()