Source code for sktime_mcp.tools.fit_predict
"""
fit_predict tool for sktime MCP.
Executes complete forecasting workflows.
"""
import logging
from typing import Any
from sktime_mcp.runtime.executor import get_executor
logger = logging.getLogger(__name__)
def _validate_horizon(horizon: int) -> dict[str, Any]:
"""
Validate the horizon parameter.
Checks if the horizon parameter is strictly integer or not
Checks if the horizon parameter is greater than 0 or not
"""
warnings = []
if not isinstance(horizon, int):
return {
"valid": False,
"error": (
f"'horizon' must be an integer, got {type(horizon).__name__}. "
f'Example: {{"horizon": 12}}'
),
"warnings": warnings,
}
if horizon <= 0:
return {
"valid": False,
"error": (
f"'horizon' must be greater than 0, got {horizon}. Example: {{\"horizon\": 12}}"
),
"warnings": warnings,
}
return {"valid": True, "warnings": warnings}
[docs]
def fit_predict_tool(
    estimator_handle: str,
    dataset: str,
    horizon: int = 12,
    data_handle: str | None = None,
) -> dict[str, Any]:
    """
    Execute a complete fit-predict workflow.

    Exactly one data source must be given: either a demo ``dataset`` name or
    a ``data_handle``. Pass an empty ``dataset`` when using ``data_handle``.

    Args:
        estimator_handle: Handle from instantiate_estimator
        dataset: Name of demo dataset (e.g., "airline", "sunspots")
        horizon: Forecast horizon (default: 12)
        data_handle: Optional handle from load_data_source for custom data

    Returns:
        On a validation failure, a dictionary with:
        - success: False
        - error: Description of the problem
        Otherwise the result of the executor's ``fit_predict`` call,
        which per the example below contains:
        - success: bool
        - predictions: Forecast values
        - horizon: Number of steps predicted

    Example:
        >>> fit_predict_tool("est_abc123", "airline", horizon=12)
        {
            "success": True,
            "predictions": {1: 450.2, 2: 460.5, ...},
            "horizon": 12
        }
    """
    validation = _validate_horizon(horizon)
    if not validation["valid"]:
        return {
            "success": False,
            "error": validation["error"],
        }

    if dataset and data_handle:
        return {
            "success": False,
            "error": "Provide either 'dataset' or 'data_handle', not both.",
        }

    # A blank value ("" or whitespace only) counts as "not provided" for BOTH
    # sources. Checking only `data_handle is None` would let an empty
    # data_handle slip through with no usable data source at all.
    has_dataset = bool(dataset) and bool(str(dataset).strip())
    has_data_handle = bool(data_handle) and bool(str(data_handle).strip())
    if not has_dataset and not has_data_handle:
        return {
            "success": False,
            "error": (
                "Either 'dataset' (e.g. 'airline') or "
                "'data_handle' (from load_data_source) is required."
            ),
        }

    executor = get_executor()
    return executor.fit_predict(estimator_handle, dataset, horizon, data_handle=data_handle)
def predict_tool(
    estimator_handle: str,
    horizon: int = 12,
) -> dict[str, Any]:
    """
    Generate predictions from a fitted estimator.

    The horizon is expanded into the list of steps ``1..horizon`` and
    handed to the executor as ``fh``.

    Args:
        estimator_handle: Handle of a fitted estimator
        horizon: Forecast horizon (default: 12)

    Returns:
        The executor's prediction result, or a dictionary with
        ``success: False`` and an ``error`` message when the horizon is
        invalid.
    """
    check = _validate_horizon(horizon)
    if check["valid"]:
        steps_ahead = list(range(1, horizon + 1))
        return get_executor().predict(estimator_handle, fh=steps_ahead)

    # Invalid horizon: report the validator's message without touching the executor.
    return {"success": False, "error": check["error"]}
def list_datasets_tool() -> dict[str, Any]:
    """
    List available demo datasets.

    Returns:
        Dictionary with:
        - success: always True
        - datasets: Whatever the executor's ``list_datasets`` returns
    """
    available = get_executor().list_datasets()
    return {"success": True, "datasets": available}
[docs]
def fit_predict_async_tool(
    estimator_handle: str,
    dataset: str | None = None,
    data_handle: str | None = None,
    horizon: int = 12,
) -> dict[str, Any]:
    """
    Execute a fit-predict workflow in the background (non-blocking).

    Schedules the training as a background job and returns immediately
    with a job_id. Use check_job_status to monitor progress.

    Accepts either a demo dataset name or a data handle from
    load_data_source -- exactly one must be provided.

    Args:
        estimator_handle: Handle from instantiate_estimator
        dataset: Name of demo dataset (e.g., "airline", "sunspots")
        data_handle: Handle from load_data_source (e.g., "data_abc123")
        horizon: Forecast horizon (default: 12)

    Returns:
        On a validation failure, a dictionary with ``success: False`` and an
        ``error`` message. Otherwise a dictionary with:
        - success: True
        - job_id: Job ID for tracking progress
        - message: Information about the job
        - estimator: Estimator name ("Unknown" if it could not be resolved)
        - data_source: The dataset name or data handle that was used
        - horizon: The requested forecast horizon

    Example:
        >>> fit_predict_async_tool("est_abc123", dataset="airline", horizon=12)
        >>> fit_predict_async_tool("est_abc123", data_handle="data_xyz", horizon=5)
    """
    validation = _validate_horizon(horizon)
    if not validation["valid"]:
        return {
            "success": False,
            "error": validation["error"],
        }

    if dataset and data_handle:
        return {
            "success": False,
            "error": "Provide either 'dataset' or 'data_handle', not both.",
        }

    if not dataset and not data_handle:
        return {
            "success": False,
            "error": (
                "Either 'dataset' (e.g. 'airline') or "
                "'data_handle' (from load_data_source) is required."
            ),
        }

    import asyncio
    import threading

    from sktime_mcp.runtime.jobs import get_job_manager

    executor = get_executor()
    job_manager = get_job_manager()

    # Resolving the estimator name is best-effort: it is only used for
    # display, so a failed lookup must not prevent the job from starting.
    try:
        handle_info = executor._handle_manager.get_info(estimator_handle)
        estimator_name = handle_info.estimator_name
    except Exception as e:
        logger.warning("Could not get estimator name: %s", e)
        estimator_name = "Unknown"

    source_name = dataset if dataset else data_handle

    # Create job
    job_id = job_manager.create_job(
        job_type="fit_predict",
        estimator_handle=estimator_handle,
        estimator_name=estimator_name,
        dataset_name=source_name,
        horizon=horizon,
        total_steps=3,
    )

    coro = executor.fit_predict_async(
        estimator_handle,
        dataset=dataset,
        data_handle=data_handle,
        horizon=horizon,
        job_id=job_id,
    )

    # Schedule the coroutine. If an event loop is already running in this
    # thread, hand the coroutine to it. Otherwise a freshly created loop
    # would never be run and the job would stay pending forever, so run the
    # coroutine to completion on its own loop in a daemon thread instead.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if loop is not None:
        asyncio.run_coroutine_threadsafe(coro, loop)
    else:
        threading.Thread(
            target=asyncio.run,
            args=(coro,),
            name=f"fit-predict-{job_id}",
            daemon=True,
        ).start()

    return {
        "success": True,
        "job_id": job_id,
        "message": (
            f"Training job started for {estimator_name} on {source_name}. "
            f"Use check_job_status('{job_id}') to monitor progress."
        ),
        "estimator": estimator_name,
        "data_source": source_name,
        "horizon": horizon,
    }