ndarray.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. #!/usr/bin/env python
  2. # Copyright Jim Bosch & Ankit Daftery 2010-2012.
  3. # Distributed under the Boost Software License, Version 1.0.
  4. # (See accompanying file LICENSE_1_0.txt or copy at
  5. # http://www.boost.org/LICENSE_1_0.txt)
  6. import ndarray_ext
  7. import unittest
  8. import numpy
  9. class TestNdarray(unittest.TestCase):
  10. def testNdzeros(self):
  11. for dtp in (numpy.int16, numpy.int32, numpy.float32, numpy.complex128):
  12. v = numpy.zeros(60, dtype=dtp)
  13. dt = numpy.dtype(dtp)
  14. for shape in ((60,),(6,10),(4,3,5),(2,2,3,5)):
  15. a1 = ndarray_ext.zeros(shape,dt)
  16. a2 = v.reshape(a1.shape)
  17. self.assertEqual(shape,a1.shape)
  18. self.assert_((a1 == a2).all())
  19. def testNdzeros_matrix(self):
  20. for dtp in (numpy.int16, numpy.int32, numpy.float32, numpy.complex128):
  21. dt = numpy.dtype(dtp)
  22. shape = (6, 10)
  23. a1 = ndarray_ext.zeros_matrix(shape, dt)
  24. a2 = numpy.matrix(numpy.zeros(shape, dtype=dtp))
  25. self.assertEqual(shape,a1.shape)
  26. self.assert_((a1 == a2).all())
  27. self.assertEqual(type(a1), type(a2))
  28. def testNdarray(self):
  29. a = range(0,60)
  30. for dtp in (numpy.int16, numpy.int32, numpy.float32, numpy.complex128):
  31. v = numpy.array(a, dtype=dtp)
  32. dt = numpy.dtype(dtp)
  33. a1 = ndarray_ext.array(a)
  34. a2 = ndarray_ext.array(a,dt)
  35. self.assert_((a1 == v).all())
  36. self.assert_((a2 == v).all())
  37. for shape in ((60,),(6,10),(4,3,5),(2,2,3,5)):
  38. a1 = a1.reshape(shape)
  39. self.assertEqual(shape,a1.shape)
  40. a2 = a2.reshape(shape)
  41. self.assertEqual(shape,a2.shape)
  42. def testNdempty(self):
  43. for dtp in (numpy.int16, numpy.int32, numpy.float32, numpy.complex128):
  44. dt = numpy.dtype(dtp)
  45. for shape in ((60,),(6,10),(4,3,5),(2,2,3,5)):
  46. a1 = ndarray_ext.empty(shape,dt)
  47. a2 = ndarray_ext.c_empty(shape,dt)
  48. self.assertEqual(shape,a1.shape)
  49. self.assertEqual(shape,a2.shape)
  50. def testTranspose(self):
  51. for dtp in (numpy.int16, numpy.int32, numpy.float32, numpy.complex128):
  52. dt = numpy.dtype(dtp)
  53. for shape in ((6,10),(4,3,5),(2,2,3,5)):
  54. a1 = numpy.empty(shape,dt)
  55. a2 = a1.transpose()
  56. a1 = ndarray_ext.transpose(a1)
  57. self.assertEqual(a1.shape,a2.shape)
  58. def testSqueeze(self):
  59. a1 = numpy.array([[[3,4,5]]])
  60. a2 = a1.squeeze()
  61. a1 = ndarray_ext.squeeze(a1)
  62. self.assertEqual(a1.shape,a2.shape)
  63. def testReshape(self):
  64. a1 = numpy.empty((2,2))
  65. a2 = ndarray_ext.reshape(a1,(1,4))
  66. self.assertEqual(a2.shape,(1,4))
  67. def testShapeIndex(self):
  68. a = numpy.arange(24)
  69. a.shape = (1,2,3,4)
  70. def shape_check(i):
  71. print(i)
  72. self.assertEqual(ndarray_ext.shape_index(a,i) ,a.shape[i] )
  73. for i in range(4):
  74. shape_check(i)
  75. for i in range(-1,-5,-1):
  76. shape_check(i)
  77. try:
  78. ndarray_ext.shape_index(a,4) # out of bounds -- should raise IndexError
  79. self.assertTrue(False)
  80. except IndexError:
  81. pass
  82. def testStridesIndex(self):
  83. a = numpy.arange(24)
  84. a.shape = (1,2,3,4)
  85. def strides_check(i):
  86. print(i)
  87. self.assertEqual(ndarray_ext.strides_index(a,i) ,a.strides[i] )
  88. for i in range(4):
  89. strides_check(i)
  90. for i in range(-1,-5,-1):
  91. strides_check(i)
  92. try:
  93. ndarray_ext.strides_index(a,4) # out of bounds -- should raise IndexError
  94. self.assertTrue(False)
  95. except IndexError:
  96. pass
  97. if __name__=="__main__":
  98. unittest.main()