#!/usr/bin/env python3
"""
MCP CLI Command Server - Executes whitelisted network and system commands
Provides safe, controlled access to common CLI tools via MCP protocol
"""
import asyncio
import json
import logging
import os
import re
import subprocess
from typing import Any, Dict, List, Optional

from flask import Flask, request, jsonify
from mcp.server import Server
from mcp.types import Tool, TextContent
# Logging: INFO and above, each record stamped with time, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# HTTP front end: the Flask application that the route decorators below attach to.
app = Flask(import_name=__name__)
# MCP server object; its list_tools()/call_tool() decorators register the handlers below.
mcp_server = Server("mcp-cli")
# Security whitelist: execute_cli_command refuses any command name not listed here.
# Per entry:
#   binary        -- executable passed as argv[0]
#   allowed_flags -- option flags declared as permitted for this command
#   max_timeout   -- upper bound, in seconds, on the subprocess runtime
#   description   -- human-readable summary
ALLOWED_COMMANDS = dict(
    ping=dict(
        binary='ping',
        allowed_flags=['-c', '-W', '-i'],
        max_timeout=30,
        description='Ping a host to check connectivity',
    ),
    nmap_ping_scan=dict(
        binary='nmap',
        allowed_flags=['-sn'],
        max_timeout=300,
        description='NMAP ping scan for host discovery',
    ),
    nmap_port_scan=dict(
        binary='nmap',
        allowed_flags=['-p', '-sV', '-sT', '-Pn'],
        max_timeout=600,
        description='NMAP port scan on specified hosts',
    ),
    dig=dict(
        binary='dig',
        allowed_flags=['+short', '+trace', '+nssearch'],
        max_timeout=30,
        description='DNS lookup using dig',
    ),
    curl_get=dict(
        binary='curl',
        allowed_flags=['-I', '-L', '-s', '-m'],
        max_timeout=30,
        description='HTTP GET request using curl',
    ),
    traceroute=dict(
        binary='traceroute',
        allowed_flags=['-m', '-w'],
        max_timeout=60,
        description='Trace network route to host',
    ),
    whois=dict(
        binary='whois',
        allowed_flags=[],
        max_timeout=30,
        description='WHOIS lookup for domain or IP',
    ),
    host=dict(
        binary='host',
        allowed_flags=['-t'],
        max_timeout=10,
        description='DNS host lookup',
    ),
    netcat_test=dict(
        binary='nc',
        allowed_flags=['-z', '-v', '-w'],
        max_timeout=10,
        description='Test port connectivity with netcat',
    ),
    mtr=dict(
        binary='mtr',
        allowed_flags=['-r', '-c', '-n'],
        max_timeout=60,
        description='Network diagnostic with MTR',
    ),
)
# Validation patterns.
# Anchoring: ``\Z`` is used instead of ``$`` because ``$`` also matches just
# before a trailing newline, which let values such as "example.com\n" or
# "80\n" validate.  Digits are spelled ``[0-9]`` instead of ``\d`` because
# ``\d`` matches any Unicode decimal digit (and ``int()`` accepts those too),
# while the raw string is what gets handed to the CLI tool.
IP_PATTERN = re.compile(r'^([0-9]{1,3}\.){3}[0-9]{1,3}\Z')
CIDR_PATTERN = re.compile(r'^([0-9]{1,3}\.){3}[0-9]{1,3}/[0-9]{1,2}\Z')
HOSTNAME_PATTERN = re.compile(
    r'^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?'
    r'(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*\Z'
)
PORT_PATTERN = re.compile(r'^[0-9]{1,5}\Z')
PORT_RANGE_PATTERN = re.compile(r'^[0-9]{1,5}-[0-9]{1,5}\Z')
# Host: alphanumerics, dots and hyphens; optional :port; optional path.
# ``.`` does not match a newline, so a newline anywhere in the path is rejected.
URL_PATTERN = re.compile(r'^https?://[a-zA-Z0-9][a-zA-Z0-9.-]*[a-zA-Z0-9](:[0-9]{1,5})?(/.*)?\Z')
# Characters with special meaning to a shell, plus CR/LF (line breaks have no
# place in any argument these tools accept).
SHELL_METACHAR_PATTERN = re.compile(r'[;&|`$()<>\\\r\n]')
def validate_ip(ip: str) -> bool:
    """Return True if *ip* is a dotted-quad IPv4 address with octets 0-255.

    ``fullmatch`` is used rather than ``match``. IP_PATTERN may end in ``$``,
    which also matches just before a trailing newline, so ``match`` accepted
    "1.2.3.4" followed by a newline; ``int()`` then silently stripped it.
    The ASCII check rejects non-ASCII decimal digits. ``int()`` would accept
    them, but the raw string is what reaches the CLI tool.
    """
    if IP_PATTERN.fullmatch(ip) is None or not ip.isascii():
        return False
    return all(0 <= int(octet) <= 255 for octet in ip.split('.'))
