[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_preprocessing.hxx VIGRA

1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 
36 #ifndef VIGRA_RF_PREPROCESSING_HXX
37 #define VIGRA_RF_PREPROCESSING_HXX
38 
39 #include <limits>
40 #include <vigra/mathutil.hxx>
41 #include "rf_common.hxx"
42 
43 namespace vigra
44 {
45 
46 /** Class used while preprocessing (currently used only during learn).
 *
 * This class is used internally by the Random Forest learn function.
 * Different split functors may need to process the data in different ways
 * (e.g., regression labels that should not be touched, and classification
 * labels that must be converted into an integral format).
 *
 * The primary template below is only declared, never defined: the class
 * exists solely as specializations for a fixed Tag class.
 *
 * The Tag class is determined by Splitfunctor::Preprocessor_t. In this file
 * specializations exist for ClassificationTag and RegressionTag. See the
 * RegressionTag specialization for the basic interface if you need to write
 * a new preprocessor (e.g., for soft labels).
 *
 * Template parameters, as used by the specializations in this file:
 *   LabelType - label type of the ProblemSpec held by the RegressionTag
 *               specialization
 *   T1, C1    - template arguments of the feature view MultiArrayView<2, T1, C1>
 *   T2, C2    - template arguments of the response view MultiArrayView<2, T2, C2>
 */
62 template<class Tag, class LabelType, class T1, class C1, class T2, class C2>
63 class Processor;
64 
65 namespace detail
66 {
67 
68  /* Common helper function used in all Processors.
69  * This function analyses the options struct and calculates the real
70  * values needed for the current problem (data)
71  */
72  template<class T>
73  void fill_external_parameters(RandomForestOptions const & options,
74  ProblemSpec<T> & ext_param)
75  {
76  // set correct value for mtry
77  switch(options.mtry_switch_)
78  {
79  case RF_SQRT:
80  ext_param.actual_mtry_ =
81  int(std::floor(
82  std::sqrt(double(ext_param.column_count_))
83  + 0.5));
84  break;
85  case RF_LOG:
86  // this is in Breimans original paper
87  ext_param.actual_mtry_ =
88  int(1+(std::log(double(ext_param.column_count_))
89  /std::log(2.0)));
90  break;
91  case RF_FUNCTION:
92  ext_param.actual_mtry_ =
93  options.mtry_func_(ext_param.column_count_);
94  break;
95  case RF_ALL:
96  ext_param.actual_mtry_ = ext_param.column_count_;
97  break;
98  default:
99  ext_param.actual_mtry_ =
100  options.mtry_;
101  }
102  // set correct value for msample
103  switch(options.training_set_calc_switch_)
104  {
105  case RF_CONST:
106  ext_param.actual_msample_ =
107  options.training_set_size_;
108  break;
109  case RF_PROPORTIONAL:
110  ext_param.actual_msample_ =
111  static_cast<int>(std::ceil(options.training_set_proportion_ *
112  ext_param.row_count_));
113  break;
114  case RF_FUNCTION:
115  ext_param.actual_msample_ =
116  options.training_set_func_(ext_param.row_count_);
117  break;
118  default:
119  vigra_precondition(1!= 1, "unexpected error");
120 
121  }
122 
123  }
124 
125  /* Returns true if MultiArray contains NaNs
126  */
127  template<unsigned int N, class T, class C>
128  bool contains_nan(MultiArrayView<N, T, C> const & in)
129  {
130  for(int ii = 0; ii < in.size(); ++ii)
131  if(isnan(in[ii]))
132  return true;
133  return false;
134  }
135 
136  /* Returns true if MultiArray contains Infs
137  */
138  template<unsigned int N, class T, class C>
139  bool contains_inf(MultiArrayView<N, T, C> const & in)
140  {
141  if(!std::numeric_limits<T>::has_infinity)
142  return false;
143  for(int ii = 0; ii < in.size(); ++ii)
144  if(abs(in[ii]) == std::numeric_limits<T>::infinity())
145  return true;
146  return false;
147  }
148 } // namespace detail
149 
150 
151 
152 /** Preprocessor used during classification.
 *
 * This class converts the labels into integral labels (LabelInt) which are
 * used by the standard split functor to address memory in the node objects.
 * The label of row k is replaced by the index of that label within
 * ext_param.classes.
 */
157 template<class LabelType, class T1, class C1, class T2, class C2>
158 class Processor<ClassificationTag, LabelType, T1, C1, T2, C2>
159 {
160  public:
     // integral label type produced by this preprocessor
161  typedef Int32 LabelInt;
     // Reference to the caller's feature view (the features are not copied),
     // so the caller's view must outlive this Processor.
165  MultiArrayView<2, T1, C1>const & features_;
     // Integral labels, same shape as the response; column 0 is filled in the
     // constructor.
166  MultiArray<2, LabelInt> intLabels_;
168 
     /** Checks the training data and fills in the problem specification.
      *
      * \param features  feature matrix; column_count_ is taken from shape(1),
      *                  row_count_ from shape(0). Only a reference is stored.
      * \param response  labels; the label of row k is read from response(k, 0).
      * \param options   forwarded to detail::fill_external_parameters()
      *                  (only read here, despite the non-const reference).
      * \param ext_param updated in place:
      *                  - column_count_, row_count_, problem_type_
      *                    (CLASSIFICATION) and used_ are set;
      *                  - if class_count_ == 0, the distinct labels (sorted,
      *                    since they are collected in a std::set) are passed
      *                    to classes_();
      *                  - if class_weights_ is empty, class_weights() receives
      *                    class_count_ copies of NumericTraits<T2>::one();
      *                  - mtry / msample are set by fill_external_parameters().
      *
      * NaN or inf values in features or response fail a vigra_precondition.
      * A label that is not in ext_param.classes throws std::runtime_error.
      */
169  template<class T>
170  Processor(MultiArrayView<2, T1, C1>const & features,
171  MultiArrayView<2, T2, C2>const & response,
172  RandomForestOptions &options,
173  ProblemSpec<T> &ext_param)
174  :
175  features_( features) // do not touch the features.
176  {
177  vigra_precondition(!detail::contains_nan(features), "RandomForest(): Feature matrix "
178  "contains NaNs");
179  vigra_precondition(!detail::contains_nan(response), "RandomForest(): Response "
180  "contains NaNs");
181  vigra_precondition(!detail::contains_inf(features), "RandomForest(): Feature matrix "
182  "contains inf");
183  vigra_precondition(!detail::contains_inf(response), "RandomForest(): Response "
184  "contains inf");
185  // set the problem-specific sizes and type
186  ext_param.column_count_ = features.shape(1);
187  ext_param.row_count_ = features.shape(0);
188  ext_param.problem_type_ = CLASSIFICATION;
189  ext_param.used_ = true;
190  intLabels_.reshape(response.shape());
191 
192  // get the class labels (only if the caller did not preset them)
193  if(ext_param.class_count_ == 0)
194  {
195  // collect the distinct labels in a (sorted) set and register them
196  // as the class list; their positions become the integral labels.
197  std::set<T2> labelToInt;
         // NOTE(review): rows are counted with features.shape(0) but read
         // from response; this assumes response has at least as many rows
         // as features -- no such check is visible here.
198  for(MultiArrayIndex k = 0; k < features.shape(0); ++k)
199  labelToInt.insert(response(k,0));
200  std::vector<T2> tmp_(labelToInt.begin(), labelToInt.end());
201  ext_param.classes_(tmp_.begin(), tmp_.end());
202  }
     // Map each label to its index in ext_param.classes (linear search).
     // NOTE(review): std::find is evaluated twice per row; the first result
     // could be reused. The error below uses std::runtime_error, unlike the
     // vigra_precondition checks above.
203  for(MultiArrayIndex k = 0; k < features.shape(0); ++k)
204  {
205  if(std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0)) == ext_param.classes.end())
206  {
207  throw std::runtime_error("RandomForest(): invalid label in training data.");
208  }
209  else
210  intLabels_(k, 0) = std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0))
211  - ext_param.classes.begin();
212  }
213  // set class weights: default to a weight of one per class
214  if(ext_param.class_weights_.size() == 0)
215  {
217  tmp(static_cast<std::size_t>(ext_param.class_count_),
218  NumericTraits<T2>::one());
219  ext_param.class_weights(tmp.begin(), tmp.end());
220  }
221 
222  // set mtry and msample (needs column_count_ / row_count_ set above)
223  detail::fill_external_parameters(options, ext_param);
224 
225  // set strata: stratify by the integral class labels
226  strata_ = intLabels_;
227 
228  }
229 
230  /** Access the processed features (the caller's view, unmodified).
231  */
233  {
234  return features_;
235  }
236 
237  /** Access the processed labels: a view of the integral labels.
238  */
240  {
241  return MultiArrayView<2, LabelInt>(intLabels_);
242  }
243 
244  /** Access the processed strata: a flat view over the integral labels'
      * data (intLabels_.size() elements).
245  */
247  {
248  return ArrayVectorView<LabelInt>(intLabels_.size(), intLabels_.data());
249  }
250 
251  /** Access the strata fractions - not used currently; returns an empty view.
252  */
254  {
255  return ArrayVectorView< double>();
256  }
257 };
258 
259 
260 
261 /** Regression preprocessor - passes the data through untouched.
 *
 * Only views of the features and responses are stored (no data is copied);
 * the constructor validates the data and fills in the problem specification.
 */
