"""
Code generation tool for sktime MCP.
Generates Python code to recreate estimators and pipelines.
"""
import json
import keyword
import math
from typing import Any

from sktime_mcp.registry.interface import get_registry
from sktime_mcp.runtime.executor import DEMO_DATASETS
from sktime_mcp.runtime.handles import get_handle_manager
def _format_value(value: Any) -> str:
"""Format a parameter value for Python code generation."""
if isinstance(value, str):
return f'"{value}"'
elif isinstance(value, (list, tuple)):
if isinstance(value, tuple):
items = ", ".join(_format_value(v) for v in value)
return f"({items})" if len(value) != 1 else f"({items},)"
else:
items = ", ".join(_format_value(v) for v in value)
return f"[{items}]"
elif isinstance(value, dict):
items = ", ".join(f'"{k}": {_format_value(v)}' for k, v in value.items())
return f"{{{items}}}"
elif isinstance(value, bool):
return str(value)
elif value is None:
return "None"
elif isinstance(value, (int, float)):
return str(value)
else:
# For complex objects, try to represent as str
return repr(value)
def _get_estimator_module(estimator_name: str) -> str | None:
    """Look up the module path of an estimator's class.

    Returns the ``__module__`` of the registry node's ``class_ref``, or
    ``None`` when the registry has no node for ``estimator_name`` or the
    node has no ``class_ref``.
    """
    entry = get_registry().get_estimator_by_name(estimator_name)
    if not entry or not entry.class_ref:
        return None
    return entry.class_ref.__module__
def _is_valid_var_name(var_name: str) -> bool:
"""Return True when var_name is a valid non-keyword Python identifier."""
return isinstance(var_name, str) and var_name.isidentifier() and not keyword.iskeyword(var_name)
def _generate_single_estimator_code(
    estimator_name: str, params: dict[str, Any], var_name: str = "model"
) -> dict[str, Any]:
    """Build the import line and constructor call for one estimator.

    Args:
        estimator_name: Class name of the estimator to instantiate.
        params: Constructor keyword arguments; a falsy value yields a call
            with no arguments.
        var_name: Name of the variable the instance is assigned to.

    Returns:
        On success, a dict with ``success``, ``code`` (import line, blank
        line, assignment), ``imports`` and ``instantiation``. When the
        estimator's module cannot be resolved, a dict with ``success`` set to
        ``False`` and an ``error`` message.
    """
    module_path = _get_estimator_module(estimator_name)
    if not module_path:
        return {"success": False, "error": f"Could not find module for estimator: {estimator_name}"}

    import_lines = [f"from {module_path} import {estimator_name}"]

    # An empty or missing params mapping simply renders an empty argument list.
    rendered_args = ", ".join(
        f"{name}={_format_value(val)}" for name, val in (params or {}).items()
    )
    assignment = f"{var_name} = {estimator_name}({rendered_args})"

    return {
        "success": True,
        "code": "\n".join([*import_lines, "", assignment]),
        "imports": import_lines,
        "instantiation": assignment,
    }
