ndarray.cpp 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. // Copyright Jim Bosch & Ankit Daftery 2010-2012.
  2. // Copyright Stefan Seefeld 2016.
  3. // Distributed under the Boost Software License, Version 1.0.
  4. // (See accompanying file LICENSE_1_0.txt or copy at
  5. // http://www.boost.org/LICENSE_1_0.txt)
#include <boost/python/numpy.hpp>
#include <cstddef>
#include <vector>
namespace p = boost::python;
namespace np = boost::python::numpy;
  9. np::ndarray zeros(p::tuple shape, np::dtype dt) { return np::zeros(shape, dt);}
  10. np::ndarray array2(p::object obj, np::dtype dt) { return np::array(obj,dt);}
  11. np::ndarray array1(p::object obj) { return np::array(obj);}
  12. np::ndarray empty1(p::tuple shape, np::dtype dt) { return np::empty(shape,dt);}
  13. np::ndarray c_empty(p::tuple shape, np::dtype dt)
  14. {
  15. // convert 'shape' to a C array so we can test the corresponding
  16. // version of the constructor
  17. unsigned len = p::len(shape);
  18. Py_intptr_t *c_shape = new Py_intptr_t[len];
  19. for (unsigned i = 0; i != len; ++i)
  20. c_shape[i] = p::extract<Py_intptr_t>(shape[i]);
  21. np::ndarray result = np::empty(len, c_shape, dt);
  22. delete [] c_shape;
  23. return result;
  24. }
  25. np::ndarray transpose(np::ndarray arr) { return arr.transpose();}
  26. np::ndarray squeeze(np::ndarray arr) { return arr.squeeze();}
  27. np::ndarray reshape(np::ndarray arr,p::tuple tup) { return arr.reshape(tup);}
  28. Py_intptr_t shape_index(np::ndarray arr,int k) { return arr.shape(k); }
  29. Py_intptr_t strides_index(np::ndarray arr,int k) { return arr.strides(k); }
  30. BOOST_PYTHON_MODULE(ndarray_ext)
  31. {
  32. np::initialize();
  33. p::def("zeros", zeros);
  34. p::def("zeros_matrix", zeros, np::as_matrix<>());
  35. p::def("array", array2);
  36. p::def("array", array1);
  37. p::def("empty", empty1);
  38. p::def("c_empty", c_empty);
  39. p::def("transpose", transpose);
  40. p::def("squeeze", squeeze);
  41. p::def("reshape", reshape);
  42. p::def("shape_index", shape_index);
  43. p::def("strides_index", strides_index);
  44. }