"""Calculator activities for Temporal workflows."""
import ast
import operator
import logging
from temporalio import activity
from nexus_mcp_calculator.service import (
CalculateRequest,
CalculateResponse,
AddRequest,
SubtractRequest,
MultiplyRequest,
DivideRequest,
PowerRequest,
SumListRequest,
BasicOperationResponse,
)
# Standard module-level logger, named after this module.
# NOTE(review): the activities below log through ``activity.logger``; this
# logger looks unused here — confirm before removing it.
logger: logging.Logger = logging.getLogger(__name__)
class SafeExpressionEvaluator:
    """Safe mathematical expression evaluator using AST parsing.

    The expression is parsed with ``ast.parse(..., mode="eval")`` and the
    resulting tree is walked node by node. Only numeric literals, the
    operators in ``ALLOWED_OPERATORS`` and calls to the names in
    ``ALLOWED_FUNCTIONS`` are accepted, so no arbitrary Python code is ever
    executed. Every numeric literal is converted to ``float`` before use.
    """

    # Allowed operators for safe evaluation. Binary (Add..Mod) and unary
    # (USub, UAdd) operators share one table; the AST node type decides
    # which arity is applied.
    ALLOWED_OPERATORS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.Pow: operator.pow,
        ast.Mod: operator.mod,
        ast.USub: operator.neg,
        ast.UAdd: operator.pos,
    }

    # Allowed functions, looked up by the bare name used in the expression.
    ALLOWED_FUNCTIONS = {
        "abs": abs,
        "round": round,
        "max": max,
        "min": min,
        "sum": sum,
    }

    def evaluate(self, expression: str) -> float:
        """Safely evaluate a mathematical expression.

        Args:
            expression: Source text such as ``"2 * (3 + 4)"`` or
                ``"sum([1, 2, 3])"``. Surrounding whitespace is ignored.

        Returns:
            The numeric result as a ``float``.

        Raises:
            ValueError: If the expression cannot be parsed, uses a construct
                that is not allowed, or fails to evaluate (division or modulo
                by zero, overflow, non-real result, ...). The underlying
                exception is chained as ``__cause__``.
        """
        try:
            # Parse the expression into an AST
            tree = ast.parse(expression.strip(), mode="eval")
            return self._eval_node(tree.body)
        except Exception as e:
            # Normalise every failure (e.g. SyntaxError from the parser,
            # OverflowError from ``**``) into ValueError so callers only
            # have to handle one exception type.
            raise ValueError(f"Invalid expression '{expression}': {e}") from e

    def _eval_node(self, node: ast.AST) -> float:
        """Recursively evaluate an AST node, dispatching on its type."""
        if isinstance(node, ast.Constant):
            return self._eval_constant(node)
        if isinstance(node, ast.BinOp):
            return self._eval_binop(node)
        if isinstance(node, ast.UnaryOp):
            return self._eval_unaryop(node)
        if isinstance(node, ast.Call):
            return self._eval_call(node)
        raise ValueError(f"Unsupported AST node type: {type(node)}")

    def _eval_constant(self, node: ast.Constant) -> float:
        """Return a numeric literal as a float; reject anything else."""
        value = node.value
        # bool is a subclass of int; reject it explicitly so that
        # ``True + 1`` is not silently evaluated as 2.0.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValueError(f"Unsupported constant type: {type(value)}")
        return float(value)

    def _eval_binop(self, node: ast.BinOp) -> float:
        """Evaluate a binary operation (e.g. +, -, *, /, %, **)."""
        left = self._eval_node(node.left)
        right = self._eval_node(node.right)
        op = type(node.op)
        if op not in self.ALLOWED_OPERATORS:
            raise ValueError(f"Unsupported operator: {op.__name__}")
        if op == ast.Div and right == 0:
            raise ValueError("Division by zero")
        if op == ast.Mod and right == 0:
            raise ValueError("Modulo by zero")
        result = self.ALLOWED_OPERATORS[op](left, right)
        # float ** float yields a complex number for a negative base with a
        # fractional exponent (e.g. (-8) ** 0.5); that is not a real result.
        if isinstance(result, complex):
            raise ValueError("Result is not a real number")
        return result

    def _eval_unaryop(self, node: ast.UnaryOp) -> float:
        """Evaluate a unary operation (e.g. -x, +x)."""
        operand = self._eval_node(node.operand)
        op = type(node.op)
        if op not in self.ALLOWED_OPERATORS:
            raise ValueError(f"Unsupported unary operator: {op.__name__}")
        return self.ALLOWED_OPERATORS[op](operand)

    def _eval_call(self, node: ast.Call) -> float:
        """Evaluate a call to one of ``ALLOWED_FUNCTIONS``."""
        if not isinstance(node.func, ast.Name):
            raise ValueError("Complex function calls not supported")
        func_name = node.func.id
        if func_name not in self.ALLOWED_FUNCTIONS:
            raise ValueError(f"Function '{func_name}' not allowed")
        # Keyword arguments were previously dropped without notice, so
        # ``round(x, ndigits=2)`` rounded to 0 digits; reject them instead.
        if node.keywords:
            raise ValueError(
                f"Keyword arguments are not supported in '{func_name}'"
            )
        args = [self._eval_call_arg(arg) for arg in node.args]
        try:
            # Special handling for round function - second argument must be int
            if func_name == "round" and len(args) > 1:
                args[1] = int(args[1])
            return float(self.ALLOWED_FUNCTIONS[func_name](*args))
        except Exception as e:
            raise ValueError(f"Error calling function '{func_name}': {e}") from e

    def _eval_call_arg(self, node: ast.AST) -> "float | list[float]":
        """Evaluate one call argument; list/tuple literals become float lists.

        Accepting ``[...]`` / ``(...)`` here, and only here, is what lets
        ``sum([1, 2, 3])`` and ``max((1, 5))`` work. A bare list at the top
        level of an expression is still rejected by ``_eval_node``.
        """
        if isinstance(node, (ast.List, ast.Tuple)):
            return [self._eval_node(element) for element in node.elts]
        return self._eval_node(node)