def _generate_pipeline_code(
    components: list[str], params_list: list[dict[str, Any]], var_name: str = "pipeline"
) -> dict[str, Any]:
    """Generate Python code for a pipeline.

    Each component is instantiated as ``step_<i>`` and the steps are then
    combined according to the registry task of each component:

    * one component: ``var_name`` is bound directly to ``step_0``;
    * transformers followed by a forecaster: ``TransformedTargetForecaster``
      (with a ``TransformerPipeline`` chain when there are several
      transformers);
    * transformers followed by a classifier or regressor: ``Pipeline``;
    * transformers only: ``TransformerPipeline``.

    Args:
        components: Estimator class names, in pipeline order.
        params_list: Constructor keyword arguments per component. Missing
            trailing entries are treated as "no parameters"; surplus entries
            are ignored.
        var_name: Name of the variable the final object is assigned to.

    Returns:
        On success, a dict with ``success``, ``code``, ``imports`` (sorted)
        and ``pipeline_type``. On failure, a dict with ``success`` set to
        ``False`` and an ``error`` message.
    """
    if not components:
        # Without this guard the final-task lookup below would raise IndexError.
        return {"success": False, "error": "Pipeline must contain at least one component"}

    registry = get_registry()

    # Collect the import line and the task type of every component.
    imports: set[str] = set()
    component_tasks = []
    for comp_name in components:
        node = registry.get_estimator_by_name(comp_name)
        if not node:
            return {"success": False, "error": f"Unknown estimator in pipeline: {comp_name}"}
        if not node.class_ref:
            # Same failure mode (and message) as the single-estimator path.
            return {"success": False, "error": f"Could not find module for estimator: {comp_name}"}
        component_tasks.append(node.task)
        imports.add(f"from {node.class_ref.__module__} import {comp_name}")

    # Pad a short params list so every component still gets a step_<i>
    # definition; otherwise the pipeline code would reference undefined names.
    step_params = list(params_list or [])
    step_params += [{}] * (len(components) - len(step_params))

    # Build component instantiations (surplus params entries are ignored).
    component_code_lines = []
    for i, (comp_name, params) in enumerate(zip(components, step_params, strict=False)):
        args = ", ".join(
            f"{key}={_format_value(value)}" for key, value in (params or {}).items()
        )
        component_code_lines.append(f"step_{i} = {comp_name}({args})")

    # Determine pipeline type
    n_steps = len(components)
    all_transformers_except_last = all(task == "transformation" for task in component_tasks[:-1])
    final_task = component_tasks[-1]

    if n_steps == 1:
        # Single component, no pipeline wrapper needed. The "Pipeline" label
        # is kept for backward compatibility with existing callers.
        pipeline_type = "Pipeline"
        pipeline_code = f"{var_name} = step_0"
    elif all_transformers_except_last and final_task == "forecasting":
        pipeline_type = "TransformedTargetForecaster"
        imports.add("from sktime.forecasting.compose import TransformedTargetForecaster")
        if n_steps == 2:
            pipeline_code = "\n".join(
                [
                    f"{var_name} = TransformedTargetForecaster([",
                    '("transformer", step_0),',
                    '("forecaster", step_1),',
                    "])",
                ]
            )
        else:
            # Multiple transformers - chain them before the forecaster.
            imports.add("from sktime.transformations.compose import TransformerPipeline")
            transformer_steps = ", ".join(
                f'("step_{i}", step_{i})' for i in range(n_steps - 1)
            )
            pipeline_code = "\n".join(
                [
                    "transformer_chain = TransformerPipeline([",
                    transformer_steps,
                    "])",
                    f"{var_name} = TransformedTargetForecaster([",
                    '("transformers", transformer_chain),',
                    f'("forecaster", step_{n_steps - 1}),',
                    "])",
                ]
            )
    elif all_transformers_except_last and final_task in ("classification", "regression"):
        # NOTE(review): sktime.pipeline.Pipeline may not accept a list of
        # (name, estimator) tuples like sklearn's Pipeline does - confirm
        # against the sktime API before relying on this generated code.
        pipeline_type = "Pipeline"
        imports.add("from sktime.pipeline import Pipeline")
        steps = ", ".join(f'("step_{i}", step_{i})' for i in range(n_steps))
        pipeline_code = "\n".join([f"{var_name} = Pipeline([", steps, "])"])
    elif all(task == "transformation" for task in component_tasks):
        # All transformers - use TransformerPipeline
        pipeline_type = "TransformerPipeline"
        imports.add("from sktime.transformations.compose import TransformerPipeline")
        steps = ", ".join(f'("step_{i}", step_{i})' for i in range(n_steps))
        pipeline_code = "\n".join([f"{var_name} = TransformerPipeline([", steps, "])"])
    else:
        return {"success": False, "error": "Unsupported pipeline composition type"}

    # Combine all code: imports, blank line, step definitions, blank line, pipeline.
    sorted_imports = sorted(imports)
    code = "\n".join([*sorted_imports, "", *component_code_lines, "", pipeline_code])
    return {
        "success": True,
        "code": code,
        "imports": sorted_imports,
        # Tracked explicitly per branch rather than inferred from the import
        # strings, so a TransformerPipeline is no longer reported as "Pipeline".
        "pipeline_type": pipeline_type,
    }