41 #include <pcl/ml/branch_estimator.h>
42 #include <pcl/ml/stats_estimator.h>
50 template <
class FeatureType,
class LabelType>
63 feature.serialize(stream);
65 stream.write(
reinterpret_cast<const char*
>(&threshold),
sizeof(threshold));
67 stream.write(
reinterpret_cast<const char*
>(&value),
sizeof(value));
68 stream.write(
reinterpret_cast<const char*
>(&variance),
sizeof(variance));
70 const int num_of_sub_nodes =
static_cast<int>(sub_nodes.size());
71 stream.write(
reinterpret_cast<const char*
>(&num_of_sub_nodes),
72 sizeof(num_of_sub_nodes));
73 for (
int sub_node_index = 0; sub_node_index < num_of_sub_nodes; ++sub_node_index) {
74 sub_nodes[sub_node_index].serialize(stream);
85 feature.deserialize(stream);
87 stream.read(
reinterpret_cast<char*
>(&threshold),
sizeof(threshold));
89 stream.read(
reinterpret_cast<char*
>(&value),
sizeof(value));
90 stream.read(
reinterpret_cast<char*
>(&variance),
sizeof(variance));
93 stream.read(
reinterpret_cast<char*
>(&num_of_sub_nodes),
sizeof(num_of_sub_nodes));
94 sub_nodes.resize(num_of_sub_nodes);
96 if (num_of_sub_nodes > 0) {
97 for (
int sub_node_index = 0; sub_node_index < num_of_sub_nodes;
99 sub_nodes[sub_node_index].deserialize(stream);
122 template <
class LabelDataType,
class NodeType,
class DataSet,
class ExampleIndex>
128 : branch_estimator_(branch_estimator)
136 return branch_estimator_->getNumOfBranches();
160 std::vector<ExampleIndex>& examples,
161 std::vector<LabelDataType>& label_data,
162 std::vector<float>& results,
163 std::vector<unsigned char>& flags,
164 const float threshold)
const
166 const std::size_t num_of_examples = examples.size();
167 const std::size_t num_of_branches = getNumOfBranches();
170 std::vector<LabelDataType> sums(num_of_branches + 1, 0);
171 std::vector<LabelDataType> sqr_sums(num_of_branches + 1, 0);
172 std::vector<std::size_t> branch_element_count(num_of_branches + 1, 0);
174 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
175 branch_element_count[branch_index] = 1;
176 ++branch_element_count[num_of_branches];
179 for (std::size_t example_index = 0; example_index < num_of_examples;
181 unsigned char branch_index;
183 results[example_index], flags[example_index], threshold, branch_index);
185 LabelDataType label = label_data[example_index];
187 sums[branch_index] += label;
188 sums[num_of_branches] += label;
190 sqr_sums[branch_index] += label * label;
191 sqr_sums[num_of_branches] += label * label;
193 ++branch_element_count[branch_index];
194 ++branch_element_count[num_of_branches];
197 std::vector<float> variances(num_of_branches + 1, 0);
198 for (std::size_t branch_index = 0; branch_index < num_of_branches + 1;
200 const float mean_sum =
201 static_cast<float>(sums[branch_index]) / branch_element_count[branch_index];
202 const float mean_sqr_sum =
static_cast<float>(sqr_sums[branch_index]) /
203 branch_element_count[branch_index];
204 variances[branch_index] = mean_sqr_sum - mean_sum * mean_sum;
207 float information_gain = variances[num_of_branches];
208 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
211 const float weight =
static_cast<float>(branch_element_count[branch_index]) /
212 static_cast<float>(branch_element_count[num_of_branches]);
213 information_gain -= weight * variances[branch_index];
216 return information_gain;
228 std::vector<unsigned char>& flags,
229 const float threshold,
230 std::vector<unsigned char>& branch_indices)
const
232 const std::size_t num_of_results = results.size();
233 const std::size_t num_of_branches = getNumOfBranches();
235 branch_indices.resize(num_of_results);
236 for (std::size_t result_index = 0; result_index < num_of_results; ++result_index) {
237 unsigned char branch_index;
239 results[result_index], flags[result_index], threshold, branch_index);
240 branch_indices[result_index] = branch_index;
253 const unsigned char flag,
254 const float threshold,
255 unsigned char& branch_index)
const
257 branch_estimator_->computeBranchIndex(result, flag, threshold, branch_index);
271 std::vector<ExampleIndex>& examples,
272 std::vector<LabelDataType>& label_data,
273 NodeType& node)
const
275 const std::size_t num_of_examples = examples.size();
277 LabelDataType sum = 0.0f;
278 LabelDataType sqr_sum = 0.0f;
279 for (std::size_t example_index = 0; example_index < num_of_examples;
281 const LabelDataType label = label_data[example_index];
284 sqr_sum += label * label;
287 sum /= num_of_examples;
288 sqr_sum /= num_of_examples;
290 const float variance = sqr_sum - sum * sum;
293 node.variance = variance;
304 stream <<
"ERROR: RegressionVarianceStatsEstimator does not implement "
305 "generateCodeForBranchIndex(...)";
316 stream <<
"ERROR: RegressionVarianceStatsEstimator does not implement "
317 "generateCodeForBranchIndex(...)";
Interface for branch estimators.
Node for a regression tree which optimizes variance.
RegressionVarianceNode()
Constructor.
void serialize(std::ostream &stream) const
Serializes the node to the specified stream.
LabelType variance
The variance of the labels that ended up at this node during training.
void deserialize(std::istream &stream)
Deserializes a node from the specified stream.
float threshold
The threshold applied on the feature response.
FeatureType feature
The feature associated with the node.
LabelType value
The label value of this node.
std::vector< RegressionVarianceNode > sub_nodes
The child nodes.
Statistics estimator for regression trees which optimizes variance.
void generateCodeForOutput(NodeType &node, std::ostream &stream) const
Generates code for label output.
void computeAndSetNodeStats(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, NodeType &node) const
Computes and sets the statistics for a node.
void computeBranchIndex(const float result, const unsigned char flag, const float threshold, unsigned char &branch_index) const
Computes the branch index for the specified result.
LabelDataType getLabelOfNode(NodeType &node) const
Returns the label of the specified node.
void computeBranchIndices(std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold, std::vector< unsigned char > &branch_indices) const
Computes the branch indices for all supplied results.
RegressionVarianceStatsEstimator(BranchEstimator *branch_estimator)
Constructor.
float computeInformationGain(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold) const
Computes the information gain obtained by the specified threshold.
std::size_t getNumOfBranches() const
Returns the number of branches the corresponding tree has.
void generateCodeForBranchIndexComputation(NodeType &node, std::ostream &stream) const
Generates code for branch index computation.
Class interface for gathering statistics for decision tree learning.
Define standard C methods and C++ classes that are common to all methods.