test_reduce_by_key.cpp 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
//---------------------------------------------------------------------------//
// Copyright (c) 2015 Jakub Szuppe <j.szuppe@gmail.com>
//
// Distributed under the Boost Software License, Version 1.0
// See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt
//
// See http://boostorg.github.com/compute for more information.
//---------------------------------------------------------------------------//
#define BOOST_TEST_MODULE TestReduceByKey
#include <boost/test/unit_test.hpp>

#include <utility>

#include <boost/compute/lambda.hpp>
#include <boost/compute/system.hpp>
#include <boost/compute/functional.hpp>
#include <boost/compute/algorithm/inclusive_scan.hpp>
#include <boost/compute/algorithm/reduce_by_key.hpp>
#include <boost/compute/container/vector.hpp>

#include "check_macros.hpp"
#include "context_setup.hpp"

namespace bc = boost::compute;
// Basic int reduction with no explicit reduction functor or key predicate:
// each run of equal adjacent keys collapses into one output key, and the
// values of that run are summed (the -3 run 1 + 1 + 1 + 2 + 5 yields 10).
BOOST_AUTO_TEST_CASE(reduce_by_key_int)
{
    // NOTE(review): the "//! [reduce_by_key_int]" markers appear to delimit a
    // documentation snippet, so the code between them is kept verbatim.
//! [reduce_by_key_int]
// setup keys and values
int keys[] = { 0, 2, -3, -3, -3, -3, -3, 4 };
int data[] = { 1, 1, 1, 1, 1, 2, 5, 1 };
boost::compute::vector<int> keys_input(keys, keys + 8, queue);
boost::compute::vector<int> values_input(data, data + 8, queue);
boost::compute::vector<int> keys_output(8, context);
boost::compute::vector<int> values_output(8, context);
// reduce by key
boost::compute::reduce_by_key(keys_input.begin(), keys_input.end(), values_input.begin(),
keys_output.begin(), values_output.begin(), queue);
// keys_output = { 0, 2, -3, 4 }
// values_output = { 1, 1, 10, 1 }
//! [reduce_by_key_int]

    // Only the first 4 output slots are compared; slots 4-7 of the 8-element
    // output vectors are not inspected by this test.
CHECK_RANGE_EQUAL(int, 4, keys_output, (0, 2, -3, 4));
CHECK_RANGE_EQUAL(int, 4, values_output, (1, 1, 10, 1));
}
  40. BOOST_AUTO_TEST_CASE(reduce_by_key_int_long_vector)
  41. {
  42. size_t size = 1024;
  43. bc::vector<int> keys_input(size, int(0), queue);
  44. bc::vector<int> values_input(size, int(1), queue);
  45. bc::vector<int> keys_output(size, context);
  46. bc::vector<int> values_output(size, context);
  47. bc::reduce_by_key(keys_input.begin(), keys_input.end(), values_input.begin(),
  48. keys_output.begin(), values_output.begin(), queue);
  49. CHECK_RANGE_EQUAL(int, 1, keys_output, (0));
  50. CHECK_RANGE_EQUAL(int, 1, values_output, (static_cast<int>(size)));
  51. keys_input[137] = 1;
  52. keys_input[677] = 1;
  53. keys_input[1001] = 1;
  54. bc::inclusive_scan(keys_input.begin(), keys_input.end(), keys_input.begin(), queue);
  55. bc::reduce_by_key(keys_input.begin(), keys_input.end(), values_input.begin(),
  56. keys_output.begin(), values_output.begin(), queue);
  57. CHECK_RANGE_EQUAL(int, 4, keys_output, (0, 1, 2, 3));
  58. CHECK_RANGE_EQUAL(int, 4, values_output, (137, 540, 324, 23));
  59. }
  60. BOOST_AUTO_TEST_CASE(reduce_by_key_empty_vector)
  61. {
  62. bc::vector<int> keys_input(context);
  63. bc::vector<int> values_input(context);
  64. bc::vector<int> keys_output(context);
  65. bc::vector<int> values_output(context);
  66. bc::reduce_by_key(keys_input.begin(), keys_input.end(), values_input.begin(),
  67. keys_output.begin(), values_output.begin(), queue);
  68. BOOST_CHECK(keys_output.empty());
  69. BOOST_CHECK(values_output.empty());
  70. }
  71. BOOST_AUTO_TEST_CASE(reduce_by_key_int_one_key_value)
  72. {
  73. int keys[] = { 22 };
  74. int data[] = { -9 };
  75. bc::vector<int> keys_input(keys, keys + 1, queue);
  76. bc::vector<int> values_input(data, data + 1, queue);
  77. bc::vector<int> keys_output(1, context);
  78. bc::vector<int> values_output(1, context);
  79. bc::reduce_by_key(keys_input.begin(), keys_input.end(), values_input.begin(),
  80. keys_output.begin(), values_output.begin(), queue);
  81. CHECK_RANGE_EQUAL(int, 1, keys_output, (22));
  82. CHECK_RANGE_EQUAL(int, 1, values_output, (-9));
  83. }
  84. BOOST_AUTO_TEST_CASE(reduce_by_key_int_min_max)
  85. {
  86. int keys[] = { 0, 2, 2, 3, 3, 3, 3, 3, 4 };
  87. int data[] = { 1, 2, 1, -3, 1, 4, 2, 5, 77 };
  88. bc::vector<int> keys_input(keys, keys + 9, queue);
  89. bc::vector<int> values_input(data, data + 9, queue);
  90. bc::vector<int> keys_output(9, context);
  91. bc::vector<int> values_output(9, context);
  92. bc::reduce_by_key(keys_input.begin(), keys_input.end(), values_input.begin(),
  93. keys_output.begin(), values_output.begin(), bc::min<int>(),
  94. bc::equal_to<int>(), queue);
  95. CHECK_RANGE_EQUAL(int, 4, keys_output, (0, 2, 3, 4));
  96. CHECK_RANGE_EQUAL(int, 4, values_output, (1, 1, -3, 77));
  97. bc::reduce_by_key(keys_input.begin(), keys_input.end(), values_input.begin(),
  98. keys_output.begin(), values_output.begin(), bc::max<int>(),
  99. bc::equal_to<int>(), queue);
  100. CHECK_RANGE_EQUAL(int, 4, keys_output, (0, 2, 3, 4));
  101. CHECK_RANGE_EQUAL(int, 4, values_output, (1, 2, 5, 77));
  102. }
  103. BOOST_AUTO_TEST_CASE(reduce_by_key_float_max)
  104. {
  105. int keys[] = { 0, 2, 2, 3, 3, 3, 3, 3, 4 };
  106. float data[] = { 1.0, 2.0, -1.5, -3.0, 1.0, -0.24, 2, 5, 77.1 };
  107. bc::vector<int> keys_input(keys, keys + 9, queue);
  108. bc::vector<float> values_input(data, data + 9, queue);
  109. bc::vector<int> keys_output(9, context);
  110. bc::vector<float> values_output(9, context);
  111. bc::reduce_by_key(keys_input.begin(), keys_input.end(), values_input.begin(),
  112. keys_output.begin(), values_output.begin(), bc::max<float>(),
  113. queue);
  114. CHECK_RANGE_EQUAL(int, 4, keys_output, (0, 2, 3, 4));
  115. BOOST_CHECK_CLOSE(float(values_output[0]), 1.0f, 1e-4f);
  116. BOOST_CHECK_CLOSE(float(values_output[1]), 2.0f, 1e-4f);
  117. BOOST_CHECK_CLOSE(float(values_output[2]), 5.0f, 1e-4f);
  118. BOOST_CHECK_CLOSE(float(values_output[3]), 77.1f, 1e-4f);
  119. }
  120. BOOST_AUTO_TEST_CASE(reduce_by_key_int2)
  121. {
  122. using bc::int2_;
  123. int keys[] = { 0, 2, 3, 3, 3, 3, 4, 4 };
  124. int2_ data[] = {
  125. int2_(0, 1), int2_(-3, 2), int2_(0, 1), int2_(0, 1),
  126. int2_(-3, 0), int2_(0, 0), int2_(-3, 2), int2_(-7, -2)
  127. };
  128. bc::vector<int> keys_input(keys, keys + 8, queue);
  129. bc::vector<int2_> values_input(data, data + 8, queue);
  130. bc::vector<int> keys_output(8, context);
  131. bc::vector<int2_> values_output(8, context);
  132. bc::reduce_by_key(keys_input.begin(), keys_input.end(), values_input.begin(),
  133. keys_output.begin(), values_output.begin(), queue);
  134. CHECK_RANGE_EQUAL(int, 4, keys_output, (0, 2, 3, 4));
  135. CHECK_RANGE_EQUAL(int2_, 4, values_output,
  136. (int2_(0, 1), int2_(-3, 2), int2_(-3, 2), int2_(-10, 0)));
  137. }
  138. BOOST_AUTO_TEST_CASE(reduce_by_key_int2_long_vector)
  139. {
  140. using bc::int2_;
  141. size_t size = 1024;
  142. bc::vector<int> keys_input(size, int(0), queue);
  143. bc::vector<int2_> values_input(size, int2_(1, -1), queue);
  144. bc::vector<int> keys_output(size, context);
  145. bc::vector<int2_> values_output(size, context);
  146. bc::reduce_by_key(keys_input.begin(), keys_input.end(), values_input.begin(),
  147. keys_output.begin(), values_output.begin(), queue);
  148. CHECK_RANGE_EQUAL(int, 1, keys_output, (0));
  149. CHECK_RANGE_EQUAL(int2_, 1, values_output, (int2_(int(size), -int(size))));
  150. keys_input[137] = 1;
  151. keys_input[677] = 1;
  152. keys_input[1001] = 1;
  153. bc::inclusive_scan(keys_input.begin(), keys_input.end(), keys_input.begin(), queue);
  154. bc::reduce_by_key(keys_input.begin(), keys_input.end(), values_input.begin(),
  155. keys_output.begin(), values_output.begin(), queue);
  156. CHECK_RANGE_EQUAL(int, 4, keys_output, (0, 1, 2, 3));
  157. CHECK_RANGE_EQUAL(int2_, 4, values_output,
  158. (int2_(137, -137), int2_(540, -540), int2_(324, -324), int2_(23, -23)));
  159. }
// Closes the enclosing test suite. No matching BOOST_AUTO_TEST_SUITE appears
// in this file; NOTE(review): presumably it is opened by context_setup.hpp
// (which also supplies `context` and `queue`) — confirm.
BOOST_AUTO_TEST_SUITE_END()