def validate_cidr(cidr: str) -> bool:
    """Return True if *cidr* is IPv4 CIDR notation ("a.b.c.d/prefix", prefix 0-32).

    ``fullmatch`` plus the ASCII check close the same gaps as in
    validate_ip. Without them, a trailing newline after the prefix
    (e.g. "10.0.0.0/24\\n") and non-ASCII digits both validated, because
    ``int()`` tolerates them.
    """
    if CIDR_PATTERN.fullmatch(cidr) is None or not cidr.isascii():
        return False
    ip, prefix = cidr.split('/')
    return validate_ip(ip) and 0 <= int(prefix) <= 32
def validate_hostname(hostname: str) -> bool:
    """Return True if *hostname* is a dot-separated sequence of DNS-style labels.

    Each label is 1-63 ASCII alphanumerics/hyphens and does not start or end
    with a hyphen; the whole name is at most 255 characters. ``fullmatch``
    is used so a trailing newline (which ``$`` tolerates under ``match``)
    is rejected.
    """
    if len(hostname) > 255:
        return False
    return HOSTNAME_PATTERN.fullmatch(hostname) is not None
def validate_port(port: str) -> bool:
    """Return True if *port* is a decimal TCP/UDP port number in 1-65535.

    ``fullmatch`` rejects a trailing newline, which ``match`` with ``$``
    accepted. The ASCII check rejects non-ASCII digits, which ``int()``
    would otherwise convert.
    """
    if PORT_PATTERN.fullmatch(port) is None or not port.isascii():
        return False
    return 1 <= int(port) <= 65535
def validate_port_range(port_range: str) -> bool:
    """Return True if *port_range* is "start-end" with valid ports and start <= end.

    ``fullmatch`` and the ASCII check keep a trailing newline or non-ASCII
    digits from slipping through the pattern and being normalised by ``int()``.
    """
    if PORT_RANGE_PATTERN.fullmatch(port_range) is None or not port_range.isascii():
        return False
    start, end = port_range.split('-')
    return validate_port(start) and validate_port(end) and int(start) <= int(end)
def validate_url(url: str) -> bool:
    """Return True if *url* is an http(s) URL accepted by URL_PATTERN.

    ``fullmatch`` is used so that a URL followed by a trailing newline is
    rejected; ``match`` with a ``$`` anchor accepted it.
    """
    return URL_PATTERN.fullmatch(url) is not None
def check_shell_metacharacters(text: str) -> bool:
    """Report whether *text* is free of the characters in SHELL_METACHAR_PATTERN.

    Returns True when the text is clean (no match anywhere), False otherwise.
    """
    return not SHELL_METACHAR_PATTERN.search(text)
def _command_failure(message: str, returncode: int = 1) -> Dict[str, Any]:
    """Build the result dict for a command that was refused, failed to launch, or timed out."""
    return {
        'success': False,
        'stdout': '',
        'stderr': message,
        'returncode': returncode,
    }


