Point Cloud Library (PCL)  1.11.1-dev
fern_trainer.hpp
1 /*
2  * Software License Agreement (BSD License)
3  *
4  * Point Cloud Library (PCL) - www.pointclouds.org
5  * Copyright (c) 2010-2011, Willow Garage, Inc.
6  *
7  * All rights reserved.
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * * Redistributions of source code must retain the above copyright
14  * notice, this list of conditions and the following disclaimer.
15  * * Redistributions in binary form must reproduce the above
16  * copyright notice, this list of conditions and the following
17  * disclaimer in the documentation and/or other materials provided
18  * with the distribution.
19  * * Neither the name of Willow Garage, Inc. nor the names of its
20  * contributors may be used to endorse or promote products derived
21  * from this software without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34  * POSSIBILITY OF SUCH DAMAGE.
35  *
36  */
37 
38 #pragma once
39 
40 namespace pcl {
41 
// Default constructor of the fern trainer.
// NOTE(review): the declarator line (listing line 47) is absent from this listing;
// the index at the end of the listing names it `FernTrainer()` — confirm against the
// class header.
//
// Initial state set by the member-initializer list below:
//   - fern_depth_        = 10
//   - num_of_features_   = 1000
//   - num_of_thresholds_ = 10
//   - feature_handler_ and stats_estimator_ are null; both are dereferenced later
//     during training, so they must be assigned before training starts
//   - data_set_, label_data_ and examples_ are value-initialized
42 template <class FeatureType,
43  class DataSet,
44  class LabelType,
45  class ExampleIndex,
46  class NodeType>
48 : fern_depth_(10)
49 , num_of_features_(1000)
50 , num_of_thresholds_(10)
51 , feature_handler_(nullptr)
52 , stats_estimator_(nullptr)
53 , data_set_()
54 , label_data_()
55 , examples_()
56 {}
57 
// Destructor of the fern trainer; the body is empty, so nothing is released here
// explicitly.
// NOTE(review): the declarator line (listing line 63) is absent from this listing;
// the index at the end of the listing names it `virtual ~FernTrainer()` — confirm
// against the class header.
58 template <class FeatureType,
59  class DataSet,
60  class LabelType,
61  class ExampleIndex,
62  class NodeType>
64 {}
65 
// Trains the fern passed in as `fern`.
// NOTE(review): the declarator lines (listing lines 72-73) are absent from this
// listing; the index at the end of the listing gives the signature as
// `void train(Fern<FeatureType, NodeType>& fern)` — confirm against the class header.
//
// Outline of what the body does:
//   1. asks feature_handler_ for num_of_features_ random candidate features and
//      evaluates each of them on all examples_;
//   2. for every depth level, picks the (feature, threshold) pair with the largest
//      size-weighted information gain summed over the current branches, stores it
//      in the fern, and splits the per-feature branch containers;
//   3. re-evaluates the selected features, maps every example to a node index and
//   lets stats_estimator_ compute the statistics of every node.
//
// NOTE(review): feature_handler_ and stats_estimator_ are dereferenced without a
// null check although the constructor initializes both to nullptr.
66 template <class FeatureType,
67  class DataSet,
68  class LabelType,
69  class ExampleIndex,
70  class NodeType>
71 void
74 {
75  const std::size_t num_of_branches = stats_estimator_->getNumOfBranches();
76  const std::size_t num_of_examples = examples_.size();
77 
78  // create random features
79  std::vector<FeatureType> features;
80  feature_handler_->createRandomFeatures(num_of_features_, features);
81 
82  // setup fern
83  fern.initialize(fern_depth_);
84 
85  // evaluate all features
    // One result vector and one flag vector per candidate feature; both are filled
    // by evaluateFeature (the meaning of the flags is defined by the feature
    // handler, which is not visible here).
86  std::vector<std::vector<float>> feature_results(num_of_features_);
87  std::vector<std::vector<unsigned char>> flags(num_of_features_);
88 
89  for (std::size_t feature_index = 0; feature_index < num_of_features_;
90  ++feature_index) {
91  feature_results[feature_index].reserve(num_of_examples);
92  flags[feature_index].reserve(num_of_examples);
93 
94  feature_handler_->evaluateFeature(features[feature_index],
95  data_set_,
96  examples_,
97  feature_results[feature_index],
98  flags[feature_index]);
99  }
100 
101  // iteratively select features and thresholds
    // Every candidate feature keeps its own copy of results, flags, examples and
    // labels, partitioned by branch; the partition is refined once per depth level.
102  std::vector<std::vector<std::vector<float>>> branch_feature_results(
103  num_of_features_); // [feature_index][branch_index][result_index]
104  std::vector<std::vector<std::vector<unsigned char>>> branch_flags(
105  num_of_features_); // [feature_index][branch_index][flag_index]
106  std::vector<std::vector<std::vector<ExampleIndex>>> branch_examples(
107  num_of_features_); // [feature_index][branch_index][result_index]
108  std::vector<std::vector<std::vector<LabelType>>> branch_label_data(
109  num_of_features_); // [feature_index][branch_index][flag_index]
110 
111  // - initialize branch feature results and flags
    // Before the first split there is a single branch holding all examples.
    // examples_ and label_data_ are copied once per candidate feature here.
112  for (std::size_t feature_index = 0; feature_index < num_of_features_;
113  ++feature_index) {
114  branch_feature_results[feature_index].resize(1);
115  branch_flags[feature_index].resize(1);
116  branch_examples[feature_index].resize(1);
117  branch_label_data[feature_index].resize(1);
118 
119  branch_feature_results[feature_index][0] = feature_results[feature_index];
120  branch_flags[feature_index][0] = flags[feature_index];
121  branch_examples[feature_index][0] = examples_;
122  branch_label_data[feature_index][0] = label_data_;
123  }
124 
125  for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
126  // get thresholds
    // The thresholds are derived from the unsplit feature_results, which are not
    // modified inside this loop, so the same thresholds are produced at every depth.
127  std::vector<std::vector<float>> thresholds(num_of_features_);
128 
129  for (std::size_t feature_index = 0; feature_index < num_of_features_;
130  ++feature_index) {
    // NOTE(review): this reserves on the outer vector, which already holds
    // num_of_features_ elements; presumably thresholds[feature_index] was meant —
    // confirm. createThresholdsUniform resizes the inner vector itself.
131  thresholds.reserve(num_of_thresholds_);
132  createThresholdsUniform(num_of_thresholds_,
133  feature_results[feature_index],
134  thresholds[feature_index]);
135  }
136 
137  // compute information gain
    // Only a candidate whose total gain is strictly greater than 0 can be selected,
    // because the running best starts at 0.0f and the comparison below uses '>'.
138  int best_feature_index = -1;
139  float best_feature_threshold = 0.0f;
140  float best_feature_information_gain = 0.0f;
141 
142  for (std::size_t feature_index = 0; feature_index < num_of_features_;
143  ++feature_index) {
144  for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds_;
145  ++threshold_index) {
146  float information_gain = 0.0f;
147  for (std::size_t branch_index = 0;
148  branch_index < branch_feature_results[feature_index].size();
149  ++branch_index) {
150  const float branch_information_gain =
151  stats_estimator_->computeInformationGain(
152  data_set_,
153  branch_examples[feature_index][branch_index],
154  branch_label_data[feature_index][branch_index],
155  branch_feature_results[feature_index][branch_index],
156  branch_flags[feature_index][branch_index],
157  thresholds[feature_index][threshold_index]);
158 
    // Weight each branch's gain by the number of results it contains.
159  information_gain +=
160  branch_information_gain *
161  branch_feature_results[feature_index][branch_index].size();
162  }
163 
164  if (information_gain > best_feature_information_gain) {
165  best_feature_information_gain = information_gain;
166  best_feature_index = static_cast<int>(feature_index);
167  best_feature_threshold = thresholds[feature_index][threshold_index];
168  }
169  }
170  }
171 
172  // add feature to the feature list of the fern
    // NOTE(review): if no candidate reached a gain above 0 (or there are no
    // features/thresholds), best_feature_index is still -1 here and is used to
    // index `features` — looks like an out-of-range access; confirm and guard.
173  fern.accessFeature(depth_index) = features[best_feature_index];
174  fern.accessThreshold(depth_index) = best_feature_threshold;
175 
176  // update branch feature results and flags
    // Every existing branch of every candidate feature is split into
    // num_of_branches new branches.
    // NOTE(review): the branch indices below are computed from each candidate
    // feature's OWN results compared against the selected feature's threshold,
    // not from the selected feature's results. Presumably the split should follow
    // the selected feature so that all candidates share the same partition —
    // verify against the intended algorithm.
177  for (std::size_t feature_index = 0; feature_index < num_of_features_;
178  ++feature_index) {
179  std::vector<std::vector<float>>& cur_branch_feature_results =
180  branch_feature_results[feature_index];
181  std::vector<std::vector<unsigned char>>& cur_branch_flags =
182  branch_flags[feature_index];
183  std::vector<std::vector<ExampleIndex>>& cur_branch_examples =
184  branch_examples[feature_index];
185  std::vector<std::vector<LabelType>>& cur_branch_label_data =
186  branch_label_data[feature_index];
187 
188  const std::size_t total_num_of_new_branches =
189  num_of_branches * cur_branch_feature_results.size();
190 
191  std::vector<std::vector<float>> new_branch_feature_results(
192  total_num_of_new_branches); // [branch_index][example_index]
193  std::vector<std::vector<unsigned char>> new_branch_flags(
194  total_num_of_new_branches); // [branch_index][example_index]
195  std::vector<std::vector<ExampleIndex>> new_branch_examples(
196  total_num_of_new_branches); // [branch_index][example_index]
197  std::vector<std::vector<LabelType>> new_branch_label_data(
198  total_num_of_new_branches); // [branch_index][example_index]
199 
200  for (std::size_t branch_index = 0;
201  branch_index < cur_branch_feature_results.size();
202  ++branch_index) {
203  const std::size_t num_of_examples_in_this_branch =
204  cur_branch_feature_results[branch_index].size();
205 
206  std::vector<unsigned char> branch_indices;
207  branch_indices.reserve(num_of_examples_in_this_branch);
208 
209  stats_estimator_->computeBranchIndices(cur_branch_feature_results[branch_index],
210  cur_branch_flags[branch_index],
211  best_feature_threshold,
212  branch_indices);
213 
214  // split results into different branches
    // Children of old branch b occupy the new slots
    // [b * num_of_branches, (b + 1) * num_of_branches).
    // Assumes computeBranchIndices produced one index per example, each
    // below num_of_branches — TODO confirm.
215  const std::size_t base_branch_index = branch_index * num_of_branches;
216  for (std::size_t example_index = 0;
217  example_index < num_of_examples_in_this_branch;
218  ++example_index) {
219  const std::size_t combined_branch_index =
220  base_branch_index + branch_indices[example_index];
221 
222  new_branch_feature_results[combined_branch_index].push_back(
223  cur_branch_feature_results[branch_index][example_index]);
224  new_branch_flags[combined_branch_index].push_back(
225  cur_branch_flags[branch_index][example_index]);
226  new_branch_examples[combined_branch_index].push_back(
227  cur_branch_examples[branch_index][example_index]);
228  new_branch_label_data[combined_branch_index].push_back(
229  cur_branch_label_data[branch_index][example_index]);
230  }
231  }
232 
    // NOTE(review): these are copy assignments of containers that go out of
    // scope right afterwards; swapping or moving would avoid the copies.
233  branch_feature_results[feature_index] = new_branch_feature_results;
234  branch_flags[feature_index] = new_branch_flags;
235  branch_examples[feature_index] = new_branch_examples;
236  branch_label_data[feature_index] = new_branch_label_data;
237  }
238  }
239 
240  // set node statistics
241  // - re-evaluate selected features
    // The features stored in the fern are evaluated again on all examples, each
    // with its own stored threshold, to obtain the per-depth branch index of
    // every example.
