#ifndef VIGRA_RF_PREPROCESSING_HXX
#define VIGRA_RF_PREPROCESSING_HXX

#include <limits>
#include <set>
#include <vector>
#include <algorithm>
#include <stdexcept>
#include <vigra/mathutil.hxx>
#include "rf_common.hxx"

namespace vigra
{
/* Preprocessing helper used while learning a random forest;
   specialized below for ClassificationTag and RegressionTag. */
template<class Tag,
         class LabelType,
         class T1, class C1,
         class T2, class C2>
class Processor;
namespace detail
{

/* Shared helper: compute the actual values of mtry (features tested per split)
   and msample (samples drawn per tree) from the user-supplied options. */
template<class T>
void fill_external_parameters(RandomForestOptions const & options,
                              ProblemSpec<T> & ext_param)
{
    // set the actual value of mtry
    switch(options.mtry_switch_)
    {
        case RF_SQRT:
            ext_param.actual_mtry_ =
                int(std::floor(std::sqrt(double(ext_param.column_count_)) + 0.5));
            break;
        case RF_LOG:
            ext_param.actual_mtry_ =
                int(1 + (std::log(double(ext_param.column_count_)) / std::log(2.0)));
            break;
        case RF_FUNCTION:
            ext_param.actual_mtry_ =
                options.mtry_func_(ext_param.column_count_);
            break;
        case RF_ALL:
            ext_param.actual_mtry_ = ext_param.column_count_;
            break;
        default:
            ext_param.actual_mtry_ = options.mtry_;
    }
    // set the actual size of the training sample drawn per tree
    switch(options.training_set_calc_switch_)
    {
        case RF_CONST:
            ext_param.actual_msample_ = options.training_set_size_;
            break;
        case RF_PROPORTIONAL:
            ext_param.actual_msample_ =
                static_cast<int>(std::ceil(options.training_set_proportion_ *
                                           ext_param.row_count_));
            break;
        case RF_FUNCTION:
            ext_param.actual_msample_ =
                options.training_set_func_(ext_param.row_count_);
            break;
        default:
            vigra_precondition(false, "unexpected error");
    }
}
/* Check whether a MultiArrayView contains NaN values. */
template<unsigned int N, class T, class C>
bool contains_nan(MultiArrayView<N, T, C> const & in)
{
    for(int ii = 0; ii < in.size(); ++ii)
        if(isnan(in[ii]))
            return true;
    return false;
}
/* Check whether a MultiArrayView contains infinite values. */
template<unsigned int N, class T, class C>
bool contains_inf(MultiArrayView<N, T, C> const & in)
{
    if(!std::numeric_limits<T>::has_infinity)
        return false;
    for(int ii = 0; ii < in.size(); ++ii)
        if(abs(in[ii]) == std::numeric_limits<T>::infinity())
            return true;
    return false;
}

} // namespace detail
/* Preprocessor used for classification problems: converts the user-supplied
   labels to integral labels and fills the external problem specification. */
template<class LabelType, class T1, class C1, class T2, class C2>
class Processor<ClassificationTag, LabelType, T1, C1, T2, C2>
{
  public:
    typedef Int32 LabelInt;

    MultiArrayView<2, T1, C1> const & features_;
    MultiArray<2, LabelInt>           intLabels_;
    MultiArrayView<2, LabelInt>       strata_;

    template<class T>
    Processor(MultiArrayView<2, T1, C1> const & features,
              MultiArrayView<2, T2, C2> const & response,
              RandomForestOptions & options,
              ProblemSpec<T> & ext_param)
    :   features_(features)    // only a view - the features are not copied
    {
        vigra_precondition(!detail::contains_nan(features),
            "RandomForest(): Feature matrix contains NaNs.");
        vigra_precondition(!detail::contains_nan(response),
            "RandomForest(): Response contains NaNs.");
        vigra_precondition(!detail::contains_inf(features),
            "RandomForest(): Feature matrix contains infinite values.");
        vigra_precondition(!detail::contains_inf(response),
            "RandomForest(): Response contains infinite values.");

        // basic problem parameters
        ext_param.column_count_ = features.shape(1);
        ext_param.row_count_    = features.shape(0);
        ext_param.problem_type_ = CLASSIFICATION;
        ext_param.used_         = true;
        intLabels_.reshape(response.shape());

        // if no classes were supplied, collect them from the response column
        if(ext_param.class_count_ == 0)
        {
            std::set<T2> labelToInt;
            for(MultiArrayIndex k = 0; k < response.shape(0); ++k)
                labelToInt.insert(response(k, 0));
            std::vector<T2> tmp_(labelToInt.begin(), labelToInt.end());
            ext_param.classes_(tmp_.begin(), tmp_.end());
        }

        // map each response value to the index of its class
        for(MultiArrayIndex k = 0; k < response.shape(0); ++k)
        {
            if(std::find(ext_param.classes.begin(), ext_param.classes.end(),
                         response(k, 0)) == ext_param.classes.end())
            {
                throw std::runtime_error(
                    "RandomForest(): invalid label in training data.");
            }
            intLabels_(k, 0) =
                std::find(ext_param.classes.begin(), ext_param.classes.end(),
                          response(k, 0))
                - ext_param.classes.begin();
        }

        // default to uniform class weights if none were given
        if(ext_param.class_weights_.size() == 0)
        {
            ArrayVector<T2> tmp(static_cast<std::size_t>(ext_param.class_count_),
                                NumericTraits<T2>::one());
            ext_param.class_weights(tmp.begin(), tmp.end());
        }

        // set mtry and msample
        detail::fill_external_parameters(options, ext_param);

        // each class forms its own stratum for stratified sampling
        strata_ = intLabels_;
    }
};
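/* Usage sketch (illustration only, with placeholder names and template
   arguments chosen for the example - this class is normally instantiated
   internally during RandomForest::learn(), not by user code):

       MultiArray<2, double> features(Shape2(n_samples, n_features));
       MultiArray<2, Int32>  labels(Shape2(n_samples, 1));
       RandomForestOptions   options;
       ProblemSpec<Int32>    ext_param;

       Processor<ClassificationTag, Int32,
                 double, StridedArrayTag,
                 Int32,  StridedArrayTag>
           proc(features, labels, options, ext_param);

   Afterwards ext_param holds row/column counts, the class list, actual_mtry_
   and actual_msample_, and proc.strata_ holds the integral labels used for
   stratified sampling. */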
/* Preprocessor used for regression problems: essentially a pass-through that
   validates the data and fills the external problem specification. */
template<class LabelType, class T1, class C1, class T2, class C2>
class Processor<RegressionTag, LabelType, T1, C1, T2, C2>
{
  public:
    // only views are stored - no data is copied
    MultiArrayView<2, T1, C1>      features_;
    MultiArrayView<2, T2, C2>      response_;
    ProblemSpec<LabelType> const & ext_param_;

    Processor(MultiArrayView<2, T1, C1> features,
              MultiArrayView<2, T2, C2> response,
              RandomForestOptions const & options,
              ProblemSpec<LabelType> & ext_param)
    :   features_(features),
        response_(response),
        ext_param_(ext_param)
    {
        // basic problem parameters
        ext_param.column_count_ = features.shape(1);
        ext_param.row_count_    = features.shape(0);
        ext_param.problem_type_ = REGRESSION;
        ext_param.used_         = true;
        detail::fill_external_parameters(options, ext_param);

        vigra_precondition(!detail::contains_nan(features),
            "Processor(): Feature matrix contains NaNs.");
        vigra_precondition(!detail::contains_nan(response),
            "Processor(): Response contains NaNs.");
        vigra_precondition(!detail::contains_inf(features),
            "Processor(): Feature matrix contains infinite values.");
        vigra_precondition(!detail::contains_inf(response),
            "Processor(): Response contains infinite values.");

        // for regression, each response column is treated as one "class"
        ext_param.response_size_ = response.shape(1);
        ext_param.class_count_   = response_.shape(1);
        std::vector<T2> tmp_(ext_param.class_count_, 0);
        ext_param.classes_(tmp_.begin(), tmp_.end());
    }
};
} // namespace vigra

#endif // VIGRA_RF_PREPROCESSING_HXX