# train.py
"""PPO training entry point for the guguji biped robot (Gazebo-based)."""
from __future__ import annotations

import argparse
import sys
from datetime import datetime
from pathlib import Path

# Repository root (one level above this script). It is prepended to sys.path
# so that `guguji_rl` imports resolve when the script is run directly rather
# than as an installed package.
SCRIPT_ROOT = Path(__file__).resolve().parents[1]
if str(SCRIPT_ROOT) not in sys.path:
    sys.path.insert(0, str(SCRIPT_ROOT))

import torch

from guguji_rl.config import load_config, resolve_project_path, save_yaml
from guguji_rl.evaluation import evaluate_forward_progress, print_forward_progress_summary
  12. def resolve_device(device_name: str) -> str:
  13. if device_name == 'auto':
  14. return 'cuda' if torch.cuda.is_available() else 'cpu'
  15. if device_name == 'cuda' and not torch.cuda.is_available():
  16. raise RuntimeError('配置要求使用 CUDA,但当前 torch 检测不到可用 GPU。')
  17. return device_name
  18. def parse_args() -> argparse.Namespace:
  19. parser = argparse.ArgumentParser(description='Train PPO policy for guguji biped robot.')
  20. parser.add_argument(
  21. '--config',
  22. default='configs/balance_ppo.yaml',
  23. help='训练配置文件路径,默认使用 balance_ppo.yaml',
  24. )
  25. parser.add_argument(
  26. '--device',
  27. default=None,
  28. help='可选覆盖配置文件中的设备设置,例如 cpu / cuda / auto',
  29. )
  30. parser.add_argument(
  31. '--total-timesteps',
  32. type=int,
  33. default=None,
  34. help='可选覆盖配置文件中的 total_timesteps',
  35. )
  36. parser.add_argument(
  37. '--init-model',
  38. default=None,
  39. help='可选指定一个已有 PPO 模型,用于继续训练或做课程学习初始化',
  40. )
  41. parser.add_argument(
  42. '--skip-auto-eval',
  43. action='store_true',
  44. help='训练完成后跳过自动前进评估',
  45. )
  46. return parser.parse_args()
  47. def resolve_input_path(path_str: str) -> Path:
  48. path = Path(path_str)
  49. if path.is_absolute() or path.exists():
  50. return path
  51. return SCRIPT_ROOT / path
def main() -> int:
    """Run one PPO training session and, optionally, an automatic evaluation.

    Returns:
        0 on success, 1 when stable-baselines3 is not installed.
    """
    args = parse_args()
    try:
        # Imported lazily so the CLI can print a helpful install hint when
        # the dependency is missing, instead of failing at module import time.
        from stable_baselines3 import PPO
        from stable_baselines3.common.callbacks import CheckpointCallback
        from stable_baselines3.common.monitor import Monitor
    except ImportError:
        print(
            '缺少 stable-baselines3,请先进入 guguji_rl 目录安装依赖: '
            'pip install -r requirements.txt',
            file=sys.stderr,
        )
        return 1
    from guguji_rl.envs import GazeboBipedEnv
    config = load_config(resolve_input_path(args.config))
    # CLI flags, when given, override values loaded from the YAML config.
    if args.device is not None:
        config['training']['device'] = args.device
    if args.total_timesteps is not None:
        config['training']['total_timesteps'] = args.total_timesteps
    if args.init_model is not None:
        config['training']['init_model_path'] = str(resolve_input_path(args.init_model))
    # Resolve the training device in one place so switching CPU / GPU only
    # requires editing the YAML.
    config['training']['device'] = resolve_device(config['training']['device'])
    output_root = resolve_project_path(config, config['training']['output_root'])
    # Each run gets a timestamped directory under the configured output root.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_dir = output_root / f"{config['experiment']['name']}_{timestamp}"
    run_dir.mkdir(parents=True, exist_ok=True)
    # Persist the fully-resolved config so the experiment can be reproduced.
    save_yaml(config, run_dir / 'resolved_config.yaml')
    env = Monitor(GazeboBipedEnv(config))
    checkpoint_callback = CheckpointCallback(
        save_freq=max(int(config['training']['checkpoint_freq']), 1),
        save_path=str(run_dir / 'checkpoints'),
        name_prefix='guguji_ppo',
    )
    policy_kwargs = {
        'net_arch': list(config['training']['policy_net_arch']),
    }
    print(f"训练设备: {config['training']['device']}")
    print(f"输出目录: {run_dir}")
    # Start with MLP + PPO to get the training loop working end to end; the
    # network size can be scaled up incrementally later.
    model = PPO(
        policy='MlpPolicy',
        env=env,
        verbose=1,
        seed=int(config['training']['seed']),
        learning_rate=float(config['training']['learning_rate']),
        n_steps=int(config['training']['n_steps']),
        batch_size=int(config['training']['batch_size']),
        gamma=float(config['training']['gamma']),
        gae_lambda=float(config['training']['gae_lambda']),
        clip_range=float(config['training']['clip_range']),
        ent_coef=float(config['training']['ent_coef']),
        vf_coef=float(config['training']['vf_coef']),
        device=config['training']['device'],
        tensorboard_log=str(run_dir / 'tensorboard'),
        policy_kwargs=policy_kwargs,
    )
    init_model_path = config['training'].get('init_model_path')
    if init_model_path:
        resolved_init_model_path = resolve_input_path(str(init_model_path))
        # Rather than loading the whole saved PPO object, pour the old model's
        # parameters into the freshly-built one: the current config's
        # hyperparameters stay in effect and only the previously learned
        # policy weights are reused (curriculum-style initialization).
        model.set_parameters(str(resolved_init_model_path), exact_match=False, device=config['training']['device'])
        print(f"已加载课程初始化模型: {resolved_init_model_path}")
    model.learn(
        total_timesteps=int(config['training']['total_timesteps']),
        callback=checkpoint_callback,
        progress_bar=True,
    )
    final_model_path = run_dir / 'final_model'
    model.save(final_model_path)
    env.close()
    print(f'训练完成,模型已保存到: {run_dir / "final_model.zip"}')
    evaluation_config = config['evaluation']
    if bool(evaluation_config.get('auto_forward_progress', True)) and not args.skip_auto_eval:
        try:
            # Run a forward-progress evaluation after each training run so
            # delta_x / mean_vx can be inspected at a glance.
            summary = evaluate_forward_progress(
                config=config,
                model_path=final_model_path,
                episodes=int(evaluation_config['forward_progress_episodes']),
                max_steps=int(evaluation_config['forward_progress_max_steps']),
                deterministic=bool(evaluation_config['forward_progress_deterministic']),
            )
            print_forward_progress_summary(summary)
        except Exception as error:
            # Evaluation is best-effort: a failure here is reported but does
            # not change the (successful) training exit code.
            print(f'自动前进评估失败: {error}', file=sys.stderr)
    return 0
if __name__ == '__main__':
    # Propagate main()'s integer return value as the process exit code.
    raise SystemExit(main())