242  std::vector<std::vector<float>> final_feature_results(
243  fern_depth_); // [feature_index][example_index]
244  std::vector<std::vector<unsigned char>> final_flags(
245  fern_depth_); // [feature_index][example_index]
246  std::vector<std::vector<unsigned char>> final_branch_indices(
247  fern_depth_); // [feature_index][example_index]
248  for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
249  final_feature_results[depth_index].reserve(num_of_examples);
250  final_flags[depth_index].reserve(num_of_examples);
251  final_branch_indices[depth_index].reserve(num_of_examples);
252 
253  feature_handler_->evaluateFeature(fern.accessFeature(depth_index),
254  data_set_,
255  examples_,
256  final_feature_results[depth_index],
257  final_flags[depth_index]);
258 
259  stats_estimator_->computeBranchIndices(final_feature_results[depth_index],
260  final_flags[depth_index],
261  fern.accessThreshold(depth_index),
262  final_branch_indices[depth_index]);
263  }
264 
265  // - distribute examples to nodes
    // NOTE(review): the containers are sized 1 << fern_depth_ (two branches per
    // level), while node_index below is built in base num_of_branches. If
    // getNumOfBranches() can return more than 2, node_index can exceed the
    // container size — confirm that only binary branching is supported.
266  std::vector<std::vector<LabelType>> node_labels(
267  0x1 << fern_depth_); // [node_index][example_index]
268  std::vector<std::vector<ExampleIndex>> node_examples(
269  0x1 << fern_depth_); // [node_index][example_index]
270 
271  for (std::size_t example_index = 0; example_index < num_of_examples;
272  ++example_index) {
    // The node index is the sequence of per-depth branch indices read as a
    // number in base num_of_branches, depth 0 being the most significant digit.
273  std::size_t node_index = 0;
274  for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
275  node_index *= num_of_branches;
276  node_index += final_branch_indices[depth_index][example_index];
277  }
278 
279  node_labels[node_index].push_back(label_data_[example_index]);
280  node_examples[node_index].push_back(examples_[example_index]);
281  }
282 
283  // - compute and set statistics for every node
284  const std::size_t num_of_nodes = 0x1 << fern_depth_;
285  for (std::size_t node_index = 0; node_index < num_of_nodes; ++node_index) {
286  stats_estimator_->computeAndSetNodeStats(data_set_,
287  node_examples[node_index],
288  node_labels[node_index],
289  fern[node_index]);
290  }
291 }
292 
/**
 * \brief Creates uniformly spaced thresholds inside the range of the supplied values.
 *
 * The range [min, max] of \a values is divided into (num_of_thresholds + 2) equal
 * steps and threshold i (0-based) is placed at min + step * (i + 1). The first
 * threshold therefore lies one step above the minimum and the last one two steps
 * below the maximum.
 *
 * If no value can be compared (empty input, or only NaN values) there is no range
 * to sample; all thresholds are then set to 0 instead of the infinite values the
 * overflowing sentinel range would otherwise produce.
 *
 * \param[in]  num_of_thresholds number of thresholds to create
 * \param[in]  values            values whose range is sampled; only read
 * \param[out] thresholds        resized to num_of_thresholds and overwritten
 *
 * NOTE(review): the class-qualifier line of this definition (listing line 299) is
 * not present in this listing; the declarator is reproduced as shown. The index at
 * the end of the listing describes this function as a static member of FernTrainer
 * - confirm against the class header.
 */