def execute_cli_command(command_name: str, args: List[str], timeout: Optional[int] = None) -> Dict[str, Any]:
    """
    Execute a whitelisted CLI command with security validation.

    The command runs without a shell: ``[binary, *args]`` goes straight to
    ``subprocess.run``. Before running, each argument is checked for shell
    metacharacters. Any argument that looks like an option (leading ``-`` or
    ``+``) must also be listed in the command's ``allowed_flags``.

    Args:
        command_name: Key into ALLOWED_COMMANDS.
        args: Arguments appended after the binary; non-str items are
            converted with ``str()``.
        timeout: Optional timeout in seconds, clamped to [1, max_timeout].
            Defaults to the command's max_timeout.

    Returns:
        Dict with keys success, stdout, stderr and returncode. A timeout
        yields returncode 124; refusal or launch failure yields 1.
    """
    cmd_config = ALLOWED_COMMANDS.get(command_name)
    if cmd_config is None:
        return _command_failure(f'Command not allowed: {command_name}')

    max_timeout = cmd_config['max_timeout']
    # Callers may shorten the limit but never exceed it. The lower bound
    # stops a zero or negative value from expiring immediately.
    timeout = max_timeout if timeout is None else max(1, min(timeout, max_timeout))

    # Normalise to str once: both subprocess and the log line need strings.
    str_args = [str(arg) for arg in args]
    allowed_flags = set(cmd_config['allowed_flags'])
    for arg in str_args:
        if not check_shell_metacharacters(arg):
            return _command_failure(f'Invalid characters in argument: {arg}')
        # Enforce the declared per-command flag whitelist. This also rejects
        # values that would be parsed as options, such as "-5" or "-oN".
        if arg[:1] in ('-', '+') and arg not in allowed_flags:
            return _command_failure(f'Flag not allowed for {command_name}: {arg}')

    cmd = [cmd_config['binary'], *str_args]
    logger.info("Executing command: %s (timeout: %ss)", ' '.join(cmd), timeout)
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            # Tool output is not guaranteed to decode in the locale encoding;
            # replace bad bytes instead of failing after the command ran.
            errors='replace',
            # The child never needs input; don't let it read this process's stdin.
            stdin=subprocess.DEVNULL,
            timeout=timeout,
            check=False,
        )
    except subprocess.TimeoutExpired:
        logger.warning("Command timed out after %ss", timeout)
        return _command_failure(f'Command timed out after {timeout}s', 124)
    except Exception as e:
        # E.g. the binary is not installed (FileNotFoundError).
        logger.error("Command execution failed: %s", e)
        return _command_failure(f'Execution error: {str(e)}')

    logger.info("Command completed with return code: %s", result.returncode)
    return {
        'success': result.returncode == 0,
        'stdout': result.stdout.strip(),
        'stderr': result.stderr.strip(),
        'returncode': result.returncode,
    }
# MCP tool definitions: every entry in list_tools becomes one advertised Tool.


def _schema_property(json_type: str, description: str, **extra: Any) -> Dict[str, Any]:
    """Return one JSON-Schema property; *extra* carries optional keys such as ``default``."""
    return {"type": json_type, "description": description, **extra}


def _object_schema(properties: Dict[str, Dict[str, Any]], *required: str) -> Dict[str, Any]:
    """Return an object-typed input schema with *properties* and the *required* key names."""
    return {"type": "object", "properties": properties, "required": list(required)}


