"""
Executor for sktime MCP.
Responsible for instantiating estimators, loading datasets,
and running fit/predict operations.
"""
import asyncio
import importlib
import inspect
import logging
import os
import uuid
from typing import Any

import pandas as pd

from sktime_mcp.registry.interface import get_registry
from sktime_mcp.runtime.handles import get_handle_manager
from sktime_mcp.runtime.jobs import JobStatus, get_job_manager
logger = logging.getLogger(__name__)
# Dynamically discover the sktime demo datasets at import time: every
# ``load_*`` function in ``sktime.datasets`` that can be called without
# arguments is exposed to the MCP server (this replaced an older hardcoded
# dictionary of dataset names).


def _takes_no_required_args(func: Any) -> bool:
    """Return True if ``func`` can be called with no arguments at all."""
    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError):
        # Signature cannot be introspected: keep the loader rather than hide it.
        return True
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    return all(
        param.kind in variadic or param.default is not inspect.Parameter.empty
        for param in signature.parameters.values()
    )


def _discover_demo_datasets() -> dict[str, str]:
    """Map dataset name -> dotted loader path for sktime's demo datasets.

    Every ``load_<name>`` function exported by ``sktime.datasets`` is included
    under the key ``<name>``, except loaders that have required parameters:
    the executor invokes loaders with no arguments, so such loaders could
    never succeed and would only clutter the list of available datasets.

    Returns:
        The mapping, or an empty dict if ``sktime.datasets`` cannot be
        imported or inspected.
    """
    try:
        import sktime.datasets as ds_module

        return {
            name.removeprefix("load_"): f"sktime.datasets.{name}"
            for name, obj in inspect.getmembers(ds_module, inspect.isfunction)
            if name.startswith("load_") and _takes_no_required_args(obj)
        }
    except Exception:  # pragma: no cover - sktime missing or broken
        # Best effort: the server still runs, just without demo datasets.
        logging.getLogger(__name__).debug(
            "Could not discover sktime demo datasets", exc_info=True
        )
        return {}


