from dataclasses import dataclass
from enum import Enum
from typing import List, Dict, Set, Optional
import json
import hashlib
import time
class ThreatCategory(Enum):
SPOOFING = "spoofing"
TAMPERING = "tampering"
REPUDIATION = "repudiation"
INFORMATION_DISCLOSURE = "information_disclosure"
DENIAL_OF_SERVICE = "denial_of_service"
ELEVATION_OF_PRIVILEGE = "elevation_of_privilege"
class RiskLevel(Enum):
LOW = 1
MEDIUM = 2
HIGH = 3
CRITICAL = 4
@dataclass
class Asset:
"""Represents a system asset to be protected"""
name: str
asset_type: str # data, process, interactor, data_store
description: str
value: RiskLevel
data_classification: str # public, internal, confidential, restricted
@dataclass
class DataFlow:
"""Represents data flow between system components"""
source: str
destination: str
data_type: str
protocol: str
authentication_required: bool
encryption_in_transit: bool
@dataclass
class TrustBoundary:
"""Represents trust boundary in the system"""
name: str
internal_assets: Set[str]
external_assets: Set[str]
boundary_type: str # network, process, physical
@dataclass
class Threat:
"""Represents an identified threat"""
threat_id: str
category: ThreatCategory
target_asset: str
description: str
impact: RiskLevel
likelihood: RiskLevel
risk_score: int
mitigation_controls: List[str]
residual_risk: RiskLevel
class STRIDEThreatModeling:
"""Comprehensive STRIDE threat modeling implementation"""
def __init__(self):
self.assets = {}
self.data_flows = []
self.trust_boundaries = []
self.identified_threats = []
self.threat_templates = self._load_threat_templates()
def _load_threat_templates(self) -> Dict[ThreatCategory, List[str]]:
"""Load threat templates for each STRIDE category"""
return {
ThreatCategory.SPOOFING: [
"Attacker impersonates legitimate user",
"Service impersonation attack",
"DNS spoofing attack",
"IP address spoofing",
"Certificate spoofing"
],
ThreatCategory.TAMPERING: [
"Data modification in transit",
"Database tampering",
"Code injection attacks",
"Configuration tampering",
"Log file manipulation"
],
ThreatCategory.REPUDIATION: [
"User denies performing action",
"Insufficient audit logging",
"Log tampering to hide evidence",
"Weak digital signatures",
"Missing transaction records"
],
ThreatCategory.INFORMATION_DISCLOSURE: [
"Sensitive data exposure",
"SQL injection revealing data",
"Directory traversal attacks",
"Memory dumps containing secrets",
"Side-channel attacks"
],
ThreatCategory.DENIAL_OF_SERVICE: [
"DDoS attacks",
"Resource exhaustion",
"Algorithmic complexity attacks",
"Database connection exhaustion",
"Memory exhaustion attacks"
],
ThreatCategory.ELEVATION_OF_PRIVILEGE: [
"Buffer overflow exploitation",
"SQL injection with admin access",
"Privilege escalation vulnerabilities",
"Cross-site scripting (XSS)",
"Insecure direct object references"
]
}
def add_asset(self, asset: Asset):
"""Add an asset to the threat model"""
self.assets[asset.name] = asset
def add_data_flow(self, data_flow: DataFlow):
"""Add a data flow to the threat model"""
self.data_flows.append(data_flow)
def add_trust_boundary(self, trust_boundary: TrustBoundary):
"""Add a trust boundary to the threat model"""
self.trust_boundaries.append(trust_boundary)
def analyze_threats(self) -> List[Threat]:
"""Perform comprehensive threat analysis"""
self.identified_threats.clear()
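        # Clearing first keeps repeated calls idempotent (no duplicate threats)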
# Analyze each asset against STRIDE categories
for asset_name, asset in self.assets.items():
self._analyze_asset_threats(asset)
# Analyze data flows
for data_flow in self.data_flows:
self._analyze_data_flow_threats(data_flow)
# Analyze trust boundary crossings
for boundary in self.trust_boundaries:
self._analyze_trust_boundary_threats(boundary)
# Sort threats by risk score
self.identified_threats.sort(key=lambda t: t.risk_score, reverse=True)
return self.identified_threats
def _analyze_asset_threats(self, asset: Asset):
"""Analyze threats specific to an asset"""
# Spoofing threats for interactors
if asset.asset_type == "interactor":
self._create_threat(
ThreatCategory.SPOOFING,
asset.name,
f"Attacker could impersonate {asset.name}",
asset.value,
RiskLevel.MEDIUM
)
# Tampering threats for data stores and processes
if asset.asset_type in ["data_store", "process"]:
self._create_threat(
ThreatCategory.TAMPERING,
asset.name,
f"Unauthorized modification of {asset.name}",
asset.value,
RiskLevel.MEDIUM
)
        # Information disclosure for assets holding confidential or restricted data
if asset.data_classification in ["confidential", "restricted"]:
self._create_threat(
ThreatCategory.INFORMATION_DISCLOSURE,
asset.name,
f"Unauthorized access to sensitive data in {asset.name}",
asset.value,
RiskLevel.HIGH
)
# Denial of service for critical processes
if asset.asset_type == "process" and asset.value == RiskLevel.CRITICAL:
self._create_threat(
ThreatCategory.DENIAL_OF_SERVICE,
asset.name,
f"Service disruption of critical process {asset.name}",
asset.value,
RiskLevel.HIGH
)
def _analyze_data_flow_threats(self, data_flow: DataFlow):
"""Analyze threats in data flows"""
# Tampering threats for unencrypted data in transit
if not data_flow.encryption_in_transit:
self._create_threat(
ThreatCategory.TAMPERING,
f"{data_flow.source}->{data_flow.destination}",
f"Data tampering in transit between {data_flow.source} and {data_flow.destination}",
RiskLevel.HIGH,
RiskLevel.MEDIUM
)
# Information disclosure for unencrypted sensitive data
if not data_flow.encryption_in_transit and data_flow.data_type in ["personal", "financial"]:
self._create_threat(
ThreatCategory.INFORMATION_DISCLOSURE,
f"{data_flow.source}->{data_flow.destination}",
f"Sensitive data exposure in transit",
RiskLevel.HIGH,
RiskLevel.MEDIUM
)
# Spoofing threats for unauthenticated flows
if not data_flow.authentication_required:
self._create_threat(
ThreatCategory.SPOOFING,
f"{data_flow.source}->{data_flow.destination}",
f"Source spoofing in data flow from {data_flow.source}",
RiskLevel.MEDIUM,
RiskLevel.HIGH
)
def _analyze_trust_boundary_threats(self, boundary: TrustBoundary):
"""Analyze threats at trust boundaries"""
# Elevation of privilege at trust boundaries
self._create_threat(
ThreatCategory.ELEVATION_OF_PRIVILEGE,
boundary.name,
f"Privilege escalation across {boundary.name} trust boundary",
RiskLevel.HIGH,
RiskLevel.MEDIUM
)
# Information disclosure across boundaries
self._create_threat(
ThreatCategory.INFORMATION_DISCLOSURE,
boundary.name,
f"Unauthorized information flow across {boundary.name}",
RiskLevel.MEDIUM,
RiskLevel.MEDIUM
)
def _create_threat(self, category: ThreatCategory, target: str,
description: str, impact: RiskLevel, likelihood: RiskLevel):
"""Create a threat with risk calculation"""
risk_score = impact.value * likelihood.value
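        # Impact x likelihood on the enum values gives a 1 (LOW x LOW) to 16 (CRITICAL x CRITICAL) scale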
# Determine residual risk based on typical controls
residual_risk = self._calculate_residual_risk(category, impact, likelihood)
# Get typical mitigation controls
controls = self._get_mitigation_controls(category)
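        # MD5 below only derives a short, deterministic threat ID; it is not used as a security control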
threat = Threat(
threat_id=hashlib.md5(f"{category.value}{target}{description}".encode()).hexdigest()[:8],
category=category,
target_asset=target,
description=description,
impact=impact,
likelihood=likelihood,
risk_score=risk_score,
mitigation_controls=controls,
residual_risk=residual_risk
)
self.identified_threats.append(threat)
def _calculate_residual_risk(self, category: ThreatCategory,
impact: RiskLevel, likelihood: RiskLevel) -> RiskLevel:
"""Calculate residual risk after typical controls"""
# Simplified residual risk calculation
initial_risk = impact.value * likelihood.value
# Reduction factors based on typical controls
reduction_factors = {
ThreatCategory.SPOOFING: 0.7, # Strong authentication reduces risk
ThreatCategory.TAMPERING: 0.6, # Integrity controls
ThreatCategory.REPUDIATION: 0.5, # Audit logging
ThreatCategory.INFORMATION_DISCLOSURE: 0.6, # Encryption and access controls
ThreatCategory.DENIAL_OF_SERVICE: 0.8, # Harder to fully mitigate
ThreatCategory.ELEVATION_OF_PRIVILEGE: 0.7 # Access controls help
}
reduced_risk = initial_risk * reduction_factors.get(category, 0.8)
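        # e.g. impact HIGH (3) x likelihood MEDIUM (2) = 6; the SPOOFING factor of 0.7 gives 4.2, which maps to HIGH below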
if reduced_risk <= 2:
return RiskLevel.LOW
elif reduced_risk <= 4:
return RiskLevel.MEDIUM
elif reduced_risk <= 8:
return RiskLevel.HIGH
else:
return RiskLevel.CRITICAL
def _get_mitigation_controls(self, category: ThreatCategory) -> List[str]:
"""Get typical mitigation controls for threat category"""
control_mapping = {
ThreatCategory.SPOOFING: [
"Multi-factor authentication",
"Strong credential policies",
"Certificate-based authentication",
"Digital signatures"
],
ThreatCategory.TAMPERING: [
"Input validation",
"Data integrity checks",
"Code signing",
"Database transaction controls",
"Checksums and hashing"
],
ThreatCategory.REPUDIATION: [
"Comprehensive audit logging",
"Digital signatures",
"Timestamping services",
"Log integrity protection"
],
ThreatCategory.INFORMATION_DISCLOSURE: [
"Encryption at rest and in transit",
"Access controls and authorization",
"Data loss prevention (DLP)",
"Network segmentation",
"Least privilege principle"
],
ThreatCategory.DENIAL_OF_SERVICE: [
"Rate limiting and throttling",
"Load balancing and scaling",
"DDoS protection services",
"Resource monitoring and alerting",
"Circuit breakers"
],
ThreatCategory.ELEVATION_OF_PRIVILEGE: [
"Input validation and sanitization",
"Least privilege access controls",
"Security code reviews",
"Runtime protection (ASLR, DEP)",
"Regular security patching"
]
}
return control_mapping.get(category, [])
def generate_threat_report(self) -> Dict:
"""Generate comprehensive threat assessment report"""
if not self.identified_threats:
self.analyze_threats()
# Categorize threats by risk level
risk_distribution = {level: 0 for level in RiskLevel}
category_distribution = {cat: 0 for cat in ThreatCategory}
for threat in self.identified_threats:
risk_distribution[threat.residual_risk] += 1
category_distribution[threat.category] += 1
# Calculate risk metrics
total_threats = len(self.identified_threats)
high_risk_threats = sum(1 for t in self.identified_threats
if t.residual_risk in [RiskLevel.HIGH, RiskLevel.CRITICAL])
report = {
"threat_model_summary": {
"total_assets": len(self.assets),
"total_data_flows": len(self.data_flows),
"total_trust_boundaries": len(self.trust_boundaries),
"total_threats_identified": total_threats,
"high_risk_threats": high_risk_threats,
"risk_score": sum(t.risk_score for t in self.identified_threats)
},
"risk_distribution": {level.name: count for level, count in risk_distribution.items()},
"category_distribution": {cat.name: count for cat, count in category_distribution.items()},
"top_threats": [
{
"threat_id": t.threat_id,
"category": t.category.name,
"description": t.description,
"target": t.target_asset,
"risk_score": t.risk_score,
"residual_risk": t.residual_risk.name,
"controls": t.mitigation_controls
}
for t in self.identified_threats[:10]
],
"recommended_actions": self._generate_recommendations()
}
return report
def _generate_recommendations(self) -> List[str]:
"""Generate security recommendations based on threat analysis"""
recommendations = []
# High-level recommendations based on threat patterns
high_risk_count = sum(1 for t in self.identified_threats
if t.residual_risk == RiskLevel.HIGH)
critical_risk_count = sum(1 for t in self.identified_threats
if t.residual_risk == RiskLevel.CRITICAL)
if critical_risk_count > 0:
recommendations.append(f"URGENT: Address {critical_risk_count} critical risk threats immediately")
if high_risk_count > 5:
recommendations.append(f"Prioritize mitigation of {high_risk_count} high-risk threats")
# Category-specific recommendations
category_counts = {}
for threat in self.identified_threats:
if threat.residual_risk in [RiskLevel.HIGH, RiskLevel.CRITICAL]:
category_counts[threat.category] = category_counts.get(threat.category, 0) + 1
for category, count in category_counts.items():
if count >= 3:
recommendations.append(f"Focus on {category.name.replace('_', ' ')} controls - {count} threats identified")
# Generic recommendations
recommendations.extend([
"Implement comprehensive security monitoring and alerting",
"Conduct regular penetration testing and vulnerability assessments",
"Establish incident response procedures",
"Implement security training for development teams"
])
return recommendations
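# Illustrative usage sketch for the base STRIDE model above. This is not part of
# the modeling library itself; the asset, data-flow, and boundary names below
# ("web_app", "user_db", "dmz") are hypothetical placeholders.
def _example_stride_usage() -> Dict:
    """Build a minimal threat model and return its report (sketch only)."""
    model = STRIDEThreatModeling()
    model.add_asset(Asset(
        name="web_app",
        asset_type="process",
        description="Customer-facing web application",
        value=RiskLevel.CRITICAL,
        data_classification="confidential"
    ))
    model.add_asset(Asset(
        name="user_db",
        asset_type="data_store",
        description="Primary user database",
        value=RiskLevel.HIGH,
        data_classification="restricted"
    ))
    model.add_data_flow(DataFlow(
        source="web_app",
        destination="user_db",
        data_type="personal",
        protocol="tcp",
        authentication_required=True,
        encryption_in_transit=False
    ))
    model.add_trust_boundary(TrustBoundary(
        name="dmz",
        internal_assets={"user_db"},
        external_assets={"web_app"},
        boundary_type="network"
    ))
    # analyze_threats() populates and sorts identified_threats;
    # generate_threat_report() summarizes them as a plain dict.
    model.analyze_threats()
    return model.generate_threat_report()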
# AWS-specific extensions to the STRIDE threat model
class AWSSecurityThreatModeling(STRIDEThreatModeling):
"""Specialized threat modeling for AWS environments"""
def __init__(self):
super().__init__()
self.aws_service_threats = self._load_aws_service_threats()
def _load_aws_service_threats(self) -> Dict[str, List[Dict]]:
"""Load AWS service-specific threat patterns"""
return {
"s3": [
{
"category": ThreatCategory.INFORMATION_DISCLOSURE,
"description": "Misconfigured S3 bucket permissions allowing public access",
"impact": RiskLevel.HIGH,
"likelihood": RiskLevel.HIGH,
"controls": ["Bucket policies", "ACLs", "Block public access settings"]
},
{
"category": ThreatCategory.TAMPERING,
"description": "Unauthorized S3 object modification",
"impact": RiskLevel.MEDIUM,
"likelihood": RiskLevel.MEDIUM,
"controls": ["Object versioning", "MFA delete", "IAM policies"]
}
],
"dynamodb": [
{
"category": ThreatCategory.INFORMATION_DISCLOSURE,
"description": "Overprivileged IAM roles accessing DynamoDB data",
"impact": RiskLevel.HIGH,
"likelihood": RiskLevel.MEDIUM,
"controls": ["Fine-grained IAM policies", "VPC endpoints", "Encryption"]
}
],
"lambda": [
{
"category": ThreatCategory.ELEVATION_OF_PRIVILEGE,
"description": "Lambda function with excessive IAM permissions",
"impact": RiskLevel.HIGH,
"likelihood": RiskLevel.MEDIUM,
"controls": ["Least privilege IAM roles", "Resource-based policies"]
},
{
"category": ThreatCategory.INFORMATION_DISCLOSURE,
"description": "Sensitive data in Lambda environment variables",
"impact": RiskLevel.MEDIUM,
"likelihood": RiskLevel.HIGH,
"controls": ["AWS Secrets Manager", "Parameter Store", "Encryption"]
}
],
"api_gateway": [
{
"category": ThreatCategory.DENIAL_OF_SERVICE,
"description": "API Gateway without throttling limits",
"impact": RiskLevel.MEDIUM,
"likelihood": RiskLevel.HIGH,
"controls": ["Throttling limits", "Usage plans", "WAF integration"]
}
]
}
def analyze_aws_service_threats(self, service_name: str):
"""Analyze threats specific to AWS services"""
if service_name not in self.aws_service_threats:
return
for threat_pattern in self.aws_service_threats[service_name]:
self._create_threat(
threat_pattern["category"],
f"aws_{service_name}",
threat_pattern["description"],
threat_pattern["impact"],
threat_pattern["likelihood"]
)