# train.py
from __future__ import annotations
import argparse
import sys
from datetime import datetime
from pathlib import Path

# Make the repository root importable so `guguji_rl` resolves when this
# script is executed directly (SCRIPT_ROOT is the parent of this file's
# directory, presumably the repo root — confirm against the repo layout).
SCRIPT_ROOT = Path(__file__).resolve().parents[1]
if str(SCRIPT_ROOT) not in sys.path:
    sys.path.insert(0, str(SCRIPT_ROOT))

import torch
from guguji_rl.config import load_config, resolve_project_path, save_yaml
  11. def resolve_device(device_name: str) -> str:
  12. if device_name == 'auto':
  13. return 'cuda' if torch.cuda.is_available() else 'cpu'
  14. if device_name == 'cuda' and not torch.cuda.is_available():
  15. raise RuntimeError('配置要求使用 CUDA,但当前 torch 检测不到可用 GPU。')
  16. return device_name
  17. def parse_args() -> argparse.Namespace:
  18. parser = argparse.ArgumentParser(description='Train PPO policy for guguji biped robot.')
  19. parser.add_argument(
  20. '--config',
  21. default='configs/balance_ppo.yaml',
  22. help='训练配置文件路径,默认使用 balance_ppo.yaml',
  23. )
  24. parser.add_argument(
  25. '--device',
  26. default=None,
  27. help='可选覆盖配置文件中的设备设置,例如 cpu / cuda / auto',
  28. )
  29. parser.add_argument(
  30. '--total-timesteps',
  31. type=int,
  32. default=None,
  33. help='可选覆盖配置文件中的 total_timesteps',
  34. )
  35. return parser.parse_args()
  36. def resolve_input_path(path_str: str) -> Path:
  37. path = Path(path_str)
  38. if path.is_absolute() or path.exists():
  39. return path
  40. return SCRIPT_ROOT / path
  41. def main() -> int:
  42. args = parse_args()
  43. try:
  44. from stable_baselines3 import PPO
  45. from stable_baselines3.common.callbacks import CheckpointCallback
  46. from stable_baselines3.common.monitor import Monitor
  47. except ImportError:
  48. print(
  49. '缺少 stable-baselines3,请先进入 guguji_rl 目录安装依赖: '
  50. 'pip install -r requirements.txt',
  51. file=sys.stderr,
  52. )
  53. return 1
  54. from guguji_rl.envs import GazeboBipedEnv
  55. config = load_config(resolve_input_path(args.config))
  56. if args.device is not None:
  57. config['training']['device'] = args.device
  58. if args.total_timesteps is not None:
  59. config['training']['total_timesteps'] = args.total_timesteps
  60. # 这里统一解析训练设备,方便你只改 YAML 就切换 CPU / GPU。
  61. config['training']['device'] = resolve_device(config['training']['device'])
  62. output_root = resolve_project_path(config, config['training']['output_root'])
  63. timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
  64. run_dir = output_root / f"{config['experiment']['name']}_{timestamp}"
  65. run_dir.mkdir(parents=True, exist_ok=True)
  66. # 保存一份展开后的配置,便于后面复现实验。
  67. save_yaml(config, run_dir / 'resolved_config.yaml')
  68. env = Monitor(GazeboBipedEnv(config))
  69. checkpoint_callback = CheckpointCallback(
  70. save_freq=max(int(config['training']['checkpoint_freq']), 1),
  71. save_path=str(run_dir / 'checkpoints'),
  72. name_prefix='guguji_ppo',
  73. )
  74. policy_kwargs = {
  75. 'net_arch': list(config['training']['policy_net_arch']),
  76. }
  77. print(f"训练设备: {config['training']['device']}")
  78. print(f"输出目录: {run_dir}")
  79. # 先用 MLP + PPO 跑通训练闭环,后面你可以再逐步增大网络规模。
  80. model = PPO(
  81. policy='MlpPolicy',
  82. env=env,
  83. verbose=1,
  84. seed=int(config['training']['seed']),
  85. learning_rate=float(config['training']['learning_rate']),
  86. n_steps=int(config['training']['n_steps']),
  87. batch_size=int(config['training']['batch_size']),
  88. gamma=float(config['training']['gamma']),
  89. gae_lambda=float(config['training']['gae_lambda']),
  90. clip_range=float(config['training']['clip_range']),
  91. ent_coef=float(config['training']['ent_coef']),
  92. vf_coef=float(config['training']['vf_coef']),
  93. device=config['training']['device'],
  94. tensorboard_log=str(run_dir / 'tensorboard'),
  95. policy_kwargs=policy_kwargs,
  96. )
  97. model.learn(
  98. total_timesteps=int(config['training']['total_timesteps']),
  99. callback=checkpoint_callback,
  100. progress_bar=True,
  101. )
  102. model.save(run_dir / 'final_model')
  103. env.close()
  104. print(f'训练完成,模型已保存到: {run_dir / "final_model.zip"}')
  105. return 0
if __name__ == '__main__':
    # Propagate main()'s return code as the process exit status.
    raise SystemExit(main())