template <class FeatureType,
          class DataSet,
          class LabelType,
          class ExampleIndex,
          class NodeType>
void
createThresholdsUniform(const std::size_t num_of_thresholds,
                        std::vector<float>& values,
                        std::vector<float>& thresholds)
{
  // estimate range of values; NaN entries fail both comparisons and are skipped
  float min_value = ::std::numeric_limits<float>::max();
  float max_value = -::std::numeric_limits<float>::max();

  for (const float value : values) {
    if (value < min_value)
      min_value = value;
    if (value > max_value)
      max_value = value;
  }

  // The sentinels are still in place (min > max) when nothing was comparable.
  // Subtracting them would overflow, so fall back to finite thresholds.
  if (min_value > max_value) {
    thresholds.assign(num_of_thresholds, 0.0f);
    return;
  }

  const float range = max_value - min_value;
  const float step = range / static_cast<float>(num_of_thresholds + 2);

  // compute thresholds
  thresholds.resize(num_of_thresholds);

  for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds;
       ++threshold_index) {
    thresholds[threshold_index] =
        min_value + step * static_cast<float>(threshold_index + 1);
  }
}
329 
330 } // namespace pcl
pcl
Definition: convolution.h:46
pcl::Fern::accessThreshold
float & accessThreshold(const std::size_t threshold_index)
Access operator for thresholds.
Definition: fern.h:186
pcl::FernTrainer::createThresholdsUniform
static void createThresholdsUniform(const std::size_t num_of_thresholds, std::vector< float > &values, std::vector< float > &thresholds)
Creates uniformly distributed thresholds over the range of the supplied values.
Definition: fern_trainer.hpp:300
pcl::Fern::initialize
void initialize(const std::size_t num_of_decisions)
Initializes the fern.
Definition: fern.h:62
pcl::Fern
Class representing a Fern.
Definition: fern.h:49
pcl::FernTrainer::train
void train(Fern< FeatureType, NodeType > &fern)
Trains a fern using the set training data and settings.
Definition: fern_trainer.hpp:72
pcl::FernTrainer::~FernTrainer
virtual ~FernTrainer()
Destructor.
Definition: fern_trainer.hpp:63
pcl::Fern::accessFeature
FeatureType & accessFeature(const std::size_t feature_index)
Access operator for features.
Definition: fern.h:166
pcl::FernTrainer::FernTrainer
FernTrainer()
Constructor.
Definition: fern_trainer.hpp:47