"""
Composition Validator for sktime MCP.
sktime estimators are composable:
- transformers → forecasters
- pipelines
- reduction strategies
This module enforces:
- Compatible task types
- Valid composition order
- Tag compatibility
This prevents invalid pipelines at planning time.
"""
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from sktime_mcp.registry.interface import EstimatorNode, get_registry
logger = logging.getLogger(__name__)


class CompositionType(Enum):
    """Enumerate the kinds of estimator composition known to this validator.

    Each member's value is its lowercase string identifier.
    """

    # Generic sequential pipeline.
    PIPELINE = "pipeline"
    # Chain made only of transformers.
    TRANSFORMER_PIPELINE = "transformer_pipeline"
    # Transformers placed in front of a forecaster.
    FORECASTING_PIPELINE = "forecasting_pipeline"
    # Selection among alternative estimators.
    MULTIPLEXER = "multiplexer"
    # Several estimators of the same task combined side by side.
    ENSEMBLE = "ensemble"
    # Reduction of one task to another.
    REDUCTION = "reduction"
@dataclass
class CompositionRule:
    """
    Describe one permitted way of combining two kinds of estimator.

    A rule states that an estimator whose task is ``source_task`` may be
    combined with an estimator whose task is ``target_task`` by means of
    ``composition_type``.

    Attributes:
        source_task: Task of the estimator being composed.
        target_task: Task of the estimator it is composed with.
        composition_type: How the two estimators are combined.
        position: Placement of ``source_task`` relative to ``target_task``;
            one of "before", "after", or "any".
        description: Human-readable explanation of the rule.
    """

    source_task: str
    target_task: str
    composition_type: CompositionType
    position: str  # expected values: "before", "after", "any"
    description: str
