# config.py (4.7 KB)
  1. from __future__ import annotations
  2. import copy
  3. from pathlib import Path
  4. from typing import Any
  5. import yaml
  6. MODULE_ROOT = Path(__file__).resolve().parents[1]
  7. PROJECT_ROOT = MODULE_ROOT.parent
  8. DEFAULT_CONFIG: dict[str, Any] = {
  9. 'experiment': {
  10. 'name': 'guguji_rl_experiment',
  11. },
  12. 'robot': {
  13. 'model_name': 'guguji',
  14. 'urdf_path': 'guguji_ros2_ws/src/guguji_ros2/urdf/guguji.urdf',
  15. 'joint_names': [],
  16. 'command_topic_prefix': '/guguji/command',
  17. 'reference_gait': {
  18. 'enabled': False,
  19. 'period': 0.9,
  20. 'stance_ratio': 0.62,
  21. 'hip_pitch_amplitude': 0.0,
  22. 'hip_pitch_bias': 0.0,
  23. 'knee_pitch_amplitude': 0.0,
  24. 'knee_pitch_bias': 0.0,
  25. 'swing_knee_scale': 1.0,
  26. 'ankle_pitch_amplitude': 0.0,
  27. 'ankle_pitch_bias': 0.0,
  28. 'push_off_ankle_scale': 0.0,
  29. },
  30. },
  31. 'ros': {
  32. 'joint_state_topic': '/joint_states',
  33. 'tf_topic': '/tf',
  34. 'clock_topic': '/clock',
  35. 'world_control_service': '/world/default/control',
  36. },
  37. 'sim': {
  38. 'world_name': 'default',
  39. 'step_mode': 'realtime',
  40. 'control_dt': 0.05,
  41. 'service_step_iterations': 50,
  42. 'reset_settle_seconds': 1.0,
  43. 'action_publish_delay': 0.01,
  44. 'post_step_wait_seconds': 0.01,
  45. 'spawn_x': 0.0,
  46. 'spawn_y': 0.0,
  47. 'spawn_z': 0.35,
  48. 'spawn_roll': 0.0,
  49. 'spawn_pitch': 0.0,
  50. 'spawn_yaw': 0.0,
  51. },
  52. 'task': {
  53. 'target_forward_velocity': 0.0,
  54. 'target_base_height': None,
  55. 'max_roll_rad': 0.6,
  56. 'max_pitch_rad': 0.6,
  57. 'min_base_height': 0.12,
  58. },
  59. 'rewards': {
  60. 'alive_bonus': 1.0,
  61. 'velocity_tracking_scale': 1.0,
  62. 'velocity_tracking_sigma': 0.3,
  63. 'forward_progress_scale': 0.0,
  64. 'hip_alternation_scale': 0.0,
  65. 'hip_target_separation': 0.3,
  66. 'hip_antiphase_sigma': 0.2,
  67. 'knee_flexion_scale': 0.0,
  68. 'knee_target': 0.2,
  69. 'knee_flexion_sigma': 0.12,
  70. 'upright_scale': 1.0,
  71. 'height_scale': 1.0,
  72. 'action_rate_penalty_scale': 0.02,
  73. 'joint_limit_penalty_scale': 0.02,
  74. 'lateral_velocity_penalty_scale': 0.05,
  75. 'backward_velocity_penalty_scale': 0.0,
  76. 'stall_penalty_scale': 0.0,
  77. 'stall_velocity_threshold': 0.0,
  78. 'fall_penalty': -10.0,
  79. },
  80. 'training': {
  81. 'algorithm': 'ppo',
  82. 'total_timesteps': 200000,
  83. 'max_episode_steps': 400,
  84. 'seed': 42,
  85. 'device': 'auto',
  86. 'init_model_path': None,
  87. 'initial_log_std': None,
  88. 'curriculum_stages': [],
  89. 'learning_rate': 3e-4,
  90. 'n_steps': 1024,
  91. 'batch_size': 256,
  92. 'gamma': 0.99,
  93. 'gae_lambda': 0.95,
  94. 'clip_range': 0.2,
  95. 'ent_coef': 0.0,
  96. 'vf_coef': 0.5,
  97. 'policy_net_arch': [256, 256],
  98. 'checkpoint_freq': 20000,
  99. 'output_root': 'guguji_rl/outputs',
  100. },
  101. 'evaluation': {
  102. 'episodes': 3,
  103. 'deterministic': True,
  104. 'auto_forward_progress': True,
  105. 'forward_progress_episodes': 2,
  106. 'forward_progress_max_steps': 500,
  107. 'forward_progress_deterministic': True,
  108. },
  109. }
  110. def _deep_update(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
  111. for key, value in override.items():
  112. if isinstance(value, dict) and isinstance(base.get(key), dict):
  113. _deep_update(base[key], value)
  114. else:
  115. base[key] = value
  116. return base
  117. def load_config(config_path: str | Path) -> dict[str, Any]:
  118. config_path = Path(config_path).resolve()
  119. with config_path.open('r', encoding='utf-8') as file:
  120. user_config = yaml.safe_load(file) or {}
  121. config = _deep_update(copy.deepcopy(DEFAULT_CONFIG), user_config)
  122. config['meta'] = {
  123. 'config_path': str(config_path),
  124. # 不再假设配置文件一定放在仓库里的固定层级,
  125. # 这样外部临时配置、导出的实验配置也可以直接拿来训练。
  126. 'project_root': str(PROJECT_ROOT),
  127. 'rl_root': str(MODULE_ROOT),
  128. }
  129. return config
  130. def resolve_project_path(config: dict[str, Any], relative_path: str | Path) -> Path:
  131. relative_path = Path(relative_path)
  132. if relative_path.is_absolute():
  133. return relative_path
  134. return Path(config['meta']['project_root']) / relative_path
  135. def save_yaml(data: dict[str, Any], output_path: str | Path) -> None:
  136. output_path = Path(output_path)
  137. output_path.parent.mkdir(parents=True, exist_ok=True)
  138. with output_path.open('w', encoding='utf-8') as file:
  139. yaml.safe_dump(data, file, sort_keys=False, allow_unicode=True)