test_struct.cpp 4.9 KB

  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2013-2014 Kyle Lutz <kyle.r.lutz@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
#include <boost/compute/config.hpp>

#define BOOST_TEST_MODULE TestStruct
#include <boost/test/unit_test.hpp>

#include <cstddef>
#include <cstring>
#include <string>
#include <vector>

#include <boost/compute/function.hpp>
#include <boost/compute/algorithm/copy.hpp>
#include <boost/compute/algorithm/find_if.hpp>
#include <boost/compute/algorithm/transform.hpp>
#include <boost/compute/container/vector.hpp>
#include <boost/compute/functional/field.hpp>
#include <boost/compute/types/struct.hpp>
#include <boost/compute/type_traits/type_name.hpp>
#include <boost/compute/type_traits/type_definition.hpp>
#include <boost/compute/utility/source.hpp>

namespace compute = boost::compute;
  23. // example code defining an atom class
  24. namespace chemistry {
  25. struct Atom
  26. {
  27. Atom(float _x, float _y, float _z, int _number)
  28. : x(_x), y(_y), z(_z), number(_number)
  29. {
  30. }
  31. float x;
  32. float y;
  33. float z;
  34. int number;
  35. };
  36. } // end chemistry namespace
  37. // adapt the chemistry::Atom class
  38. BOOST_COMPUTE_ADAPT_STRUCT(chemistry::Atom, Atom, (x, y, z, number))
  39. struct StructWithArray {
  40. int value;
  41. int array[3];
  42. };
  43. BOOST_COMPUTE_ADAPT_STRUCT(StructWithArray, StructWithArray, (value, array))
  44. #include "check_macros.hpp"
  45. #include "context_setup.hpp"
  46. BOOST_AUTO_TEST_CASE(atom_type_name)
  47. {
  48. BOOST_CHECK(std::strcmp(compute::type_name<chemistry::Atom>(), "Atom") == 0);
  49. }
  50. BOOST_AUTO_TEST_CASE(atom_struct)
  51. {
  52. std::vector<chemistry::Atom> atoms;
  53. atoms.push_back(chemistry::Atom(1.f, 0.f, 0.f, 1));
  54. atoms.push_back(chemistry::Atom(0.f, 1.f, 0.f, 1));
  55. atoms.push_back(chemistry::Atom(0.f, 0.f, 0.f, 8));
  56. compute::vector<chemistry::Atom> vec(atoms.size(), context);
  57. compute::copy(atoms.begin(), atoms.end(), vec.begin(), queue);
  58. // find the oxygen atom
  59. BOOST_COMPUTE_FUNCTION(bool, is_oxygen, (chemistry::Atom atom),
  60. {
  61. return atom.number == 8;
  62. });
  63. compute::vector<chemistry::Atom>::iterator iter =
  64. compute::find_if(vec.begin(), vec.end(), is_oxygen, queue);
  65. BOOST_CHECK(iter == vec.begin() + 2);
  66. // copy the atomic numbers to another vector
  67. compute::vector<int> atomic_numbers(vec.size(), context);
  68. compute::transform(
  69. vec.begin(), vec.end(),
  70. atomic_numbers.begin(),
  71. compute::field<int>("number"),
  72. queue
  73. );
  74. CHECK_RANGE_EQUAL(int, 3, atomic_numbers, (1, 1, 8));
  75. }
  76. BOOST_AUTO_TEST_CASE(custom_kernel)
  77. {
  78. std::vector<chemistry::Atom> data;
  79. data.push_back(chemistry::Atom(1.f, 0.f, 0.f, 1));
  80. data.push_back(chemistry::Atom(0.f, 1.f, 0.f, 1));
  81. data.push_back(chemistry::Atom(0.f, 0.f, 0.f, 8));
  82. compute::vector<chemistry::Atom> atoms(data.size(), context);
  83. compute::copy(data.begin(), data.end(), atoms.begin(), queue);
  84. std::string source = BOOST_COMPUTE_STRINGIZE_SOURCE(
  85. __kernel void custom_kernel(__global const Atom *atoms,
  86. __global float *distances)
  87. {
  88. const uint i = get_global_id(0);
  89. const __global Atom *atom = &atoms[i];
  90. const float4 center = { 0, 0, 0, 0 };
  91. const float4 position = { atom->x, atom->y, atom->z, 0 };
  92. distances[i] = distance(position, center);
  93. }
  94. );
  95. // add type definition for Atom to the start of the program source
  96. source = compute::type_definition<chemistry::Atom>() + "\n" + source;
  97. compute::program program =
  98. compute::program::build_with_source(source, context);
  99. compute::vector<float> distances(atoms.size(), context);
  100. compute::kernel custom_kernel = program.create_kernel("custom_kernel");
  101. custom_kernel.set_arg(0, atoms);
  102. custom_kernel.set_arg(1, distances);
  103. queue.enqueue_1d_range_kernel(custom_kernel, 0, atoms.size(), 1);
  104. }
  105. // Creates a StructWithArray containing 'x', 'y', 'z'.
  106. StructWithArray make_struct_with_array(int x, int y, int z)
  107. {
  108. StructWithArray s;
  109. s.value = 0;
  110. s.array[0] = x;
  111. s.array[1] = y;
  112. s.array[2] = z;
  113. return s;
  114. }
  115. BOOST_AUTO_TEST_CASE(struct_with_array)
  116. {
  117. compute::vector<StructWithArray> structs(context);
  118. structs.push_back(make_struct_with_array(1, 2, 3), queue);
  119. structs.push_back(make_struct_with_array(4, 5, 6), queue);
  120. structs.push_back(make_struct_with_array(7, 8, 9), queue);
  121. BOOST_COMPUTE_FUNCTION(int, sum_array, (StructWithArray x),
  122. {
  123. return x.array[0] + x.array[1] + x.array[2];
  124. });
  125. compute::vector<int> results(structs.size(), context);
  126. compute::transform(
  127. structs.begin(), structs.end(), results.begin(), sum_array, queue
  128. );
  129. CHECK_RANGE_EQUAL(int, 3, results, (6, 15, 24));
  130. }