@dataclass
class ValidationResult:
    """
    Outcome of validating a proposed composition.

    Attributes:
        valid: True when the composition is acceptable.
        errors: Problems that make the composition invalid.
        warnings: Concerns that do not by themselves invalidate it.
        suggestions: Hints on how to repair an invalid composition.
    """

    valid: bool
    errors: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)
    suggestions: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return the result as a plain dict suitable for JSON encoding.

        The lists are returned as-is (not copied).
        """
        exported = ("valid", "errors", "warnings", "suggestions")
        return {attr: getattr(self, attr) for attr in exported}
class CompositionValidator:
    """
    Validator for sktime estimator compositions.

    This class encodes the rules for valid estimator compositions
    in sktime, allowing validation at planning time rather than
    runtime.

    Rules whose composition type is ENSEMBLE or MULTIPLEXER describe
    combining estimators side by side. Only the remaining (sequential)
    rules are used to judge adjacent steps of a pipeline.
    """

    # Valid composition rules
    COMPOSITION_RULES: list[CompositionRule] = [
        # Transformers can precede forecasters
        CompositionRule(
            source_task="transformation",
            target_task="forecasting",
            composition_type=CompositionType.FORECASTING_PIPELINE,
            position="before",
            description="Transformers can be applied before forecasters in a pipeline",
        ),
        # Transformers can be chained
        CompositionRule(
            source_task="transformation",
            target_task="transformation",
            composition_type=CompositionType.TRANSFORMER_PIPELINE,
            position="before",
            description="Transformers can be chained together",
        ),
        # Forecasters can be ensembled
        CompositionRule(
            source_task="forecasting",
            target_task="forecasting",
            composition_type=CompositionType.ENSEMBLE,
            position="any",
            description="Forecasters can be combined in an ensemble",
        ),
        # Classifiers can be ensembled
        CompositionRule(
            source_task="classification",
            target_task="classification",
            composition_type=CompositionType.ENSEMBLE,
            position="any",
            description="Classifiers can be combined in an ensemble",
        ),
        # Transformers can precede classifiers
        CompositionRule(
            source_task="transformation",
            target_task="classification",
            composition_type=CompositionType.PIPELINE,
            position="before",
            description="Transformers can be applied before classifiers",
        ),
        # Transformers can precede regressors
        CompositionRule(
            source_task="transformation",
            target_task="regression",
            composition_type=CompositionType.PIPELINE,
            position="before",
            description="Transformers can be applied before regressors",
        ),
    ]

    # Composition types that combine estimators side by side instead of
    # feeding one into the next. They must never justify two adjacent
    # pipeline steps; otherwise the forecaster ensemble rule would make
    # "ARIMA -> ExponentialSmoothing" look like a valid pipeline.
    _NON_SEQUENTIAL_TYPES = frozenset(
        {CompositionType.ENSEMBLE, CompositionType.MULTIPLEXER}
    )

    # Known transformer categories
    TRANSFORMER_CATEGORIES = {
        "Imputer": "missing_value_handler",
        "Detrend": "trend_remover",
        "Deseasonalize": "seasonality_remover",
        "Differencer": "differencer",
        "BoxCoxTransformer": "power_transform",
        "LogTransformer": "power_transform",
        "ScaledLogitTransformer": "power_transform",
        "Lag": "lag_creator",
        "WindowSummarizer": "feature_creator",
        "DateTimeFeatures": "feature_creator",
    }

    def __init__(self):
        """Initialize the validator with the registry from ``get_registry()``."""
        self._registry = get_registry()

    @classmethod
    def _sequential_rules(cls) -> list[CompositionRule]:
        """Return rules allowing a ``source_task`` step directly before a ``target_task`` step."""
        return [
            rule
            for rule in cls.COMPOSITION_RULES
            if rule.composition_type not in cls._NON_SEQUENTIAL_TYPES
            and rule.position in ("before", "any")
        ]

    def validate_pipeline(
        self,
        components: list[str],
    ) -> ValidationResult:
        """
        Validate a proposed pipeline composition.

        A single known estimator is always valid. For longer pipelines every
        adjacent pair is checked against the sequential rules and the tag
        checks, and the last component must not be a transformer.

        Args:
            components: List of estimator names in pipeline order

        Returns:
            ValidationResult with validity status and any issues
        """
        if not components:
            return ValidationResult(
                valid=False,
                errors=["Pipeline cannot be empty"],
            )

        # Resolve every name first so all unknown estimators are reported together.
        nodes: list[tuple[str, EstimatorNode | None]] = [
            (name, self._registry.get_estimator_by_name(name)) for name in components
        ]
        unknown = [f"Unknown estimator: {name}" for name, node in nodes if node is None]
        if unknown:
            return ValidationResult(valid=False, errors=unknown)

        if len(nodes) == 1:
            # Single component is always valid if it exists
            return ValidationResult(valid=True)

        errors: list[str] = []
        warnings: list[str] = []
        suggestions: list[str] = []

        # Check pairwise compatibility of adjacent steps. pair_errors is empty
        # exactly when the pair is valid, so extending unconditionally is safe,
        # and warnings from valid pairs are kept.
        for (_, current_node), (_, next_node) in zip(nodes, nodes[1:]):
            _, pair_errors, pair_warnings = self._check_pair_compatibility(
                current_node, next_node
            )
            errors.extend(pair_errors)
            warnings.extend(pair_warnings)

        # Check final component is executable (forecaster, classifier, etc.)
        final_name, final_node = nodes[-1]
        if final_node.task == "transformation":
            errors.append(
                f"Pipeline ends with transformer '{final_name}'. "
                "The final component should be a forecaster, classifier, or regressor."
            )
            suggestions.append("Add a forecaster like 'ARIMA' or 'ExponentialSmoothing' at the end")

        # Check for duplicate consecutive components (positions are 1-based)
        for position, (left, right) in enumerate(zip(components, components[1:]), start=1):
            if left == right:
                warnings.append(
                    f"Duplicate consecutive component: '{left}' at positions {position} and {position + 1}"
                )

        return ValidationResult(
            valid=not errors,
            errors=errors,
            warnings=warnings,
            suggestions=suggestions,
        )

    def _check_pair_compatibility(
        self,
        first: EstimatorNode,
        second: EstimatorNode,
    ) -> tuple[bool, list[str], list[str]]:
        """
        Check if two estimators can be composed in sequence.

        Args:
            first: Estimator placed first.
            second: Estimator placed immediately after ``first``.

        Returns:
            (is_valid, errors, warnings), where is_valid is True exactly when
            errors is empty.
        """
        errors: list[str] = []
        warnings: list[str] = []

        # Only sequential rules apply here: ensemble rules describe side-by-side
        # combination, not one estimator feeding the next.
        has_rule = any(
            rule.source_task == first.task and rule.target_task == second.task
            for rule in self._sequential_rules()
        )

        if not has_rule:
            # No rule found - check if it's an obvious error
            if first.task == second.task == "forecasting":
                errors.append(
                    f"Cannot chain forecasters '{first.name}' → '{second.name}' directly. "
                    "Use an ensemble or multiplexer instead."
                )
            elif first.task in ("classification", "regression") and second.task != first.task:
                errors.append(
                    f"Invalid composition: {first.task} '{first.name}' → {second.task} '{second.name}'"
                )
            else:
                warnings.append(
                    f"Unusual composition: {first.task} '{first.name}' → {second.task} '{second.name}'"
                )

        # Check tag compatibility
        tag_errors, tag_warnings = self._check_tag_compatibility(first, second)
        errors.extend(tag_errors)
        warnings.extend(tag_warnings)

        return not errors, errors, warnings

    def _check_tag_compatibility(
        self,
        first: EstimatorNode,
        second: EstimatorNode,
    ) -> tuple[list[str], list[str]]:
        """
        Check tag-based compatibility between two adjacent estimators.

        Currently this only warns when a "univariate-only" estimator precedes one
        tagged "capability:multivariate". No errors are produced yet.

        Returns:
            (errors, warnings)
        """
        errors: list[str] = []
        warnings: list[str] = []

        # Check univariate vs multivariate
        first_univariate = first.tags.get("univariate-only", False)
        second_multivariate = second.tags.get("capability:multivariate", False)

        if first_univariate and second_multivariate:
            warnings.append(
                f"'{first.name}' is univariate-only but placed before "
                f"multivariate-capable '{second.name}'"
            )

        # Check if transformer output is compatible with next component's input
        # This is a simplified check - full check would need mtype resolution

        return errors, warnings

    def get_valid_compositions(
        self,
        estimator_name: str,
    ) -> dict[str, Any]:
        """
        Get valid compositions for an estimator.

        "can_precede" and "can_follow" use the same sequential rules as
        validate_pipeline, so they never advertise a chain that the pipeline
        check would reject. Side-by-side partners (ensemble/multiplexer rules)
        are listed separately under "can_ensemble_with". Each list holds task
        names, de-duplicated, in rule order.

        Args:
            estimator_name: Name of the estimator

        Returns:
            Dictionary with "can_precede", "can_follow" and "can_ensemble_with"
            lists, plus an "error" message if the estimator is unknown.
        """
        estimator = self._registry.get_estimator_by_name(estimator_name)
        if estimator is None:
            return {
                "can_precede": [],
                "can_follow": [],
                "can_ensemble_with": [],
                "error": f"Unknown estimator: {estimator_name}",
            }

        task = estimator.task
        sequential = self._sequential_rules()
        # This estimator can precede things of target_task
        can_precede = [rule.target_task for rule in sequential if rule.source_task == task]
        # Things of source_task can precede this estimator
        can_follow = [rule.source_task for rule in sequential if rule.target_task == task]
        # Partners for side-by-side combination, whichever side this task is on
        can_ensemble_with = [
            rule.target_task if rule.source_task == task else rule.source_task
            for rule in self.COMPOSITION_RULES
            if rule.composition_type in self._NON_SEQUENTIAL_TYPES
            and task in (rule.source_task, rule.target_task)
        ]

        # dict.fromkeys de-duplicates while keeping a deterministic order
        # (set() order depends on per-process string hashing).
        return {
            "can_precede": list(dict.fromkeys(can_precede)),
            "can_follow": list(dict.fromkeys(can_follow)),
            "can_ensemble_with": list(dict.fromkeys(can_ensemble_with)),
        }

    def suggest_pipeline(
        self,
        task: str,
        requirements: dict[str, Any] | None = None,
    ) -> list[str]:
        """
        Suggest a valid pipeline for a given task.

        Args:
            task: Target task (e.g., "forecasting")
            requirements: Optional requirements (e.g., {"handles_missing": True})

        Returns:
            List of suggested estimator names forming a valid pipeline; empty
            for tasks other than "forecasting" and "classification".
        """
        suggestions: list[str] = []

        if task == "forecasting":
            # Suggest common preprocessing → forecaster pipeline
            if requirements and requirements.get("handles_missing"):
                suggestions.append("Imputer")

            # Get a suitable forecaster.
            # NOTE(review): the raw requirements dict is forwarded as the tag
            # filter; confirm the registry understands keys like "handles_missing".
            forecasters = self._registry.get_all_estimators(
                task="forecasting",
                tags=requirements or None,
            )
            if forecasters:
                # Pick first match
                suggestions.append(forecasters[0].name)
            else:
                suggestions.append("NaiveForecaster")  # Fallback

        elif task == "classification":
            # Suggest a single classifier (no preprocessing step is added)
            classifiers = self._registry.get_all_estimators(task="classification")
            if classifiers:
                suggestions.append(classifiers[0].name)

        return suggestions
# Lazily created CompositionValidator shared by every caller of the accessor below.
_validator_instance: CompositionValidator | None = None


def get_composition_validator() -> CompositionValidator:
    """Return the shared CompositionValidator, building it on first request."""
    global _validator_instance
    if _validator_instance is not None:
        return _validator_instance
    _validator_instance = CompositionValidator()
    return _validator_instance