@mcp_server.list_tools()
async def list_tools() -> list[Tool]:
    """Describe every CLI tool exposed by this server.

    Returns one Tool per supported command. Each Tool carries its name, a
    description and a JSON-Schema object describing the accepted arguments.
    """
    catalogue = (
        (
            "cli_ping",
            "Ping a host to check connectivity and measure latency",
            _object_schema(
                {
                    "host": _schema_property("string", "IP address or hostname to ping"),
                    "count": _schema_property(
                        "integer", "Number of ping packets (default: 4, max: 10)", default=4
                    ),
                    "timeout": _schema_property(
                        "integer", "Timeout in seconds (default: 5, max: 30)", default=5
                    ),
                },
                "host",
            ),
        ),
        (
            "cli_nmap_ping_scan",
            "NMAP ping scan (host discovery) on network range",
            _object_schema(
                {
                    "target": _schema_property(
                        "string", "IP address, hostname, or CIDR range (e.g., 192.168.1.0/24)"
                    ),
                },
                "target",
            ),
        ),
        (
            "cli_nmap_port_scan",
            "NMAP port scan on specified host and ports",
            _object_schema(
                {
                    "host": _schema_property("string", "IP address or hostname to scan"),
                    "ports": _schema_property(
                        "string", "Port or port range (e.g., '80', '1-1000', '22,80,443')"
                    ),
                    "service_version": _schema_property(
                        "boolean", "Detect service versions (-sV flag)", default=False
                    ),
                },
                "host",
                "ports",
            ),
        ),
        (
            "cli_dig",
            "DNS lookup using dig command",
            _object_schema(
                {
                    "domain": _schema_property("string", "Domain name to query"),
                    "record_type": _schema_property(
                        "string", "DNS record type (A, AAAA, MX, TXT, NS, CNAME, etc.)", default="A"
                    ),
                    "short": _schema_property(
                        "boolean", "Short output format (+short flag)", default=False
                    ),
                },
                "domain",
            ),
        ),
        (
            "cli_curl_get",
            "HTTP GET request using curl",
            _object_schema(
                {
                    "url": _schema_property("string", "URL to fetch (http:// or https://)"),
                    "headers_only": _schema_property(
                        "boolean", "Fetch headers only (-I flag)", default=False
                    ),
                    "follow_redirects": _schema_property(
                        "boolean", "Follow redirects (-L flag)", default=True
                    ),
                    "timeout": _schema_property(
                        "integer", "Timeout in seconds (max: 30)", default=10
                    ),
                },
                "url",
            ),
        ),
        (
            "cli_traceroute",
            "Trace network route to host",
            _object_schema(
                {
                    "host": _schema_property("string", "IP address or hostname"),
                    "max_hops": _schema_property(
                        "integer", "Maximum number of hops (default: 30)", default=30
                    ),
                },
                "host",
            ),
        ),
        (
            "cli_whois",
            "WHOIS lookup for domain or IP address",
            _object_schema(
                {
                    "target": _schema_property("string", "Domain name or IP address"),
                },
                "target",
            ),
        ),
        (
            "cli_host",
            "DNS host lookup command",
            _object_schema(
                {
                    "hostname": _schema_property("string", "Hostname or IP to query"),
                    "record_type": _schema_property(
                        "string", "DNS record type (A, AAAA, MX, TXT, etc.)", default="A"
                    ),
                },
                "hostname",
            ),
        ),
        (
            "cli_netcat_test",
            "Test port connectivity using netcat",
            _object_schema(
                {
                    "host": _schema_property("string", "IP address or hostname"),
                    "port": _schema_property("integer", "Port number to test"),
                    "timeout": _schema_property(
                        "integer", "Timeout in seconds (default: 5, max: 10)", default=5
                    ),
                },
                "host",
                "port",
            ),
        ),
        (
            "cli_mtr",
            "Network diagnostic combining ping and traceroute (MTR)",
            _object_schema(
                {
                    "host": _schema_property("string", "IP address or hostname"),
                    "cycles": _schema_property(
                        "integer", "Number of pings per hop (default: 10, max: 20)", default=10
                    ),
                },
                "host",
            ),
        ),
    )
    return [
        Tool(name=tool_name, description=summary, inputSchema=schema)
        for tool_name, summary, schema in catalogue
    ]
