JOIN
Get Time
forums   
Search | Watch Thread  |  My Post History  |  My Watches  |  User Settings
View: Flat (newest first)  | Threaded  | Tree
Previous Thread  |  Next Thread
Using STL for binary search on functions | Reply
After reading your tutorial (very good, thanks) and doing some of the practice problems, I thought: why code the same thing over and over? Why not use the STL? The algorithms are all there (e.g. upper_bound, lower_bound, ...). They are mostly used on containers, but the STL places no such restriction on them: they work on anything you have an iterator for. So define an iterator that just stores the function argument, with operator++ incrementing the argument and operator* computing the function value (plus some boilerplate the STL needs). The only requirement is the usual one for binary search: the function must be non-decreasing over the argument range, so that the sequence of function values is sorted.

As a proof of concept, I came up with the following code. For a given value x, it computes the first n such that n^2 > x (or n^2 >= x), and the last n such that n^2 < x (or n^2 <= x).
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <iterator>
#include <sstream>
 
using namespace std;
 
// the iterator
//
// A random-access "iterator" over the arguments of a function: it stores an
// argument value n and dereferences to func(n). Moving the iterator moves the
// argument, so std::lower_bound / std::upper_bound can binary-search the
// argument range [a, b) of a function that is non-decreasing on that range.
//
// U: argument type (integral; also used as the distance type).
// T: function result type (what the searched-for value is compared with).
//
// NOTE: operator* returns by value, so this is a proxy iterator. It is tagged
// random_access so the algorithms jump with += (O(1)) instead of stepping with
// ++ (O(n)). That is enough for the std:: binary-search algorithms, but it does
// not satisfy every requirement of a standard random access iterator.
template<typename U, typename T >
struct function_iterator
{
  // Traits read by std::iterator_traits. These replace the std::iterator base,
  // which is deprecated since C++17 and was instantiated with value_type and
  // difference_type swapped (values seen through operator* are T, distances U).
  typedef std::random_access_iterator_tag iterator_category;
  typedef T value_type;
  typedef U difference_type;
  typedef void pointer;
  typedef T reference;

  U n;            // current function argument
  T (*func)(U);   // function being searched; must be non-null to dereference

  function_iterator() : n(0), func(nullptr) {}
  function_iterator( U _n, T (*_func)(U) ) : n(_n), func(_func) {}

  // Function value at the current argument (and at an offset from it).
  T operator*() const { return func(n); }
  T operator[]( U diff ) const { return func(n + diff); }

  // The current argument. operator& is kept so existing `&lower_bound(...)`
  // callers keep working; prefer argument() in new code, because overloading
  // unary & hides the object's real address.
  U operator&() const { return n; }
  U argument() const { return n; }

  // Stepping. Some standard libraries (e.g. libstdc++) implement std::advance
  // for random-access iterators with ++/-- for steps of +-1, so both
  // directions must exist even though the binary search only moves forward.
  function_iterator& operator++() { ++n; return *this; }
  function_iterator operator++(int) { function_iterator old(*this); ++n; return old; }
  function_iterator& operator--() { --n; return *this; }
  function_iterator operator--(int) { function_iterator old(*this); --n; return old; }

  // Jumps and distance.
  function_iterator& operator+=( U diff ) { n += diff; return *this; }
  function_iterator& operator-=( U diff ) { n -= diff; return *this; }
  function_iterator operator+( U diff ) const { return function_iterator(n + diff, func); }
  function_iterator operator-( U diff ) const { return function_iterator(n - diff, func); }
  friend function_iterator operator+( U diff, const function_iterator& it ) { return it + diff; }
  U operator-( const function_iterator& it ) const { return n - it.n; }

  // Ordering compares arguments only (func is assumed to be the same).
  bool operator==( const function_iterator& x ) const { return n == x.n; }
  bool operator!=( const function_iterator& x ) const { return n != x.n; }
  bool operator<( const function_iterator& x ) const { return n < x.n; }
  bool operator>( const function_iterator& x ) const { return n > x.n; }
  bool operator<=( const function_iterator& x ) const { return n <= x.n; }
  bool operator>=( const function_iterator& x ) const { return n >= x.n; }
};
 
// some convenience functions using the iterator
template<typename U, typename T>
U find_first_larger_or_equal( T (*func)(U), U a, U b, T x ) {
  function_iterator<U,T> lo( a, func ), hi( b, func );
  return &lower_bound( lo, hi, x );
}
 
template<typename U, typename T>
U find_first_larger( T (*func)(U), U a, U b, T x ) {
  function_iterator<U,T> lo( a, func ), hi( b, func );
  return &upper_bound( lo, hi, x );
}
 
// Largest argument n in [a, b) with func(n) < x: the argument just before the
// first one whose value is >= x. Returns a-1 when no argument qualifies (this
// wraps around if U is unsigned and a == 0). func must be non-decreasing.
template<typename U, typename T>
U find_last_smaller( T (*func)(U), U a, U b, T x ) {
  const U first_not_smaller = find_first_larger_or_equal( func, a, b, x );
  return first_not_smaller - 1;
}
 
