Source code for sktime_mcp.tools.list_estimators

"""
list_estimators tool for sktime MCP.
Discovers estimators by task type, capability tags, and/or name search.
"""

import difflib
from typing import Any

from sktime_mcp.registry.interface import get_registry


[docs] def list_estimators_tool( task: str | None = None, tags: dict[str, Any] | None = None, query: str | None = None, limit: int = 50, offset: int = 0, ) -> dict[str, Any]: """ Discover sktime estimators by task type, capability tags, and/or name search. All filters are combined: query narrows by name/docstring, then task and tags are applied on top. Args: task: Filter by task type. Options: "forecasting", "classification", "regression", "transformation", "clustering", "detection" tags: Filter by capability tags. Example: {"capability:pred_int": True} query: Search by name or description (substring, case-insensitive). limit: Maximum number of results to return (default: 50) offset: Number of results to skip for pagination (default: 0). Returns: Dictionary with: - success: bool - estimators: List of estimator summaries - count: Number of results returned in this page - total: Total matching estimators (before limit/offset) - offset: Current offset (for pagination) - limit: Current limit (for pagination) - has_more: True if more results exist beyond this page """ registry = get_registry() try: # Validate task if task is not None: valid_tasks = registry.get_available_tasks() if task not in valid_tasks: suggestions = difflib.get_close_matches(task, valid_tasks, n=3, cutoff=0.6) return { "success": False, "error": f"Invalid task: '{task}'. Valid options: {valid_tasks}." + (f" Did you mean: {suggestions}?" if suggestions else ""), } # Validate tag keys if tags is not None: valid_tag_keys = {t["tag"] for t in registry.get_available_tags()} invalid_keys = [k for k in tags if k not in valid_tag_keys] if invalid_keys: suggestions = { k: difflib.get_close_matches(k, valid_tag_keys, n=1, cutoff=0.6) for k in invalid_keys } return { "success": False, "error": f"Invalid tag key(s): {invalid_keys}. 
Use get_available_tags to see valid keys.", "suggestions": {k: v[0] if v else None for k, v in suggestions.items()}, } if query: estimators = registry.search_estimators(query) if task: estimators = [e for e in estimators if e.task == task] if tags: estimators = registry._filter_by_tags(estimators, tags) else: estimators = registry.get_all_estimators(task=task, tags=tags) total = len(estimators) if offset < 0: return { "success": False, "error": "offset must be a non-negative integer.", } if limit < 1: return { "success": False, "error": "limit must be a positive integer.", } page = estimators[offset : offset + limit] results = [est.to_summary() for est in page] return { "success": True, "estimators": results, "count": len(results), "total": total, "offset": offset, "limit": limit, "has_more": (offset + limit) < total, "task_filter": task, "tag_filter": tags, "query": query, } except Exception as e: return { "success": False, "error": str(e), }
def get_available_tasks() -> dict[str, Any]:
    """Return the registry's available task types in a success envelope."""
    task_names = get_registry().get_available_tasks()
    response: dict[str, Any] = {"success": True}
    response["tasks"] = task_names
    return response
def get_available_tags() -> dict[str, Any]:
    """Return the registry's available capability tags in a success envelope."""
    tag_entries = get_registry().get_available_tags()
    response: dict[str, Any] = {"success": True}
    response["tags"] = tag_entries
    return response