DEMO_DATASETS = _discover_demo_datasets()
[docs]
class Executor:
"""
Execution runtime for sktime estimators.
Handles instantiation, fitting, and prediction.
"""
[docs]
def __init__(self):
self._registry = get_registry()
self._handle_manager = get_handle_manager()
self._job_manager = get_job_manager()
self._data_handles = {} # Store data handles
self._auto_format_enabled = (
os.environ.get("SKTIME_MCP_AUTO_FORMAT", "true").lower() == "true"
)
[docs]
def instantiate(
self,
estimator_name: str,
params: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Instantiate an estimator and return a handle."""
node = self._registry.get_estimator_by_name(estimator_name)
if node is None:
return {"success": False, "error": f"Unknown estimator: {estimator_name}"}
try:
instance = node.class_ref(**(params or {}))
handle_id = self._handle_manager.create_handle(
estimator_name=estimator_name,
instance=instance,
params=params or {},
)
return {
"success": True,
"handle": handle_id,
"estimator": estimator_name,
"params": params or {},
}
except Exception as e:
return {"success": False, "error": str(e)}
# L-7: We can also add custom load_dataset functions here
[docs]
def load_dataset(self, name: str) -> dict[str, Any]:
"""Load a demo dataset."""
if name not in DEMO_DATASETS:
return {
"success": False,
"error": f"Unknown dataset: {name}",
"available": list(DEMO_DATASETS.keys()),
}
try:
module_path = DEMO_DATASETS[name]
parts = module_path.rsplit(".", 1)
module = __import__(parts[0], fromlist=[parts[1]])
loader = getattr(module, parts[1])
data = loader()
if isinstance(data, tuple):
y, X = data[0], data[1] if len(data) > 1 else None
else:
y, X = data, None
return {
"success": True,
"name": name,
"shape": y.shape if hasattr(y, "shape") else len(y),
"type": str(type(y).__name__),
"data": y,
"exog": X,
}
except Exception as e:
return {"success": False, "error": str(e)}
[docs]
def fit(
self,
handle_id: str,
y: Any,
X: Any | None = None,
fh: Any | None = None,
) -> dict[str, Any]:
"""Fit an estimator."""
try:
instance = self._handle_manager.get_instance(handle_id)
except KeyError:
return {"success": False, "error": f"Handle not found: {handle_id}"}
try:
if fh is not None:
instance.fit(y, X=X, fh=fh)
elif X is not None:
instance.fit(y, X=X)
else:
instance.fit(y)
self._handle_manager.mark_fitted(handle_id)
return {"success": True, "handle": handle_id, "fitted": True}
except Exception as e:
return {"success": False, "error": str(e)}
[docs]
def predict(
self,
handle_id: str,
fh: int | list[int] | None = None,
X: Any | None = None,
) -> dict[str, Any]:
"""Generate predictions."""
try:
instance = self._handle_manager.get_instance(handle_id)
except KeyError:
return {"success": False, "error": f"Handle not found: {handle_id}"}
if not self._handle_manager.is_fitted(handle_id):
return {"success": False, "error": "Estimator not fitted"}
try:
if fh is None:
fh = list(range(1, 13))
predictions = instance.predict(fh=fh, X=X) if X is not None else instance.predict(fh=fh)
if isinstance(predictions, pd.Series):
# Convert index to string to avoid JSON serialization issues with Period/DatetimeIndex
predictions_copy = predictions.copy()
predictions_copy.index = predictions_copy.index.astype(str)
result = predictions_copy.to_dict()
elif isinstance(predictions, pd.DataFrame):
predictions_copy = predictions.copy()
predictions_copy.index = predictions_copy.index.astype(str)
result = predictions_copy.to_dict(orient="list")
else:
result = predictions.tolist() if hasattr(predictions, "tolist") else predictions
return {
"success": True,
"predictions": result,
"horizon": len(fh) if hasattr(fh, "__len__") else fh,
}
except Exception as e:
return {"success": False, "error": str(e)}
[docs]
def fit_predict(
self,
handle_id: str,
dataset: str,
horizon: int = 12,
data_handle: str | None = None,
) -> dict[str, Any]:
"""Convenience method: load data, fit, and predict."""
if dataset and data_handle:
return {
"success": False,
"error": "Provide either 'dataset' or 'data_handle', not both.",
}
if data_handle is None and (not dataset or not str(dataset).strip()):
return {
"success": False,
"error": (
"Either 'dataset' (e.g. 'airline') or "
"'data_handle' (from load_data_source) is required."
),
}
if data_handle is not None:
# Use custom loaded data
if data_handle not in self._data_handles:
return {
"success": False,
"error": f"Unknown data handle: {data_handle}",
"available_handles": list(self._data_handles.keys()),
}
data_info = self._data_handles[data_handle]
y = data_info["y"]
X = data_info.get("X")
else:
# Use demo dataset
data_result = self.load_dataset(dataset)
if not data_result["success"]:
return data_result
y = data_result["data"]
X = data_result.get("exog")
fh = list(range(1, horizon + 1))
fit_result = self.fit(handle_id, y, X=X, fh=fh)
if not fit_result["success"]:
return fit_result
return self.predict(handle_id, fh=fh, X=X)
[docs]
async def fit_predict_async(
self,
handle_id: str,
dataset: str | None = None,
data_handle: str | None = None,
horizon: int = 12,
job_id: str | None = None,
) -> dict[str, Any]:
"""
Async version of fit_predict with job tracking.
Runs the training in the background without blocking the MCP server.
Accepts either a demo dataset name or a data handle from
load_data_source.
Args:
handle_id: Estimator handle
dataset: Demo dataset name
data_handle: Data handle from load_data_source
horizon: Forecast horizon
job_id: Optional job ID for tracking (created if not provided)
Returns:
Dictionary with success status and job_id
"""
# Get estimator info for job tracking
try:
handle_info = self._handle_manager.get_info(handle_id)
estimator_name = handle_info.estimator_name
except Exception as e:
logger.warning(f"Could not get estimator name: {e}")
estimator_name = "Unknown"
source_name = dataset if dataset else data_handle
# Create job if not provided
if job_id is None:
job_id = self._job_manager.create_job(
job_type="fit_predict",
estimator_handle=handle_id,
estimator_name=estimator_name,
dataset_name=source_name,
horizon=horizon,
total_steps=3,
)
try:
# Update status to RUNNING
self._job_manager.update_job(job_id, status=JobStatus.RUNNING)
# Step 1: Load data
if data_handle:
# Use custom data from a loaded handle
self._job_manager.update_job(
job_id,
completed_steps=0,
current_step=f"Loading data from handle '{data_handle}'...",
)
await asyncio.sleep(0.01)
if data_handle not in self._data_handles:
self._job_manager.update_job(
job_id,
status=JobStatus.FAILED,
errors=[f"Unknown data handle: {data_handle}"],
)
return {
"success": False,
"error": f"Unknown data handle: {data_handle}",
"available_handles": list(self._data_handles.keys()),
}
data_info = self._data_handles[data_handle]
y = data_info["y"]
X = data_info.get("X")
else:
# Use built-in demo dataset
self._job_manager.update_job(
job_id,
completed_steps=0,
current_step=f"Loading dataset '{dataset}'...",
)
await asyncio.sleep(0.01)
data_result = self.load_dataset(dataset)
if not data_result["success"]:
self._job_manager.update_job(
job_id,
status=JobStatus.FAILED,
errors=[f"Failed to load dataset: {data_result.get('error')}"],
)
return data_result
y = data_result["data"]
X = data_result.get("exog")
fh = list(range(1, horizon + 1))
# Step 2: Fit model
self._job_manager.update_job(
job_id,
completed_steps=1,
current_step=f"Fitting {estimator_name} on {source_name}...",
)
await asyncio.sleep(0.01)
loop = asyncio.get_event_loop()
fit_result = await loop.run_in_executor(
None, lambda: self.fit(handle_id, y, X=X, fh=fh)
)
if not fit_result["success"]:
self._job_manager.update_job(
job_id,
status=JobStatus.FAILED,
errors=[f"Fit failed: {fit_result.get('error')}"],
)
return fit_result
# Step 3: Generate predictions
self._job_manager.update_job(
job_id,
completed_steps=2,
current_step=f"Generating predictions (horizon={horizon})...",
)
await asyncio.sleep(0.01)
predict_result = await loop.run_in_executor(
None, lambda: self.predict(handle_id, fh=fh, X=X)
)
if not predict_result["success"]:
self._job_manager.update_job(
job_id,
status=JobStatus.FAILED,
errors=[f"Prediction failed: {predict_result.get('error')}"],
)
return predict_result
# Mark as completed
self._job_manager.update_job(
job_id,
status=JobStatus.COMPLETED,
completed_steps=3,
current_step="Completed",
result=predict_result,
)
return predict_result
except Exception as e:
logger.exception(f"Error in async fit_predict for job {job_id}")
self._job_manager.update_job(job_id, status=JobStatus.FAILED, errors=[str(e)])
return {"success": False, "error": str(e), "job_id": job_id}
# L-9: We can add more methods here to handle diverse use cases and their pipelines
[docs]
def instantiate_pipeline(
self,
components: list[str],
params_list: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
"""
Instantiate a pipeline from a list of components.
Args:
components: List of estimator names in pipeline order
params_list: Optional list of parameter dicts for each component
Returns:
Dictionary with success status and handle
"""
if not components:
return {"success": False, "error": "Pipeline cannot be empty"}
# Validate the pipeline first
from sktime_mcp.composition.validator import get_composition_validator
validator = get_composition_validator()
validation = validator.validate_pipeline(components)
if not validation.valid:
return {
"success": False,
"error": "Invalid pipeline composition",
"validation_errors": validation.errors,
"suggestions": validation.suggestions,
}
try:
# If only one component, just instantiate it directly
if len(components) == 1:
params = params_list[0] if params_list else {}
return self.instantiate(components[0], params)
# Build the pipeline
# Get all component nodes
component_instances = []
params_list = params_list or [{}] * len(components)
for i, comp_name in enumerate(components):
node = self._registry.get_estimator_by_name(comp_name)
if node is None:
return {"success": False, "error": f"Unknown estimator: {comp_name}"}
params = params_list[i] if i < len(params_list) else {}
instance = node.class_ref(**params)
component_instances.append(instance)
# Determine the type of pipeline to create
# Check if all but last are transformers
all_transformers_except_last = all(
self._registry.get_estimator_by_name(comp).task == "transformation"
for comp in components[:-1]
)
final_task = self._registry.get_estimator_by_name(components[-1]).task
if all_transformers_except_last and final_task == "forecasting":
# Use TransformedTargetForecaster
from sktime.forecasting.compose import TransformedTargetForecaster
# Chain transformers if multiple
if len(component_instances) == 2:
pipeline = TransformedTargetForecaster(
[
("transformer", component_instances[0]),
("forecaster", component_instances[1]),
]
)
else:
# Multiple transformers - chain them
from sktime.transformations.compose import TransformerPipeline
transformer_pipeline = TransformerPipeline(
[(f"step_{i}", comp) for i, comp in enumerate(component_instances[:-1])]
)
pipeline = TransformedTargetForecaster(
[
("transformers", transformer_pipeline),
("forecaster", component_instances[-1]),
]
)
elif all_transformers_except_last and final_task in ("classification", "regression"):
# Use sklearn-style Pipeline
from sktime.pipeline import Pipeline
pipeline = Pipeline(
[(f"step_{i}", comp) for i, comp in enumerate(component_instances)]
)
elif all(
self._registry.get_estimator_by_name(comp).task == "transformation"
for comp in components
):
# All transformers - use TransformerPipeline
from sktime.transformations.compose import TransformerPipeline
pipeline = TransformerPipeline(
[(f"step_{i}", comp) for i, comp in enumerate(component_instances)]
)
else:
return {
"success": False,
"error": "Unsupported pipeline composition type",
"hint": "Currently supports: transformers → forecaster, transformers → classifier/regressor, or transformer chains",
}
# Create a handle for the pipeline
pipeline_name = " → ".join(components)
handle_id = self._handle_manager.create_handle(
estimator_name=pipeline_name,
instance=pipeline,
params={"components": components, "params_list": params_list},
)
return {
"success": True,
"handle": handle_id,
"pipeline": pipeline_name,
"components": components,
"params_list": params_list,
}
except Exception as e:
import traceback
return {
"success": False,
"error": str(e),
"traceback": traceback.format_exc(),
}
[docs]
def list_datasets(self) -> list[str]:
"""List available demo datasets."""
return list(DEMO_DATASETS.keys())
[docs]
def load_data_source(self, config: dict[str, Any]) -> dict[str, Any]:
"""
Load data from any source (pandas, SQL, file, etc.).
Args:
config: Data source configuration with 'type' key
Examples:
- {"type": "pandas", "data": df, "time_column": "date", "target_column": "value"}
- {"type": "sql", "connection_string": "...", "query": "...", "time_column": "date"}
- {"type": "file", "path": "/path/to/data.csv", "time_column": "date"}
Returns:
Dictionary with:
- success: bool
- data_handle: str (handle ID for the loaded data)
- metadata: dict (information about the data)
- validation: dict (validation results)
"""
try:
from sktime_mcp.data import DataSourceRegistry
# Create adapter
adapter = DataSourceRegistry.create_adapter(config)
# Load data
data = adapter.load()
# Validate
is_valid, validation_report = adapter.validate(data)
if not is_valid:
return {
"success": False,
"error": "Data validation failed",
"validation": validation_report,
}
# Convert to sktime format
y, X = adapter.to_sktime_format(data)
# Update metadata to reflect the target and used columns
metadata = adapter.get_metadata().copy()
metadata["columns"] = [y.name if hasattr(y, "name") and y.name else "target"]
if X is not None:
metadata["exog_columns"] = list(X.columns)
# Inject column dtypes so LLMs can distinguish time index vs target
metadata["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()}
# Generate handle
data_handle = f"data_{uuid.uuid4().hex[:8]}"
# Store
self._data_handles[data_handle] = {
"y": y,
"X": X,
"metadata": metadata,
"validation": validation_report,
"config": config, # Store config for reference
}
# Apply auto-formatting if enabled
if getattr(self, "_auto_format_enabled", True):
try:
format_result = self.format_data_handle(
data_handle, auto_infer_freq=True, fill_missing=True, remove_duplicates=True
)
if format_result["success"]:
# Return the NEW handle (formatted)
return {
"success": True,
"data_handle": format_result["data_handle"],
"original_handle": data_handle,
"metadata": format_result["metadata"],
"validation": validation_report,
"formatted": True,
"changes_made": format_result["changes_made"],
}
except Exception as e:
logger.warning(f"Auto-formatting failed: {e}")
# Continue with unformatted data if formatting fails
_final_meta = adapter.get_metadata().copy()
_final_meta["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()}
return {
"success": True,
"data_handle": data_handle,
"metadata": _final_meta,
"validation": validation_report,
}
except Exception as e:
logger.exception("Error loading data source")
return {
"success": False,
"error": str(e),
"error_type": type(e).__name__,
}
[docs]
async def load_data_source_async(
self,
config: dict[str, Any],
job_id: str | None = None,
) -> dict[str, Any]:
"""
Async version of load_data_source with job tracking.
Runs data loading in the background without blocking the
MCP server. Progress is tracked via the JobManager.
Args:
config: Data source configuration
job_id: Optional job ID (created if not provided)
Returns:
Dictionary with data_handle and metadata
"""
source_type = config.get("type", "unknown")
if job_id is None:
job_id = self._job_manager.create_job(
job_type="data_loading",
estimator_handle="",
dataset_name=source_type,
total_steps=3,
)
try:
self._job_manager.update_job(job_id, status=JobStatus.RUNNING)
# Step 1: Load raw data
self._job_manager.update_job(
job_id, completed_steps=0, current_step=f"Loading data from '{source_type}'..."
)
await asyncio.sleep(0.01)
from sktime_mcp.data import DataSourceRegistry
loop = asyncio.get_event_loop()
adapter = DataSourceRegistry.create_adapter(config)
data = await loop.run_in_executor(None, adapter.load)
# Step 2: Validate
self._job_manager.update_job(
job_id, completed_steps=1, current_step="Validating data..."
)
await asyncio.sleep(0.01)
is_valid, validation_report = adapter.validate(data)
if not is_valid:
self._job_manager.update_job(
job_id, status=JobStatus.FAILED, errors=["Data validation failed"]
)
return {
"success": False,
"error": "Data validation failed",
"validation": validation_report,
}
# Step 3: Convert, store, and format
self._job_manager.update_job(
job_id, completed_steps=2, current_step="Converting to sktime format..."
)
await asyncio.sleep(0.01)
y, X = adapter.to_sktime_format(data)
metadata = adapter.get_metadata().copy()
metadata["columns"] = [y.name if hasattr(y, "name") and y.name else "target"]
if X is not None:
metadata["exog_columns"] = list(X.columns)
# Inject column dtypes so LLMs can distinguish time index vs target
metadata["dtypes"] = {col: str(dtype) for col, dtype in data.dtypes.items()}
data_handle = f"data_{uuid.uuid4().hex[:8]}"
self._data_handles[data_handle] = {
"y": y,
"X": X,
"metadata": metadata,
"validation": validation_report,
"config": config,
}
# auto-format if enabled
if getattr(self, "_auto_format_enabled", True):
try:
format_result = self.format_data_handle(
data_handle, auto_infer_freq=True, fill_missing=True, remove_duplicates=True
)
if format_result["success"]:
data_handle = format_result["data_handle"]
metadata = format_result["metadata"]
except Exception as e:
logger.warning(f"Auto-formatting failed: {e}")
result = {
"success": True,
"data_handle": data_handle,
"metadata": metadata,
"validation": validation_report,
}
# mark completed with the data_handle in the result
self._job_manager.update_job(
job_id,
status=JobStatus.COMPLETED,
completed_steps=3,
current_step="Completed",
result=result,
)
return result
except Exception as e:
logger.exception(f"Error in async data loading for job {job_id}")
self._job_manager.update_job(job_id, status=JobStatus.FAILED, errors=[str(e)])
return {
"success": False,
"error": str(e),
"job_id": job_id,
}
[docs]
def list_data_handles(self) -> dict[str, Any]:
"""
List all loaded data handles.
Returns:
Dictionary with list of data handles and their metadata
"""
handles = []
for handle_id, data_info in self._data_handles.items():
handles.append(
{
"handle": handle_id,
"metadata": data_info["metadata"],
"validation": data_info["validation"],
}
)
return {
"success": True,
"count": len(handles),
"handles": handles,
}
[docs]
def release_data_handle(self, data_handle: str) -> dict[str, Any]:
"""
Release a data handle and free memory.
Args:
data_handle: Data handle to release
Returns:
Dictionary with success status
"""
if data_handle in self._data_handles:
del self._data_handles[data_handle]
return {
"success": True,
"message": f"Data handle '{data_handle}' released",
}
else:
return {
"success": False,
"error": f"Data handle '{data_handle}' not found",
}
# Process-wide executor, created lazily by get_executor().
_executor_instance: "Executor | None" = None


def get_executor() -> "Executor":
    """Return the shared Executor, constructing it on the first call."""
    global _executor_instance
    if _executor_instance is not None:
        return _executor_instance
    _executor_instance = Executor()
    return _executor_instance