function.hpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_FUNCTION_HPP
  11. #define BOOST_COMPUTE_FUNCTION_HPP
  #include <cstddef>
  #include <cstring>
  #include <map>
  #include <sstream>
  #include <string>
  #include <vector>

  #include <boost/assert.hpp>
  #include <boost/config.hpp>
  #include <boost/function_types/parameter_types.hpp>
  #include <boost/preprocessor/repetition.hpp>
  #include <boost/mpl/for_each.hpp>
  #include <boost/mpl/size.hpp>
  #include <boost/mpl/transform.hpp>
  #include <boost/static_assert.hpp>
  #include <boost/tuple/tuple.hpp>
  #include <boost/type_traits/add_pointer.hpp>
  #include <boost/type_traits/function_traits.hpp>

  #include <boost/compute/cl.hpp>
  #include <boost/compute/config.hpp>
  #include <boost/compute/type_traits/type_name.hpp>
  30. namespace boost {
  31. namespace compute {
  32. namespace detail {
  33. template<class ResultType, class ArgTuple>
  34. class invoked_function
  35. {
  36. public:
  37. typedef ResultType result_type;
  38. BOOST_STATIC_CONSTANT(
  39. size_t, arity = boost::tuples::length<ArgTuple>::value
  40. );
  41. invoked_function(const std::string &name,
  42. const std::string &source)
  43. : m_name(name),
  44. m_source(source)
  45. {
  46. }
  47. invoked_function(const std::string &name,
  48. const std::string &source,
  49. const std::map<std::string, std::string> &definitions)
  50. : m_name(name),
  51. m_source(source),
  52. m_definitions(definitions)
  53. {
  54. }
  55. invoked_function(const std::string &name,
  56. const std::string &source,
  57. const ArgTuple &args)
  58. : m_name(name),
  59. m_source(source),
  60. m_args(args)
  61. {
  62. }
  63. invoked_function(const std::string &name,
  64. const std::string &source,
  65. const std::map<std::string, std::string> &definitions,
  66. const ArgTuple &args)
  67. : m_name(name),
  68. m_source(source),
  69. m_definitions(definitions),
  70. m_args(args)
  71. {
  72. }
  73. std::string name() const
  74. {
  75. return m_name;
  76. }
  77. std::string source() const
  78. {
  79. return m_source;
  80. }
  81. const std::map<std::string, std::string>& definitions() const
  82. {
  83. return m_definitions;
  84. }
  85. const ArgTuple& args() const
  86. {
  87. return m_args;
  88. }
  89. private:
  90. std::string m_name;
  91. std::string m_source;
  92. std::map<std::string, std::string> m_definitions;
  93. ArgTuple m_args;
  94. };
  95. } // end detail namespace
  96. /// \class function
  97. /// \brief A function object.
  98. template<class Signature>
  99. class function
  100. {
  101. public:
  102. /// \internal_
  103. typedef typename
  104. boost::function_traits<Signature>::result_type result_type;
  105. /// \internal_
  106. BOOST_STATIC_CONSTANT(
  107. size_t, arity = boost::function_traits<Signature>::arity
  108. );
  109. /// \internal_
  110. typedef Signature signature;
  111. /// Creates a new function object with \p name.
  112. function(const std::string &name)
  113. : m_name(name)
  114. {
  115. }
  116. /// Destroys the function object.
  117. ~function()
  118. {
  119. }
  120. /// \internal_
  121. std::string name() const
  122. {
  123. return m_name;
  124. }
  125. /// \internal_
  126. void set_source(const std::string &source)
  127. {
  128. m_source = source;
  129. }
  130. /// \internal_
  131. std::string source() const
  132. {
  133. return m_source;
  134. }
  135. /// \internal_
  136. void define(std::string name, std::string value = std::string())
  137. {
  138. m_definitions[name] = value;
  139. }
  140. bool operator==(const function<Signature>& other) const
  141. {
  142. return
  143. (m_name == other.m_name)
  144. && (m_definitions == other.m_definitions)
  145. && (m_source == other.m_source);
  146. }
  147. bool operator!=(const function<Signature>& other) const
  148. {
  149. return !(*this == other);
  150. }
  151. /// \internal_
  152. detail::invoked_function<result_type, boost::tuple<> >
  153. operator()() const
  154. {
  155. BOOST_STATIC_ASSERT_MSG(
  156. arity == 0,
  157. "Non-nullary function invoked with zero arguments"
  158. );
  159. return detail::invoked_function<result_type, boost::tuple<> >(
  160. m_name, m_source, m_definitions
  161. );
  162. }
  163. /// \internal_
  164. template<class Arg1>
  165. detail::invoked_function<result_type, boost::tuple<Arg1> >
  166. operator()(const Arg1 &arg1) const
  167. {
  168. BOOST_STATIC_ASSERT_MSG(
  169. arity == 1,
  170. "Non-unary function invoked one argument"
  171. );
  172. return detail::invoked_function<result_type, boost::tuple<Arg1> >(
  173. m_name, m_source, m_definitions, boost::make_tuple(arg1)
  174. );
  175. }
  176. /// \internal_
  177. template<class Arg1, class Arg2>
  178. detail::invoked_function<result_type, boost::tuple<Arg1, Arg2> >
  179. operator()(const Arg1 &arg1, const Arg2 &arg2) const
  180. {
  181. BOOST_STATIC_ASSERT_MSG(
  182. arity == 2,
  183. "Non-binary function invoked with two arguments"
  184. );
  185. return detail::invoked_function<result_type, boost::tuple<Arg1, Arg2> >(
  186. m_name, m_source, m_definitions, boost::make_tuple(arg1, arg2)
  187. );
  188. }
  189. /// \internal_
  190. template<class Arg1, class Arg2, class Arg3>
  191. detail::invoked_function<result_type, boost::tuple<Arg1, Arg2, Arg3> >
  192. operator()(const Arg1 &arg1, const Arg2 &arg2, const Arg3 &arg3) const
  193. {
  194. BOOST_STATIC_ASSERT_MSG(
  195. arity == 3,
  196. "Non-ternary function invoked with three arguments"
  197. );
  198. return detail::invoked_function<result_type, boost::tuple<Arg1, Arg2, Arg3> >(
  199. m_name, m_source, m_definitions, boost::make_tuple(arg1, arg2, arg3)
  200. );
  201. }
  202. private:
  203. std::string m_name;
  204. std::string m_source;
  205. std::map<std::string, std::string> m_definitions;
  206. };
  207. /// Creates a function object given its \p name and \p source.
  208. ///
  209. /// \param name The function name.
  210. /// \param source The function source code.
  211. ///
  212. /// \see BOOST_COMPUTE_FUNCTION()
  213. template<class Signature>
  214. inline function<Signature>
  215. make_function_from_source(const std::string &name, const std::string &source)
  216. {
  217. function<Signature> f(name);
  218. f.set_source(source);
  219. return f;
  220. }
  221. namespace detail {
  // given a string containing the arguments declaration for a function
  // like: "(int a, const float b)", returns a vector containing the name
  // of each argument (e.g. ["a", "b"]).
  //
  // the name of an argument is whatever follows the last whitespace
  // character inside its declaration; whitespace surrounding a declaration
  // is ignored. commas nested inside '<' ... '>' do not separate
  // arguments. an empty list ("()" or "( )") yields an empty vector.
  inline std::vector<std::string> parse_argument_names(const char *arguments)
  {
      // measured once; the string does not change while it is parsed
      const std::size_t length = std::strlen(arguments);

      BOOST_ASSERT_MSG(
          length >= 2 && arguments[0] == '(' && arguments[length - 1] == ')',
          "Arguments should start and end with parentheses"
      );

      std::vector<std::string> args;

      // keeps the index arithmetic below in range even when assertions
      // are compiled out
      if(length < 2){
          return args;
      }

      const char *const whitespace = " \t\r\n";

      // text between the enclosing parentheses
      const std::string list(arguments + 1, length - 2);

      if(list.find_first_not_of(whitespace) == std::string::npos){
          return args;
      }

      std::size_t depth = 0; // current '<' ... '>' nesting level
      std::size_t begin = 0; // start of the declaration being scanned

      // one extra iteration (i == list.size()) flushes the last declaration
      for(std::size_t i = 0; i <= list.size(); i++){
          if(i < list.size()){
              const char c = list[i];

              if(c == '<'){
                  depth++;
              }
              else if(c == '>' && depth > 0){
                  // an unmatched '>' is ignored instead of underflowing
                  depth--;
              }

              if(c != ',' || depth > 0){
                  continue;
              }
          }

          // list[begin, i) now holds one complete declaration
          const std::string declaration = list.substr(begin, i - begin);
          begin = i + 1;

          const std::size_t last = declaration.find_last_not_of(whitespace);
          if(last == std::string::npos){
              // blank slot (e.g. "a,,b"): keep it so the caller's argument
              // count check still sees it
              args.push_back(std::string());
              continue;
          }

          const std::size_t separator =
              declaration.find_last_of(whitespace, last);
          const std::size_t first =
              (separator == std::string::npos) ? 0 : separator + 1;

          args.push_back(declaration.substr(first, last - first + 1));
      }

      return args;
  }
  258. struct signature_argument_inserter
  259. {
  260. signature_argument_inserter(std::stringstream &s_, const char *arguments, size_t last)
  261. : s(s_)
  262. {
  263. n = 0;
  264. m_last = last;
  265. m_argument_names = parse_argument_names(arguments);
  266. BOOST_ASSERT_MSG(
  267. m_argument_names.size() == last,
  268. "Wrong number of arguments"
  269. );
  270. }
  271. template<class T>
  272. void operator()(const T*)
  273. {
  274. s << type_name<T>() << " " << m_argument_names[n];
  275. if(n+1 < m_last){
  276. s << ", ";
  277. }
  278. n++;
  279. }
  280. size_t n;
  281. size_t m_last;
  282. std::stringstream &s;
  283. std::vector<std::string> m_argument_names;
  284. };
  285. template<class Signature>
  286. inline std::string make_function_declaration(const char *name, const char *arguments)
  287. {
  288. typedef typename
  289. boost::function_traits<Signature>::result_type result_type;
  290. typedef typename
  291. boost::function_types::parameter_types<Signature>::type parameter_types;
  292. typedef typename
  293. mpl::size<parameter_types>::type arity_type;
  294. std::stringstream s;
  295. s << "inline " << type_name<result_type>() << " " << name;
  296. s << "(";
  297. if(arity_type::value > 0){
  298. signature_argument_inserter i(s, arguments, arity_type::value);
  299. mpl::for_each<
  300. typename mpl::transform<parameter_types, boost::add_pointer<mpl::_1>
  301. >::type>(i);
  302. }
  303. s << ")";
  304. return s.str();
  305. }
  306. struct argument_list_inserter
  307. {
  308. argument_list_inserter(std::stringstream &s_, const char first, size_t last)
  309. : s(s_)
  310. {
  311. n = 0;
  312. m_last = last;
  313. m_name = first;
  314. }
  315. template<class T>
  316. void operator()(const T*)
  317. {
  318. s << type_name<T>() << " " << m_name++;
  319. if(n+1 < m_last){
  320. s << ", ";
  321. }
  322. n++;
  323. }
  324. size_t n;
  325. size_t m_last;
  326. char m_name;
  327. std::stringstream &s;
  328. };
  329. template<class Signature>
  330. inline std::string generate_argument_list(const char first = 'a')
  331. {
  332. typedef typename
  333. boost::function_types::parameter_types<Signature>::type parameter_types;
  334. typedef typename
  335. mpl::size<parameter_types>::type arity_type;
  336. std::stringstream s;
  337. s << '(';
  338. if(arity_type::value > 0){
  339. argument_list_inserter i(s, first, arity_type::value);
  340. mpl::for_each<
  341. typename mpl::transform<parameter_types, boost::add_pointer<mpl::_1>
  342. >::type>(i);
  343. }
  344. s << ')';
  345. return s.str();
  346. }
  347. // used by the BOOST_COMPUTE_FUNCTION() macro to create a function
  348. // with the given signature, name, arguments, and source.
  349. template<class Signature>
  350. inline function<Signature>
  351. make_function_impl(const char *name, const char *arguments, const char *source)
  352. {
  353. std::stringstream s;
  354. s << make_function_declaration<Signature>(name, arguments);
  355. s << source;
  356. return make_function_from_source<Signature>(name, s.str());
  357. }
  358. } // end detail namespace
  359. } // end compute namespace
  360. } // end boost namespace
  /// Creates a function object with \p name and \p source.
  ///
  /// \param return_type The return type for the function.
  /// \param name The name of the function.
  /// \param arguments A list of arguments for the function.
  /// \param source The OpenCL C source code for the function.
  ///
  /// The macro declares a variable called \p name of type
  /// <tt>function<return_type arguments></tt>. The function declaration and
  /// signature are generated automatically from the \p return_type,
  /// \p name, and \p arguments macro parameters.
  ///
  /// The source code for the function is interpreted as OpenCL C99 source
  /// code which is stringified and passed to the OpenCL compiler when the
  /// function is invoked.
  ///
  /// For example, to create a function which squares a number:
  /// \code
  /// BOOST_COMPUTE_FUNCTION(float, square, (float x),
  /// {
  ///     return x * x;
  /// });
  /// \endcode
  ///
  /// And to create a function which sums two numbers:
  /// \code
  /// BOOST_COMPUTE_FUNCTION(int, sum_two, (int x, int y),
  /// {
  ///     return x + y;
  /// });
  /// \endcode
  ///
  /// \see BOOST_COMPUTE_CLOSURE()
  #if defined(BOOST_COMPUTE_DOXYGEN_INVOKED)
  #define BOOST_COMPUTE_FUNCTION(return_type, name, arguments, source)
  #else
  // the source is received through "..." so that a body containing
  // top-level commas is still accepted and stringified as a whole
  #define BOOST_COMPUTE_FUNCTION(return_type, name, arguments, ...)              \
      ::boost::compute::function<return_type arguments> name =                   \
          ::boost::compute::detail::make_function_impl<return_type arguments>(   \
              #name, #arguments, #__VA_ARGS__                                    \
          )
  #endif
  401. #endif // BOOST_COMPUTE_FUNCTION_HPP