// Point Cloud Library (PCL) 1.14.1-dev
// File: fern_trainer.hpp
1 /*
2  * Software License Agreement (BSD License)
3  *
4  * Point Cloud Library (PCL) - www.pointclouds.org
5  * Copyright (c) 2010-2011, Willow Garage, Inc.
6  *
7  * All rights reserved.
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * * Redistributions of source code must retain the above copyright
14  * notice, this list of conditions and the following disclaimer.
15  * * Redistributions in binary form must reproduce the above
16  * copyright notice, this list of conditions and the following
17  * disclaimer in the documentation and/or other materials provided
18  * with the distribution.
19  * * Neither the name of Willow Garage, Inc. nor the names of its
20  * contributors may be used to endorse or promote products derived
21  * from this software without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34  * POSSIBILITY OF SUCH DAMAGE.
35  *
36  */
37 
38 #pragma once
39 
40 namespace pcl {
41 
42 template <class FeatureType,
43  class DataSet,
44  class LabelType,
45  class ExampleIndex,
46  class NodeType>
48 : fern_depth_(10)
49 , num_of_features_(1000)
50 , num_of_thresholds_(10)
51 , feature_handler_(nullptr)
52 , stats_estimator_(nullptr)
53 , data_set_()
54 , label_data_()
55 , examples_()
56 {}
57 
58 template <class FeatureType,
59  class DataSet,
60  class LabelType,
61  class ExampleIndex,
62  class NodeType>
63 void
66 {
67  const std::size_t num_of_branches = stats_estimator_->getNumOfBranches();
68  const std::size_t num_of_examples = examples_.size();
69 
70  // create random features
71  std::vector<FeatureType> features;
72  feature_handler_->createRandomFeatures(num_of_features_, features);
73 
74  // setup fern
75  fern.initialize(fern_depth_);
76 
77  // evaluate all features
78  std::vector<std::vector<float>> feature_results(num_of_features_);
79  std::vector<std::vector<unsigned char>> flags(num_of_features_);
80 
81  for (std::size_t feature_index = 0; feature_index < num_of_features_;
82  ++feature_index) {
83  feature_results[feature_index].reserve(num_of_examples);
84  flags[feature_index].reserve(num_of_examples);
85 
86  feature_handler_->evaluateFeature(features[feature_index],
87  data_set_,
88  examples_,
89  feature_results[feature_index],
90  flags[feature_index]);
91  }
92 
93  // iteratively select features and thresholds
94  std::vector<std::vector<std::vector<float>>> branch_feature_results(
95  num_of_features_); // [feature_index][branch_index][result_index]
96  std::vector<std::vector<std::vector<unsigned char>>> branch_flags(
97  num_of_features_); // [feature_index][branch_index][flag_index]
98  std::vector<std::vector<std::vector<ExampleIndex>>> branch_examples(
99  num_of_features_); // [feature_index][branch_index][result_index]
100  std::vector<std::vector<std::vector<LabelType>>> branch_label_data(
101  num_of_features_); // [feature_index][branch_index][flag_index]
102 
103  // - initialize branch feature results and flags
104  for (std::size_t feature_index = 0; feature_index < num_of_features_;
105  ++feature_index) {
106  branch_feature_results[feature_index].resize(1);
107  branch_flags[feature_index].resize(1);
108  branch_examples[feature_index].resize(1);
109  branch_label_data[feature_index].resize(1);
110 
111  branch_feature_results[feature_index][0] = feature_results[feature_index];
112  branch_flags[feature_index][0] = flags[feature_index];
113  branch_examples[feature_index][0] = examples_;
114  branch_label_data[feature_index][0] = label_data_;
115  }
116 
117  for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
118  // get thresholds
119  std::vector<std::vector<float>> thresholds(num_of_features_);
120 
121  for (std::size_t feature_index = 0; feature_index < num_of_features_;
122  ++feature_index) {
123  thresholds.reserve(num_of_thresholds_);
124  createThresholdsUniform(num_of_thresholds_,
125  feature_results[feature_index],
126  thresholds[feature_index]);
127  }
128 
129  // compute information gain
130  int best_feature_index = -1;
131  float best_feature_threshold = 0.0f;
132  float best_feature_information_gain = 0.0f;
133 
134  for (std::size_t feature_index = 0; feature_index < num_of_features_;
135  ++feature_index) {
136  for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds_;
137  ++threshold_index) {
138  float information_gain = 0.0f;
139  for (std::size_t branch_index = 0;
140  branch_index < branch_feature_results[feature_index].size();
141  ++branch_index) {
142  const float branch_information_gain =
143  stats_estimator_->computeInformationGain(
144  data_set_,
145  branch_examples[feature_index][branch_index],
146  branch_label_data[feature_index][branch_index],
147  branch_feature_results[feature_index][branch_index],
148  branch_flags[feature_index][branch_index],
149  thresholds[feature_index][threshold_index]);
150 
151  information_gain +=
152  branch_information_gain *
153  branch_feature_results[feature_index][branch_index].size();
154  }
155 
156  if (information_gain > best_feature_information_gain) {
157  best_feature_information_gain = information_gain;
158  best_feature_index = static_cast<int>(feature_index);
159  best_feature_threshold = thresholds[feature_index][threshold_index];
160  }
161  }
162  }
163 
164  // add feature to the feature list of the fern
165  fern.accessFeature(depth_index) = features[best_feature_index];
166  fern.accessThreshold(depth_index) = best_feature_threshold;
167 
168  // update branch feature results and flags
169  for (std::size_t feature_index = 0; feature_index < num_of_features_;
170  ++feature_index) {
171  std::vector<std::vector<float>>& cur_branch_feature_results =
172  branch_feature_results[feature_index];
173  std::vector<std::vector<unsigned char>>& cur_branch_flags =
174  branch_flags[feature_index];
175  std::vector<std::vector<ExampleIndex>>& cur_branch_examples =
176  branch_examples[feature_index];
177  std::vector<std::vector<LabelType>>& cur_branch_label_data =
178  branch_label_data[feature_index];
179 
180  const std::size_t total_num_of_new_branches =
181  num_of_branches * cur_branch_feature_results.size();
182 
183  std::vector<std::vector<float>> new_branch_feature_results(
184  total_num_of_new_branches); // [branch_index][example_index]
185  std::vector<std::vector<unsigned char>> new_branch_flags(
186  total_num_of_new_branches); // [branch_index][example_index]
187  std::vector<std::vector<ExampleIndex>> new_branch_examples(
188  total_num_of_new_branches); // [branch_index][example_index]
189  std::vector<std::vector<LabelType>> new_branch_label_data(
190  total_num_of_new_branches); // [branch_index][example_index]
191 
192  for (std::size_t branch_index = 0;
193  branch_index < cur_branch_feature_results.size();
194  ++branch_index) {
195  const std::size_t num_of_examples_in_this_branch =
196  cur_branch_feature_results[branch_index].size();
197 
198  std::vector<unsigned char> branch_indices;
199  branch_indices.reserve(num_of_examples_in_this_branch);
200 
201  stats_estimator_->computeBranchIndices(cur_branch_feature_results[branch_index],
202  cur_branch_flags[branch_index],
203  best_feature_threshold,
204  branch_indices);
205 
206  // split results into different branches
207  const std::size_t base_branch_index = branch_index * num_of_branches;
208  for (std::size_t example_index = 0;
209  example_index < num_of_examples_in_this_branch;
210  ++example_index) {
211  const std::size_t combined_branch_index =
212  base_branch_index + branch_indices[example_index];
213 
214  new_branch_feature_results[combined_branch_index].push_back(
215  cur_branch_feature_results[branch_index][example_index]);
216  new_branch_flags[combined_branch_index].push_back(
217  cur_branch_flags[branch_index][example_index]);
218  new_branch_examples[combined_branch_index].push_back(
219  cur_branch_examples[branch_index][example_index]);
220  new_branch_label_data[combined_branch_index].push_back(
221  cur_branch_label_data[branch_index][example_index]);
222  }
223  }
224 
225  branch_feature_results[feature_index] = new_branch_feature_results;
226  branch_flags[feature_index] = new_branch_flags;
227  branch_examples[feature_index] = new_branch_examples;
228  branch_label_data[feature_index] = new_branch_label_data;
229  }
230  }
231 
232  // set node statistics
233  // - re-evaluate selected features
234  std::vector<std::vector<float>> final_feature_results(
235  fern_depth_); // [feature_index][example_index]
236  std::vector<std::vector<unsigned char>> final_flags(
237  fern_depth_); // [feature_index][example_index]
238  std::vector<std::vector<unsigned char>> final_branch_indices(
239  fern_depth_); // [feature_index][example_index]
240  for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
241  final_feature_results[depth_index].reserve(num_of_examples);
242  final_flags[depth_index].reserve(num_of_examples);
243  final_branch_indices[depth_index].reserve(num_of_examples);
244 
245  feature_handler_->evaluateFeature(fern.accessFeature(depth_index),
246  data_set_,
247  examples_,
248  final_feature_results[depth_index],
249  final_flags[depth_index]);
250 
251  stats_estimator_->computeBranchIndices(final_feature_results[depth_index],
252  final_flags[depth_index],
253  fern.accessThreshold(depth_index),
254  final_branch_indices[depth_index]);
255  }
256 
257  // - distribute examples to nodes
258  std::vector<std::vector<LabelType>> node_labels(
259  0x1 << fern_depth_); // [node_index][example_index]
260  std::vector<std::vector<ExampleIndex>> node_examples(
261  0x1 << fern_depth_); // [node_index][example_index]
262 
263  for (std::size_t example_index = 0; example_index < num_of_examples;
264  ++example_index) {
265  std::size_t node_index = 0;
266  for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
267  node_index *= num_of_branches;
268  node_index += final_branch_indices[depth_index][example_index];
269  }
270 
271  node_labels[node_index].push_back(label_data_[example_index]);
272  node_examples[node_index].push_back(examples_[example_index]);
273  }
274 
275  // - compute and set statistics for every node
276  const std::size_t num_of_nodes = 0x1 << fern_depth_;
277  for (std::size_t node_index = 0; node_index < num_of_nodes; ++node_index) {
278  stats_estimator_->computeAndSetNodeStats(data_set_,
279  node_examples[node_index],
280  node_labels[node_index],
281  fern[node_index]);
282  }
283 }
284 
285 template <class FeatureType,
286  class DataSet,
287  class LabelType,
288  class ExampleIndex,
289  class NodeType>
290 void
292  createThresholdsUniform(const std::size_t num_of_thresholds,
293  std::vector<float>& values,
294  std::vector<float>& thresholds)
295 {
296  // estimate range of values
297  float min_value = ::std::numeric_limits<float>::max();
298  float max_value = -::std::numeric_limits<float>::max();
299 
300  const std::size_t num_of_values = values.size();
301  for (int value_index = 0; value_index < num_of_values; ++value_index) {
302  const float value = values[value_index];
303 
304  if (value < min_value)
305  min_value = value;
306  if (value > max_value)
307  max_value = value;
308  }
309 
310  const float range = max_value - min_value;
311  const float step = range / (num_of_thresholds + 2);
312 
313  // compute thresholds
314  thresholds.resize(num_of_thresholds);
315 
316  for (int threshold_index = 0; threshold_index < num_of_thresholds;
317  ++threshold_index) {
318  thresholds[threshold_index] = min_value + step * (threshold_index + 1);
319  }
320 }
321 
322 } // namespace pcl
// Cross-references from the generated documentation:
//
// Fern
//   Class representing a Fern.
//   Definition: fern.h:49
//
// float& Fern::accessThreshold(const std::size_t threshold_index)
//   Access operator for thresholds.
//   Definition: fern.h:177
//
// void Fern::initialize(const std::size_t num_of_decisions)
//   Initializes the fern.
//   Definition: fern.h:59
//
// FeatureType& Fern::accessFeature(const std::size_t feature_index)
//   Access operator for features.
//   Definition: fern.h:157
//
// static void FernTrainer::createThresholdsUniform(const std::size_t num_of_thresholds,
//                                                  std::vector<float>& values,
//                                                  std::vector<float>& thresholds)
//   Creates uniformly distributed thresholds over the range of the supplied values.
//
// void FernTrainer::train(Fern<FeatureType, NodeType>& fern)
//   Trains a decision tree using the set training data and settings.
//
// FernTrainer::FernTrainer()
//   Constructor.