Source code for econsimulacra.simulator

from __future__ import annotations

import asyncio
import json
from pathlib import Path
from typing import Any, Generic, Optional, Type, TypeVar

from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.tree import Tree
from tqdm import tqdm

from .agents import Agent
from .envs import Environment, TimeTranslator
from .events import Event
from .llm_services import LLMClient, PersonaBuilder, PromptBuilder
from .logs import Logger

# Type variable for the observation objects passed from the environment to agents.
ObsT = TypeVar("ObsT")


class Simulator(Generic[ObsT]):
    """Simulator class.

    The Simulator class is responsible for running the simulation.

    Basic usage:

    >>> config_path = pathlib.Path("path/to/config.json")
    >>> logger = DictLogger()
    >>> simulator = Simulator(
    ...     config=config_path,
    ...     env_class=Environment,
    ...     logger=logger,
    ...     summarizer_class=SimulationSummarizer,
    ... )
    >>> asyncio.run(simulator.simulate(seed=42))
    """

    def __init__(
        self,
        config: dict[str, Any] | Path,
        env_class: Type[Environment],
        logger: Optional[Logger] = None,
        summarizer_class: Optional[Type[SimulationSummarizer]] = None,
    ) -> None:
        """Initialization.

        Args:
            config (dict or Path): The configuration for the simulation, either
                as an already-loaded dict or as the path of a JSON file.
                See also econsimulacra.envs.base.Environment for the required
                and optional configuration fields.
            env_class (Type[Environment]): The environment class to use for the
                simulation. This must be a subclass of Environment.
            logger (Logger, optional): An optional logger instance for logging
                simulation data. If not provided, no logging will be performed.
            summarizer_class (Type[SimulationSummarizer], optional): An optional
                summarizer class for summarizing the simulation. This must be a
                subclass of SimulationSummarizer. If not provided, no
                summarization will be performed.

        Raises:
            KeyError: If the configuration has no ``"simulation"`` section.
        """
        raw_config: dict[str, Any]
        if isinstance(config, Path):
            # Use a context manager so the handle is closed deterministically,
            # and decode as UTF-8 (the JSON interchange encoding) instead of
            # whatever the platform locale happens to be.
            with open(config, "r", encoding="utf-8") as config_file:
                raw_config = json.load(config_file)
        else:
            raw_config = config
        # Lists are turned into tuples everywhere in the configuration.
        self.config: dict[str, Any] = self._convert_list_to_tuple(raw_config)
        self.parallel_batch_size: Optional[int] = self.config["simulation"].get(
            "parallelBatchSize"
        )
        self.env: Environment = env_class(config=self.config, logger=logger)
        self.summarizer: Optional[SimulationSummarizer] = (
            summarizer_class(self.env) if summarizer_class is not None else None
        )

    def _convert_list_to_tuple(self, obj: Any) -> Any:
        """Return a copy of ``obj`` with every nested list replaced by a tuple.

        Dicts are rebuilt with converted keys and values, lists and tuples are
        rebuilt as tuples of converted items, and any other object is returned
        unchanged.
        """
        if isinstance(obj, dict):
            return {
                self._convert_list_to_tuple(key): self._convert_list_to_tuple(value)
                for key, value in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return tuple(self._convert_list_to_tuple(item) for item in obj)
        return obj

    async def simulate(
        self,
        seed: Optional[int] = None,
    ) -> None:
        """Execute the full simulation loop asynchronously.

        Args:
            seed (int, optional): Random seed passed to ``env.reset()``.

        Raises:
            ValueError: If the configured ``parallelBatchSize`` is smaller
                than 1. The check runs before the environment is reset.

        Note:
            The simulation proceeds as follows:

            1. The environment is reset with the given random seed.
            2. If a summarizer is configured, ``summarize_start()`` is called.
            3. For each of ``config["simulation"]["numSteps"]`` steps:

               - the agent ids of the environment are split into chunks of
                 ``parallel_batch_size`` (1 when it is not configured);
               - within a chunk, each agent's observation is fetched and its
                 ``act()`` coroutine is awaited concurrently through
                 :func:`asyncio.gather`; chunks are processed one after
                 another, so at most one chunk of coroutines is in flight;
               - list-valued actions are recursively converted into tuples;
               - the joint action dictionary (agent id -> action) is passed
                 to ``env.step()`` once all agents have acted.

            4. After all steps, the environment's logger is saved if present.
            5. If a summarizer is configured, ``summarize_end()`` is called.

            Agent actions are therefore generated asynchronously, while the
            environment itself advances synchronously, once per step.
        """
        batch_size: int = (
            1 if self.parallel_batch_size is None else self.parallel_batch_size
        )
        if batch_size < 1:
            # A zero step would only fail later inside range(), and a negative
            # one would silently produce no batches (nobody would ever act).
            raise ValueError(
                f"parallelBatchSize must be a positive integer, got {batch_size!r}."
            )

        def _chunked(seq: list[int], size: int) -> list[list[int]]:
            return [seq[i : i + size] for i in range(0, len(seq), size)]

        # Defined once, outside the step loop: it only closes over ``self``.
        async def _act_one(agent_id: int) -> tuple[int, dict[str, Any]]:
            agent: Agent = self.env.agent_id2agent[agent_id]
            obs: ObsT = self.env.get_observations(agent_id=agent_id)
            action_dic: dict[str, Any] = await agent.act(obs=obs)
            return agent_id, self._convert_list_to_tuple(action_dic)

        self.env.reset(seed=seed)
        if self.summarizer is not None:
            self.summarizer.summarize_start()
        num_steps: int = self.config["simulation"]["numSteps"]
        for _ in tqdm(
            range(num_steps), desc="Simulating", unit="step", ncols=80, leave=True
        ):
            all_actions_dic: dict[int, dict[str, Any]] = {}
            for batch in _chunked(self.env.agent_ids, batch_size):
                results: list[tuple[int, dict[str, Any]]] = await asyncio.gather(
                    *(_act_one(agent_id) for agent_id in batch)
                )
                all_actions_dic.update(results)
            # The environment advances only after every agent has acted.
            self.env.step(all_actions_dic=all_actions_dic)
        if self.env.logger is not None:
            self.env.logger.save()
        if self.summarizer is not None:
            self.summarizer.summarize_end()

    def register_classes(self, class_list: list[Type]) -> None:
        """Forward ``class_list`` to ``env.register_classes()``."""
        self.env.register_classes(class_list)
class SimulationSummarizer:
    """Print ``rich``-formatted summaries of a simulation to the console.

    ``summarize_start`` renders the configuration of the environment as a
    tree; ``summarize_end`` renders the invalid-action counts and, when one is
    available, the last prompt of the first household agent.
    """

    def __init__(self, env: Environment) -> None:
        """Store the environment whose state is summarized.

        Args:
            env (Environment): The environment to read the summary data from.
        """
        self.env: Environment = env

    def summarize_start(self) -> None:
        """Print a tree describing the simulation configuration.

        The tree lists, in order: the seed, the parallel batch size (only if
        configured), the number of steps, the grid-space size, the social
        network, the items, the non-household agents, the events (if any) and
        the environment services (if any).
        """
        console: Console = Console()
        tree: Tree = Tree("Simulation Configuration")
        tree.add(f"[green]Seed[/green]: {self.env.seed}")
        simulation_config = self.env.config["simulation"]
        if "parallelBatchSize" in simulation_config:
            tree.add(
                f"[green]Parallel Batch Size[/green]: {simulation_config['parallelBatchSize']}"
            )
        tree.add(f"[green]Number of Steps[/green]: {simulation_config['numSteps']}")
        tree.add(f"[green]Grid Space[/green]: {self.env.grid_space.get_space_size()}")
        self._add_social_network(tree)
        self._add_items(tree)
        self._add_agents(tree)
        self._add_events(tree)
        self._add_services(tree)
        print()
        console.print(Panel(tree, title="[bold green]Summary[/bold green]"))
        print()

    def _add_social_network(self, tree: Tree) -> None:
        """Add the social-network and recommender-system settings to ``tree``."""
        social_network_branch: Tree = tree.add("[green]Social Network[/green]")
        social_network_branch.add(
            f"[green]Follow Cap[/green]: {self.env.social_network.follow_cap}"
        )
        recsys = self.env.social_network.rec_sys
        recsys_branch: Tree = social_network_branch.add(
            f"[green]Recommender System: {recsys.__class__.__name__}[/green]"
        )
        recsys_branch.add(f"[green]Max Recommendations[/green]: {recsys.max_recs}")
        recsys_branch.add(
            f"[green]Randomized Recommendations[/green]: {recsys.is_randomized}"
        )
        # The temperature is only reported for randomized recommenders.
        if recsys.is_randomized:
            recsys_branch.add(f"[green]Temperature[/green]: {recsys.temperature}")

    def _add_items(self, tree: Tree) -> None:
        """Add one branch per item with its total amount and initial price."""
        items_branch: Tree = tree.add("[green]Items[/green]")
        for item_name, item in self.env.item_name2item.items():
            item_branch: Tree = items_branch.add(f"[green]{item_name}[/green]")
            item_branch.add(
                f"[green]Total Amount[/green]: {self.env.get_total_amount(item_name=item_name):.1f}"
            )
            if item_name == self.env.cash_name:
                # Cash is only labelled; no price is shown for it.
                item_branch.add("[green]Cash[/green]")
            else:
                item_branch.add(
                    f"[green]Initial Price[/green]: {item.price:.1f} {self.env.cash_name}"
                )

    def _add_agents(self, tree: Tree) -> None:
        """Add the household count and one branch per non-household agent."""
        agents_branch: Tree = tree.add("[green]Agents[/green]")
        agents_branch.add(
            f"[green]Number of Households[/green]: {len(self.env.household_ids)}"
        )
        # Build the lookup once instead of scanning household_ids per agent.
        household_ids = set(self.env.household_ids)
        for agent_id, agent in self.env.agent_id2agent.items():
            # Households are only counted above, not listed individually.
            if agent_id in household_ids:
                continue
            agent_branch: Tree = agents_branch.add(f"[green]Agent {agent_id}[/green]")
            agent_branch.add(f"[green]Name[/green]: {agent.agent_name}")
            for item_name in self.env.item_name2item:
                item_amount: int = int(agent.get_item_amount(item_name=item_name))
                # Items the agent does not hold are omitted.
                if item_amount > 0:
                    agent_branch.add(f"[green]{item_name}[/green]: {item_amount}")
            agent_branch.add(
                f"[green]Receive Rich Info[/green]: {agent.is_rich_info_allowed}"
            )
            agent_branch.add(
                f"[green]Provide Info for All Agents[/green]: {agent.provide_info4all_agents()}"
            )
            agent_branch.add(
                f"[green]Provide Info for Co-Located Agents[/green]: {agent.provide_info4co_located_agents()}"
            )
            agent_branch.add(
                f"[green]Provide Info for Allowed Agents[/green]: {agent.provide_info4allowed_agents()}"
            )

    def _add_events(self, tree: Tree) -> None:
        """Add one branch per registered event with its trigger settings."""
        events: list[Event] = self.env.event_manager.events
        if len(events) == 0:
            return
        events_branch: Tree = tree.add("[green]Events[/green]")
        for event in events:
            event_branch: Tree = events_branch.add(
                f"[green]{event.__class__.__name__}[/green]"
            )
            trigger = event.trigger
            # Only the trigger conditions that are actually set are listed.
            if trigger.at is not None:
                event_branch.add(f"[green]Trigger at[/green]: {trigger.at}")
            if trigger.every is not None:
                event_branch.add(f"[green]Trigger every[/green]: {trigger.every}")
            if len(trigger.logs) > 0:
                event_branch.add(
                    f"[green]Trigger with logs[/green]: {[log.__name__ for log in trigger.logs]}"
                )
            if trigger.between is not None:
                event_branch.add(f"[green]Trigger between[/green]: {trigger.between}")
            if trigger.probability is not None:
                event_branch.add(
                    f"[green]Trigger probability[/green]: {trigger.probability}"
                )

    def _add_services(self, tree: Tree) -> None:
        """Add one branch per environment service, with type-specific details."""
        if len(self.env.service_dic) == 0:
            return
        service_branch: Tree = tree.add("[green]Environment Services[/green]")
        for service in self.env.service_dic.values():
            service_class_name: str = service.__class__.__name__
            if isinstance(service, LLMClient):
                llm_client_branch: Tree = service_branch.add(
                    f"[green]LLM Client: {service_class_name}[/green]"
                )
                llm_client_branch.add(
                    f"[green]Model Name[/green]: {service.model_name}"
                )
                # Optional attributes are read with getattr and skipped when
                # absent or None.
                max_concurrent_generations: Optional[int] = getattr(
                    service, "max_concurrent_generations", None
                )
                if max_concurrent_generations is not None:
                    llm_client_branch.add(
                        f"[green]Max Concurrent Generations[/green]: {max_concurrent_generations}"
                    )
            elif isinstance(service, PersonaBuilder):
                persona_builder_branch: Tree = service_branch.add(
                    f"[green]Persona Builder: {service_class_name}[/green]"
                )
                max_magnitude: Optional[int] = getattr(service, "max_magnitude", None)
                if max_magnitude is not None:
                    persona_builder_branch.add(
                        f"[green]Max Magnitude[/green]: {max_magnitude}"
                    )
                attributes: Optional[list[str]] = getattr(service, "attributes", None)
                if attributes is not None:
                    persona_builder_branch.add(
                        f"[green]Attributes[/green]: {attributes}"
                    )
            elif isinstance(service, PromptBuilder):
                service_branch.add(
                    f"[green]Prompt Builder: {service_class_name}[/green]"
                )
            elif isinstance(service, TimeTranslator):
                time_translator_branch: Tree = service_branch.add(
                    f"[green]Time Translator: {service_class_name}[/green]"
                )
                start_datetime: Optional[Any] = getattr(service, "start_datetime", None)
                end_datetime: Optional[Any] = getattr(service, "end_datetime", None)
                time_delta: Optional[Any] = getattr(service, "time_delta", None)
                if start_datetime is not None:
                    time_translator_branch.add(
                        f"[green]Start Datetime[/green]: {str(start_datetime)}"
                    )
                if end_datetime is not None:
                    time_translator_branch.add(
                        f"[green]End Datetime[/green]: {str(end_datetime)}"
                    )
                if time_delta is not None:
                    time_translator_branch.add(
                        f"[green]Time Delta[/green]: {str(time_delta)}"
                    )
            else:
                # Any other service type is listed by class name only.
                service_branch.add(f"[green]Service: ({service_class_name})[/green]")

    def summarize_end(self) -> None:
        """Print the invalid-action table and one example agent prompt.

        The table has one row per entry of ``env.invalid_action_dic``. After
        it, the ``last_prompt`` attribute of the first household agent is
        printed in a panel, provided a household exists and the attribute is
        present, not ``None`` and not empty.
        """
        # Function-scope import: Text is the only rich name this class needs
        # beyond the ones imported at module level.
        from rich.text import Text

        console: Console = Console()
        table: Table = Table(title="Invalid Actions", show_lines=True)
        table.add_column("Action Type", justify="center", style="cyan", no_wrap=True)
        table.add_column("Number", justify="center", style="magenta")
        for action_type, count in self.env.invalid_action_dic.items():
            table.add_row(action_type, str(count))
        print()
        console.print(table)
        print()

        household_ids = self.env.household_ids
        if len(household_ids) == 0:
            # Without households there is no example agent to show.
            return
        example_agent_id: int = household_ids[0]
        agent: Agent = self.env.agent_id2agent[example_agent_id]
        last_prompt: Optional[str] = getattr(agent, "last_prompt", None)
        if last_prompt is not None and last_prompt != "":
            console.print(
                Panel(
                    # Wrap in Text so brackets inside the prompt are shown
                    # verbatim instead of being parsed as rich console markup.
                    Text(last_prompt),
                    title=f"[bold green]Agent {example_agent_id}'s Last Prompt[/bold green]",
                )
            )