Nav2 Navigation Stack - rolling  main
ROS 2 Navigation Stack
path_align_critic.cpp
1 // Copyright (c) 2023 Open Navigation LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "nav2_mppi_controller/critics/path_align_critic.hpp"
16 
17 namespace mppi::critics
18 {
19 
21 {
22  auto getParentParam = parameters_handler_->getParamGetter(parent_name_);
23  getParentParam(enforce_path_inversion_, "enforce_path_inversion", false);
24 
25  auto getParam = parameters_handler_->getParamGetter(name_);
26  getParam(power_, "cost_power", 1);
27  getParam(weight_, "cost_weight", 10.0f);
28 
29  getParam(max_path_occupancy_ratio_, "max_path_occupancy_ratio", 0.07f);
30  getParam(offset_from_furthest_, "offset_from_furthest", 20);
31  getParam(trajectory_point_step_, "trajectory_point_step", 4);
32  getParam(
33  threshold_to_consider_,
34  "threshold_to_consider", 0.5f);
35  getParam(use_path_orientations_, "use_path_orientations", false);
36 
37  RCLCPP_INFO(
38  logger_,
39  "ReferenceTrajectoryCritic instantiated with %d power and %f weight",
40  power_, weight_);
41 }
42 
44 {
45  if (!enabled_) {
46  return;
47  }
48 
49  geometry_msgs::msg::Pose goal = utils::getCriticGoal(data, enforce_path_inversion_);
50 
51  // Don't apply close to goal, let the goal critics take over
52  if (utils::withinPositionGoalTolerance(
53  threshold_to_consider_, data.state.pose.pose, goal))
54  {
55  return;
56  }
57 
58  // Don't apply when first getting bearing w.r.t. the path
59  utils::setPathFurthestPointIfNotSet(data);
60  // Up to furthest only, closest path point is always 0 from path handler
61  const size_t path_segments_count = *data.furthest_reached_path_point;
62  float path_segments_flt = static_cast<float>(path_segments_count);
63  if (path_segments_count < offset_from_furthest_) {
64  return;
65  }
66 
67  // Don't apply when dynamic obstacles are blocking significant proportions of the local path
68  utils::setPathCostsIfNotSet(data, costmap_ros_);
69  std::vector<bool> & path_pts_valid = *data.path_pts_valid;
70  float invalid_ctr = 0.0f;
71  for (size_t i = 0; i < path_segments_count; i++) {
72  if (!path_pts_valid[i]) {invalid_ctr += 1.0f;}
73  if (invalid_ctr / path_segments_flt > max_path_occupancy_ratio_ && invalid_ctr > 2.0f) {
74  return;
75  }
76  }
77 
78  const size_t batch_size = data.trajectories.x.rows();
79  Eigen::ArrayXf cost(data.costs.rows());
80  cost.setZero();
81 
82  // Find integrated distance in the path
83  std::vector<float> path_integrated_distances(path_segments_count, 0.0f);
84  std::vector<utils::Pose2D> path(path_segments_count);
85  float dx = 0.0f, dy = 0.0f;
86  for (unsigned int i = 1; i != path_segments_count; i++) {
87  auto & pose = path[i - 1];
88  pose.x = data.path.x(i - 1);
89  pose.y = data.path.y(i - 1);
90  pose.theta = data.path.yaws(i - 1);
91 
92  dx = data.path.x(i) - pose.x;
93  dy = data.path.y(i) - pose.y;
94  path_integrated_distances[i] = path_integrated_distances[i - 1] + sqrtf(dx * dx + dy * dy);
95  }
96 
97  // Finish populating the path vector
98  auto & final_pose = path[path_segments_count - 1];
99  final_pose.x = data.path.x(path_segments_count - 1);
100  final_pose.y = data.path.y(path_segments_count - 1);
101  final_pose.theta = data.path.yaws(path_segments_count - 1);
102 
103  float summed_path_dist = 0.0f, dyaw = 0.0f;
104  unsigned int num_samples = 0u;
105  unsigned int path_pt = 0u;
106  float traj_integrated_distance = 0.0f;
107 
108  int strided_traj_rows = data.trajectories.x.rows();
109  int strided_traj_cols = floor((data.trajectories.x.cols() - 1) / trajectory_point_step_) + 1;
110  int outer_stride = strided_traj_rows * trajectory_point_step_;
111  // Get strided trajectory information
112  const auto T_x = Eigen::Map<const Eigen::ArrayXXf, 0,
113  Eigen::Stride<-1, -1>>(data.trajectories.x.data(),
114  strided_traj_rows, strided_traj_cols, Eigen::Stride<-1, -1>(outer_stride, 1));
115  const auto T_y = Eigen::Map<const Eigen::ArrayXXf, 0,
116  Eigen::Stride<-1, -1>>(data.trajectories.y.data(),
117  strided_traj_rows, strided_traj_cols, Eigen::Stride<-1, -1>(outer_stride, 1));
118  const auto T_yaw = Eigen::Map<const Eigen::ArrayXXf, 0,
119  Eigen::Stride<-1, -1>>(data.trajectories.yaws.data(), strided_traj_rows, strided_traj_cols,
120  Eigen::Stride<-1, -1>(outer_stride, 1));
121  const auto traj_sampled_size = T_x.cols();
122 
123  for (size_t t = 0; t < batch_size; ++t) {
124  summed_path_dist = 0.0f;
125  num_samples = 0u;
126  traj_integrated_distance = 0.0f;
127  path_pt = 0u;
128  float Tx_m1 = T_x(t, 0);
129  float Ty_m1 = T_y(t, 0);
130  for (int p = 1; p < traj_sampled_size; p++) {
131  const float Tx = T_x(t, p);
132  const float Ty = T_y(t, p);
133  dx = Tx - Tx_m1;
134  dy = Ty - Ty_m1;
135  Tx_m1 = Tx;
136  Ty_m1 = Ty;
137  traj_integrated_distance += sqrtf(dx * dx + dy * dy);
138  path_pt = utils::findClosestPathPt(
139  path_integrated_distances, traj_integrated_distance, path_pt);
140 
141  // The nearest path point to align to needs to be not in collision, else
142  // let the obstacle critic take over in this region due to dynamic obstacles
143  if (path_pts_valid[path_pt]) {
144  const auto & pose = path[path_pt];
145  dx = pose.x - Tx;
146  dy = pose.y - Ty;
147  num_samples++;
148  if (use_path_orientations_) {
149  dyaw = angles::shortest_angular_distance(pose.theta, T_yaw(t, p));
150  summed_path_dist += sqrtf(dx * dx + dy * dy + dyaw * dyaw);
151  } else {
152  summed_path_dist += sqrtf(dx * dx + dy * dy);
153  }
154  }
155  }
156  if (num_samples > 0u) {
157  cost(t) = summed_path_dist / static_cast<float>(num_samples);
158  } else {
159  cost(t) = 0.0f;
160  }
161  }
162 
163  if (power_ > 1u) {
164  data.costs += (cost * weight_).pow(power_).eval();
165  } else {
166  data.costs += (cost * weight_).eval();
167  }
168 }
169 
170 } // namespace mppi::critics
171 
172 #include <pluginlib/class_list_macros.hpp>
173 
174 PLUGINLIB_EXPORT_CLASS(
auto getParamGetter(const std::string &ns)
Get an object to retrieve parameters.
Abstract critic objective function to score trajectories.
void score(CriticData &data) override
Evaluate cost related to trajectories path alignment.
void initialize() override
Initialize critic.
Data to pass to critics for scoring, including state, trajectories, pruned path, global goal,...
Definition: critic_data.hpp:40