264 template<class LabelType, class T1, class C1, class T2, class C2>
265 class Processor<RegressionTag,LabelType, T1, C1, T2, C2>
266 {
267 public:
268  // only views are created - no data copied.
269  MultiArrayView<2, T1, C1> features_;
270  MultiArrayView<2, T2, C2> response_;
     // Stored by reference: the caller's options and problem spec must
     // outlive this Processor.
271  RandomForestOptions const & options_;
     // NOTE(review): initialised from a ProblemSpec<T>&. If T differs from
     // LabelType, binding would need a conversion, and a temporary created
     // that way would not outlive the constructor -- confirm T == LabelType
     // at the call sites.
272  ProblemSpec<LabelType> const &
273  ext_param_;
274  // Allocated (response rows x 1) in the constructor; its contents are not
275  MultiArray<2, int> strata_;
     // NOTE(review): not initialised by the constructor shown here.
276  bool strata_filled;
277 
     /** Stores views of the data and fills in the problem specification.
      *
      * ext_param is updated in place: column_count_ (features.shape(1)),
      * row_count_ (features.shape(0)), problem_type_ (REGRESSION), used_,
      * mtry/msample (via fill_external_parameters), response_size_ and
      * class_count_ (both response.shape(1)), and classes_() is called
      * with class_count_ zero-valued entries (presumably placeholders, as
      * regression has no classes -- confirm).
      *
      * NaN or inf values in features or response fail a vigra_precondition.
      * NOTE(review): these checks run after ext_param has already been
      * modified, so a failing check leaves ext_param partially updated.
      */
278  // copy the views.
279  template<class T>
281  MultiArrayView<2, T2, C2> response,
282  RandomForestOptions const & options,
283  ProblemSpec<T>& ext_param)
284  :
285  features_(features),
286  response_(response),
287  options_(options),
288  ext_param_(ext_param)
289  {
290  // set the problem-specific sizes and type
291  ext_param.column_count_ = features.shape(1);
292  ext_param.row_count_ = features.shape(0);
293  ext_param.problem_type_ = REGRESSION;
294  ext_param.used_ = true;
295  detail::fill_external_parameters(options, ext_param);
296  vigra_precondition(!detail::contains_nan(features), "Processor(): Feature Matrix "
297  "Contains NaNs");
298  vigra_precondition(!detail::contains_nan(response), "Processor(): Response "
299  "Contains NaNs");
300  vigra_precondition(!detail::contains_inf(features), "Processor(): Feature Matrix "
301  "Contains inf");
302  vigra_precondition(!detail::contains_inf(response), "Processor(): Response "
303  "Contains inf");
304  strata_ = MultiArray<2, int> (MultiArrayShape<2>::type(response_.shape(0), 1));
     // one response column per regression target
305  ext_param.response_size_ = response.shape(1);
306  ext_param.class_count_ = response_.shape(1);
307  std::vector<T2> tmp_(ext_param.class_count_, 0);
308  ext_param.classes_(tmp_.begin(), tmp_.end());
309  }
310 
311  /** access preprocessed features (the stored view, unmodified)
312  */
314  {
315  return features_;
316  }
317 
318  /** access preprocessed response (the stored view, unmodified)
319  */
321  {
322  return response_;
323  }
324 
325  /** access strata - this is not used currently
326  */
328  {
329  return strata_;
330  }
331 };
332 }
333 #endif //VIGRA_RF_PREPROCESSING_HXX
334 
335 
336 

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.10.0 (Thu Jan 8 2015)