| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- from __future__ import annotations
- import argparse
- import sys
- from pathlib import Path
# Root directory for this project: the parent of the folder that contains this
# file (i.e. two path components above the resolved script path).
SCRIPT_ROOT = Path(__file__).resolve().parents[1]

# Put the root at the very front of the import search path, but only once, so
# the local ``guguji_rl`` package imported below resolves when this script is
# executed directly.
if str(SCRIPT_ROOT) not in sys.path:
    sys.path.insert(0, str(SCRIPT_ROOT))
- from guguji_rl.config import load_config
- def parse_args() -> argparse.Namespace:
- parser = argparse.ArgumentParser(description='Run a trained PPO policy online in Gazebo / ROS 2.')
- parser.add_argument('--config', default='configs/walk_ppo.yaml', help='配置文件路径')
- parser.add_argument('--model', required=True, help='训练好的模型路径,例如 outputs/.../final_model.zip')
- parser.add_argument(
- '--deterministic',
- action='store_true',
- help='是否强制使用确定性动作,适合策略回放和部署调试',
- )
- parser.add_argument(
- '--max-episodes',
- type=int,
- default=0,
- help='最多运行多少个 episode,0 表示一直循环运行',
- )
- return parser.parse_args()
- def resolve_input_path(path_str: str) -> Path:
- path = Path(path_str)
- if path.is_absolute() or path.exists():
- return path
- return SCRIPT_ROOT / path
def _run_episode(env, model, deterministic: bool) -> tuple[float, bool, bool]:
    """Roll out a single episode of ``model`` in ``env``.

    Uses the reset/step shapes the loop relies on: ``env.reset()`` returns
    ``(observation, info)`` and ``env.step()`` returns
    ``(observation, reward, terminated, truncated, info)``.

    Returns:
        ``(total_reward, terminated, truncated)`` for the finished episode.
    """
    # Reuse the training environment directly so policy replay and later online
    # deployment share the same observation/action logic.
    observation, _ = env.reset()
    terminated = False
    truncated = False
    total_reward = 0.0
    while not terminated and not truncated:
        action, _ = model.predict(observation, deterministic=deterministic)
        observation, reward, terminated, truncated, _ = env.step(action)
        # float() keeps the running sum a plain Python float even if the env
        # returns NumPy scalars, so the ':.3f' formatting below always works.
        total_reward += float(reward)
    return total_reward, terminated, truncated


def main() -> int:
    """Load a trained PPO policy and run it online in the Gazebo biped env.

    Episodes are run back to back until ``--max-episodes`` is reached (forever
    when it is <= 0) or the user presses Ctrl+C. One summary line is printed
    per finished episode.

    Returns:
        Process exit code: 1 if stable-baselines3 is not installed, otherwise
        0 (including when the run is stopped with Ctrl+C).
    """
    args = parse_args()
    try:
        from stable_baselines3 import PPO
    except ImportError:
        print('缺少 stable-baselines3,请先安装 requirements.txt', file=sys.stderr)
        return 1
    # NOTE(review): presumably imported lazily because it pulls in the
    # Gazebo / ROS 2 stack -- confirm before moving it to module level.
    from guguji_rl.envs import GazeboBipedEnv

    config = load_config(resolve_input_path(args.config))
    # The CLI flag can only force determinism on; otherwise the config decides.
    # This is evaluated once instead of on every step (`or` still
    # short-circuits, so the config key is read only when the flag is off).
    deterministic = args.deterministic or bool(config['evaluation']['deterministic'])
    # Load the policy BEFORE creating the environment. Previously a bad or
    # missing model file raised after the env existed but outside the
    # try/finally, so env.close() was never called.
    model = PPO.load(resolve_input_path(args.model))

    env = GazeboBipedEnv(config)
    episode_index = 0
    try:
        while args.max_episodes <= 0 or episode_index < args.max_episodes:
            total_reward, terminated, truncated = _run_episode(env, model, deterministic)
            print(
                f'episode={episode_index} '
                f'total_reward={total_reward:.3f} '
                f'terminated={terminated} truncated={truncated}'
            )
            episode_index += 1
    except KeyboardInterrupt:
        print('收到 Ctrl+C,停止在线策略运行。')
    finally:
        env.close()
    return 0
if __name__ == '__main__':
    # Use main()'s integer return value as the process exit status
    # (sys.exit raises SystemExit with that code).
    sys.exit(main())
|