test_scatter_if.cpp 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2015 Jakub Pola <jakub.pola@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #define BOOST_TEST_MODULE TestScatterIf
  11. #include <boost/test/unit_test.hpp>
  12. #include <boost/compute/system.hpp>
  13. #include <boost/compute/algorithm/scatter_if.hpp>
  14. #include <boost/compute/container/vector.hpp>
  15. #include <boost/compute/iterator/constant_buffer_iterator.hpp>
  16. #include <boost/compute/iterator/counting_iterator.hpp>
  17. #include <boost/compute/functional.hpp>
  18. #include "check_macros.hpp"
  19. #include "context_setup.hpp"
  20. namespace bc = boost::compute;
  21. BOOST_AUTO_TEST_CASE(scatter_if_int)
  22. {
  23. int input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  24. bc::vector<int> input(input_data, input_data + 10, queue);
  25. int map_data[] = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  26. bc::vector<int> map(map_data, map_data + 10, queue);
  27. int stencil_data[] = {0, 1, 0, 1, 0, 1, 0, 1, 0, 1};
  28. bc::vector<bc::uint_> stencil(stencil_data, stencil_data + 10, queue);
  29. bc::vector<int> output(input.size(), -1, queue);
  30. bc::scatter_if(input.begin(), input.end(),
  31. map.begin(), stencil.begin(),
  32. output.begin(),
  33. queue);
  34. CHECK_RANGE_EQUAL(int, 10, output, (9, -1, 7, -1, 5, -1, 3, -1, 1, -1) );
  35. }
  36. BOOST_AUTO_TEST_CASE(scatter_if_constant_indices)
  37. {
  38. int input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  39. bc::vector<int> input(input_data, input_data + 10, queue);
  40. int map_data[] = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  41. bc::buffer map_buffer(context,
  42. 10 * sizeof(int),
  43. bc::buffer::read_only | bc::buffer::use_host_ptr,
  44. map_data);
  45. int stencil_data[] = {0, 1, 0, 1, 0, 1, 0, 1, 0, 1};
  46. bc::buffer stencil_buffer(context,
  47. 10 * sizeof(bc::uint_),
  48. bc::buffer::read_only | bc::buffer::use_host_ptr,
  49. stencil_data);
  50. bc::vector<int> output(input.size(), -1, queue);
  51. bc::scatter_if(input.begin(),
  52. input.end(),
  53. bc::make_constant_buffer_iterator<int>(map_buffer, 0),
  54. bc::make_constant_buffer_iterator<int>(stencil_buffer, 0),
  55. output.begin(),
  56. queue);
  57. CHECK_RANGE_EQUAL(int, 10, output, (9, -1, 7, -1, 5, -1, 3, -1, 1, -1) );
  58. }
  59. BOOST_AUTO_TEST_CASE(scatter_if_function)
  60. {
  61. int input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  62. bc::vector<int> input(input_data, input_data + 10, queue);
  63. int map_data[] = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  64. bc::vector<int> map(map_data, map_data + 10, queue);
  65. int stencil_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  66. bc::vector<bc::uint_> stencil(stencil_data, stencil_data + 10, queue);
  67. bc::vector<int> output(input.size(), -1, queue);
  68. BOOST_COMPUTE_FUNCTION(int, gt_than_5, (int x),
  69. {
  70. if (x > 5)
  71. return true;
  72. else
  73. return false;
  74. });
  75. bc::scatter_if(input.begin(),
  76. input.end(),
  77. map.begin(),
  78. stencil.begin(),
  79. output.begin(),
  80. gt_than_5,
  81. queue);
  82. CHECK_RANGE_EQUAL(int, 10, output, (9, 8, 7, 6, -1, -1, -1, -1, -1, -1) );
  83. }
  84. BOOST_AUTO_TEST_CASE(scatter_if_counting_iterator)
  85. {
  86. int input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  87. bc::vector<int> input(input_data, input_data + 10, queue);
  88. int map_data[] = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  89. bc::vector<int> map(map_data, map_data + 10, queue);
  90. bc::vector<int> output(input.size(), -1, queue);
  91. BOOST_COMPUTE_FUNCTION(int, gt_than_5, (int x),
  92. {
  93. if (x > 5)
  94. return true;
  95. else
  96. return false;
  97. });
  98. bc::scatter_if(input.begin(),
  99. input.end(),
  100. map.begin(),
  101. bc::make_counting_iterator<int>(0),
  102. output.begin(),
  103. gt_than_5,
  104. queue);
  105. CHECK_RANGE_EQUAL(int, 10, output, (9, 8, 7, 6, -1, -1, -1, -1, -1, -1) );
  106. }
  107. BOOST_AUTO_TEST_SUITE_END() // Closes the enclosing Boost.Test suite; the matching opening macro is not visible in this file -- presumably supplied by context_setup.hpp (TODO confirm).