# Single shared evaluator used by calculate_activity. The class holds no
# per-instance state, so one module-level instance is enough.
_evaluator: SafeExpressionEvaluator = SafeExpressionEvaluator()
@activity.defn
async def calculate_activity(input: CalculateRequest) -> CalculateResponse:
    """Activity to evaluate a mathematical expression and return the result.

    Safely evaluates mathematical expressions using AST parsing.
    Supports basic arithmetic operators (+, -, *, /, %, **) and
    common functions (abs, round, max, min, sum).

    Args:
        input: Request whose ``expression`` field holds the text to evaluate.

    Returns:
        A ``CalculateResponse`` carrying the numeric ``result`` and the
        original ``expression``.

    Raises:
        ValueError: If the expression is invalid or fails to evaluate. The
            error is logged and then re-raised unchanged.
    """
    activity.logger.info(f"🧮 Calculator.calculate activity called with expression: '{input.expression}'")
    try:
        result = _evaluator.evaluate(input.expression)
        response = CalculateResponse(result=result, expression=input.expression)
        activity.logger.info(f"🧮 Calculator.calculate activity result: {result}")
        return response
    except ValueError as e:
        activity.logger.error(f"🧮 Calculator.calculate activity error: {e}")
        raise
@activity.defn
async def add_activity(input: AddRequest) -> BasicOperationResponse:
    """Activity to add two numbers together.

    Args:
        input: Request carrying the operands ``a`` and ``b``.

    Returns:
        A ``BasicOperationResponse`` whose ``result`` is ``a + b`` and whose
        ``operation`` reads ``"<a> + <b> = <result>"``.
    """
    activity.logger.info(f"➕ Calculator.add activity called: {input.a} + {input.b}")
    result = input.a + input.b
    operation = f"{input.a} + {input.b} = {result}"
    response = BasicOperationResponse(result=result, operation=operation)
    activity.logger.info(f"➕ Calculator.add activity result: {result}")
    return response
@activity.defn
async def subtract_activity(input: SubtractRequest) -> BasicOperationResponse:
    """Activity to subtract the second number from the first.

    Args:
        input: Request carrying the minuend ``a`` and subtrahend ``b``.

    Returns:
        A ``BasicOperationResponse`` whose ``result`` is ``a - b`` and whose
        ``operation`` reads ``"<a> - <b> = <result>"``.
    """
    activity.logger.info(f"➖ Calculator.subtract activity called: {input.a} - {input.b}")
    result = input.a - input.b
    operation = f"{input.a} - {input.b} = {result}"
    response = BasicOperationResponse(result=result, operation=operation)
    activity.logger.info(f"➖ Calculator.subtract activity result: {result}")
    return response
