// Fragment of a constructor member-initializer list (presumably the FernTrainer
// default constructor; its signature and body are not part of this listing).
// Only the initializers visible below are documented.
42 template <
class FeatureType,
// Default search budget: 1000 candidate features and 10 thresholds per feature.
49 , num_of_features_(1000)
50 , num_of_thresholds_(10)
// Both collaborators start out null. The training code in this listing
// dereferences them without a visible null check, so they have to be assigned
// before training — NOTE(review): confirm a setter/guard exists elsewhere.
51 , feature_handler_(nullptr)
52 , stats_estimator_(nullptr)
// Fragment of the fern training routine (presumably FernTrainer::train; the
// signature, braces and several statements are missing from this listing).
// Visible flow: create random candidate features, evaluate them on all examples,
// then for each fern depth pick the (feature, threshold) pair with the highest
// accumulated information gain, split every per-feature branch set, and finally
// compute per-node statistics for the trained fern.
58 template <
class FeatureType,
// Number of child branches per decision (from the stats estimator) and the
// number of training examples.
67 const std::size_t num_of_branches = stats_estimator_->getNumOfBranches();
68 const std::size_t num_of_examples = examples_.size();
// Ask the feature handler for num_of_features_ random candidate features.
71 std::vector<FeatureType> features;
72 feature_handler_->createRandomFeatures(num_of_features_, features);
// One response vector and one flag vector per candidate feature.
78 std::vector<std::vector<float>> feature_results(num_of_features_);
79 std::vector<std::vector<unsigned char>> flags(num_of_features_);
// Evaluate every candidate feature; capacity is reserved for one entry per example.
// (Some arguments of evaluateFeature are not visible in this listing.)
81 for (std::size_t feature_index = 0; feature_index < num_of_features_;
83 feature_results[feature_index].reserve(num_of_examples);
84 flags[feature_index].reserve(num_of_examples);
86 feature_handler_->evaluateFeature(features[feature_index],
89 feature_results[feature_index],
90 flags[feature_index]);
// Per-feature, per-branch bookkeeping indexed as [feature][branch][example].
// (Constructor arguments are not visible in this listing.)
94 std::vector<std::vector<std::vector<float>>> branch_feature_results(
96 std::vector<std::vector<std::vector<unsigned char>>> branch_flags(
98 std::vector<std::vector<std::vector<ExampleIndex>>> branch_examples(
100 std::vector<std::vector<std::vector<LabelType>>> branch_label_data(
// Every feature starts with a single branch that holds copies of all results,
// flags, examples and labels.
104 for (std::size_t feature_index = 0; feature_index < num_of_features_;
106 branch_feature_results[feature_index].resize(1);
107 branch_flags[feature_index].resize(1);
108 branch_examples[feature_index].resize(1);
109 branch_label_data[feature_index].resize(1);
111 branch_feature_results[feature_index][0] = feature_results[feature_index];
112 branch_flags[feature_index][0] = flags[feature_index];
113 branch_examples[feature_index][0] = examples_;
114 branch_label_data[feature_index][0] = label_data_;
// One iteration per fern decision level.
117 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
// Candidate thresholds, one list per feature.
119 std::vector<std::vector<float>> thresholds(num_of_features_);
121 for (std::size_t feature_index = 0; feature_index < num_of_features_;
// NOTE(review): this reserves capacity on the OUTER vector (already sized to
// num_of_features_), not on thresholds[feature_index]; it looks like
// thresholds[feature_index].reserve(...) was intended. Results are unaffected
// because createThresholdsUniform resizes its output itself.
123 thresholds.reserve(num_of_thresholds_);
124 createThresholdsUniform(num_of_thresholds_,
125 feature_results[feature_index],
126 thresholds[feature_index]);
// Best candidate found so far; -1 means "none selected yet".
130 int best_feature_index = -1;
131 float best_feature_threshold = 0.0f;
132 float best_feature_information_gain = 0.0f;
// Exhaustive search over every (feature, threshold) pair.
134 for (std::size_t feature_index = 0; feature_index < num_of_features_;
136 for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds_;
138 float information_gain = 0.0f;
// Gain is computed per existing branch; the visible expression multiplies each
// branch gain by that branch's example count (the accumulating statement itself
// is not visible in this listing — presumably it adds into information_gain).
139 for (std::size_t branch_index = 0;
140 branch_index < branch_feature_results[feature_index].size();
142 const float branch_information_gain =
143 stats_estimator_->computeInformationGain(
145 branch_examples[feature_index][branch_index],
146 branch_label_data[feature_index][branch_index],
147 branch_feature_results[feature_index][branch_index],
148 branch_flags[feature_index][branch_index],
149 thresholds[feature_index][threshold_index]);
152 branch_information_gain *
153 branch_feature_results[feature_index][branch_index].size();
// Strict '>' against an initial 0.0f: a candidate is only accepted when its
// accumulated gain is positive.
156 if (information_gain > best_feature_information_gain) {
157 best_feature_information_gain = information_gain;
158 best_feature_index =
static_cast<int>(feature_index);
159 best_feature_threshold = thresholds[feature_index][threshold_index];
// Store the winning feature for this depth in the fern.
// NOTE(review): if no candidate produced a gain above 0, best_feature_index is
// still -1 here and features[-1] is an out-of-bounds access — no guard is
// visible in this listing; confirm.
165 fern.
accessFeature(depth_index) = features[best_feature_index];
// Re-partition the branch bookkeeping of EVERY candidate feature so the next
// depth level sees num_of_branches times as many branches.
169 for (std::size_t feature_index = 0; feature_index < num_of_features_;
171 std::vector<std::vector<float>>& cur_branch_feature_results =
172 branch_feature_results[feature_index];
173 std::vector<std::vector<unsigned char>>& cur_branch_flags =
174 branch_flags[feature_index];
175 std::vector<std::vector<ExampleIndex>>& cur_branch_examples =
176 branch_examples[feature_index];
177 std::vector<std::vector<LabelType>>& cur_branch_label_data =
178 branch_label_data[feature_index];
// Each existing branch fans out into num_of_branches new branches.
180 const std::size_t total_num_of_new_branches =
181 num_of_branches * cur_branch_feature_results.size();
183 std::vector<std::vector<float>> new_branch_feature_results(
184 total_num_of_new_branches);
185 std::vector<std::vector<unsigned char>> new_branch_flags(
186 total_num_of_new_branches);
187 std::vector<std::vector<ExampleIndex>> new_branch_examples(
188 total_num_of_new_branches);
189 std::vector<std::vector<LabelType>> new_branch_label_data(
190 total_num_of_new_branches);
192 for (std::size_t branch_index = 0;
193 branch_index < cur_branch_feature_results.size();
195 const std::size_t num_of_examples_in_this_branch =
196 cur_branch_feature_results[branch_index].size();
198 std::vector<unsigned char> branch_indices;
199 branch_indices.reserve(num_of_examples_in_this_branch);
// NOTE(review): the split uses the responses of the feature being iterated
// (feature_index) together with the threshold of the BEST feature. Splitting
// on the best feature's own responses may have been intended — confirm.
201 stats_estimator_->computeBranchIndices(cur_branch_feature_results[branch_index],
202 cur_branch_flags[branch_index],
203 best_feature_threshold,
// New branch slot = parent branch * num_of_branches + child index of the example.
207 const std::size_t base_branch_index = branch_index * num_of_branches;
208 for (std::size_t example_index = 0;
209 example_index < num_of_examples_in_this_branch;
211 const std::size_t combined_branch_index =
212 base_branch_index + branch_indices[example_index];
214 new_branch_feature_results[combined_branch_index].push_back(
215 cur_branch_feature_results[branch_index][example_index]);
216 new_branch_flags[combined_branch_index].push_back(
217 cur_branch_flags[branch_index][example_index]);
218 new_branch_examples[combined_branch_index].push_back(
219 cur_branch_examples[branch_index][example_index]);
220 new_branch_label_data[combined_branch_index].push_back(
221 cur_branch_label_data[branch_index][example_index]);
// Replace the old partition with the new one. These are copy assignments of
// nested vectors; the temporaries are not used afterwards in the visible code,
// so moving them would avoid the copies.
225 branch_feature_results[feature_index] = new_branch_feature_results;
226 branch_flags[feature_index] = new_branch_flags;
227 branch_examples[feature_index] = new_branch_examples;
228 branch_label_data[feature_index] = new_branch_label_data;
// Final pass: evaluate the selected fern features on the training examples and
// derive each example's branch index at every depth.
// (Constructor arguments and some call arguments are not visible in this listing.)
234 std::vector<std::vector<float>> final_feature_results(
236 std::vector<std::vector<unsigned char>> final_flags(
238 std::vector<std::vector<unsigned char>> final_branch_indices(
240 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
241 final_feature_results[depth_index].reserve(num_of_examples);
242 final_flags[depth_index].reserve(num_of_examples);
243 final_branch_indices[depth_index].reserve(num_of_examples);
245 feature_handler_->evaluateFeature(fern.
accessFeature(depth_index),
248 final_feature_results[depth_index],
249 final_flags[depth_index]);
251 stats_estimator_->computeBranchIndices(final_feature_results[depth_index],
252 final_flags[depth_index],
254 final_branch_indices[depth_index]);
// Labels and examples grouped by the fern node they end up in.
258 std::vector<std::vector<LabelType>> node_labels(
260 std::vector<std::vector<ExampleIndex>> node_examples(
// The node index of an example is its sequence of branch indices read as a
// number in base num_of_branches (first depth is the most significant digit).
263 for (std::size_t example_index = 0; example_index < num_of_examples;
265 std::size_t node_index = 0;
266 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
267 node_index *= num_of_branches;
268 node_index += final_branch_indices[depth_index][example_index];
271 node_labels[node_index].push_back(label_data_[example_index]);
272 node_examples[node_index].push_back(examples_[example_index]);
// NOTE(review): the node count is 2^fern_depth_, whereas node_index above is
// built in base num_of_branches; the two agree only when num_of_branches == 2.
// The shift is also performed on an int literal, so a large fern_depth_ would
// overflow before the conversion to std::size_t.
276 const std::size_t num_of_nodes = 0x1 << fern_depth_;
// Let the stats estimator compute and store statistics for every node.
277 for (std::size_t node_index = 0; node_index < num_of_nodes; ++node_index) {
278 stats_estimator_->computeAndSetNodeStats(data_set_,
279 node_examples[node_index],
280 node_labels[node_index],
// Fragment of the uniform threshold generator (presumably createThresholdsUniform;
// the full signature, braces and some statements are missing from this listing).
// It scans `values` for their minimum and maximum and fills `thresholds` with
// num_of_thresholds evenly spaced values: min + step * (i + 1).
285 template <
class FeatureType,
293 std::vector<float>& values,
294 std::vector<float>& thresholds)
// Start from the widest possible bounds so any value narrows them.
297 float min_value = ::std::numeric_limits<float>::max();
298 float max_value = -::std::numeric_limits<float>::max();
// NOTE(review): the loop index is an int compared against a std::size_t count
// (signed/unsigned comparison); a std::size_t index would match.
// NOTE(review): for an empty `values` the bounds stay at their sentinels, so
// `range` below becomes -max - max, which overflows float to -infinity.
300 const std::size_t num_of_values = values.size();
301 for (
int value_index = 0; value_index < num_of_values; ++value_index) {
302 const float value = values[value_index];
// The bodies of these two checks are not visible here; presumably they update
// min_value and max_value respectively.
304 if (value < min_value)
306 if (value > max_value)
// Spacing between thresholds. With a divisor of num_of_thresholds + 2 and
// multipliers 1..num_of_thresholds, the largest threshold is
// min + step * num_of_thresholds, i.e. two steps below max_value.
// NOTE(review): confirm whether + 1 was intended for a symmetric margin.
310 const float range = max_value - min_value;
311 const float step = range / (num_of_thresholds + 2);
// Any previous content of `thresholds` is overwritten.
314 thresholds.resize(num_of_thresholds);
// Same int-vs-std::size_t comparison as in the min/max loop above.
316 for (
int threshold_index = 0; threshold_index < num_of_thresholds;
318 thresholds[threshold_index] = min_value + step * (threshold_index + 1);
Class representing a Fern.
float & accessThreshold(const std::size_t threshold_index)
Access operator for thresholds.
void initialize(const std::size_t num_of_decisions)
Initializes the fern.
FeatureType & accessFeature(const std::size_t feature_index)
Access operator for features.
static void createThresholdsUniform(const std::size_t num_of_thresholds, std::vector< float > &values, std::vector< float > &thresholds)
Creates uniformly distributed thresholds over the range of the supplied values.
void train(Fern< FeatureType, NodeType > &fern)
Trains a fern using the set training data and settings.
FernTrainer()
Constructor.