train.py

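"""Train a PPO policy for the guguji biped robot.

Example invocations (paths and values are illustrative):

    python train.py
    python train.py --device cuda --total-timesteps 2000000
    python train.py --init-model runs/<experiment>_<timestamp>/final_model.zip
"""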
from __future__ import annotations

import argparse
import sys
from datetime import datetime
from pathlib import Path
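
# Make the repository root importable so `guguji_rl` resolves even when this
# script is run directly rather than from an installed package.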
SCRIPT_ROOT = Path(__file__).resolve().parents[1]
if str(SCRIPT_ROOT) not in sys.path:
    sys.path.insert(0, str(SCRIPT_ROOT))

import torch

from guguji_rl.config import load_config, resolve_project_path, save_yaml


def resolve_device(device_name: str) -> str:
    if device_name == 'auto':
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    if device_name == 'cuda' and not torch.cuda.is_available():
        raise RuntimeError('The config requests CUDA, but torch cannot detect a usable GPU.')
    return device_name


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description='Train PPO policy for guguji biped robot.')
    parser.add_argument(
        '--config',
        default='configs/balance_ppo.yaml',
        help='Path to the training config file (defaults to balance_ppo.yaml)',
    )
    parser.add_argument(
        '--device',
        default=None,
        help='Optionally override the device setting from the config, e.g. cpu / cuda / auto',
    )
    parser.add_argument(
        '--total-timesteps',
        type=int,
        default=None,
        help='Optionally override total_timesteps from the config',
    )
    parser.add_argument(
        '--init-model',
        default=None,
        help='Optionally point to an existing PPO model to resume training or initialize curriculum learning',
    )
    return parser.parse_args()
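

# Paths from the command line may be absolute or relative to the current working
# directory; anything else is resolved against the repository root.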
def resolve_input_path(path_str: str) -> Path:
    path = Path(path_str)
    if path.is_absolute() or path.exists():
        return path
    return SCRIPT_ROOT / path


def main() -> int:
    args = parse_args()
    try:
        from stable_baselines3 import PPO
        from stable_baselines3.common.callbacks import CheckpointCallback
        from stable_baselines3.common.monitor import Monitor
    except ImportError:
        print(
            'stable-baselines3 is missing; install the dependencies from the guguji_rl directory first: '
            'pip install -r requirements.txt',
            file=sys.stderr,
        )
        return 1

    from guguji_rl.envs import GazeboBipedEnv

    config = load_config(resolve_input_path(args.config))
    if args.device is not None:
        config['training']['device'] = args.device
    if args.total_timesteps is not None:
        config['training']['total_timesteps'] = args.total_timesteps
    if args.init_model is not None:
        config['training']['init_model_path'] = str(resolve_input_path(args.init_model))

    # Resolve the training device in one place, so switching between CPU and GPU
    # only requires editing the YAML.
    config['training']['device'] = resolve_device(config['training']['device'])

    output_root = resolve_project_path(config, config['training']['output_root'])
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_dir = output_root / f"{config['experiment']['name']}_{timestamp}"
    run_dir.mkdir(parents=True, exist_ok=True)

    # Save a copy of the fully resolved config so the experiment can be reproduced later.
    save_yaml(config, run_dir / 'resolved_config.yaml')
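
    # Monitor records per-episode reward and length so SB3 can log rollout statistics.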
    env = Monitor(GazeboBipedEnv(config))
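
    # With a single (non-vectorized) env, save_freq is counted in environment
    # steps; the max(..., 1) keeps a misconfigured checkpoint_freq from being zero.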
    checkpoint_callback = CheckpointCallback(
        save_freq=max(int(config['training']['checkpoint_freq']), 1),
        save_path=str(run_dir / 'checkpoints'),
        name_prefix='guguji_ppo',
    )
    policy_kwargs = {
        'net_arch': list(config['training']['policy_net_arch']),
    }

    print(f"Training device: {config['training']['device']}")
    print(f"Output directory: {run_dir}")

    # Start with a plain MLP + PPO to get the training loop running end to end;
    # the network can be scaled up gradually later.
    model = PPO(
        policy='MlpPolicy',
        env=env,
        verbose=1,
        seed=int(config['training']['seed']),
        learning_rate=float(config['training']['learning_rate']),
        n_steps=int(config['training']['n_steps']),
        batch_size=int(config['training']['batch_size']),
        gamma=float(config['training']['gamma']),
        gae_lambda=float(config['training']['gae_lambda']),
        clip_range=float(config['training']['clip_range']),
        ent_coef=float(config['training']['ent_coef']),
        vf_coef=float(config['training']['vf_coef']),
        device=config['training']['device'],
        tensorboard_log=str(run_dir / 'tensorboard'),
        policy_kwargs=policy_kwargs,
    )

    init_model_path = config['training'].get('init_model_path')
    if init_model_path:
        resolved_init_model_path = resolve_input_path(str(init_model_path))
        # Instead of loading the whole saved PPO object, pour the old model's
        # parameters into the new one: the hyperparameters still come from the
        # current config, and only the previously learned policy weights are reused.
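        # exact_match=False relaxes loading, so the checkpoint may omit parts of
        # the model (e.g. optimizer state) rather than matching key for key.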
        model.set_parameters(
            str(resolved_init_model_path),
            exact_match=False,
            device=config['training']['device'],
        )
        print(f"Loaded curriculum-initialization model: {resolved_init_model_path}")

    model.learn(
        total_timesteps=int(config['training']['total_timesteps']),
        callback=checkpoint_callback,
        progress_bar=True,
    )
    model.save(run_dir / 'final_model')
    env.close()
    print(f'Training finished; model saved to: {run_dir / "final_model.zip"}')
    return 0


if __name__ == '__main__':
    raise SystemExit(main())