260 lines
7.2 KiB
Python
260 lines
7.2 KiB
Python
"""
Security utilities for path sanitization and file handling.
"""
|
|
|
|
import os
|
|
import re
|
|
import uuid
|
|
from werkzeug.utils import secure_filename
|
|
|
|
|
|
def sanitize_folder_path(path):
    """
    Sanitize folder path to prevent traversal attacks.

    Leading/trailing whitespace and slashes are removed and repeated
    slashes are collapsed. Input that is empty after stripping (for
    example ``"/"`` or ``"   "``) denotes the root folder and yields ``""``.

    Args:
        path: The folder path to sanitize

    Returns:
        Sanitized path string ("" for the root folder)

    Raises:
        ValueError: If path contains forbidden patterns or characters, a
            component longer than 100 characters, more than 10 components,
            or more than 500 characters in total
    """
    if not path:
        return ""

    # Remove any leading/trailing slashes and whitespace
    path = path.strip().strip('/')

    # Input made only of whitespace/slashes refers to the root folder; without
    # this check it would fall through and be rejected by the character check.
    if not path:
        return ""

    # Reject paths containing dangerous patterns (mostly also caught by the
    # character whitelist below, but these give a more specific message)
    dangerous_patterns = [
        '..',          # Parent directory traversal
        './',          # Current directory reference
        '\\',          # Windows path separator
        '\0',          # Null byte
        '~',           # Home directory reference
        '\x00',        # Alternative null byte
        '%2e%2e',      # URL encoded ..
        '%252e%252e',  # Double URL encoded ..
    ]

    # Check both original and lowercase version
    path_lower = path.lower()
    for pattern in dangerous_patterns:
        if pattern in path or pattern in path_lower:
            raise ValueError(f"Invalid path: contains forbidden pattern '{pattern}'")

    # Only allow ASCII letters, digits, spaces, hyphens, underscores, and
    # forward slashes. A literal space is used instead of ``\s``, which would
    # also admit tabs, newlines and other Unicode whitespace/control separators.
    if not re.fullmatch(r'[a-zA-Z0-9 \-_/]+', path):
        raise ValueError("Invalid path: contains forbidden characters")

    # Normalize path (drop empty components produced by double slashes)
    path_parts = [p for p in path.split('/') if p]

    for part in path_parts:
        # Defensive: '.' is already rejected above, but never let a
        # traversal or whitespace-only component through.
        if part in ('.', '..') or part.strip() == '':
            raise ValueError("Invalid path: contains directory traversal")

        # Check each part doesn't exceed reasonable length
        if len(part) > 100:
            raise ValueError("Invalid path: folder name too long")

    # Check total depth
    if len(path_parts) > 10:
        raise ValueError("Invalid path: folder depth exceeds maximum allowed")

    normalized = '/'.join(path_parts)

    # Final length check
    if len(normalized) > 500:
        raise ValueError("Invalid path: total path length exceeds maximum allowed")

    return normalized
|
|
|
|
|
|
def generate_secure_file_path(file_type, original_filename):
    """
    Generate secure file path using UUID to prevent predictable paths.

    The stored name is a random UUID4 followed by the lower-cased original
    extension; nothing else from ``original_filename`` is kept. When the
    extension is not permitted for ``file_type`` but is permitted for some
    other type, that other type is used as the subdirectory instead.

    Args:
        file_type: Requested type of file (image, markdown, text, document)
        original_filename: Original uploaded filename

    Returns:
        Secure relative path of the form ``"<type>/<uuid4><ext>"``

    Raises:
        ValueError: If the filename is empty or its extension is not
            allowed for any file type
    """
    if not original_filename:
        raise ValueError("Filename is required")

    extension = os.path.splitext(original_filename)[1].lower()

    # Whitelist of allowed extensions per file type. Every extension belongs
    # to exactly one type, so the reverse lookup below is unambiguous.
    extensions_by_type = {
        'markdown': {'.md', '.markdown', '.mdown', '.mkd'},
        'image': {'.png', '.jpg', '.jpeg', '.gif', '.webp', '.svg'},
        'text': {'.txt'},
        'document': {'.pdf', '.doc', '.docx'}
    }

    if extension not in extensions_by_type.get(file_type, set()):
        # The requested type doesn't accept this extension: infer the type
        # from the extension, or reject it if no type accepts it.
        matching_types = [
            kind for kind, exts in extensions_by_type.items() if extension in exts
        ]
        if not matching_types:
            raise ValueError(f"File extension '{extension}' is not allowed")
        file_type = matching_types[0]

    return f"{file_type}/{uuid.uuid4()}{extension}"
