# check_env.py (2.0 KB)
  1. from __future__ import annotations
  2. import argparse
  3. import sys
  4. from pathlib import Path
  5. import numpy as np
  6. SCRIPT_ROOT = Path(__file__).resolve().parents[1]
  7. if str(SCRIPT_ROOT) not in sys.path:
  8. sys.path.insert(0, str(SCRIPT_ROOT))
  9. from guguji_rl.config import load_config
  10. def parse_args() -> argparse.Namespace:
  11. parser = argparse.ArgumentParser(description='Sanity check for the Gazebo RL environment.')
  12. parser.add_argument('--config', default='configs/balance_ppo.yaml', help='配置文件路径')
  13. parser.add_argument('--steps', type=int, default=5, help='检查时执行多少个 step')
  14. return parser.parse_args()
  15. def resolve_input_path(path_str: str) -> Path:
  16. path = Path(path_str)
  17. if path.is_absolute() or path.exists():
  18. return path
  19. return SCRIPT_ROOT / path
  20. def main() -> int:
  21. args = parse_args()
  22. try:
  23. from guguji_rl.envs import GazeboBipedEnv
  24. except ImportError:
  25. print('缺少训练环境依赖,请先进入 guguji_rl 目录安装 requirements.txt', file=sys.stderr)
  26. return 1
  27. config = load_config(resolve_input_path(args.config))
  28. env = GazeboBipedEnv(config)
  29. observation, info = env.reset()
  30. print('reset ok')
  31. print(f'observation shape: {observation.shape}')
  32. print(f'target_base_height: {info["target_base_height"]:.4f}')
  33. for step_index in range(args.steps):
  34. action = np.zeros(env.action_space.shape[0], dtype=np.float32)
  35. observation, reward, terminated, truncated, info = env.step(action)
  36. reward_terms = info['reward_terms']
  37. print(
  38. f'step={step_index} reward={reward:.3f} '
  39. f'vx={reward_terms["forward_velocity"]:.3f} '
  40. f'base_z={reward_terms["base_height"]:.3f}'
  41. )
  42. if terminated or truncated:
  43. print(f'episode finished early: terminated={terminated} truncated={truncated}')
  44. break
  45. env.close()
  46. return 0
  47. if __name__ == '__main__':
  48. raise SystemExit(main())