# run_policy.py
from __future__ import annotations

import argparse
import sys
from collections.abc import Sequence
from pathlib import Path
  5. SCRIPT_ROOT = Path(__file__).resolve().parents[1]
  6. if str(SCRIPT_ROOT) not in sys.path:
  7. sys.path.insert(0, str(SCRIPT_ROOT))
  8. from guguji_rl.config import load_config
  9. def parse_args() -> argparse.Namespace:
  10. parser = argparse.ArgumentParser(description='Run a trained PPO policy online in Gazebo / ROS 2.')
  11. parser.add_argument('--config', default='configs/walk_ppo.yaml', help='配置文件路径')
  12. parser.add_argument('--model', required=True, help='训练好的模型路径,例如 outputs/.../final_model.zip')
  13. parser.add_argument(
  14. '--deterministic',
  15. action='store_true',
  16. help='是否强制使用确定性动作,适合策略回放和部署调试',
  17. )
  18. parser.add_argument(
  19. '--max-episodes',
  20. type=int,
  21. default=0,
  22. help='最多运行多少个 episode,0 表示一直循环运行',
  23. )
  24. return parser.parse_args()
  25. def resolve_input_path(path_str: str) -> Path:
  26. path = Path(path_str)
  27. if path.is_absolute() or path.exists():
  28. return path
  29. return SCRIPT_ROOT / path
  30. def main() -> int:
  31. args = parse_args()
  32. try:
  33. from stable_baselines3 import PPO
  34. except ImportError:
  35. print('缺少 stable-baselines3,请先安装 requirements.txt', file=sys.stderr)
  36. return 1
  37. from guguji_rl.envs import GazeboBipedEnv
  38. config = load_config(resolve_input_path(args.config))
  39. env = GazeboBipedEnv(config)
  40. model = PPO.load(resolve_input_path(args.model))
  41. episode_index = 0
  42. try:
  43. while args.max_episodes <= 0 or episode_index < args.max_episodes:
  44. # 这里直接复用训练环境,方便策略回放和后续在线部署共用同一套观测/动作逻辑。
  45. observation, _ = env.reset()
  46. terminated = False
  47. truncated = False
  48. total_reward = 0.0
  49. while not terminated and not truncated:
  50. action, _ = model.predict(
  51. observation,
  52. deterministic=args.deterministic or bool(config['evaluation']['deterministic']),
  53. )
  54. observation, reward, terminated, truncated, _ = env.step(action)
  55. total_reward += reward
  56. print(
  57. f'episode={episode_index} '
  58. f'total_reward={total_reward:.3f} '
  59. f'terminated={terminated} truncated={truncated}'
  60. )
  61. episode_index += 1
  62. except KeyboardInterrupt:
  63. print('收到 Ctrl+C,停止在线策略运行。')
  64. finally:
  65. env.close()
  66. return 0
  67. if __name__ == '__main__':
  68. raise SystemExit(main())