train.py

from __future__ import annotations

import argparse
import copy
import sys
from datetime import datetime
from pathlib import Path
from typing import Any

# Make the project root importable when the script is run directly.
SCRIPT_ROOT = Path(__file__).resolve().parents[1]
if str(SCRIPT_ROOT) not in sys.path:
    sys.path.insert(0, str(SCRIPT_ROOT))

import torch

from guguji_rl.config import load_config, resolve_project_path, save_yaml
from guguji_rl.evaluation import evaluate_forward_progress, print_forward_progress_summary
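
# Expected config layout, inferred from the keys this script reads (the
# authoritative schema is whatever configs/balance_ppo.yaml actually defines):
#
#   experiment:
#     name: <run name>
#   task:
#     target_forward_velocity: <target speed>
#   training:
#     device, seed, learning_rate, n_steps, batch_size, gamma, gae_lambda,
#     clip_range, ent_coef, vf_coef, policy_net_arch, checkpoint_freq,
#     total_timesteps, output_root
#     curriculum_stages: [...]   # optional; see build_curriculum_stage_configs
#     init_model_path: ...       # optional
#     initial_log_std: ...       # optional
#   evaluation:
#     auto_forward_progress, forward_progress_episodes,
#     forward_progress_max_steps, forward_progress_deterministic
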
def resolve_device(device_name: str) -> str:
    if device_name == 'auto':
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    if device_name == 'cuda' and not torch.cuda.is_available():
        raise RuntimeError('The config requests CUDA, but torch cannot detect a usable GPU.')
    return device_name


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description='Train a PPO policy for the guguji biped robot.')
    parser.add_argument(
        '--config',
        default='configs/balance_ppo.yaml',
        help='Path to the training config file; defaults to balance_ppo.yaml.',
    )
    parser.add_argument(
        '--device',
        default=None,
        help='Optionally override the device from the config file, e.g. cpu / cuda / auto.',
    )
    parser.add_argument(
        '--total-timesteps',
        type=int,
        default=None,
        help='Optionally override total_timesteps from the config file.',
    )
    parser.add_argument(
        '--init-model',
        default=None,
        help='Optionally point at an existing PPO model to resume training or to seed curriculum learning.',
    )
    parser.add_argument(
        '--skip-auto-eval',
        action='store_true',
        help='Skip the automatic forward-progress evaluation after training.',
    )
    return parser.parse_args()
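
# Illustrative invocations (assuming this file lives one directory below the
# project root, e.g. scripts/train.py; adjust the paths to your checkout):
#
#   python scripts/train.py
#   python scripts/train.py --device cuda --total-timesteps 2000000
#   python scripts/train.py --init-model <output_root>/<run_dir>/final_model.zip --skip-auto-eval
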
def resolve_input_path(path_str: str) -> Path:
    path = Path(path_str)
    if path.is_absolute() or path.exists():
        return path
    return SCRIPT_ROOT / path


def maybe_override_policy_log_std(model: object, initial_log_std: float | None) -> None:
    """Optionally shrink PPO's initial exploration variance, e.g. for a fine-tuning stage after curriculum learning."""
    if initial_log_std is None:
        return
    policy = getattr(model, 'policy', None)
    if policy is None or not hasattr(policy, 'log_std'):
        raise RuntimeError('The current policy object does not support setting log_std directly.')
    # Set every action dimension's log standard deviation to the same value.
    # When continuing training on top of an existing gait, this lowers the
    # exploration noise and cuts down on pointless flailing.
    policy.log_std.data.fill_(float(initial_log_std))
    print(f'Policy initial log_std set to: {float(initial_log_std):.3f}')
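
# For orientation (based on how stable-baselines3 defines its diagonal
# Gaussian policy): the per-dimension action std is exp(log_std), so
# log_std = -1.0 gives std ≈ 0.37, versus the default log_std = 0.0
# (std = 1.0). The -1.0 here is just an illustrative value.
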
def sanitize_stage_name(stage_name: str) -> str:
    sanitized = ''.join(
        character if character.isalnum() or character in {'-', '_'} else '_'
        for character in stage_name.strip()
    )
    return sanitized.strip('_') or 'stage'


def build_curriculum_stage_configs(config: dict[str, Any]) -> list[tuple[str | None, dict[str, Any]]]:
    """Expand the curriculum stages into a list of independent, directly trainable configs."""
    raw_stages = config['training'].get('curriculum_stages') or []
    if not raw_stages:
        single_stage_config = copy.deepcopy(config)
        single_stage_config['training'].pop('curriculum_stages', None)
        return [(None, single_stage_config)]
    stage_configs: list[tuple[str | None, dict[str, Any]]] = []
    for stage_index, raw_stage in enumerate(raw_stages, start=1):
        if not isinstance(raw_stage, dict):
            raise RuntimeError('Every stage in training.curriculum_stages must be a dict.')
        stage_config = copy.deepcopy(config)
        stage_config['training'].pop('curriculum_stages', None)
        raw_name = str(raw_stage.get('name') or f'stage_{stage_index}')
        stage_name = f'{stage_index:02d}_{sanitize_stage_name(raw_name)}'
        # For now a curriculum stage only controls the target forward velocity,
        # the stage's timestep budget, and the exploration variance. This lets
        # the walking stages ramp the target speed up from slow to fast instead
        # of pushing it too high in one go.
        if 'target_forward_velocity' in raw_stage:
            stage_config['task']['target_forward_velocity'] = float(raw_stage['target_forward_velocity'])
        if 'total_timesteps' in raw_stage:
            stage_config['training']['total_timesteps'] = int(raw_stage['total_timesteps'])
        if 'initial_log_std' in raw_stage:
            stage_config['training']['initial_log_std'] = float(raw_stage['initial_log_std'])
        stage_config['experiment']['name'] = f"{config['experiment']['name']}_{stage_name}"
        stage_configs.append((stage_name, stage_config))
    return stage_configs
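
# A minimal sketch of what training.curriculum_stages might look like in the
# YAML config. The key names match what build_curriculum_stage_configs reads;
# the stage names and numbers are made up for illustration:
#
#   training:
#     curriculum_stages:
#       - name: slow_walk
#         target_forward_velocity: 0.2
#         total_timesteps: 500000
#       - name: fast_walk
#         target_forward_velocity: 0.5
#         total_timesteps: 1000000
#         initial_log_std: -1.0
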
def main() -> int:
    args = parse_args()
    try:
        from stable_baselines3 import PPO
        from stable_baselines3.common.callbacks import CheckpointCallback
        from stable_baselines3.common.monitor import Monitor
    except ImportError:
        print(
            'stable-baselines3 is missing; install the dependencies from the '
            'guguji_rl directory first: pip install -r requirements.txt',
            file=sys.stderr,
        )
        return 1
    from guguji_rl.envs import GazeboBipedEnv

    config = load_config(resolve_input_path(args.config))
    if args.device is not None:
        config['training']['device'] = args.device
    if args.total_timesteps is not None:
        config['training']['total_timesteps'] = args.total_timesteps
    if args.init_model is not None:
        config['training']['init_model_path'] = str(resolve_input_path(args.init_model))
    # Resolve the training device in one place, so switching between CPU and
    # GPU only takes a YAML edit.
    config['training']['device'] = resolve_device(config['training']['device'])
    output_root = resolve_project_path(config, config['training']['output_root'])
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_dir = output_root / f"{config['experiment']['name']}_{timestamp}"
    run_dir.mkdir(parents=True, exist_ok=True)
    # Save a copy of the resolved config so the experiment can be reproduced later.
    save_yaml(config, run_dir / 'resolved_config.yaml')
    stage_configs = build_curriculum_stage_configs(config)
    print(f"Training device: {config['training']['device']}")
    print(f"Output directory: {run_dir}")
    model = None
    final_model_path = run_dir / 'final_model'
    final_stage_config = config
    for stage_index, (stage_name, stage_config) in enumerate(stage_configs, start=1):
        stage_dir = run_dir if stage_name is None else run_dir / stage_name
        stage_dir.mkdir(parents=True, exist_ok=True)
        # Save the effective config of each stage separately; this makes it
        # easy to review the experiment afterwards.
        save_yaml(stage_config, stage_dir / 'resolved_config.yaml')
        if stage_name is not None:
            print(
                f'Starting curriculum stage {stage_index}/{len(stage_configs)}: {stage_name} '
                f'(target_forward_velocity={stage_config["task"]["target_forward_velocity"]:.2f}, '
                f'timesteps={int(stage_config["training"]["total_timesteps"])})'
            )
        env = Monitor(GazeboBipedEnv(stage_config))
        checkpoint_callback = CheckpointCallback(
            save_freq=max(int(stage_config['training']['checkpoint_freq']), 1),
            save_path=str(stage_dir / 'checkpoints'),
            name_prefix='guguji_ppo',
        )
        try:
            if model is None:
                policy_kwargs = {
                    'net_arch': list(stage_config['training']['policy_net_arch']),
                }
                # Get the training loop running end to end with a plain MLP +
                # PPO first; the network can be scaled up gradually later.
                model = PPO(
                    policy='MlpPolicy',
                    env=env,
                    verbose=1,
                    seed=int(stage_config['training']['seed']),
                    learning_rate=float(stage_config['training']['learning_rate']),
                    n_steps=int(stage_config['training']['n_steps']),
                    batch_size=int(stage_config['training']['batch_size']),
                    gamma=float(stage_config['training']['gamma']),
                    gae_lambda=float(stage_config['training']['gae_lambda']),
                    clip_range=float(stage_config['training']['clip_range']),
                    ent_coef=float(stage_config['training']['ent_coef']),
                    vf_coef=float(stage_config['training']['vf_coef']),
                    device=stage_config['training']['device'],
                    tensorboard_log=str(run_dir / 'tensorboard'),
                    policy_kwargs=policy_kwargs,
                )
                init_model_path = stage_config['training'].get('init_model_path')
                if init_model_path:
                    resolved_init_model_path = resolve_input_path(str(init_model_path))
                    # Instead of loading the whole PPO object, pour the old
                    # model's parameters into the fresh one. This keeps the
                    # hyperparameters from the current config file and only
                    # reuses the previously learned policy weights.
                    model.set_parameters(
                        str(resolved_init_model_path),
                        exact_match=False,
                        device=stage_config['training']['device'],
                    )
                    print(f'Loaded curriculum initialization model: {resolved_init_model_path}')
            else:
                model.set_env(env)
            maybe_override_policy_log_std(model, stage_config['training'].get('initial_log_std'))
            # Only reset the timestep counter on the first stage, so logged
            # curves stay continuous across curriculum stages.
            model.learn(
                total_timesteps=int(stage_config['training']['total_timesteps']),
                callback=checkpoint_callback,
                progress_bar=True,
                reset_num_timesteps=(stage_index == 1),
            )
            stage_model_path = stage_dir / 'final_model'
            model.save(stage_model_path)
            final_model_path = stage_model_path
            final_stage_config = stage_config
            if stage_name is not None:
                print(f'Curriculum stage finished; model saved to: {stage_model_path.with_suffix(".zip")}')
        finally:
            env.close()
    if final_model_path != run_dir / 'final_model' and model is not None:
        # In curriculum mode, save an extra copy of the final model at the run
        # root so there is a single canonical path to reference.
        model.save(run_dir / 'final_model')
        final_model_path = run_dir / 'final_model'
    print(f'Training finished; model saved to: {run_dir / "final_model.zip"}')
    evaluation_config = final_stage_config['evaluation']
    if bool(evaluation_config.get('auto_forward_progress', True)) and not args.skip_auto_eval:
        try:
            # Automatically run a forward-progress evaluation after training,
            # for a quick look at delta_x / mean_vx.
            summary = evaluate_forward_progress(
                config=final_stage_config,
                model_path=final_model_path,
                episodes=int(evaluation_config['forward_progress_episodes']),
                max_steps=int(evaluation_config['forward_progress_max_steps']),
                deterministic=bool(evaluation_config['forward_progress_deterministic']),
            )
            print_forward_progress_summary(summary)
        except Exception as error:
            print(f'Automatic forward-progress evaluation failed: {error}', file=sys.stderr)
    return 0

if __name__ == '__main__':
    raise SystemExit(main())