| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- from __future__ import annotations
- import copy
- from pathlib import Path
- from typing import Any
- import yaml
# Directory two levels above this file (the parent of the directory that
# contains it).  Exposed to callers as ``meta['rl_root']`` by ``load_config``.
MODULE_ROOT = Path(__file__).resolve().parents[1]
# Parent of MODULE_ROOT.  Exposed as ``meta['project_root']`` and used by
# ``resolve_project_path`` as the anchor for relative paths.
PROJECT_ROOT = MODULE_ROOT.parent
# Baseline configuration.  ``load_config`` deep-copies this mapping and then
# overlays the values from the user's YAML file on top of it, so every key
# listed here is present in a loaded config unless the file overrides it.
DEFAULT_CONFIG: dict[str, Any] = {
    'experiment': {
        'name': 'guguji_rl_experiment',
    },
    'robot': {
        'model_name': 'guguji',
        'urdf_path': 'guguji_ros2_ws/src/guguji_ros2/urdf/guguji.urdf',
        'joint_names': [],
        'command_topic_prefix': '/guguji/command',
        'reference_gait': {
            'enabled': False,
            'period': 0.9,
            'stance_ratio': 0.62,
            'hip_pitch_amplitude': 0.0,
            'hip_pitch_bias': 0.0,
            'knee_pitch_amplitude': 0.0,
            'knee_pitch_bias': 0.0,
            'swing_knee_scale': 1.0,
            'ankle_pitch_amplitude': 0.0,
            'ankle_pitch_bias': 0.0,
            'push_off_ankle_scale': 0.0,
        },
    },
    'ros': {
        'joint_state_topic': '/joint_states',
        'tf_topic': '/tf',
        'clock_topic': '/clock',
        'world_control_service': '/world/default/control',
    },
    'sim': {
        'world_name': 'default',
        'step_mode': 'realtime',
        'control_dt': 0.05,
        'service_step_iterations': 50,
        'reset_settle_seconds': 1.0,
        'action_publish_delay': 0.01,
        'post_step_wait_seconds': 0.01,
        'spawn_x': 0.0,
        'spawn_y': 0.0,
        'spawn_z': 0.35,
        'spawn_roll': 0.0,
        'spawn_pitch': 0.0,
        'spawn_yaw': 0.0,
    },
    'task': {
        'target_forward_velocity': 0.0,
        'target_base_height': None,
        'max_roll_rad': 0.6,
        'max_pitch_rad': 0.6,
        'min_base_height': 0.12,
    },
    'rewards': {
        'alive_bonus': 1.0,
        'velocity_tracking_scale': 1.0,
        'velocity_tracking_sigma': 0.3,
        'forward_progress_scale': 0.0,
        'hip_alternation_scale': 0.0,
        'hip_target_separation': 0.3,
        'hip_antiphase_sigma': 0.2,
        'knee_flexion_scale': 0.0,
        'knee_target': 0.2,
        'knee_flexion_sigma': 0.12,
        'upright_scale': 1.0,
        'height_scale': 1.0,
        'action_rate_penalty_scale': 0.02,
        'joint_limit_penalty_scale': 0.02,
        'lateral_velocity_penalty_scale': 0.05,
        'backward_velocity_penalty_scale': 0.0,
        'stall_penalty_scale': 0.0,
        'stall_velocity_threshold': 0.0,
        'fall_penalty': -10.0,
    },
    'training': {
        'algorithm': 'ppo',
        'total_timesteps': 200000,
        'max_episode_steps': 400,
        'seed': 42,
        'device': 'auto',
        'init_model_path': None,
        'initial_log_std': None,
        'curriculum_stages': [],
        'learning_rate': 3e-4,
        'n_steps': 1024,
        'batch_size': 256,
        'gamma': 0.99,
        'gae_lambda': 0.95,
        'clip_range': 0.2,
        'ent_coef': 0.0,
        'vf_coef': 0.5,
        'policy_net_arch': [256, 256],
        'checkpoint_freq': 20000,
        'output_root': 'guguji_rl/outputs',
    },
    'evaluation': {
        'episodes': 3,
        'deterministic': True,
        'auto_forward_progress': True,
        'forward_progress_episodes': 2,
        'forward_progress_max_steps': 500,
        'forward_progress_deterministic': True,
    },
}
- def _deep_update(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
- for key, value in override.items():
- if isinstance(value, dict) and isinstance(base.get(key), dict):
- _deep_update(base[key], value)
- else:
- base[key] = value
- return base
def load_config(config_path: str | Path) -> dict[str, Any]:
    """Load a YAML config file and layer it over ``DEFAULT_CONFIG``.

    The defaults are deep-copied first, so ``DEFAULT_CONFIG`` itself is never
    modified.  Values from the file are then merged in with ``_deep_update``.
    Finally a ``meta`` section is attached; it replaces any ``meta`` key that
    the file itself may have contained.

    Args:
        config_path: Path to the YAML file.  It is resolved to an absolute
            path before being opened.

    Returns:
        The merged configuration, including ``meta`` with the keys
        ``config_path``, ``project_root`` and ``rl_root`` (all strings).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the top-level YAML value is not a mapping.
    """
    config_path = Path(config_path).resolve()
    with config_path.open('r', encoding='utf-8') as file:
        # An empty file parses to ``None``; treat it as "no overrides".
        user_config = yaml.safe_load(file) or {}
    if not isinstance(user_config, dict):
        # Without this check a top-level list or scalar fails inside
        # ``_deep_update`` with an AttributeError that does not mention the
        # file or what was wrong with it.
        raise ValueError(
            f'Top-level YAML value in {config_path} must be a mapping, '
            f'got {type(user_config).__name__}'
        )
    config = _deep_update(copy.deepcopy(DEFAULT_CONFIG), user_config)
    config['meta'] = {
        'config_path': str(config_path),
        # The roots come from this module's own location rather than from
        # where the config file lives.  We no longer assume the config sits
        # at a fixed depth inside the repository, so external ad-hoc configs
        # and exported experiment configs can be used for training directly.
        'project_root': str(PROJECT_ROOT),
        'rl_root': str(MODULE_ROOT),
    }
    return config
def resolve_project_path(config: dict[str, Any], relative_path: str | Path) -> Path:
    """Return ``relative_path`` anchored at the project root stored in ``config``.

    An absolute path is returned unchanged (as a ``Path``) without consulting
    ``config``.  Anything else is joined onto
    ``config['meta']['project_root']``.
    """
    candidate = Path(relative_path)
    if not candidate.is_absolute():
        project_root = Path(config['meta']['project_root'])
        candidate = project_root / candidate
    return candidate
def save_yaml(data: dict[str, Any], output_path: str | Path) -> None:
    """Write ``data`` to ``output_path`` as UTF-8 encoded YAML.

    Missing parent directories are created.  Keys keep their insertion order
    (``sort_keys=False``) and non-ASCII text is written verbatim rather than
    escaped (``allow_unicode=True``).

    Args:
        data: Mapping to serialize; it must be representable by
            ``yaml.safe_dump``.
        output_path: Destination file; overwritten if it already exists.

    Raises:
        yaml.YAMLError: If ``data`` holds a value ``safe_dump`` cannot
            represent.
        OSError: If the directory or file cannot be created or written.
    """
    output_path = Path(output_path)
    # Serialize before touching the filesystem: if ``data`` cannot be
    # represented, an existing file at ``output_path`` stays intact instead
    # of being truncated to an empty or partial document.
    text = yaml.safe_dump(data, sort_keys=False, allow_unicode=True)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open('w', encoding='utf-8') as file:
        file.write(text)
|