@activity.defn
async def multiply_activity(input: MultiplyRequest) -> BasicOperationResponse:
    """Activity to multiply two numbers together.

    Args:
        input: Request carrying the factors ``a`` and ``b``.

    Returns:
        A ``BasicOperationResponse`` whose ``result`` is ``a * b`` and whose
        ``operation`` reads ``"<a> * <b> = <result>"``.
    """
    activity.logger.info(f"✖️ Calculator.multiply activity called: {input.a} * {input.b}")
    result = input.a * input.b
    operation = f"{input.a} * {input.b} = {result}"
    response = BasicOperationResponse(result=result, operation=operation)
    activity.logger.info(f"✖️ Calculator.multiply activity result: {result}")
    return response
@activity.defn
async def divide_activity(input: DivideRequest) -> BasicOperationResponse:
    """Activity to divide the first number by the second.

    Args:
        input: Request carrying the dividend ``a`` and divisor ``b``.

    Returns:
        A ``BasicOperationResponse`` whose ``result`` is ``a / b`` and whose
        ``operation`` reads ``"<a> / <b> = <result>"``.

    Raises:
        ValueError: If ``b`` equals zero (checked before dividing, so callers
            see ValueError rather than ZeroDivisionError).
    """
    activity.logger.info(f"➗ Calculator.divide activity called: {input.a} / {input.b}")
    if input.b == 0:
        activity.logger.error("Division by zero is not allowed")
        raise ValueError("Division by zero is not allowed")
    result = input.a / input.b
    operation = f"{input.a} / {input.b} = {result}"
    response = BasicOperationResponse(result=result, operation=operation)
    activity.logger.info(f"➗ Calculator.divide activity result: {result}")
    return response
@activity.defn
async def power_activity(input: PowerRequest) -> BasicOperationResponse:
    """Activity to raise the base to the power of the exponent.

    Args:
        input: Request carrying ``base`` and ``exponent``.

    Returns:
        A ``BasicOperationResponse`` whose ``result`` is ``base ** exponent``
        and whose ``operation`` reads ``"<base> ^ <exponent> = <result>"``.

    Raises:
        ValueError: If the result overflows, if zero is raised to a negative
            power, or if the result is not a real number (a negative base
            with a fractional exponent makes Python return a complex value).
    """
    activity.logger.info(f"🔢 Calculator.power activity called: {input.base} ^ {input.exponent}")
    # NOTE(review): if PowerRequest permits int fields, a huge integer
    # exponent builds an arbitrarily large int instead of raising
    # OverflowError — confirm the field types.
    try:
        result = input.base ** input.exponent
    except OverflowError as e:
        activity.logger.error("Result too large to compute")
        raise ValueError("Result too large to compute") from e
    except ZeroDivisionError as e:
        # e.g. 0 ** -1; surface it as ValueError like the other activities.
        activity.logger.error("Zero cannot be raised to a negative power")
        raise ValueError("Zero cannot be raised to a negative power") from e
    if isinstance(result, complex):
        # e.g. (-8.0) ** 0.5 evaluates to a complex number in Python 3.
        activity.logger.error("Result is not a real number")
        raise ValueError("Result is not a real number")
    operation = f"{input.base} ^ {input.exponent} = {result}"
    response = BasicOperationResponse(result=result, operation=operation)
    activity.logger.info(f"🔢 Calculator.power activity result: {result}")
    return response
@activity.defn
async def sum_list_activity(input: SumListRequest) -> BasicOperationResponse:
    """Activity to sum all numbers in the provided list.

    Args:
        input: Request whose ``numbers`` field holds the values to add.

    Returns:
        A ``BasicOperationResponse`` whose ``result`` is ``sum(numbers)`` and
        whose ``operation`` reads ``"sum([<n1>, <n2>, ...]) = <result>"``.

    Raises:
        ValueError: If ``numbers`` is empty.
    """
    activity.logger.info(f"📊 Calculator.sum_list activity called with {len(input.numbers)} numbers")
    if not input.numbers:
        activity.logger.error("Cannot sum an empty list")
        raise ValueError("Cannot sum an empty list")
    result = sum(input.numbers)
    numbers_str = ", ".join(str(n) for n in input.numbers)
    operation = f"sum([{numbers_str}]) = {result}"
    response = BasicOperationResponse(result=result, operation=operation)
    activity.logger.info(f"📊 Calculator.sum_list activity result: {result}")
    return response