rewards.py

from __future__ import annotations

from dataclasses import dataclass

import numpy as np

from .math_utils import quaternion_xyzw_to_euler
from .ros2_interface import RobotStateSnapshot
from .urdf_utils import JointLimit


@dataclass
class RewardContext:
    """Inputs the reward function needs for a single control step."""

    current: RobotStateSnapshot
    previous: RobotStateSnapshot
    action: np.ndarray
    previous_action: np.ndarray
    joint_limits: list[JointLimit]
    target_forward_velocity: float
    target_base_height: float
    control_dt: float
    terminated: bool


class BipedRewardCalculator:
    """Computes a shaped locomotion reward plus per-term diagnostics."""

    def __init__(self, reward_config: dict) -> None:
        self.reward_config = reward_config

    def compute(self, context: RewardContext) -> tuple[float, dict[str, float]]:
        # Guard against a zero or negative dt from out-of-order snapshots.
        dt = max(context.current.sim_time - context.previous.sim_time, context.control_dt, 1e-3)
        delta_position = context.current.base_position - context.previous.base_position
        forward_velocity = float(delta_position[0] / dt)
        lateral_velocity = float(delta_position[1] / dt)
        roll, pitch, _ = quaternion_xyzw_to_euler(context.current.base_quaternion)
        upright_reward = np.exp(-4.0 * (roll * roll + pitch * pitch))
        height_error = context.current.base_position[2] - context.target_base_height
        height_reward = np.exp(-8.0 * height_error * height_error)
        sigma = max(float(self.reward_config['velocity_tracking_sigma']), 1e-6)
        velocity_error = forward_velocity - context.target_forward_velocity
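        # Gaussian tracking kernel: exp(-(v - v_target)^2 / (2 * sigma^2)).
        # It peaks at 1.0 when forward_velocity matches the target and decays
        # smoothly as the error grows; sigma sets the tolerance band.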
        velocity_tracking = np.exp(-(velocity_error * velocity_error) / (2.0 * sigma * sigma))
        # Only reward genuinely forward velocity, so the policy cannot game the
        # objective by swaying side to side or walking backwards.
        positive_forward_velocity = max(forward_velocity, 0.0)
        backward_velocity = max(-forward_velocity, 0.0)
        # If forward velocity stays too low, apply a stall penalty that pushes
        # the policy to actually take steps.
        stall_velocity_threshold = float(self.reward_config.get('stall_velocity_threshold', 0.0))
        stall_penalty = max(stall_velocity_threshold - positive_forward_velocity, 0.0)
        # Smoothness term: mean squared change between consecutive actions.
        action_rate_penalty = float(np.mean(np.square(context.action - context.previous_action)))
        joint_limit_penalty = 0.0
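        # Penalize joints that enter the outer 10% of their motion range.
        # Example: a joint at 95% of its half-range from the midpoint adds
        # 0.95 - 0.9 = 0.05 before averaging; joints inside 90% add nothing.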
        for index, joint_limit in enumerate(context.joint_limits):
            normalized = abs((context.current.joint_position[index] - joint_limit.midpoint) / joint_limit.half_range)
            joint_limit_penalty += max(normalized - 0.9, 0.0)
        joint_limit_penalty /= max(len(context.joint_limits), 1)
        joint_name_to_index = {
            joint_limit.name: index
            for index, joint_limit in enumerate(context.joint_limits)
        }
        hip_alternation_reward = 0.0
        knee_flexion_reward = 0.0
        left_hip_index = joint_name_to_index.get('left_hip_pitch_joint')
        right_hip_index = joint_name_to_index.get('right_hip_pitch_joint')
        if left_hip_index is not None and right_hip_index is not None:
            left_hip = float(context.current.joint_position[left_hip_index])
            right_hip = float(context.current.joint_position[right_hip_index])
            hip_separation = abs(left_hip - right_hip)
            hip_target_separation = max(float(self.reward_config.get('hip_target_separation', 0.3)), 1e-6)
            hip_antiphase_sigma = max(float(self.reward_config.get('hip_antiphase_sigma', 0.2)), 1e-6)
            # Encourage the left and right hips to swing in opposite
            # directions, with an amplitude that is not too small.
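            # Antiphase means left_hip ~= -right_hip, so their sum is near zero
            # and the Gaussian term below approaches 1.0; the separation factor
            # gates this on a sufficiently large swing, so standing with both
            # hips at zero scores ~0.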
            hip_separation_reward = min(hip_separation / hip_target_separation, 1.0)
            hip_antiphase_reward = np.exp(-((left_hip + right_hip) ** 2) / (2.0 * hip_antiphase_sigma ** 2))
            hip_alternation_reward = float(hip_separation_reward * hip_antiphase_reward)
        left_knee_index = joint_name_to_index.get('left_knee_pitch_joint')
        right_knee_index = joint_name_to_index.get('right_knee_pitch_joint')
        if left_knee_index is not None and right_knee_index is not None:
            left_knee = abs(float(context.current.joint_position[left_knee_index]))
            right_knee = abs(float(context.current.joint_position[right_knee_index]))
            knee_target = float(self.reward_config.get('knee_target', 0.2))
            knee_sigma = max(float(self.reward_config.get('knee_flexion_sigma', 0.12)), 1e-6)
            average_knee_flexion = 0.5 * (left_knee + right_knee)
            # Encourage moderate knee flexion so the policy learns to lift its
            # legs instead of dragging them fully extended.
            knee_flexion_reward = float(
                np.exp(-((average_knee_flexion - knee_target) ** 2) / (2.0 * knee_sigma ** 2))
            )
        reward_terms = {
            'alive_bonus': float(self.reward_config['alive_bonus']),
            'velocity_tracking': float(self.reward_config['velocity_tracking_scale']) * float(velocity_tracking),
            'forward_progress': float(self.reward_config.get('forward_progress_scale', 0.0)) * positive_forward_velocity,
            'hip_alternation': float(self.reward_config.get('hip_alternation_scale', 0.0)) * hip_alternation_reward,
            'knee_flexion': float(self.reward_config.get('knee_flexion_scale', 0.0)) * knee_flexion_reward,
            'upright': float(self.reward_config['upright_scale']) * float(upright_reward),
            'height': float(self.reward_config['height_scale']) * float(height_reward),
            'action_rate_penalty': -float(self.reward_config['action_rate_penalty_scale']) * action_rate_penalty,
            'joint_limit_penalty': -float(self.reward_config['joint_limit_penalty_scale']) * joint_limit_penalty,
            'lateral_velocity_penalty': -float(self.reward_config['lateral_velocity_penalty_scale']) * abs(lateral_velocity),
            'backward_velocity_penalty': -float(self.reward_config.get('backward_velocity_penalty_scale', 0.0)) * backward_velocity,
            'stall_penalty': -float(self.reward_config.get('stall_penalty_scale', 0.0)) * stall_penalty,
        }
        total_reward = sum(reward_terms.values())
        if context.terminated:
            # The fall penalty is added as-is, so the config value itself
            # should carry the (negative) sign.
            reward_terms['fall_penalty'] = float(self.reward_config['fall_penalty'])
            total_reward += reward_terms['fall_penalty']
        else:
            reward_terms['fall_penalty'] = 0.0
        # Diagnostics only: appended after the sum, so they never affect the reward.
        reward_terms['forward_velocity'] = forward_velocity
        reward_terms['roll'] = float(roll)
        reward_terms['pitch'] = float(pitch)
        reward_terms['base_height'] = float(context.current.base_position[2])
        reward_terms['total_reward'] = float(total_reward)
        return float(total_reward), reward_terms
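

# Usage sketch (illustrative, not part of the original module): the dict below
# spells out the config schema compute() reads. The key names are taken
# straight from the code above; every numeric value is a hypothetical
# placeholder, not a tuned default.
if __name__ == '__main__':
    example_reward_config = {
        # Required keys, indexed directly by compute().
        'alive_bonus': 0.5,
        'velocity_tracking_scale': 1.0,
        'velocity_tracking_sigma': 0.25,
        'upright_scale': 0.5,
        'height_scale': 0.5,
        'action_rate_penalty_scale': 0.01,
        'joint_limit_penalty_scale': 1.0,
        'lateral_velocity_penalty_scale': 0.5,
        'fall_penalty': -10.0,  # added as-is on termination, hence negative
        # Optional keys, read via .get() with the defaults shown in the code.
        'forward_progress_scale': 0.2,
        'hip_alternation_scale': 0.3,
        'knee_flexion_scale': 0.2,
        'backward_velocity_penalty_scale': 1.0,
        'stall_velocity_threshold': 0.1,
        'stall_penalty_scale': 1.0,
        'hip_target_separation': 0.3,
        'hip_antiphase_sigma': 0.2,
        'knee_target': 0.2,
        'knee_flexion_sigma': 0.12,
    }
    calculator = BipedRewardCalculator(example_reward_config)
    # calculator.compute(context) then returns (total_reward, reward_terms)
    # once a RewardContext is assembled from two consecutive state snapshots.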