multiplication.hpp 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945
  1. //
  2. // Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
  3. //
  4. // Distributed under the Boost Software License, Version 1.0. (See
  5. // accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt)
  7. //
  8. // The authors gratefully acknowledge the support of
  9. // Fraunhofer IOSB, Ettlingen, Germany
  10. //
  11. #ifndef BOOST_UBLAS_TENSOR_MULTIPLICATION
  12. #define BOOST_UBLAS_TENSOR_MULTIPLICATION
  #include <algorithm>
  #include <cassert>
  #include <type_traits>
  14. namespace boost {
  15. namespace numeric {
  16. namespace ublas {
  17. namespace detail {
  18. namespace recursive {
  /** @brief Computes the tensor-times-tensor product for q contraction modes
   *
   * Implements C[i1,...,ir,j1,...,js] += sum( A[i1,...,ir+q] * B[j1,...,js+q] )
   *
   * nc[x]         = na[phia[x]  ] for 1 <= x <= r
   * nc[r+x]       = nb[phib[x]  ] for 1 <= x <= s
   * na[phia[r+x]] = nb[phib[s+x]] for 1 <= x <= q
   *
   * The entries of phia and phib are one-based mode numbers; they are
   * decremented before being used as indices into the extent and stride
   * arrays. The products are accumulated into C with +=, so C is not cleared
   * by this function.
   *
   * @note is used in function ttt
   *
   * @param k    zero-based recursion level starting with 0
   * @param r    number of non-contraction indices of A
   * @param s    number of non-contraction indices of B
   * @param q    number of contraction indices with q > 0
   * @param phia pointer to the permutation tuple of length q+r for A
   * @param phib pointer to the permutation tuple of length q+s for B
   * @param c    pointer to the output tensor C with rank(C)=r+s
   * @param nc   pointer to the extents of tensor C
   * @param wc   pointer to the strides of tensor C
   * @param a    pointer to the first input tensor with rank(A)=r+q
   * @param na   pointer to the extents of the first input tensor A
   * @param wa   pointer to the strides of the first input tensor A
   * @param b    pointer to the second input tensor B with rank(B)=s+q
   * @param nb   pointer to the extents of the second input tensor B
   * @param wb   pointer to the strides of the second input tensor B
   */
  template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
  void ttt(SizeType const k,
           SizeType const r, SizeType const s, SizeType const q,
           SizeType const*const phia, SizeType const*const phib,
           PointerOut c, SizeType const*const nc, SizeType const*const wc,
           PointerIn1 a, SizeType const*const na, SizeType const*const wa,
           PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
  {
      if(k < r)
      {
          // Levels [0, r): free modes of A, walked together with mode k of C.
          SizeType const ma = phia[k]-1;
          assert(nc[k] == na[ma]);
          for(SizeType ic = 0u; ic < nc[k]; a += wa[ma], c += wc[k], ++ic)
              ttt(k+1, r, s, q, phia, phib, c, nc, wc, a, na, wa, b, nb, wb);
      }
      else if(k < r+s)
      {
          // Levels [r, r+s): free modes of B, walked together with mode k of C.
          SizeType const mb = phib[k-r]-1;
          assert(nc[k] == nb[mb]);
          for(SizeType ic = 0u; ic < nc[k]; b += wb[mb], c += wc[k], ++ic)
              ttt(k+1, r, s, q, phia, phib, c, nc, wc, a, na, wa, b, nb, wb);
      }
      else if(k < r+s+q-1)
      {
          // All contraction levels but the last: A and B advance, C stays put.
          SizeType const ma = phia[k-s]-1;
          SizeType const mb = phib[k-r]-1;
          assert(na[ma] == nb[mb]);
          for(SizeType ia = 0u; ia < na[ma]; a += wa[ma], b += wb[mb], ++ia)
              ttt(k+1, r, s, q, phia, phib, c, nc, wc, a, na, wa, b, nb, wb);
      }
      else
      {
          // Last contraction level: accumulate the products into the current C element.
          SizeType const ma = phia[k-s]-1;
          SizeType const mb = phib[k-r]-1;
          assert(na[ma] == nb[mb]);
          for(SizeType ia = 0u; ia < na[ma]; a += wa[ma], b += wb[mb], ++ia)
              *c += *a * *b;
      }
  }
  /** @brief Computes the tensor-times-tensor product for q contraction modes
   *
   * Implements C[i1,...,ir,j1,...,js] += sum( A[i1,...,ir+q] * B[j1,...,js+q] )
   *
   * @note no permutation tuple is used: the first r modes of A and the first
   *       s modes of B are the free modes, the trailing q modes of each are
   *       contracted.
   *
   * nc[x]   = na[x  ] for 1 <= x <= r
   * nc[r+x] = nb[x  ] for 1 <= x <= s
   * na[r+x] = nb[s+x] for 1 <= x <= q
   *
   * The products are accumulated into C with +=, so C is not cleared by this
   * function.
   *
   * @note is used in function ttt
   *
   * @param k  zero-based recursion level starting with 0
   * @param r  number of non-contraction indices of A
   * @param s  number of non-contraction indices of B
   * @param q  number of contraction indices with q > 0
   * @param c  pointer to the output tensor C with rank(C)=r+s
   * @param nc pointer to the extents of tensor C
   * @param wc pointer to the strides of tensor C
   * @param a  pointer to the first input tensor with rank(A)=r+q
   * @param na pointer to the extents of the first input tensor A
   * @param wa pointer to the strides of the first input tensor A
   * @param b  pointer to the second input tensor B with rank(B)=s+q
   * @param nb pointer to the extents of the second input tensor B
   * @param wb pointer to the strides of the second input tensor B
   */
  template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
  void ttt(SizeType const k,
           SizeType const r, SizeType const s, SizeType const q,
           PointerOut c, SizeType const*const nc, SizeType const*const wc,
           PointerIn1 a, SizeType const*const na, SizeType const*const wa,
           PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
  {
      if(k < r)
      {
          // Levels [0, r): free modes of A, walked together with mode k of C.
          assert(nc[k] == na[k]);
          for(SizeType ic = 0u; ic < nc[k]; a += wa[k], c += wc[k], ++ic)
              ttt(k+1, r, s, q, c, nc, wc, a, na, wa, b, nb, wb);
      }
      else if(k < r+s)
      {
          // Levels [r, r+s): free modes of B, walked together with mode k of C.
          SizeType const mb = k-r;
          assert(nc[k] == nb[mb]);
          for(SizeType ic = 0u; ic < nc[k]; b += wb[mb], c += wc[k], ++ic)
              ttt(k+1, r, s, q, c, nc, wc, a, na, wa, b, nb, wb);
      }
      else if(k < r+s+q-1)
      {
          // All contraction levels but the last: A and B advance, C stays put.
          SizeType const ma = k-s;
          SizeType const mb = k-r;
          assert(na[ma] == nb[mb]);
          for(SizeType ia = 0u; ia < na[ma]; a += wa[ma], b += wb[mb], ++ia)
              ttt(k+1, r, s, q, c, nc, wc, a, na, wa, b, nb, wb);
      }
      else
      {
          // Last contraction level: accumulate the products into the current C element.
          SizeType const ma = k-s;
          SizeType const mb = k-r;
          assert(na[ma] == nb[mb]);
          for(SizeType ia = 0u; ia < na[ma]; a += wa[ma], b += wb[mb], ++ia)
              *c += *a * *b;
      }
  }
  /** @brief Computes the tensor-times-matrix product for the contraction mode m > 0
   *
   * Implements C[i1,i2,...,im-1,j,im+1,...,ip] += sum(A[i1,i2,...,im,...,ip] * B[j,im])
   *
   * The recursion walks the modes from r down to 0 and skips mode m; mode m is
   * contracted in the innermost loop of the base case (r == 0). Products are
   * accumulated into C with +=, so C is not cleared by this function.
   *
   * @note is used in function ttm
   *
   * @param m  zero-based contraction mode with 0<m<p
   * @param r  zero-based recursion level starting with p-1
   * @param c  pointer to the output tensor
   * @param nc pointer to the extents of tensor c
   * @param wc pointer to the strides of tensor c
   * @param a  pointer to the first input tensor
   * @param na pointer to the extents of input tensor a
   * @param wa pointer to the strides of input tensor a
   * @param b  pointer to the second input tensor
   * @param nb pointer to the extents of input tensor b
   * @param wb pointer to the strides of input tensor b
   */
  template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
  void ttm(SizeType const m, SizeType const r,
           PointerOut c, SizeType const*const nc, SizeType const*const wc,
           PointerIn1 a, SizeType const*const na, SizeType const*const wa,
           PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
  {
      if(r == m) {
          // The contraction mode is not iterated here; it is handled in the base case.
          ttm(m, r-1, c, nc, wc, a, na, wa, b, nb, wb);
      }
      else if(r == 0) {
          // Base case: loop over mode 0, then over the rows j of B (mode m of C),
          // and finally contract mode m of A with the columns of B.
          for(SizeType i0 = 0u; i0 < nc[0]; c += wc[0], a += wa[0], ++i0) {
              auto cm = c;
              auto b0 = b;
              for(SizeType j = 0u; j < nc[m]; cm += wc[m], b0 += wb[0], ++j) {
                  auto am = a;
                  auto b1 = b0;
                  for(SizeType im = 0u; im < nb[1]; am += wa[m], b1 += wb[1], ++im)
                      *cm += *am * *b1;
              }
          }
      }
      else {
          // Free mode r: A and C advance in lockstep.
          for(SizeType i = 0u; i < na[r]; c += wc[r], a += wa[r], ++i)
              ttm(m, r-1, c, nc, wc, a, na, wa, b, nb, wb);
      }
  }
  /** @brief Computes the tensor-times-matrix product for the contraction mode m = 0
   *
   * Implements C[j,i2,...,ip] += sum(A[i1,i2,...,ip] * B[j,i1])
   *
   * The contraction mode is fixed to mode 0, so there is no mode parameter.
   * The base case (r <= 1) indexes extents and strides at position 1, so the
   * arrays nc, wc and wa must hold at least two entries. Products are
   * accumulated into C with +=, so C is not cleared by this function.
   *
   * @note is used in function ttm
   *
   * @param r  zero-based recursion level starting with p-1
   * @param c  pointer to the output tensor
   * @param nc pointer to the extents of tensor c
   * @param wc pointer to the strides of tensor c
   * @param a  pointer to the first input tensor
   * @param na pointer to the extents of input tensor a
   * @param wa pointer to the strides of input tensor a
   * @param b  pointer to the second input tensor
   * @param nb pointer to the extents of input tensor b
   * @param wb pointer to the strides of input tensor b
   */
  template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
  void ttm0( SizeType const r,
             PointerOut c, SizeType const*const nc, SizeType const*const wc,
             PointerIn1 a, SizeType const*const na, SizeType const*const wa,
             PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
  {
      if(r > 1) {
          // Free mode r: A and C advance in lockstep.
          for(SizeType i = 0u; i < na[r]; c += wc[r], a += wa[r], ++i)
              ttm0(r-1, c, nc, wc, a, na, wa, b, nb, wb);
      }
      else {
          // Base case: loop over mode 1, then over the rows j of B (mode 0 of C),
          // and finally contract mode 0 of A with the columns of B.
          for(SizeType i1 = 0u; i1 < nc[1]; c += wc[1], a += wa[1], ++i1) {
              auto cm = c;
              auto b0 = b;
              for(SizeType j = 0u; j < nc[0]; cm += wc[0], b0 += wb[0], ++j) {
                  auto am = a;
                  auto b1 = b0;
                  for(SizeType i0 = 0u; i0 < nb[1]; am += wa[0], b1 += wb[1], ++i0) {
                      *cm += *am * *b1;
                  }
              }
          }
      }
  }
  223. //////////////////////////////////////////////////////////////////////////////////////////
  224. //////////////////////////////////////////////////////////////////////////////////////////
  225. //////////////////////////////////////////////////////////////////////////////////////////
  226. //////////////////////////////////////////////////////////////////////////////////////////
  /** @brief Computes the tensor-times-vector product for the contraction mode m > 0
   *
   * Implements C[i1,i2,...,im-1,im+1,...,ip] += sum(A[i1,i2,...,im,...,ip] * b[im])
   *
   * A has one mode more than C, so two recursion levels are tracked: r for A
   * and q for C. Level q is only decremented together with a free mode of A;
   * it is left unchanged when the contraction mode m is skipped. The vector b
   * is read contiguously (unit stride). Products are accumulated into C with
   * +=, so C is not cleared by this function.
   *
   * @note is used in function ttv
   *
   * @param m  zero-based contraction mode with 0<m<p
   * @param r  zero-based recursion level starting with p-1 for tensor A
   * @param q  zero-based recursion level starting with p-2 for tensor C
   * @param c  pointer to the output tensor
   * @param nc pointer to the extents of tensor c (not read here)
   * @param wc pointer to the strides of tensor c
   * @param a  pointer to the first input tensor
   * @param na pointer to the extents of input tensor a
   * @param wa pointer to the strides of input tensor a
   * @param b  pointer to the second input tensor
   */
  template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
  void ttv( SizeType const m, SizeType const r, SizeType const q,
            PointerOut c, SizeType const*const nc, SizeType const*const wc,
            PointerIn1 a, SizeType const*const na, SizeType const*const wa,
            PointerIn2 b)
  {
      if(r == m) {
          // Skip the contraction mode; C has no corresponding mode, so q is kept.
          ttv(m, r-1, q, c, nc, wc, a, na, wa, b);
      }
      else if(r == 0) {
          // Base case: loop over mode 0 and contract mode m of A with b.
          for(SizeType i0 = 0u; i0 < na[0]; c += wc[0], a += wa[0], ++i0) {
              auto c1 = c;
              auto a1 = a;
              auto b1 = b;
              for(SizeType im = 0u; im < na[m]; a1 += wa[m], ++b1, ++im)
                  *c1 += *a1 * *b1;
          }
      }
      else {
          // Free mode: mode r of A corresponds to mode q of C.
          for(SizeType i = 0u; i < na[r]; c += wc[q], a += wa[r], ++i)
              ttv(m, r-1, q-1, c, nc, wc, a, na, wa, b);
      }
  }
  /** @brief Computes the tensor-times-vector product for the contraction mode m = 0
   *
   * Implements C[i2,...,ip] += sum(A[i1,...,ip] * b[i1])
   *
   * The contraction mode is fixed to mode 0, so there is no mode parameter.
   * Mode r of A corresponds to mode r-1 of C. The vector b is read
   * contiguously (unit stride). Products are accumulated into C with +=, so C
   * is not cleared by this function.
   *
   * @note is used in function ttv
   *
   * @param r  zero-based recursion level starting with p-1
   * @param c  pointer to the output tensor
   * @param nc pointer to the extents of tensor c (not read here)
   * @param wc pointer to the strides of tensor c
   * @param a  pointer to the first input tensor
   * @param na pointer to the extents of input tensor a
   * @param wa pointer to the strides of input tensor a
   * @param b  pointer to the second input tensor
   */
  template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
  void ttv0(SizeType const r,
            PointerOut c, SizeType const*const nc, SizeType const*const wc,
            PointerIn1 a, SizeType const*const na, SizeType const*const wa,
            PointerIn2 b)
  {
      if(r > 1) {
          // Free mode r of A is mode r-1 of C.
          for(SizeType i = 0u; i < na[r]; c += wc[r-1], a += wa[r], ++i)
              ttv0(r-1, c, nc, wc, a, na, wa, b);
      }
      else {
          // Base case: loop over mode 1 of A (mode 0 of C) and contract mode 0 with b.
          for(SizeType i1 = 0u; i1 < na[1]; c += wc[0], a += wa[1], ++i1)
          {
              auto c1 = c;
              auto a1 = a;
              auto b1 = b;
              for(SizeType i0 = 0u; i0 < na[0]; a1 += wa[0], ++b1, ++i0)
                  *c1 += *a1 * *b1;
          }
      }
  }
  /** @brief Computes the matrix-times-vector product
   *
   * Implements C[i1] += sum(A[i1,i2] * b[i2]) or C[i2] += sum(A[i1,i2] * b[i1])
   *
   * The non-contracted mode of A is o = 1 for m == 0 and o = 0 otherwise.
   * C is advanced with stride wc[o]; the vector b is read contiguously (unit
   * stride). Products are accumulated into C with +=, so C is not cleared by
   * this function.
   *
   * @note is used in function ttv
   *
   * @param[in]     m  zero-based contraction mode with m=0 or m=1
   * @param[in,out] c  pointer to the output tensor C
   * @param[in]        pointer to the extents of tensor C (unnamed, not read)
   * @param[in]     wc pointer to the strides of tensor C
   * @param[in]     a  pointer to the first input tensor A
   * @param[in]     na pointer to the extents of input tensor A
   * @param[in]     wa pointer to the strides of input tensor A
   * @param[in]     b  pointer to the second input tensor B
   */
  template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
  void mtv(SizeType const m,
           PointerOut c, SizeType const*const , SizeType const*const wc,
           PointerIn1 a, SizeType const*const na, SizeType const*const wa,
           PointerIn2 b)
  {
      // decides whether matrix multiplied with vector or vector multiplied with matrix
      SizeType const o = (m == 0) ? 1u : 0u;

      for(SizeType io = 0u; io < na[o]; c += wc[o], a += wa[o], ++io) {
          auto c1 = c;
          auto a1 = a;
          auto b1 = b;
          for(SizeType im = 0u; im < na[m]; a1 += wa[m], ++b1, ++im)
              *c1 += *a1 * *b1;
      }
  }
  /** @brief Computes the matrix-times-matrix product
   *
   * Implements C[i1,i3] += sum(A[i1,i2] * B[i2,i3])
   *
   * The loops are ordered j (columns of C), k (contraction), i (rows of C).
   * Products are accumulated into C with +=, so C is not cleared by this
   * function.
   *
   * @note is used in function ttm
   *
   * @param[in,out] c  pointer to the output tensor C
   * @param[in]     nc pointer to the extents of tensor C
   * @param[in]     wc pointer to the strides of tensor C
   * @param[in]     a  pointer to the first input tensor A
   * @param[in]     na pointer to the extents of input tensor A
   * @param[in]     wa pointer to the strides of input tensor A
   * @param[in]     b  pointer to the second input tensor B
   * @param[in]     nb pointer to the extents of input tensor B
   * @param[in]     wb pointer to the strides of input tensor B
   */
  template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
  void mtm(PointerOut c, SizeType const*const nc, SizeType const*const wc,
           PointerIn1 a, SizeType const*const na, SizeType const*const wa,
           PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
  {
      // C(i,j) = A(i,k) * B(k,j)
      assert(nc[0] == na[0]);
      assert(nc[1] == nb[1]);
      assert(na[1] == nb[0]);

      auto cj = c;
      auto bj = b;
      for(SizeType j = 0u; j < nc[1]; cj += wc[1], bj += wb[1], ++j) {
          auto bk = bj;
          auto ak = a;
          for(SizeType k = 0u; k < na[1]; ak += wa[1], bk += wb[0], ++k) {
              auto ci = cj;
              auto ai = ak;
              for(SizeType i = 0u; i < na[0]; ai += wa[0], ci += wc[0], ++i) {
                  *ci += *ai * *bk;
              }
          }
      }
  }
  /** @brief Computes the inner product of two tensors
   *
   * Implements c = sum(A[i1,i2,...,ip] * B[i1,i2,...,ip])
   *
   * Both tensors are traversed with the same extents n but with their own
   * strides. The running sum is passed by value through the recursion and
   * returned, so the result is v plus the sum of all element products.
   *
   * @note is used in function inner
   *
   * @param r  zero-based recursion level starting with p-1
   * @param n  pointer to the extents of input or output tensor
   * @param a  pointer to the first input tensor
   * @param wa pointer to the strides of input tensor a
   * @param b  pointer to the second input tensor
   * @param wb pointer to the strides of tensor b
   * @param v  previously computed value (start with v = 0).
   * @return   inner product of two tensors.
   */
  template <class PointerIn1, class PointerIn2, class value_t, class SizeType>
  value_t inner(SizeType const r, SizeType const*const n,
                PointerIn1 a, SizeType const*const wa,
                PointerIn2 b, SizeType const*const wb,
                value_t v)
  {
      if(r == 0) {
          // Innermost mode: accumulate the element-wise products.
          for(SizeType i0 = 0u; i0 < n[0]; a += wa[0], b += wb[0], ++i0)
              v += *a * *b;
      }
      else {
          // Outer mode: thread the running sum through the next lower level.
          for(SizeType ir = 0u; ir < n[r]; a += wa[r], b += wb[r], ++ir)
              v = inner(r-1, n, a, wa, b, wb, v);
      }
      return v;
  }
  /** @brief Computes the outer product of the two lowest modes of A with the two lowest modes of B
   *
   * Implements C[i1,i2,...,j1,j2] = A[i1,i2] * B[j1,j2] where modes 0 and 1 of
   * A map to modes 0 and 1 of C and modes 0 and 1 of B map to modes pa and
   * pa+1 of C. Each addressed element of C is overwritten (assignment, not
   * accumulation).
   *
   * @note base case of the recursive function outer
   *
   * @param[in]  pa index of the mode of C that corresponds to mode 0 of B
   * @param[out] c  pointer to the output tensor
   * @param[in]     pointer to the extents of output tensor c (unnamed, not read)
   * @param[in]  wc pointer to the strides of output tensor c
   * @param[in]  a  pointer to the first input tensor
   * @param[in]  na pointer to the extents of the first input tensor a
   * @param[in]  wa pointer to the strides of the first input tensor a
   * @param[in]  b  pointer to the second input tensor
   * @param[in]  nb pointer to the extents of the second input tensor b
   * @param[in]  wb pointer to the strides of the second input tensor b
   */
  template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
  void outer_2x2(SizeType const pa,
                 PointerOut c, SizeType const*const , SizeType const*const wc,
                 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
                 PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
  {
      for(SizeType ib1 = 0u; ib1 < nb[1]; b += wb[1], c += wc[pa+1], ++ib1) {
          auto c2 = c;
          auto b0 = b;
          for(SizeType ib0 = 0u; ib0 < nb[0]; b0 += wb[0], c2 += wc[pa], ++ib0) {
              // The current element of B is constant over the whole A sweep below.
              const auto bval = *b0;
              auto c1 = c2;
              auto a1 = a;
              for(SizeType ia1 = 0u; ia1 < na[1]; a1 += wa[1], c1 += wc[1], ++ia1) {
                  auto a0 = a1;
                  auto c0 = c1;
                  for(SizeType ia0 = 0u; ia0 < na[0]; a0 += wa[0], c0 += wc[0], ++ia0)
                      *c0 = *a0 * bval;
              }
          }
      }
  }
  /** @brief Computes the outer product of two tensors
   *
   * Implements C[i1,...,ip,j1,...,jq] = A[i1,i2,...,ip] * B[j1,j2,...,jq]
   *
   * The recursion first peels the modes of B above mode 1 (stepping C along
   * mode rc), then the modes of A above mode 1 (stepping C along mode ra), and
   * finally hands the remaining two modes of each tensor to outer_2x2.
   *
   * @note called by outer
   *
   * @param[in]  pa number of dimensions (rank) of the first input tensor A with pa > 0
   *
   * @param[in]  rc recursion level for C that starts with pc-1
   * @param[out] c  pointer to the output tensor
   * @param[in]  nc pointer to the extents of output tensor c
   * @param[in]  wc pointer to the strides of output tensor c
   *
   * @param[in]  ra recursion level for A that starts with pa-1
   * @param[in]  a  pointer to the first input tensor
   * @param[in]  na pointer to the extents of the first input tensor a
   * @param[in]  wa pointer to the strides of the first input tensor a
   *
   * @param[in]  rb recursion level for B that starts with pb-1
   * @param[in]  b  pointer to the second input tensor
   * @param[in]  nb pointer to the extents of the second input tensor b
   * @param[in]  wb pointer to the strides of the second input tensor b
   */
  template<class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
  void outer(SizeType const pa,
             SizeType const rc, PointerOut c, SizeType const*const nc, SizeType const*const wc,
             SizeType const ra, PointerIn1 a, SizeType const*const na, SizeType const*const wa,
             SizeType const rb, PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
  {
      if(rb > 1) {
          // Peel mode rb of B; it corresponds to mode rc of C.
          for(SizeType ib = 0u; ib < nb[rb]; b += wb[rb], c += wc[rc], ++ib)
              outer(pa, rc-1, c, nc, wc, ra, a, na, wa, rb-1, b, nb, wb);
      }
      else if(ra > 1) {
          // Peel mode ra of A; it corresponds to the same mode ra of C.
          for(SizeType ia = 0u; ia < na[ra]; a += wa[ra], c += wc[ra], ++ia)
              outer(pa, rc-1, c, nc, wc, ra-1, a, na, wa, rb, b, nb, wb);
      }
      else {
          // Both tensors are down to their two lowest modes.
          outer_2x2(pa, c, nc, wc, a, na, wa, b, nb, wb); //assert(ra==1 && rb==1 && rc==3);
      }
  }
  /** @brief Computes the outer product with permutation tuples
   *
   * Implements C[i1,...,ir,j1,...,js] = A[i1,...,ir] * B[j1,...,js]
   *
   * nc[x]   = na[phia[x]] for 1 <= x <= r
   * nc[r+x] = nb[phib[x]] for 1 <= x <= s
   *
   * The entries of phia and phib are one-based mode numbers; they are
   * decremented before being used as indices into the extent and stride
   * arrays. Each element of C is overwritten (assignment, not accumulation).
   *
   * @note maybe called by ttt function
   *
   * @param k    zero-based recursion level starting with 0
   * @param r    number of non-contraction indices of A
   * @param s    number of non-contraction indices of B
   * @param phia pointer to the permutation tuple of length r for A
   * @param phib pointer to the permutation tuple of length s for B
   * @param c    pointer to the output tensor C with rank(C)=r+s
   * @param nc   pointer to the extents of tensor C
   * @param wc   pointer to the strides of tensor C
   * @param a    pointer to the first input tensor with rank(A)=r
   * @param na   pointer to the extents of the first input tensor A
   * @param wa   pointer to the strides of the first input tensor A
   * @param b    pointer to the second input tensor B with rank(B)=s
   * @param nb   pointer to the extents of the second input tensor B
   * @param wb   pointer to the strides of the second input tensor B
   */
  template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
  void outer(SizeType const k,
             SizeType const r, SizeType const s,
             SizeType const*const phia, SizeType const*const phib,
             PointerOut c, SizeType const*const nc, SizeType const*const wc,
             PointerIn1 a, SizeType const*const na, SizeType const*const wa,
             PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
  {
      if(k < r)
      {
          // Levels [0, r): modes of A, walked together with mode k of C.
          SizeType const ma = phia[k]-1;
          assert(nc[k] == na[ma]);
          for(SizeType ic = 0u; ic < nc[k]; a += wa[ma], c += wc[k], ++ic)
              outer(k+1, r, s, phia, phib, c, nc, wc, a, na, wa, b, nb, wb);
      }
      else if(k < r+s-1)
      {
          // Levels [r, r+s-1): all modes of B but the last.
          SizeType const mb = phib[k-r]-1;
          assert(nc[k] == nb[mb]);
          for(SizeType ic = 0u; ic < nc[k]; b += wb[mb], c += wc[k], ++ic)
              outer(k+1, r, s, phia, phib, c, nc, wc, a, na, wa, b, nb, wb);
      }
      else
      {
          // Last mode of B: write the products.
          SizeType const mb = phib[k-r]-1;
          assert(nc[k] == nb[mb]);
          for(SizeType ic = 0u; ic < nc[k]; b += wb[mb], c += wc[k], ++ic)
              *c = *a * *b;
      }
  }
  509. } // namespace recursive
  510. } // namespace detail
  511. } // namespace ublas
  512. } // namespace numeric
  513. } // namespace boost
  514. //////////////////////////////////////////////////////////////////////////////////////////
  515. //////////////////////////////////////////////////////////////////////////////////////////
  516. //////////////////////////////////////////////////////////////////////////////////////////
  517. //////////////////////////////////////////////////////////////////////////////////////////
  518. //////////////////////////////////////////////////////////////////////////////////////////
  519. //////////////////////////////////////////////////////////////////////////////////////////
  520. //////////////////////////////////////////////////////////////////////////////////////////
  521. //////////////////////////////////////////////////////////////////////////////////////////
  522. #include <stdexcept>
  523. namespace boost {
  524. namespace numeric {
  525. namespace ublas {
  526. /** @brief Computes the tensor-times-vector product
  527. *
  528. * Implements
  529. * C[i1,i2,...,im-1,im+1,...,ip] = sum(A[i1,i2,...,im,...,ip] * b[im]) for m>1 and
  530. * C[i2,...,ip] = sum(A[i1,...,ip] * b[i1]) for m=1
  531. *
  532. * @note calls detail::ttv, detail::ttv0 or detail::mtv
  533. *
  534. * @param[in] m contraction mode with 0 < m <= p
  535. * @param[in] p number of dimensions (rank) of the first input tensor with p > 0
  536. * @param[out] c pointer to the output tensor with rank p-1
  537. * @param[in] nc pointer to the extents of tensor c
  538. * @param[in] wc pointer to the strides of tensor c
  539. * @param[in] a pointer to the first input tensor
  540. * @param[in] na pointer to the extents of input tensor a
  541. * @param[in] wa pointer to the strides of input tensor a
  542. * @param[in] b pointer to the second input tensor
  543. * @param[in] nb pointer to the extents of input tensor b
  544. * @param[in] wb pointer to the strides of input tensor b
  545. */
  546. template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
  547. void ttv(SizeType const m, SizeType const p,
  548. PointerOut c, SizeType const*const nc, SizeType const*const wc,
  549. const PointerIn1 a, SizeType const*const na, SizeType const*const wa,
  550. const PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
  551. {
  552. static_assert( std::is_pointer<PointerOut>::value & std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value,
  553. "Static error in boost::numeric::ublas::ttv: Argument types for pointers are not pointer types.");
  554. if( m == 0)
  555. throw std::length_error("Error in boost::numeric::ublas::ttv: Contraction mode must be greater than zero.");
  556. if( p < m )
  557. throw std::length_error("Error in boost::numeric::ublas::ttv: Rank must be greater equal the modus.");
  558. if( p == 0)
  559. throw std::length_error("Error in boost::numeric::ublas::ttv: Rank must be greater than zero.");
  560. if(c == nullptr || a == nullptr || b == nullptr)
  561. throw std::length_error("Error in boost::numeric::ublas::ttv: Pointers shall not be null pointers.");
  562. for(auto i = 0u; i < m-1; ++i)
  563. if(na[i] != nc[i])
  564. throw std::length_error("Error in boost::numeric::ublas::ttv: Extents (except of dimension mode) of A and C must be equal.");
  565. for(auto i = m; i < p; ++i)
  566. if(na[i] != nc[i-1])
  567. throw std::length_error("Error in boost::numeric::ublas::ttv: Extents (except of dimension mode) of A and C must be equal.");
  568. const auto max = std::max(nb[0], nb[1]);
  569. if( na[m-1] != max)
  570. throw std::length_error("Error in boost::numeric::ublas::ttv: Extent of dimension mode of A and b must be equal.");
  571. if((m != 1) && (p > 2))
  572. detail::recursive::ttv(m-1, p-1, p-2, c, nc, wc, a, na, wa, b);
  573. else if ((m == 1) && (p > 2))
  574. detail::recursive::ttv0(p-1, c, nc, wc, a, na, wa, b);
  575. else if( p == 2 )
  576. detail::recursive::mtv(m-1, c, nc, wc, a, na, wa, b);
  577. else /*if( p == 1 )*/{
  578. auto v = std::remove_pointer_t<std::remove_cv_t<PointerOut>>{};
  579. *c = detail::recursive::inner(SizeType(0), na, a, wa, b, wb, v);
  580. }
  581. }
  582. /** @brief Computes the tensor-times-matrix product
  583. *
  584. * Implements
  585. * C[i1,i2,...,im-1,j,im+1,...,ip] = sum(A[i1,i2,...,im,...,ip] * B[j,im]) for m>1 and
  586. * C[j,i2,...,ip] = sum(A[i1,i2,...,ip] * B[j,i1]) for m=1
  587. *
  588. * @note calls detail::ttm or detail::ttm0
  589. *
  590. * @param[in] m contraction mode with 0 < m <= p
  591. * @param[in] p number of dimensions (rank) of the first input tensor with p > 0
  592. * @param[out] c pointer to the output tensor with rank p-1
  593. * @param[in] nc pointer to the extents of tensor c
  594. * @param[in] wc pointer to the strides of tensor c
  595. * @param[in] a pointer to the first input tensor
  596. * @param[in] na pointer to the extents of input tensor a
  597. * @param[in] wa pointer to the strides of input tensor a
  598. * @param[in] b pointer to the second input tensor
  599. * @param[in] nb pointer to the extents of input tensor b
  600. * @param[in] wb pointer to the strides of input tensor b
  601. */
  602. template <class PointerIn1, class PointerIn2, class PointerOut, class SizeType>
  603. void ttm(SizeType const m, SizeType const p,
  604. PointerOut c, SizeType const*const nc, SizeType const*const wc,
  605. const PointerIn1 a, SizeType const*const na, SizeType const*const wa,
  606. const PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
  607. {
  608. static_assert( std::is_pointer<PointerOut>::value & std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value,
  609. "Static error in boost::numeric::ublas::ttm: Argument types for pointers are not pointer types.");
  610. if( m == 0 )
  611. throw std::length_error("Error in boost::numeric::ublas::ttm: Contraction mode must be greater than zero.");
  612. if( p < m )
  613. throw std::length_error("Error in boost::numeric::ublas::ttm: Rank must be greater equal than the specified mode.");
  614. if( p == 0)
  615. throw std::length_error("Error in boost::numeric::ublas::ttm:Rank must be greater than zero.");
  616. if(c == nullptr || a == nullptr || b == nullptr)
  617. throw std::length_error("Error in boost::numeric::ublas::ttm: Pointers shall not be null pointers.");
  618. for(auto i = 0u; i < m-1; ++i)
  619. if(na[i] != nc[i])
  620. throw std::length_error("Error in boost::numeric::ublas::ttm: Extents (except of dimension mode) of A and C must be equal.");
  621. for(auto i = m; i < p; ++i)
  622. if(na[i] != nc[i])
  623. throw std::length_error("Error in boost::numeric::ublas::ttm: Extents (except of dimension mode) of A and C must be equal.");
  624. if(na[m-1] != nb[1])
  625. throw std::length_error("Error in boost::numeric::ublas::ttm: 2nd Extent of B and M-th Extent of A must be the equal.");
  626. if(nc[m-1] != nb[0])
  627. throw std::length_error("Error in boost::numeric::ublas::ttm: 1nd Extent of B and M-th Extent of C must be the equal.");
  628. if ( m != 1 )
  629. detail::recursive::ttm (m-1, p-1, c, nc, wc, a, na, wa, b, nb, wb);
  630. else /*if (m == 1 && p > 2)*/
  631. detail::recursive::ttm0( p-1, c, nc, wc, a, na, wa, b, nb, wb);
  632. }
  633. /** @brief Computes the tensor-times-tensor product
  634. *
  635. * Implements C[i1,...,ir,j1,...,js] = sum( A[i1,...,ir+q] * B[j1,...,js+q] )
  636. *
  637. * @note calls detail::recursive::ttt or ttm or ttv or inner or outer
  638. *
  639. * nc[x] = na[phia[x] ] for 1 <= x <= r
  640. * nc[r+x] = nb[phib[x] ] for 1 <= x <= s
  641. * na[phia[r+x]] = nb[phib[s+x]] for 1 <= x <= q
  642. *
  643. * @param[in] pa number of dimensions (rank) of the first input tensor a with pa > 0
  644. * @param[in] pb number of dimensions (rank) of the second input tensor b with pb > 0
  645. * @param[in] q number of contraction dimensions with pa >= q and pb >= q and q >= 0
  646. * @param[in] phia pointer to a permutation tuple for the first input tensor a
  647. * @param[in] phib pointer to a permutation tuple for the second input tensor b
  648. * @param[out] c pointer to the output tensor with rank p-1
  649. * @param[in] nc pointer to the extents of tensor c
  650. * @param[in] wc pointer to the strides of tensor c
  651. * @param[in] a pointer to the first input tensor
  652. * @param[in] na pointer to the extents of input tensor a
  653. * @param[in] wa pointer to the strides of input tensor a
  654. * @param[in] b pointer to the second input tensor
  655. * @param[in] nb pointer to the extents of input tensor b
  656. * @param[in] wb pointer to the strides of input tensor b
  657. */
  658. template <class PointerIn1, class PointerIn2, class PointerOut, class SizeType>
  659. void ttt(SizeType const pa, SizeType const pb, SizeType const q,
  660. SizeType const*const phia, SizeType const*const phib,
  661. PointerOut c, SizeType const*const nc, SizeType const*const wc,
  662. PointerIn1 a, SizeType const*const na, SizeType const*const wa,
  663. PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
  664. {
  665. static_assert( std::is_pointer<PointerOut>::value & std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value,
  666. "Static error in boost::numeric::ublas::ttm: Argument types for pointers are not pointer types.");
  667. if( pa == 0 || pb == 0)
  668. throw std::length_error("Error in boost::numeric::ublas::ttt: tensor order must be greater zero.");
  669. if( q > pa && q > pb)
  670. throw std::length_error("Error in boost::numeric::ublas::ttt: number of contraction must be smaller than or equal to the tensor order.");
  671. SizeType const r = pa - q;
  672. SizeType const s = pb - q;
  673. if(c == nullptr || a == nullptr || b == nullptr)
  674. throw std::length_error("Error in boost::numeric::ublas::ttm: Pointers shall not be null pointers.");
  675. for(auto i = 0ul; i < r; ++i)
  676. if( na[phia[i]-1] != nc[i] )
  677. throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of lhs and res tensor not correct.");
  678. for(auto i = 0ul; i < s; ++i)
  679. if( nb[phib[i]-1] != nc[r+i] )
  680. throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of rhs and res not correct.");
  681. for(auto i = 0ul; i < q; ++i)
  682. if( nb[phib[s+i]-1] != na[phia[r+i]-1] )
  683. throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of lhs and rhs not correct.");
  684. if(q == 0ul)
  685. detail::recursive::outer(SizeType{0},r,s, phia,phib, c,nc,wc, a,na,wa, b,nb,wb);
  686. else
  687. detail::recursive::ttt(SizeType{0},r,s,q, phia,phib, c,nc,wc, a,na,wa, b,nb,wb);
  688. }
  689. /** @brief Computes the tensor-times-tensor product
  690. *
  691. * Implements C[i1,...,ir,j1,...,js] = sum( A[i1,...,ir+q] * B[j1,...,js+q] )
  692. *
  693. * @note calls detail::recursive::ttt or ttm or ttv or inner or outer
  694. *
  695. * nc[x] = na[x ] for 1 <= x <= r
  696. * nc[r+x] = nb[x ] for 1 <= x <= s
  697. * na[r+x] = nb[s+x] for 1 <= x <= q
  698. *
  699. * @param[in] pa number of dimensions (rank) of the first input tensor a with pa > 0
  700. * @param[in] pb number of dimensions (rank) of the second input tensor b with pb > 0
  701. * @param[in] q number of contraction dimensions with pa >= q and pb >= q and q >= 0
  702. * @param[out] c pointer to the output tensor with rank p-1
  703. * @param[in] nc pointer to the extents of tensor c
  704. * @param[in] wc pointer to the strides of tensor c
  705. * @param[in] a pointer to the first input tensor
  706. * @param[in] na pointer to the extents of input tensor a
  707. * @param[in] wa pointer to the strides of input tensor a
  708. * @param[in] b pointer to the second input tensor
  709. * @param[in] nb pointer to the extents of input tensor b
  710. * @param[in] wb pointer to the strides of input tensor b
  711. */
  712. template <class PointerIn1, class PointerIn2, class PointerOut, class SizeType>
  713. void ttt(SizeType const pa, SizeType const pb, SizeType const q,
  714. PointerOut c, SizeType const*const nc, SizeType const*const wc,
  715. PointerIn1 a, SizeType const*const na, SizeType const*const wa,
  716. PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
  717. {
  718. static_assert( std::is_pointer<PointerOut>::value & std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value,
  719. "Static error in boost::numeric::ublas::ttm: Argument types for pointers are not pointer types.");
  720. if( pa == 0 || pb == 0)
  721. throw std::length_error("Error in boost::numeric::ublas::ttt: tensor order must be greater zero.");
  722. if( q > pa && q > pb)
  723. throw std::length_error("Error in boost::numeric::ublas::ttt: number of contraction must be smaller than or equal to the tensor order.");
  724. SizeType const r = pa - q;
  725. SizeType const s = pb - q;
  726. SizeType const pc = r+s;
  727. if(c == nullptr || a == nullptr || b == nullptr)
  728. throw std::length_error("Error in boost::numeric::ublas::ttm: Pointers shall not be null pointers.");
  729. for(auto i = 0ul; i < r; ++i)
  730. if( na[i] != nc[i] )
  731. throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of lhs and res tensor not correct.");
  732. for(auto i = 0ul; i < s; ++i)
  733. if( nb[i] != nc[r+i] )
  734. throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of rhs and res not correct.");
  735. for(auto i = 0ul; i < q; ++i)
  736. if( nb[s+i] != na[r+i] )
  737. throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of lhs and rhs not correct.");
  738. using value_type = std::decay_t<decltype(*c)>;
  739. if(q == 0ul)
  740. detail::recursive::outer(pa, pc-1, c,nc,wc, pa-1, a,na,wa, pb-1, b,nb,wb);
  741. else if(r == 0ul && s == 0ul)
  742. *c = detail::recursive::inner(q-1, na, a,wa, b,wb, value_type(0) );
  743. else
  744. detail::recursive::ttt(SizeType{0},r,s,q, c,nc,wc, a,na,wa, b,nb,wb);
  745. }
  746. /** @brief Computes the inner product of two tensors
  747. *
  748. * Implements c = sum(A[i1,i2,...,ip] * B[i1,i2,...,ip])
  749. *
  750. * @note calls detail::inner
  751. *
  752. * @param[in] p number of dimensions (rank) of the first input tensor with p > 0
  753. * @param[in] n pointer to the extents of input or output tensor
  754. * @param[in] a pointer to the first input tensor
  755. * @param[in] wa pointer to the strides of input tensor a
  756. * @param[in] b pointer to the second input tensor
  757. * @param[in] wb pointer to the strides of input tensor b
  758. * @param[in] v inital value
  759. *
  760. * @return inner product of two tensors.
  761. */
  762. template <class PointerIn1, class PointerIn2, class value_t, class SizeType>
  763. auto inner(const SizeType p, SizeType const*const n,
  764. const PointerIn1 a, SizeType const*const wa,
  765. const PointerIn2 b, SizeType const*const wb,
  766. value_t v)
  767. {
  768. static_assert( std::is_pointer<PointerIn1>::value && std::is_pointer<PointerIn2>::value,
  769. "Static error in boost::numeric::ublas::inner: Argument types for pointers must be pointer types.");
  770. if(p<2)
  771. throw std::length_error("Error in boost::numeric::ublas::inner: Rank must be greater than zero.");
  772. if(a == nullptr || b == nullptr)
  773. throw std::length_error("Error in boost::numeric::ublas::inner: Pointers shall not be null pointers.");
  774. return detail::recursive::inner(p-1, n, a, wa, b, wb, v);
  775. }
  776. /** @brief Computes the outer product of two tensors
  777. *
  778. * Implements C[i1,...,ip,j1,...,jq] = A[i1,i2,...,ip] * B[j1,j2,...,jq]
  779. *
  780. * @note calls detail::outer
  781. *
  782. * @param[out] c pointer to the output tensor
  783. * @param[in] pc number of dimensions (rank) of the output tensor c with pc > 0
  784. * @param[in] nc pointer to the extents of output tensor c
  785. * @param[in] wc pointer to the strides of output tensor c
  786. * @param[in] a pointer to the first input tensor
  787. * @param[in] pa number of dimensions (rank) of the first input tensor a with pa > 0
  788. * @param[in] na pointer to the extents of the first input tensor a
  789. * @param[in] wa pointer to the strides of the first input tensor a
  790. * @param[in] b pointer to the second input tensor
  791. * @param[in] pb number of dimensions (rank) of the second input tensor b with pb > 0
  792. * @param[in] nb pointer to the extents of the second input tensor b
  793. * @param[in] wb pointer to the strides of the second input tensor b
  794. */
  795. template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
  796. void outer(PointerOut c, SizeType const pc, SizeType const*const nc, SizeType const*const wc,
  797. const PointerIn1 a, SizeType const pa, SizeType const*const na, SizeType const*const wa,
  798. const PointerIn2 b, SizeType const pb, SizeType const*const nb, SizeType const*const wb)
  799. {
  800. static_assert( std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value & std::is_pointer<PointerOut>::value,
  801. "Static error in boost::numeric::ublas::outer: argument types for pointers must be pointer types.");
  802. if(pa < 2u || pb < 2u)
  803. throw std::length_error("Error in boost::numeric::ublas::outer: number of extents of lhs and rhs tensor must be equal or greater than two.");
  804. if((pa + pb) != pc)
  805. throw std::length_error("Error in boost::numeric::ublas::outer: number of extents of lhs plus rhs tensor must be equal to the number of extents of C.");
  806. if(a == nullptr || b == nullptr || c == nullptr)
  807. throw std::length_error("Error in boost::numeric::ublas::outer: pointers shall not be null pointers.");
  808. detail::recursive::outer(pa, pc-1, c, nc, wc, pa-1, a, na, wa, pb-1, b, nb, wb);
  809. }
  810. }
  811. }
  812. }
  813. #endif