| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- from __future__ import annotations
- import argparse
- import sys
- from pathlib import Path
# Directory two levels above this file (the parent of the folder that holds
# the script). It is pushed to the front of sys.path, presumably so that the
# project package imported right after this can be found when the script is
# launched directly -- TODO confirm against the repository layout.
SCRIPT_ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(SCRIPT_ROOT)) == 0:
    sys.path.insert(0, str(SCRIPT_ROOT))
- from guguji_rl.config import load_config
def parse_args() -> argparse.Namespace:
    """Parse the evaluation script's command-line options.

    Options:
        --config: path to the configuration file
            (default ``configs/walk_ppo.yaml``).
        --model: path to the trained model; required.
        --episodes: optional integer overriding the configured number of
            evaluation episodes (``None`` when not given).

    Returns:
        The populated ``argparse.Namespace`` parsed from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(
        description='Evaluate trained PPO policy in Gazebo.',
    )
    # Declare the options as data and register them in one pass.
    option_table = (
        ('--config', {'default': 'configs/walk_ppo.yaml', 'help': '配置文件路径'}),
        ('--model', {'required': True, 'help': '训练好的模型路径,例如 outputs/.../final_model.zip'}),
        ('--episodes', {'type': int, 'default': None, 'help': '可选覆盖配置中的评估轮数'}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def resolve_input_path(path_str: str) -> Path:
    """Turn a user-supplied path string into a ``Path`` to open.

    The path is returned as given when it is absolute or when it already
    exists (a relative path is checked against the current working
    directory). Otherwise it is re-anchored under ``SCRIPT_ROOT``.

    Args:
        path_str: Path as typed on the command line.

    Returns:
        Either ``Path(path_str)`` unchanged or ``SCRIPT_ROOT / path_str``.
    """
    candidate = Path(path_str)
    # is_absolute() is tested first so absolute paths never touch the disk.
    usable_as_given = candidate.is_absolute() or candidate.exists()
    return candidate if usable_as_given else SCRIPT_ROOT / candidate
def main() -> int:
    """Roll out a trained PPO policy in the Gazebo biped environment.

    Loads the configuration and the saved model, runs the configured number
    of evaluation episodes and prints the accumulated reward of each one.

    The environment is closed in a ``finally`` clause so it is released even
    when loading the model or stepping the simulation raises (for example a
    wrong model path or a keyboard interrupt).

    Returns:
        ``1`` if ``stable_baselines3`` cannot be imported, otherwise ``0``
        once every episode has finished.
    """
    args = parse_args()

    # Imported lazily so a missing dependency produces a readable message
    # rather than an import traceback.
    try:
        from stable_baselines3 import PPO
    except ImportError:
        print('缺少 stable-baselines3,请先安装 requirements.txt', file=sys.stderr)
        return 1

    from guguji_rl.envs import GazeboBipedEnv

    config = load_config(resolve_input_path(args.config))
    if args.episodes is not None:
        config['evaluation']['episodes'] = args.episodes

    # Read the loop-invariant settings once, before the simulator is started,
    # so a malformed evaluation section fails without leaving an env behind.
    evaluation_cfg = config['evaluation']
    episode_count = int(evaluation_cfg['episodes'])
    deterministic = bool(evaluation_cfg['deterministic'])

    env = GazeboBipedEnv(config)
    try:
        model = PPO.load(resolve_input_path(args.model))
        for episode_index in range(episode_count):
            observation, info = env.reset()
            done = False
            truncated = False
            total_reward = 0.0
            # An episode ends when the env reports termination or truncation.
            while not done and not truncated:
                action, _ = model.predict(
                    observation,
                    deterministic=deterministic,
                )
                observation, reward, done, truncated, info = env.step(action)
                total_reward += reward
            print(f'episode={episode_index} total_reward={total_reward:.3f}')
    finally:
        # Always release the environment, including on errors and Ctrl-C.
        env.close()
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == '__main__':
    sys.exit(main())
|