class _ToolInputError(ValueError):
    """Raised while building a command when a tool argument is invalid.

    ``str(exc)`` is the complete, user-facing message returned to the caller.
    """


# DNS record-type mnemonics (A, AAAA, MX, TXT, NS, CNAME, TYPE65, ...) are
# short ASCII alphanumerics. Validating them keeps values such as
# "@198.51.100.1" (dig's "query this server" syntax) and option-like strings
# off the command line.
_RECORD_TYPE_PATTERN = re.compile(r'[A-Z0-9]{1,16}')

# (command key in ALLOWED_COMMANDS, argument list, timeout override or None)
_CommandSpec = tuple[str, List[str], Optional[int]]


def _text_result(text: str) -> list[TextContent]:
    """Wrap *text* in the single-item content list that call_tool returns."""
    return [TextContent(type="text", text=text)]


def _require_arg(arguments: Dict[str, Any], key: str) -> Any:
    """Return ``arguments[key]``, or raise _ToolInputError if the key is absent."""
    if key not in arguments:
        raise _ToolInputError(f"Error: Missing required argument: {key}")
    return arguments[key]


def _coerce_int(value: Any) -> Optional[int]:
    """Return *value* as an int if it is an integral number, else None."""
    if isinstance(value, bool):
        # bool subclasses int, but true/false is never a meaningful count or port.
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)
    return None


def _bounded_int_arg(arguments: Dict[str, Any], key: str, default: int, low: int, high: int) -> int:
    """Return the optional integer ``arguments[key]`` clamped to [low, high].

    A missing or null value yields *default*; a non-integral value raises
    _ToolInputError. The lower bound keeps zero and negative values from
    being rendered as option-like strings such as "-5".
    """
    raw = arguments.get(key)
    if raw is None:
        return default
    value = _coerce_int(raw)
    if value is None:
        raise _ToolInputError(f"Error: Argument '{key}' must be an integer: {raw}")
    return max(low, min(value, high))


def _host_arg(arguments: Dict[str, Any], key: str, *, allow_cidr: bool = False) -> str:
    """Return the required ``arguments[key]`` if it is an IPv4 address or hostname.

    With *allow_cidr*, IPv4 CIDR ranges are accepted too. The error message
    names *key* ("host", "target", "hostname"), as the per-tool messages did.
    """
    value = _require_arg(arguments, key)
    is_valid = isinstance(value, str) and (
        validate_ip(value)
        or (allow_cidr and validate_cidr(value))
        or validate_hostname(value)
    )
    if not is_valid:
        raise _ToolInputError(f"Error: Invalid {key} format: {value}")
    return value


def _record_type_arg(arguments: Dict[str, Any]) -> str:
    """Return the upper-cased ``record_type`` argument (default "A") after validation."""
    raw = arguments.get('record_type')
    if raw is None:
        return 'A'
    if isinstance(raw, str) and raw.isascii():
        record_type = raw.upper()
        if _RECORD_TYPE_PATTERN.fullmatch(record_type):
            return record_type
    raise _ToolInputError(f"Error: Invalid record type: {raw}")


def _ports_arg(arguments: Dict[str, Any]) -> str:
    """Return the required ``ports`` argument as a normalised nmap port list.

    Accepts comma-separated single ports and ``start-end`` ranges, with
    surrounding whitespace on each item dropped; anything else raises
    _ToolInputError. Checking the format, not just shell metacharacters,
    ensures only real port lists reach nmap's -p option.
    """
    raw = _require_arg(arguments, 'ports')
    if isinstance(raw, str):
        items = [item.strip() for item in raw.split(',')]
        if all(validate_port(item) or validate_port_range(item) for item in items):
            return ','.join(items)
    raise _ToolInputError(f"Error: Invalid ports specification: {raw}")