|
|
|
|
|
|
def validate_folder_access(folder_path, company_id, db_session):
    """
    Validate folder exists and belongs to company.

    An empty path (or one that sanitizes to the empty root path) is always
    considered valid. Paths rejected by ``sanitize_folder_path`` are
    reported as invalid rather than raising.

    Args:
        folder_path: Path to validate
        company_id: Company ID to check against
        db_session: Database session used to query ``NoteFolder``

    Returns:
        True if folder is valid and accessible, False otherwise
    """
    if not folder_path:
        return True  # Root folder is always valid

    try:
        # Sanitize the path first
        folder_path = sanitize_folder_path(folder_path)
    except ValueError:
        return False

    # If the input normalizes to the root (empty string), treat it as the
    # root rather than querying for a folder with an empty path.
    if not folder_path:
        return True

    # Import here to avoid circular imports
    from models import NoteFolder

    # Check if folder exists in database for this company
    folder = db_session.query(NoteFolder).filter_by(
        path=folder_path,
        company_id=company_id
    ).first()

    return folder is not None
|
|
|
|
|
|
def ensure_safe_file_path(base_path, file_path):
    """
    Ensure a file path is within the safe base directory.

    Containment is checked on the symlink-resolved (``realpath``) forms of
    both paths, so a symlink inside the base that points elsewhere is
    rejected. The comparison uses ``os.path.commonpath`` so a base of the
    filesystem root works correctly.

    Args:
        base_path: The safe base directory
        file_path: The file path to check, relative to ``base_path``

    Returns:
        Absolute safe path (``os.path.abspath`` form, not symlink-resolved)

    Raises:
        ValueError: If path would escape the base directory
    """
    # Lexically normalized absolute paths (this is what gets returned)
    base_abs = os.path.abspath(base_path)
    full_abs = os.path.abspath(os.path.join(base_abs, file_path))

    # Resolve symlinks so a link inside the base can't point outside it
    base_real = os.path.realpath(base_abs)
    full_real = os.path.realpath(full_abs)

    try:
        # commonpath avoids the `base + os.sep` prefix trick, which breaks
        # when the base is the filesystem root (base + sep == "//").
        inside = os.path.commonpath([base_real, full_real]) == base_real
    except ValueError:
        # e.g. paths on different drives on Windows
        inside = False

    if not inside:
        raise ValueError("Path traversal detected")

    return full_abs
|
|
|
|
|
|
def validate_filename(filename):
    """
    Validate and secure a filename.

    The name is passed through werkzeug's ``secure_filename`` and the
    cleaned result must be non-empty, at most 255 characters long, and
    contain a '.' (i.e. have an extension).

    Args:
        filename: The filename to validate

    Returns:
        The cleaned filename returned by ``secure_filename``

    Raises:
        ValueError: If filename is empty, cleans to nothing, is too long,
            or has no extension
    """
    if not filename:
        raise ValueError("Filename is required")

    cleaned = secure_filename(filename)

    # Reject input that secure_filename reduced to nothing
    if not cleaned:
        raise ValueError("Invalid filename")
    if len(cleaned) > 255:
        raise ValueError("Filename too long")
    if '.' not in cleaned:
        raise ValueError("Filename must have an extension")

    return cleaned
|
|
|
|
|
|
def get_safe_mime_type(filename):
    """
    Get MIME type for a filename, defaulting to safe types.

    The extension is matched case-insensitively against a fixed table;
    any unknown or missing extension maps to ``application/octet-stream``.

    Args:
        filename: The filename to check

    Returns:
        MIME type string
    """
    # Extensions grouped by the MIME type they map to
    extensions_by_mime = {
        # Markdown
        'text/markdown': ('.md', '.markdown', '.mdown', '.mkd'),
        # Images
        'image/png': ('.png',),
        'image/jpeg': ('.jpg', '.jpeg'),
        'image/gif': ('.gif',),
        'image/webp': ('.webp',),
        # NOTE(review): SVG is XML that can embed script; confirm callers serve
        # it as an attachment or under a restrictive CSP.
        'image/svg+xml': ('.svg',),
        # Text
        'text/plain': ('.txt',),
        # Documents
        'application/pdf': ('.pdf',),
        'application/msword': ('.doc',),
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document': ('.docx',),
    }
    mime_by_extension = {
        extension: mime
        for mime, extensions in extensions_by_mime.items()
        for extension in extensions
    }

    _, extension = os.path.splitext(filename)
    return mime_by_extension.get(extension.lower(), 'application/octet-stream')