# evaluate.py (2.0 KB)
  1. from __future__ import annotations
  2. import argparse
  3. import sys
  4. from pathlib import Path
  5. SCRIPT_ROOT = Path(__file__).resolve().parents[1]
  6. if str(SCRIPT_ROOT) not in sys.path:
  7. sys.path.insert(0, str(SCRIPT_ROOT))
  8. from guguji_rl.config import load_config
  9. def parse_args() -> argparse.Namespace:
  10. parser = argparse.ArgumentParser(description='Evaluate trained PPO policy in Gazebo.')
  11. parser.add_argument('--config', default='configs/walk_ppo.yaml', help='配置文件路径')
  12. parser.add_argument('--model', required=True, help='训练好的模型路径,例如 outputs/.../final_model.zip')
  13. parser.add_argument('--episodes', type=int, default=None, help='可选覆盖配置中的评估轮数')
  14. return parser.parse_args()
  15. def resolve_input_path(path_str: str) -> Path:
  16. path = Path(path_str)
  17. if path.is_absolute() or path.exists():
  18. return path
  19. return SCRIPT_ROOT / path
  20. def main() -> int:
  21. args = parse_args()
  22. try:
  23. from stable_baselines3 import PPO
  24. except ImportError:
  25. print('缺少 stable-baselines3,请先安装 requirements.txt', file=sys.stderr)
  26. return 1
  27. from guguji_rl.envs import GazeboBipedEnv
  28. config = load_config(resolve_input_path(args.config))
  29. if args.episodes is not None:
  30. config['evaluation']['episodes'] = args.episodes
  31. env = GazeboBipedEnv(config)
  32. model = PPO.load(resolve_input_path(args.model))
  33. for episode_index in range(int(config['evaluation']['episodes'])):
  34. observation, info = env.reset()
  35. done = False
  36. truncated = False
  37. total_reward = 0.0
  38. while not done and not truncated:
  39. action, _ = model.predict(
  40. observation,
  41. deterministic=bool(config['evaluation']['deterministic']),
  42. )
  43. observation, reward, done, truncated, info = env.step(action)
  44. total_reward += reward
  45. print(f'episode={episode_index} total_reward={total_reward:.3f}')
  46. env.close()
  47. return 0
  48. if __name__ == '__main__':
  49. raise SystemExit(main())