Source code for sktime_mcp.tools.instantiate
"""
instantiate_estimator tool for sktime MCP.
Creates executable estimator instances and pipelines.
"""
from typing import Any
from sktime_mcp.registry.interface import get_registry
from sktime_mcp.runtime.executor import get_executor
from sktime_mcp.runtime.handles import get_handle_manager
def _is_safe_value(value: Any) -> bool:
    """
    Tell whether ``value`` is made only of plain, serializable data.

    Accepted: None, str, int, float, bool, and arbitrarily nested lists,
    tuples and dicts built from those, where every dict key is a str.
    Anything else (callables, classes, modules, sets, arbitrary objects)
    anywhere in the structure makes the result False.
    """
    if isinstance(value, dict):
        # Every key must be a str and every value must itself be safe.
        return all(
            isinstance(field, str) and _is_safe_value(entry)
            for field, entry in value.items()
        )
    if isinstance(value, (list, tuple)):
        return all(_is_safe_value(entry) for entry in value)
    return value is None or isinstance(value, (str, int, float, bool))
def _validate_params(
params: Any,
estimator_name: str | None = None,
) -> dict[str, Any]:
"""
Validate params for type safety and optionally check keys.
Args:
params: The params argument to validate.
estimator_name: Optional estimator name for key validation.
Returns:
Dictionary with valid (bool), error (str), and warnings (list).
"""
warnings = []
# None means "use defaults", always valid
if params is None:
return {"valid": True, "warnings": warnings}
# params must be a dict
if not isinstance(params, dict):
return {
"valid": False,
"error": (
f"'params' must be a dictionary, got {type(params).__name__}. "
f'Example: {{"order": [1, 1, 1], "suppress_warnings": true}}'
),
"warnings": warnings,
}
# reject unsafe value types like callables, classes, modules
for key, value in params.items():
if not isinstance(key, str):
return {
"valid": False,
"error": (
f"Parameter keys must be strings, got {type(key).__name__} for key: {key!r}"
),
"warnings": warnings,
}
if not _is_safe_value(value):
return {
"valid": False,
"error": (
f"Unsupported type for parameter '{key}': "
f"{type(value).__name__}. "
f"Only primitive types (str, int, float, bool, list, "
f"tuple, dict, None) are allowed."
),
"warnings": warnings,
}
# check if keys match known hyperparameters (warn, don't error)
if estimator_name and params:
registry = get_registry()
node = registry.get_estimator_by_name(estimator_name)
if node is not None and node.hyperparameters:
known_keys = set(node.hyperparameters.keys())
provided_keys = set(params.keys())
unknown_keys = provided_keys - known_keys
if unknown_keys:
warnings.append(
f"Unknown parameter(s) for {estimator_name}: "
f"{sorted(unknown_keys)}. "
f"Known parameters: {sorted(known_keys)}"
)
return {"valid": True, "warnings": warnings}
[docs]
def instantiate_estimator_tool(
    estimator: str,
    params: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """
    Build an estimator instance through the executor and return its handle.

    ``params`` is validated first. If validation fails, the call stops
    immediately and the executor is never called.

    Args:
        estimator: Name of the estimator class (e.g., "ARIMA").
        params: Optional hyperparameters for the estimator. None means
            "use defaults".

    Returns:
        On validation failure: ``{"success": False, "error": <message>}``.
        Otherwise, the executor's result dictionary (success, handle,
        estimator, params). A ``warnings`` list is added only when
        validation flagged parameter names and the executor reported success.

    Example:
        >>> instantiate_estimator_tool("ARIMA", {"order": [1, 1, 1]})
        {
            "success": True,
            "handle": "est_abc123def456",
            "estimator": "ARIMA",
            "params": {"order": [1, 1, 1]}
        }
    """
    check = _validate_params(params, estimator_name=estimator)
    if not check["valid"]:
        return {"success": False, "error": check["error"]}

    outcome = get_executor().instantiate(estimator, params)

    # Key-mismatch warnings are advisory: attach them only to a successful result.
    pending = check["warnings"]
    if pending and outcome.get("success"):
        outcome["warnings"] = pending
    return outcome
[docs]
def instantiate_pipeline_tool(
components: list[str],
params_list: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
"""
Create a pipeline from a list of components and return a handle.
Args:
components: List of estimator names in pipeline order
params_list: Optional list of parameter dicts for each component
Returns:
Dictionary with:
- success: bool
- handle: Unique handle ID string
- pipeline: Name of the pipeline
- components: List of component names
- params_list: Parameters used for each component
- warnings: List of any validation warnings
Example:
>>> instantiate_pipeline_tool(
... ["ConditionalDeseasonalizer", "Detrender", "ARIMA"],
... [{}, {}, {"order": [1, 1, 1]}]
... )
{
"success": True,
"handle": "est_xyz789abc123",
"pipeline": "ConditionalDeseasonalizer → Detrender → ARIMA",
"components": ["ConditionalDeseasonalizer", "Detrender", "ARIMA"],
"params_list": [{}, {}, {"order": [1, 1, 1]}]
}
"""
all_warnings = []
# validate each params dict in params_list
if params_list is not None:
if not isinstance(params_list, list):
return {
"success": False,
"error": (
f"'params_list' must be a list of dictionaries, "
f"got {type(params_list).__name__}"
),
}
for i, params in enumerate(params_list):
comp_name = components[i] if i < len(components) else None
validation = _validate_params(params, estimator_name=comp_name)
if not validation["valid"]:
return {
"success": False,
"error": (
f"Validation failed for component {i} "
f"({comp_name or 'unknown'}): {validation['error']}"
),
}
all_warnings.extend(validation["warnings"])
executor = get_executor()
result = executor.instantiate_pipeline(components, params_list)
if all_warnings and result.get("success"):
result["warnings"] = all_warnings
return result
[docs]
def release_handle_tool(handle: str) -> dict[str, Any]:
    """
    Ask the handle manager to release an estimator handle.

    Args:
        handle: The handle ID to release.

    Returns:
        Dictionary with ``success`` (the manager's release result), the
        ``handle`` that was requested, and a human-readable ``message``.
    """
    was_released = get_handle_manager().release_handle(handle)
    message = "Handle released" if was_released else "Handle not found"
    return {"success": was_released, "handle": handle, "message": message}
[docs]
def list_handles_tool() -> dict[str, Any]:
    """
    Report every handle currently known to the handle manager.

    Returns:
        Dictionary with ``success`` (always True), ``handles`` (whatever the
        manager's ``list_handles`` returns) and ``count`` (its length).
    """
    active = get_handle_manager().list_handles()
    return {"success": True, "handles": active, "count": len(active)}
[docs]
def load_model_tool(path: str) -> dict[str, Any]:
    """
    Load a saved model from a local path or MLflow URI and register a handle.

    The model is read with ``sktime.utils.mlflow_sktime.load_model``. It is
    then registered with the handle manager (empty params, metadata recording
    the source and path) and immediately marked as fitted.

    Args:
        path: Local directory path or MLflow URI to the saved model.
            Examples:
            - "/tmp/my_arima_model"
            - "runs:/<run_id>/model"
            - "mlflow-artifacts:/<run_id>/artifacts/model"
            - "models:/<model_name>/<version>"

    Returns:
        On success: ``success``, the new ``handle``, ``estimator`` (class name
        of the loaded object), ``path`` and a ``message``.
        If the sktime MLflow loader cannot be imported: ``success`` False and
        an installation hint in ``error``.
        If loading or registration raises: ``success`` False, the exception
        text in ``error``, and the ``path``.
    """
    try:
        from sktime.utils.mlflow_sktime import load_model
    except ImportError:
        return {
            "success": False,
            "error": (
                "The 'mlflow' package is required to load saved models. "
                "Please install it with: pip install sktime-mcp[mlflow]"
            ),
        }

    try:
        model = load_model(path)
        class_name = type(model).__name__
        manager = get_handle_manager()
        new_handle = manager.create_handle(
            estimator_name=class_name,
            instance=model,
            params={},
            metadata={"source": "loaded", "path": path},
        )
        manager.mark_fitted(new_handle)
    except Exception as exc:
        # Tool boundary: report any failure as a structured error, not a crash.
        return {
            "success": False,
            "error": f"Failed to load model: {exc!s}",
            "path": path,
        }

    return {
        "success": True,
        "handle": new_handle,
        "estimator": class_name,
        "path": path,
        "message": f"Successfully loaded {class_name}",
    }