prod_test.hpp 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #ifndef TEST_PROD_OPENCL_HH
  2. #define TEST_PROD_OPENCL_HH
  #include <cstdlib>
  #include <ctime>
  #include <iostream>
  #include <typeinfo>

  #include "test_opencl.hpp"
  4. template <class T, class F, int number_of_tests, int max_dimension>
  5. class bench_prod
  6. {
  7. public:
  8. typedef test_opencl<T, F> test;
  9. void run()
  10. {
  11. opencl::library lib;
  12. int passedOperations = 0;
  13. // get default device and setup context
  14. compute::device device = compute::system::default_device();
  15. compute::context context(device);
  16. compute::command_queue queue(context, device);
  17. std::srand(time(0));
  18. ublas::matrix<T, F> a;
  19. ublas::matrix<T, F> b;
  20. ublas::matrix<T, F> resultUBLAS;
  21. ublas::matrix<T, F> resultOPENCL;
  22. ublas::vector<T> va;
  23. ublas::vector<T> vb;
  24. ublas::vector<T> result_vector_ublas_mv;
  25. ublas::vector<T> result_vector_ublas_vm;
  26. ublas::vector<T> result_vector_opencl_mv;
  27. ublas::vector<T> result_vector_opencl_vm;
  28. for (int i = 0; i<number_of_tests; i++)
  29. {
  30. int rowsA = std::rand() % max_dimension + 1;
  31. int colsA = std::rand() % max_dimension + 1;
  32. int colsB = std::rand() % max_dimension + 1;
  33. a.resize(rowsA, colsA);
  34. b.resize(colsA, colsB);
  35. va.resize(colsA);
  36. vb.resize(rowsA);
  37. test::init_matrix(a, 200);
  38. test::init_matrix(b, 200);
  39. test::init_vector(va, 200);
  40. test::init_vector(vb, 200);
  41. //matrix_matrix
  42. resultUBLAS = prod(a, b);
  43. resultOPENCL = opencl::prod(a, b, queue);
  44. //matrix_vector
  45. result_vector_ublas_mv = ublas::prod(a, va);
  46. result_vector_opencl_mv = opencl::prod(a, va, queue);
  47. //vector-matrix
  48. result_vector_ublas_vm = ublas::prod(vb, a);
  49. result_vector_opencl_vm = opencl::prod(vb, a, queue);
  50. if ((!test::compare(resultUBLAS, resultOPENCL)) || (!test::compare(result_vector_opencl_mv, result_vector_ublas_mv)) || (!test::compare(result_vector_opencl_vm, result_vector_ublas_vm)))
  51. {
  52. std::cout << "Error in calculations" << std::endl;
  53. std::cout << "passed: " << passedOperations << std::endl;
  54. return;
  55. }
  56. passedOperations++;
  57. }
  58. std::cout << "All is well (matrix opencl prod) of " << typeid(T).name() << std::endl;
  59. }
  60. };
  61. #endif