Point Cloud Library (PCL)  1.14.1-dev
decision_tree_trainer.hpp
1 /*
2  * Software License Agreement (BSD License)
3  *
4  * Point Cloud Library (PCL) - www.pointclouds.org
5  * Copyright (c) 2010-2011, Willow Garage, Inc.
6  *
7  * All rights reserved.
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * * Redistributions of source code must retain the above copyright
14  * notice, this list of conditions and the following disclaimer.
15  * * Redistributions in binary form must reproduce the above
16  * copyright notice, this list of conditions and the following
17  * disclaimer in the documentation and/or other materials provided
18  * with the distribution.
19  * * Neither the name of Willow Garage, Inc. nor the names of its
20  * contributors may be used to endorse or promote products derived
21  * from this software without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34  * POSSIBILITY OF SUCH DAMAGE.
35  *
36  */
37 
38 #pragma once
39 
40 namespace pcl {
41 
42 template <class FeatureType,
43  class DataSet,
44  class LabelType,
45  class ExampleIndex,
46  class NodeType>
48  DecisionTreeTrainer() = default;
49 
50 template <class FeatureType,
51  class DataSet,
52  class LabelType,
53  class ExampleIndex,
54  class NodeType>
56  ~DecisionTreeTrainer() = default;
57 
58 template <class FeatureType,
59  class DataSet,
60  class LabelType,
61  class ExampleIndex,
62  class NodeType>
63 void
66 {
67  // create random features
68  std::vector<FeatureType> features;
69 
70  if (!random_features_at_split_node_)
71  feature_handler_->createRandomFeatures(num_of_features_, features);
72 
73  // recursively build decision tree
74  NodeType root_node;
75  tree.setRoot(root_node);
76 
77  if (decision_tree_trainer_data_provider_) {
78  std::cerr << "use decision_tree_trainer_data_provider_" << std::endl;
79 
80  decision_tree_trainer_data_provider_->getDatasetAndLabels(
81  data_set_, label_data_, examples_);
82  trainDecisionTreeNode(
83  features, examples_, label_data_, max_tree_depth_, tree.getRoot());
84  label_data_.clear();
85  data_set_.clear();
86  examples_.clear();
87  }
88  else {
89  trainDecisionTreeNode(
90  features, examples_, label_data_, max_tree_depth_, tree.getRoot());
91  }
92 }
93 
94 template <class FeatureType,
95  class DataSet,
96  class LabelType,
97  class ExampleIndex,
98  class NodeType>
99 void
101  trainDecisionTreeNode(std::vector<FeatureType>& features,
102  std::vector<ExampleIndex>& examples,
103  std::vector<LabelType>& label_data,
104  const std::size_t max_depth,
105  NodeType& node)
106 {
107  const std::size_t num_of_examples = examples.size();
108  if (num_of_examples == 0) {
109  PCL_ERROR(
110  "Reached invalid point in decision tree training: Number of examples is 0!\n");
111  return;
112  };
113 
114  if (max_depth == 0) {
115  stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
116  return;
117  };
118 
119  if (examples.size() < min_examples_for_split_) {
120  stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
121  return;
122  }
123 
124  if (random_features_at_split_node_) {
125  features.clear();
126  feature_handler_->createRandomFeatures(num_of_features_, features);
127  }
128 
129  std::vector<float> feature_results;
130  std::vector<unsigned char> flags;
131 
132  feature_results.reserve(num_of_examples);
133  flags.reserve(num_of_examples);
134 
135  // find best feature for split
136  int best_feature_index = -1;
137  float best_feature_threshold = 0.0f;
138  float best_feature_information_gain = 0.0f;
139 
140  const std::size_t num_of_features = features.size();
141  for (std::size_t feature_index = 0; feature_index < num_of_features;
142  ++feature_index) {
143  // evaluate features
144  feature_handler_->evaluateFeature(
145  features[feature_index], data_set_, examples, feature_results, flags);
146 
147  // get list of thresholds
148  if (!thresholds_.empty()) {
149  // compute information gain for each threshold and store threshold with highest
150  // information gain
151  for (const float& threshold : thresholds_) {
152 
153  const float information_gain = stats_estimator_->computeInformationGain(
154  data_set_, examples, label_data, feature_results, flags, threshold);
155 
156  if (information_gain > best_feature_information_gain) {
157  best_feature_information_gain = information_gain;
158  best_feature_index = static_cast<int>(feature_index);
159  best_feature_threshold = threshold;
160  }
161  }
162  }
163  else {
164  std::vector<float> thresholds;
165  thresholds.reserve(num_of_thresholds_);
166  createThresholdsUniform(num_of_thresholds_, feature_results, thresholds);
167 
168  // compute information gain for each threshold and store threshold with highest
169  // information gain
170  for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds_;
171  ++threshold_index) {
172  const float threshold = thresholds[threshold_index];
173 
174  // compute information gain
175  const float information_gain = stats_estimator_->computeInformationGain(
176  data_set_, examples, label_data, feature_results, flags, threshold);
177 
178  if (information_gain > best_feature_information_gain) {
179  best_feature_information_gain = information_gain;
180  best_feature_index = static_cast<int>(feature_index);
181  best_feature_threshold = threshold;
182  }
183  }
184  }
185  }
186 
187  if (best_feature_index == -1) {
188  stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
189  return;
190  }
191 
192  // get branch indices for best feature and best threshold
193  std::vector<unsigned char> branch_indices;
194  branch_indices.reserve(num_of_examples);
195  {
196  feature_handler_->evaluateFeature(
197  features[best_feature_index], data_set_, examples, feature_results, flags);
198 
199  stats_estimator_->computeBranchIndices(
200  feature_results, flags, best_feature_threshold, branch_indices);
201  }
202 
203  stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
204 
205  // separate data
206  {
207  const std::size_t num_of_branches = stats_estimator_->getNumOfBranches();
208 
209  std::vector<std::size_t> branch_counts(num_of_branches, 0);
210  for (std::size_t example_index = 0; example_index < num_of_examples;
211  ++example_index) {
212  ++branch_counts[branch_indices[example_index]];
213  }
214 
215  node.feature = features[best_feature_index];
216  node.threshold = best_feature_threshold;
217  node.sub_nodes.resize(num_of_branches);
218 
219  for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
220  if (branch_counts[branch_index] == 0) {
221  NodeType branch_node;
222  stats_estimator_->computeAndSetNodeStats(
223  data_set_, examples, label_data, branch_node);
224  // branch_node->num_of_sub_nodes = 0;
225 
226  node.sub_nodes[branch_index] = branch_node;
227 
228  continue;
229  }
230 
231  std::vector<LabelType> branch_labels;
232  std::vector<ExampleIndex> branch_examples;
233  branch_labels.reserve(branch_counts[branch_index]);
234  branch_examples.reserve(branch_counts[branch_index]);
235 
236  for (std::size_t example_index = 0; example_index < num_of_examples;
237  ++example_index) {
238  if (branch_indices[example_index] == branch_index) {
239  branch_examples.push_back(examples[example_index]);
240  branch_labels.push_back(label_data[example_index]);
241  }
242  }
243 
244  trainDecisionTreeNode(features,
245  branch_examples,
246  branch_labels,
247  max_depth - 1,
248  node.sub_nodes[branch_index]);
249  }
250  }
251 }
252 
253 template <class FeatureType,
254  class DataSet,
255  class LabelType,
256  class ExampleIndex,
257  class NodeType>
258 void
260  createThresholdsUniform(const std::size_t num_of_thresholds,
261  std::vector<float>& values,
262  std::vector<float>& thresholds)
263 {
264  // estimate range of values
265  float min_value = ::std::numeric_limits<float>::max();
266  float max_value = -::std::numeric_limits<float>::max();
267 
268  const std::size_t num_of_values = values.size();
269  for (std::size_t value_index = 0; value_index < num_of_values; ++value_index) {
270  const float value = values[value_index];
271 
272  if (value < min_value)
273  min_value = value;
274  if (value > max_value)
275  max_value = value;
276  }
277 
278  const float range = max_value - min_value;
279  const float step = range / static_cast<float>(num_of_thresholds + 2);
280 
281  // compute thresholds
282  thresholds.resize(num_of_thresholds);
283 
284  for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds;
285  ++threshold_index) {
286  thresholds[threshold_index] =
287  min_value + step * (static_cast<float>(threshold_index + 1));
288  }
289 }
290 
291 } // namespace pcl
Class representing a decision tree.
Definition: decision_tree.h:49
NodeType & getRoot()
Returns the root node of the tree.
Definition: decision_tree.h:69
void setRoot(const NodeType &root)
Sets the root node of the tree.
Definition: decision_tree.h:62
static void createThresholdsUniform(const std::size_t num_of_thresholds, std::vector< float > &values, std::vector< float > &thresholds)
Creates uniformly distributed thresholds over the range of the supplied values.
void trainDecisionTreeNode(std::vector< FeatureType > &features, std::vector< ExampleIndex > &examples, std::vector< LabelType > &label_data, std::size_t max_depth, NodeType &node)
Trains a decision tree node from the specified features, label data, and examples.
void train(DecisionTree< NodeType > &tree)
Trains a decision tree using the set training data and settings.
virtual ~DecisionTreeTrainer()
Destructor.
DecisionTreeTrainer()
Constructor.