def _ping_command(arguments: Dict[str, Any]) -> _CommandSpec:
    """Build the ping command: -c count (1-10), -W timeout (1-30)."""
    host = _host_arg(arguments, 'host')
    count = _bounded_int_arg(arguments, 'count', 4, 1, 10)
    timeout = _bounded_int_arg(arguments, 'timeout', 5, 1, 30)
    # The overall run is capped at 30 s regardless of -W.
    return 'ping', ['-c', str(count), '-W', str(timeout), host], 30


def _nmap_ping_scan_command(arguments: Dict[str, Any]) -> _CommandSpec:
    """Build an nmap -sn host-discovery scan over an IP, CIDR range or hostname."""
    target = _host_arg(arguments, 'target', allow_cidr=True)
    return 'nmap_ping_scan', ['-sn', target], None


def _nmap_port_scan_command(arguments: Dict[str, Any]) -> _CommandSpec:
    """Build an nmap -Pn port scan, optionally with -sV service detection."""
    host = _host_arg(arguments, 'host')
    ports = _ports_arg(arguments)
    args = ['-p', ports, '-Pn']
    if arguments.get('service_version', False):
        args.append('-sV')
    args.append(host)
    return 'nmap_port_scan', args, None


def _dig_command(arguments: Dict[str, Any]) -> _CommandSpec:
    """Build a dig lookup of *domain* for *record_type*, optionally +short."""
    domain = _require_arg(arguments, 'domain')
    if not (isinstance(domain, str) and validate_hostname(domain)):
        raise _ToolInputError(f"Error: Invalid domain format: {domain}")
    args = [domain, _record_type_arg(arguments)]
    if arguments.get('short', False):
        args.insert(0, '+short')
    return 'dig', args, None


def _curl_get_command(arguments: Dict[str, Any]) -> _CommandSpec:
    """Build a silent curl GET with -m timeout (1-30), optional -I and -L."""
    url = _require_arg(arguments, 'url')
    timeout = _bounded_int_arg(arguments, 'timeout', 10, 1, 30)
    if not (isinstance(url, str) and validate_url(url)):
        raise _ToolInputError(f"Error: Invalid URL format: {url}")
    args = ['-m', str(timeout), '-s']
    if arguments.get('headers_only', False):
        args.append('-I')
    if arguments.get('follow_redirects', True):
        args.append('-L')
    # NOTE(review): no -g/--globoff is passed, so curl's URL globbing
    # ("[1-100]", "{a,b}") appears to stay active for the path that
    # validate_url permits. Confirm whether that's intended. Adding -g
    # would also require whitelisting it in ALLOWED_COMMANDS['curl_get'].
    args.append(url)
    return 'curl_get', args, None


def _traceroute_command(arguments: Dict[str, Any]) -> _CommandSpec:
    """Build a traceroute with -m max_hops (1-64)."""
    host = _host_arg(arguments, 'host')
    max_hops = _bounded_int_arg(arguments, 'max_hops', 30, 1, 64)
    return 'traceroute', ['-m', str(max_hops), host], None


def _whois_command(arguments: Dict[str, Any]) -> _CommandSpec:
    """Build a whois lookup for a domain or IPv4 address."""
    target = _host_arg(arguments, 'target')
    return 'whois', [target], None


def _host_command(arguments: Dict[str, Any]) -> _CommandSpec:
    """Build a `host -t record_type hostname` lookup."""
    hostname = _host_arg(arguments, 'hostname')
    return 'host', ['-t', _record_type_arg(arguments), hostname], None


def _netcat_test_command(arguments: Dict[str, Any]) -> _CommandSpec:
    """Build an `nc -z -v -w timeout host port` probe (timeout 1-10)."""
    host = _host_arg(arguments, 'host')
    raw_port = _require_arg(arguments, 'port')
    port = _coerce_int(raw_port)
    if port is None or not 1 <= port <= 65535:
        raise _ToolInputError(f"Error: Invalid port number: {raw_port}")
    timeout = _bounded_int_arg(arguments, 'timeout', 5, 1, 10)
    return 'netcat_test', ['-z', '-v', '-w', str(timeout), host, str(port)], None


