Nav2 Navigation Stack - kilted
ROS 2 Navigation Stack
critic_manager.cpp
1 // Copyright (c) 2022 Samsung Research America, @artofnothingness Alexey Budyakov
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "nav2_mppi_controller/critic_manager.hpp"

#include <stdexcept>
16 
17 namespace mppi
18 {
19 
21  rclcpp_lifecycle::LifecycleNode::WeakPtr parent, const std::string & name,
22  std::shared_ptr<nav2_costmap_2d::Costmap2DROS> costmap_ros, ParametersHandler * param_handler)
23 {
24  parent_ = parent;
25  costmap_ros_ = costmap_ros;
26  name_ = name;
27  auto node = parent_.lock();
28  logger_ = node->get_logger();
29  parameters_handler_ = param_handler;
30 
31  getParams();
32  loadCritics();
33 }
34 
36 {
37  auto node = parent_.lock();
38  auto getParam = parameters_handler_->getParamGetter(name_);
39  getParam(critic_names_, "critics", std::vector<std::string>{}, ParameterType::Static);
40  getParam(publish_critics_stats_, "publish_critics_stats", false, ParameterType::Static);
41 }
42 
44 {
45  if (!loader_) {
46  loader_ = std::make_unique<pluginlib::ClassLoader<critics::CriticFunction>>(
47  "nav2_mppi_controller", "mppi::critics::CriticFunction");
48  }
49 
50  auto node = parent_.lock();
51  if (publish_critics_stats_) {
52  critics_effect_pub_ = node->create_publisher<nav2_msgs::msg::CriticsStats>(
53  "~/critics_stats", 10);
54  critics_effect_pub_->on_activate();
55  }
56 
57  critics_.clear();
58  for (auto name : critic_names_) {
59  std::string fullname = getFullName(name);
60  auto instance = std::unique_ptr<critics::CriticFunction>(
61  loader_->createUnmanagedInstance(fullname));
62  critics_.push_back(std::move(instance));
63  critics_.back()->on_configure(
64  parent_, name_, name_ + "." + name, costmap_ros_,
65  parameters_handler_);
66  RCLCPP_INFO(logger_, "Critic loaded : %s", fullname.c_str());
67  }
68 }
69 
70 std::string CriticManager::getFullName(const std::string & name)
71 {
72  return "mppi::critics::" + name;
73 }
74 
76  CriticData & data) const
77 {
78  std::unique_ptr<nav2_msgs::msg::CriticsStats> stats_msg;
79  if (publish_critics_stats_) {
80  stats_msg = std::make_unique<nav2_msgs::msg::CriticsStats>();
81  stats_msg->critics.reserve(critics_.size());
82  stats_msg->changed.reserve(critics_.size());
83  stats_msg->costs_sum.reserve(critics_.size());
84  }
85 
86  for (size_t i = 0; i < critics_.size(); ++i) {
87  if (data.fail_flag) {
88  break;
89  }
90 
91  // Store costs before critic evaluation
92  Eigen::ArrayXf costs_before;
93  if (publish_critics_stats_) {
94  costs_before = data.costs;
95  }
96 
97  critics_[i]->score(data);
98 
99  // Calculate statistics if publishing is enabled
100  if (publish_critics_stats_) {
101  stats_msg->critics.push_back(critic_names_[i]);
102 
103  // Calculate sum of costs added by this individual critic
104  Eigen::ArrayXf cost_diff = data.costs - costs_before;
105  float costs_sum = cost_diff.sum();
106  stats_msg->costs_sum.push_back(costs_sum);
107  stats_msg->changed.push_back(costs_sum != 0.0f);
108  }
109  }
110 
111  // Publish statistics if enabled
112  if (critics_effect_pub_) {
113  auto node = parent_.lock();
114  stats_msg->stamp = node->get_clock()->now();
115  critics_effect_pub_->publish(std::move(stats_msg));
116  }
117 }
118 
119 } // namespace mppi
void on_configure(rclcpp_lifecycle::LifecycleNode::WeakPtr parent, const std::string &name, std::shared_ptr< nav2_costmap_2d::Costmap2DROS >, ParametersHandler *)
Configure critic manager on bringup and load plugins.
void getParams()
Get parameters (critics to load)
void evalTrajectoriesScores(CriticData &data) const
Score trajectories by the set of loaded critic functions.
virtual void loadCritics()
Load the critic plugins.
std::string getFullName(const std::string &name)
Get full-name namespaced critic IDs.
Handles getting parameters and dynamic parameter changes.
auto getParamGetter(const std::string &ns)
Get an object to retrieve parameters.
Data to pass to critics for scoring, including state, trajectories, pruned path, global goal,...
Definition: critic_data.hpp:40