templates.cpp 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. // Copyright Jim Bosch & Ankit Daftery 2010-2012.
  2. // Copyright Stefan Seefeld 2016.
  3. // Distributed under the Boost Software License, Version 1.0.
  4. // (See accompanying file LICENSE_1_0.txt or copy at
  5. // http://www.boost.org/LICENSE_1_0.txt)
// Boost.Python (and therefore Python.h) must come before any standard header.
#include <boost/python/numpy.hpp>
#include <boost/mpl/vector.hpp>
#include <boost/mpl/vector_c.hpp>
#include <complex>
#include <cstddef>
  9. namespace p = boost::python;
  10. namespace np = boost::python::numpy;
  11. struct ArrayFiller
  12. {
  13. typedef boost::mpl::vector< short, int, float, std::complex<double> > TypeSequence;
  14. typedef boost::mpl::vector_c< int, 1, 2 > DimSequence;
  15. explicit ArrayFiller(np::ndarray const & arg) : argument(arg) {}
  16. template <typename T, int N>
  17. void apply() const
  18. {
  19. if (N == 1)
  20. {
  21. char * p = argument.get_data();
  22. int stride = argument.strides(0);
  23. int size = argument.shape(0);
  24. for (int n = 0; n != size; ++n, p += stride)
  25. *reinterpret_cast<T*>(p) = static_cast<T>(n);
  26. }
  27. else
  28. {
  29. char * row_p = argument.get_data();
  30. int row_stride = argument.strides(0);
  31. int col_stride = argument.strides(1);
  32. int rows = argument.shape(0);
  33. int cols = argument.shape(1);
  34. int i = 0;
  35. for (int n = 0; n != rows; ++n, row_p += row_stride)
  36. {
  37. char * col_p = row_p;
  38. for (int m = 0; m != cols; ++i, ++m, col_p += col_stride)
  39. *reinterpret_cast<T*>(col_p) = static_cast<T>(i);
  40. }
  41. }
  42. }
  43. np::ndarray argument;
  44. };
  45. void fill(np::ndarray const & arg)
  46. {
  47. ArrayFiller filler(arg);
  48. np::invoke_matching_array<ArrayFiller::TypeSequence, ArrayFiller::DimSequence >(arg, filler);
  49. }
  50. BOOST_PYTHON_MODULE(templates_ext)
  51. {
  52. np::initialize();
  53. p::def("fill", fill);
  54. }