def _mtr_command(arguments: Dict[str, Any]) -> _CommandSpec:
    """Build an mtr report run with -c cycles (1-20)."""
    host = _host_arg(arguments, 'host')
    cycles = _bounded_int_arg(arguments, 'cycles', 10, 1, 20)
    return 'mtr', ['-r', '-c', str(cycles), host], None


# Advertised tool name -> function that validates its arguments and returns
# the command to run.
_COMMAND_BUILDERS = {
    "cli_ping": _ping_command,
    "cli_nmap_ping_scan": _nmap_ping_scan_command,
    "cli_nmap_port_scan": _nmap_port_scan_command,
    "cli_dig": _dig_command,
    "cli_curl_get": _curl_get_command,
    "cli_traceroute": _traceroute_command,
    "cli_whois": _whois_command,
    "cli_host": _host_command,
    "cli_netcat_test": _netcat_test_command,
    "cli_mtr": _mtr_command,
}


def _format_command_result(result: Dict[str, Any]) -> str:
    """Render an execute_cli_command result dict as the text returned to the client."""
    sections = []
    if result['stdout']:
        sections.append(f"Output:\n{result['stdout']}")
    if result['stderr']:
        sections.append(f"Errors:\n{result['stderr']}")
    if not result['success']:
        sections.append(f"Exit code: {result['returncode']}")
    return "\n\n".join(sections) if sections else "Command completed with no output"


@mcp_server.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
    """Execute a CLI tool.

    Validates *arguments* for the tool *name* and runs the matching
    whitelisted command through execute_cli_command in a worker thread. The
    command's output is returned as a single text item. Unknown tools,
    invalid input and unexpected failures are reported as text, never raised.
    """
    logger.info("Tool call: %s with arguments: %s", name, arguments)
    try:
        build_command = _COMMAND_BUILDERS.get(name)
        if build_command is None:
            return _text_result(f"Error: Unknown tool: {name}")
        if arguments is None:
            arguments = {}
        if not isinstance(arguments, dict):
            return _text_result("Error: Tool arguments must be a JSON object")
        try:
            command_name, args, timeout = build_command(arguments)
        except _ToolInputError as exc:
            return _text_result(str(exc))
        # subprocess.run blocks for up to the command's max_timeout (600 s for
        # nmap port scans), so keep it off the event loop.
        result = await asyncio.to_thread(execute_cli_command, command_name, args, timeout)
        return _text_result(_format_command_result(result))
    except Exception as e:
        logger.error(f"Tool execution error: {str(e)}", exc_info=True)
        return _text_result(f"Error executing tool: {str(e)}")