// Largest argument n in [a, b) with func(n) <= x: the argument just before the
// first one whose value is > x. Returns a-1 when no argument qualifies (this
// wraps around if U is unsigned and a == 0). func must be non-decreasing.
template<typename U, typename T>
U find_last_smaller_or_equal( T (*func)(U), U a, U b, T x ) {
  const U first_larger = find_first_larger( func, a, b, x );
  return first_larger - 1;
}
 
// extended assert macro
//
// assert_equal(expr, expected): evaluates `expr` exactly once. If the result
// compares unequal (!=) to `expected`, it prints the source text of both
// arguments plus the value actually obtained, then aborts. Unlike assert(),
// it is not disabled by NDEBUG.
//
// Portable version: it avoids GNU `typeof` (rejected in strict -std=c++NN
// modes) and glibc internals (__STRING, __assert_fail, __ASSERT_FUNCTION).
// The local gets an unlikely name so that an `expected` expression mentioning
// a caller variable (e.g. `x`) is not captured by the macro's own variable.
#define assert_equal( _eval, _expect ) do {                               \
    const auto assert_equal_got_ = ( _eval );                            \
    if( assert_equal_got_ != ( _expect ) ) {                             \
      std::cerr << __FILE__ << ":" << __LINE__ << ": " << __func__        \
                << ": Assertion `" << #_eval << "==" << #_expect          \
                << " (got " << assert_equal_got_ << ")' failed.\n";       \
      std::abort();                                                       \
    }                                                                     \
  } while( false )
 
// the target function
// Returns i squared. The argument is widened to long long before multiplying,
// so the product is never computed in (and cannot overflow) int.
long long square( int i )
{
  const long long wide = i;
  return wide * wide;
}
 
// Self-test of the four convenience searches against square(n) = n*n over the
// argument range [1, 21), i.e. n = 1..20 with n*n = 1..400. Besides the
// ordinary hits, each group checks the "nothing qualifies" boundary: the
// first-* searches return the range end (21), the last-* searches return one
// before the range start (0).
int main() {
  // first n with n*n >= x
  assert_equal( find_first_larger_or_equal( square, 1, 21, 1LL ), 1 );
  assert_equal( find_first_larger_or_equal( square, 1, 21, 48LL ), 7 );
  assert_equal( find_first_larger_or_equal( square, 1, 21, 49LL ), 7 );
  assert_equal( find_first_larger_or_equal( square, 1, 21, 50LL ), 8 );
  assert_equal( find_first_larger_or_equal( square, 1, 21, 400LL ), 20);
  assert_equal( find_first_larger_or_equal( square, 1, 21, 401LL ), 21 );

  // first n with n*n > x
  assert_equal( find_first_larger( square, 1, 21, 0LL ), 1 );
  assert_equal( find_first_larger( square, 1, 21, 48LL ), 7 );
  assert_equal( find_first_larger( square, 1, 21, 49LL ), 8 );
  assert_equal( find_first_larger( square, 1, 21, 50LL ), 8 );
  assert_equal( find_first_larger( square, 1, 21, 399LL ), 20 );
  assert_equal( find_first_larger( square, 1, 21, 400LL ), 21 );

  // last n with n*n < x
  assert_equal( find_last_smaller( square, 1, 21, 1LL ), 0 );
  assert_equal( find_last_smaller( square, 1, 21, 2LL ), 1 );
  assert_equal( find_last_smaller( square, 1, 21, 48LL ), 6 );
  assert_equal( find_last_smaller( square, 1, 21, 49LL ), 6 );
  assert_equal( find_last_smaller( square, 1, 21, 50LL ), 7 );
  assert_equal( find_last_smaller( square, 1, 21, 440LL ), 20 );

  // last n with n*n <= x
  assert_equal( find_last_smaller_or_equal( square, 1, 21, 0LL ), 0 );
  assert_equal( find_last_smaller_or_equal( square, 1, 21, 1LL ), 1 );
  assert_equal( find_last_smaller_or_equal( square, 1, 21, 48LL ), 6 );
  assert_equal( find_last_smaller_or_equal( square, 1, 21, 49LL ), 7 );
  assert_equal( find_last_smaller_or_equal( square, 1, 21, 50LL ), 7 );
  assert_equal( find_last_smaller_or_equal( square, 1, 21, 440LL ), 20 );

  return 0;
}


I think this can come in quite handy, and it makes you less prone to the off-by-one errors that creep in when hand-coding the binary search each time. If speed is really an issue, you can even replace the function pointer with a templated unary function object (functor) and get rid of the function-call overhead. What do you think?

EDIT: You have to use a random-access iterator for large argument ranges. The STL algorithms also work with forward iterators and still call your function only O(log n) times. With a forward iterator, however, they advance by calling operator++ O(n) times, where n is the size of the argument range.
RSS