# Flask HTTP endpoints.
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: always answers 200 with a fixed status payload."""
    payload = {"status": "healthy", "service": "mcp-cli"}
    return jsonify(payload), 200
@app.route('/mcp/list_tools', methods=['GET'])
async def http_list_tools():
    """Serve the tool catalogue (name, description, input schema) to REST clients.

    Responds 200 with ``{"tools": [...]}``. Any failure is logged and
    reported as ``{"error": ...}`` with status 500.
    """
    try:
        catalogue = []
        for tool in await list_tools():
            catalogue.append({
                "name": tool.name,
                "description": tool.description,
                "inputSchema": tool.inputSchema,
            })
        return jsonify({"tools": catalogue}), 200
    except Exception as exc:
        logger.error(f"Error listing tools: {str(exc)}", exc_info=True)
        return jsonify({"error": str(exc)}), 500
@app.route('/mcp/call_tool', methods=['POST'])
async def http_call_tool():
    """Call a tool via REST API.

    Expects a JSON object body ``{"name": str, "arguments": object}``.

    Returns:
        - 400 if the body is missing, not JSON, or not an object, or if
          ``name`` is missing.
        - 200 with the tool's text content otherwise.
        - 500 if calling the tool raises.
    """
    # silent=True returns None instead of raising on a missing or unparsable
    # JSON body. Previously such requests fell into the generic handler as a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    name = data.get('name')
    if not name:
        return jsonify({"error": "Missing 'name' parameter"}), 400
    # An explicit null (or otherwise empty) "arguments" means "no arguments".
    arguments = data.get('arguments') or {}
    try:
        result = await call_tool(name, arguments)
        return jsonify({
            "content": [{"type": item.type, "text": item.text} for item in result],
            "isError": False
        }), 200
    except Exception as e:
        logger.error(f"Error calling tool: {str(e)}", exc_info=True)
        return jsonify({
            "content": [{"type": "text", "text": f"Error: {str(e)}"}],
            "isError": True
        }), 500
def _jsonrpc_error(code: int, message: str, rpc_id: Any = None):
    """Build a JSON-RPC 2.0 error response, sent with HTTP 200 like every reply here."""
    return jsonify({
        "jsonrpc": "2.0",
        "error": {"code": code, "message": message},
        "id": rpc_id
    }), 200


@app.route('/mcp', methods=['POST'])
async def http_mcp_json_rpc():
    """MCP JSON-RPC 2.0 endpoint.

    Handles ``tools/list`` and ``tools/call``. Every reply is HTTP 200 with a
    JSON-RPC envelope. Failures use the standard error codes:
        -32700  parse error
        -32600  invalid request
        -32601  method not found
        -32602  invalid params
        -32603  internal error
    """
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising. Previously the exception handler then evaluated
    # data.get('id') with data unbound (or None) and raised itself, so the
    # client got an HTML 500 instead of a JSON-RPC error.
    data = request.get_json(silent=True)
    if data is None:
        return _jsonrpc_error(-32700, "Parse error")
    if not isinstance(data, dict):
        return _jsonrpc_error(-32600, "Invalid Request")
    rpc_id = data.get('id')
    method = data.get('method')
    params = data.get('params')
    if params is None:
        params = {}
    elif not isinstance(params, dict):
        return _jsonrpc_error(-32602, "Invalid params: 'params' must be an object", rpc_id)
    # NOTE(review): only tools/list and tools/call are handled here. A client
    # that opens with "initialize" receives -32601. Confirm whether
    # spec-compliant MCP clients are expected to use this endpoint.
    try:
        if method == 'tools/list':
            tools = await list_tools()
            result = {
                "tools": [
                    {
                        "name": tool.name,
                        "description": tool.description,
                        "inputSchema": tool.inputSchema
                    }
                    for tool in tools
                ]
            }
        elif method == 'tools/call':
            name = params.get('name')
            if not name:
                return _jsonrpc_error(-32602, "Invalid params: missing 'name'", rpc_id)
            arguments = params.get('arguments') or {}
            content = await call_tool(name, arguments)
            result = {
                "content": [{"type": item.type, "text": item.text} for item in content]
            }
        else:
            return _jsonrpc_error(-32601, f"Method not found: {method}", rpc_id)
        return jsonify({
            "jsonrpc": "2.0",
            "result": result,
            "id": rpc_id
        }), 200
    except Exception as e:
        logger.error(f"JSON-RPC error: {str(e)}", exc_info=True)
        return _jsonrpc_error(-32603, str(e), rpc_id)
if __name__ == '__main__':
    # Bind address and port come from HOST / PORT. The defaults (0.0.0.0:3017)
    # are the previously hard-coded values.
    # NOTE(review): none of the routes in this file perform authentication,
    # and 0.0.0.0 listens on every interface. Consider HOST=127.0.0.1 unless
    # access to the port is restricted some other way.
    host = os.environ.get('HOST', '0.0.0.0')
    port = int(os.environ.get('PORT', 3017))
    logger.info(f"Starting MCP CLI Server on {host}:{port}")
    app.run(host=host, port=port, debug=False)