autodiff_example.cpp 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852
  1. // Copyright (C) 2016-2018 T. Zachary Laine
  2. //
  3. // Distributed under the Boost Software License, Version 1.0. (See
  4. // accompanying file LICENSE_1_0.txt or copy at
  5. // http://www.boost.org/LICENSE_1_0.txt)
  6. #include "autodiff.h"
  7. #include <iostream>
  8. #include <boost/yap/algorithm.hpp>
  9. #include <boost/polymorphic_cast.hpp>
  10. #include <boost/hana/for_each.hpp>
  11. #define BOOST_TEST_MODULE autodiff_test
  12. #include <boost/test/included/unit_test.hpp>
  // Tolerance handed to BOOST_CHECK_CLOSE.  Boost.Test reads that argument as
  // a *percentage*, so this is a relative tolerance of 1e-7.  Being relative,
  // a comparison against an expected value of exactly 0 only passes when the
  // actual value is exactly 0 as well.
  constexpr double Epsilon = 1.0e-5;
  // Statement-safe wrapper around BOOST_CHECK_CLOSE using Epsilon.
  #define CHECK_CLOSE(A, B) do { BOOST_CHECK_CLOSE(A, B, Epsilon); } while (0)
  15. using namespace AutoDiff;
  16. //[ autodiff_expr_template_decl
  17. template <boost::yap::expr_kind Kind, typename Tuple>
  18. struct autodiff_expr
  19. {
  20. static boost::yap::expr_kind const kind = Kind;
  21. Tuple elements;
  22. };
  23. BOOST_YAP_USER_UNARY_OPERATOR(negate, autodiff_expr, autodiff_expr)
  24. BOOST_YAP_USER_BINARY_OPERATOR(plus, autodiff_expr, autodiff_expr)
  25. BOOST_YAP_USER_BINARY_OPERATOR(minus, autodiff_expr, autodiff_expr)
  26. BOOST_YAP_USER_BINARY_OPERATOR(multiplies, autodiff_expr, autodiff_expr)
  27. BOOST_YAP_USER_BINARY_OPERATOR(divides, autodiff_expr, autodiff_expr)
  28. //]
  29. //[ autodiff_expr_literals_decl
  30. namespace autodiff_placeholders {
  31. // This defines a placeholder literal operator that creates autodiff_expr
  32. // placeholders.
  33. BOOST_YAP_USER_LITERAL_PLACEHOLDER_OPERATOR(autodiff_expr)
  34. }
  35. //]
  36. //[ autodiff_function_terminals
  37. template <OPCODE Opcode>
  38. struct autodiff_fn_expr :
  39. autodiff_expr<boost::yap::expr_kind::terminal, boost::hana::tuple<OPCODE>>
  40. {
  41. autodiff_fn_expr () :
  42. autodiff_expr {boost::hana::tuple<OPCODE>{Opcode}}
  43. {}
  44. BOOST_YAP_USER_CALL_OPERATOR_N(::autodiff_expr, 1);
  45. };
  46. // Someone included <math.h>, so we have to add trailing underscores.
  47. autodiff_fn_expr<OP_SIN> const sin_;
  48. autodiff_fn_expr<OP_COS> const cos_;
  49. autodiff_fn_expr<OP_SQRT> const sqrt_;
  50. //]
  51. //[ autodiff_xform
  52. struct xform
  53. {
  54. // Create a var-node for each placeholder when we see it for the first
  55. // time.
  56. template <long long I>
  57. Node * operator() (boost::yap::expr_tag<boost::yap::expr_kind::terminal>,
  58. boost::yap::placeholder<I>)
  59. {
  60. if (list_.size() < I)
  61. list_.resize(I);
  62. auto & retval = list_[I - 1];
  63. if (retval == nullptr)
  64. retval = create_var_node();
  65. return retval;
  66. }
  67. // Create a param-node for every numeric terminal in the expression.
  68. Node * operator() (boost::yap::expr_tag<boost::yap::expr_kind::terminal>, double x)
  69. { return create_param_node(x); }
  70. // Create a "uary" node for each call expression, using its OPCODE.
  71. template <typename Expr>
  72. Node * operator() (boost::yap::expr_tag<boost::yap::expr_kind::call>,
  73. OPCODE opcode, Expr const & expr)
  74. {
  75. return create_uary_op_node(
  76. opcode,
  77. boost::yap::transform(boost::yap::as_expr<autodiff_expr>(expr), *this)
  78. );
  79. }
  80. template <typename Expr>
  81. Node * operator() (boost::yap::expr_tag<boost::yap::expr_kind::negate>,
  82. Expr const & expr)
  83. {
  84. return create_uary_op_node(
  85. OP_NEG,
  86. boost::yap::transform(boost::yap::as_expr<autodiff_expr>(expr), *this)
  87. );
  88. }
  89. // Define a mapping from binary arithmetic expr_kind to OPCODE...
  90. static OPCODE op_for_kind (boost::yap::expr_kind kind)
  91. {
  92. switch (kind) {
  93. case boost::yap::expr_kind::plus: return OP_PLUS;
  94. case boost::yap::expr_kind::minus: return OP_MINUS;
  95. case boost::yap::expr_kind::multiplies: return OP_TIMES;
  96. case boost::yap::expr_kind::divides: return OP_DIVID;
  97. default: assert(!"This should never execute"); return OPCODE{};
  98. }
  99. assert(!"This should never execute");
  100. return OPCODE{};
  101. }
  102. // ... and use it to handle all the binary arithmetic operators.
  103. template <boost::yap::expr_kind Kind, typename Expr1, typename Expr2>
  104. Node * operator() (boost::yap::expr_tag<Kind>, Expr1 const & expr1, Expr2 const & expr2)
  105. {
  106. return create_binary_op_node(
  107. op_for_kind(Kind),
  108. boost::yap::transform(boost::yap::as_expr<autodiff_expr>(expr1), *this),
  109. boost::yap::transform(boost::yap::as_expr<autodiff_expr>(expr2), *this)
  110. );
  111. }
  112. vector<Node *> & list_;
  113. };
  114. //]
  115. //[ autodiff_to_node
  116. template <typename Expr, typename ...T>
  117. Node * to_auto_diff_node (Expr const & expr, vector<Node *> & list, T ... args)
  118. {
  119. Node * retval = nullptr;
  120. // This fills in list as a side effect.
  121. retval = boost::yap::transform(expr, xform{list});
  122. assert(list.size() == sizeof...(args));
  123. // Fill in the values of the value-nodes in list with the "args"
  124. // parameter pack.
  125. auto it = list.begin();
  126. boost::hana::for_each(
  127. boost::hana::make_tuple(args ...),
  128. [&it](auto x) {
  129. Node * n = *it;
  130. VNode * v = boost::polymorphic_downcast<VNode *>(n);
  131. v->val = x;
  132. ++it;
  133. }
  134. );
  135. return retval;
  136. }
  137. //]
  138. struct F{
  139. F() { AutoDiff::autodiff_setup(); }
  140. ~F(){ AutoDiff::autodiff_cleanup(); }
  141. };
  142. BOOST_FIXTURE_TEST_SUITE(all, F)
  143. //[ autodiff_original_node_builder
  144. Node* build_linear_fun1_manually(vector<Node*>& list)
  145. {
  146. //f(x1,x2,x3) = -5*x1+sin(10)*x1+10*x2-x3/6
  147. PNode* v5 = create_param_node(-5);
  148. PNode* v10 = create_param_node(10);
  149. PNode* v6 = create_param_node(6);
  150. VNode* x1 = create_var_node();
  151. VNode* x2 = create_var_node();
  152. VNode* x3 = create_var_node();
  153. OPNode* op1 = create_binary_op_node(OP_TIMES,v5,x1); //op1 = v5*x1
  154. OPNode* op2 = create_uary_op_node(OP_SIN,v10); //op2 = sin(v10)
  155. OPNode* op3 = create_binary_op_node(OP_TIMES,op2,x1); //op3 = op2*x1
  156. OPNode* op4 = create_binary_op_node(OP_PLUS,op1,op3); //op4 = op1 + op3
  157. OPNode* op5 = create_binary_op_node(OP_TIMES,v10,x2); //op5 = v10*x2
  158. OPNode* op6 = create_binary_op_node(OP_PLUS,op4,op5); //op6 = op4+op5
  159. OPNode* op7 = create_binary_op_node(OP_DIVID,x3,v6); //op7 = x3/v6
  160. OPNode* op8 = create_binary_op_node(OP_MINUS,op6,op7); //op8 = op6 - op7
  161. x1->val = -1.9;
  162. x2->val = 2;
  163. x3->val = 5./6.;
  164. list.push_back(x1);
  165. list.push_back(x2);
  166. list.push_back(x3);
  167. return op8;
  168. }
  169. //]
  170. //[ autodiff_yap_node_builder
  171. Node* build_linear_fun1(vector<Node*>& list)
  172. {
  173. //f(x1,x2,x3) = -5*x1+sin(10)*x1+10*x2-x3/6
  174. using namespace autodiff_placeholders;
  175. return to_auto_diff_node(
  176. -5 * 1_p + sin_(10) * 1_p + 10 * 2_p - 3_p / 6,
  177. list,
  178. -1.9,
  179. 2,
  180. 5./6.
  181. );
  182. }
  183. //]
  184. Node* build_linear_function2_manually(vector<Node*>& list)
  185. {
  186. //f(x1,x2,x3) = -5*x1+-10*x1+10*x2-x3/6
  187. PNode* v5 = create_param_node(-5);
  188. PNode* v10 = create_param_node(10);
  189. PNode* v6 = create_param_node(6);
  190. VNode* x1 = create_var_node();
  191. VNode* x2 = create_var_node();
  192. VNode* x3 = create_var_node();
  193. list.push_back(x1);
  194. list.push_back(x2);
  195. list.push_back(x3);
  196. OPNode* op1 = create_binary_op_node(OP_TIMES,v5,x1); //op1 = v5*x1
  197. OPNode* op2 = create_uary_op_node(OP_NEG,v10); //op2 = -v10
  198. OPNode* op3 = create_binary_op_node(OP_TIMES,op2,x1);//op3 = op2*x1
  199. OPNode* op4 = create_binary_op_node(OP_PLUS,op1,op3);//op4 = op1 + op3
  200. OPNode* op5 = create_binary_op_node(OP_TIMES,v10,x2);//op5 = v10*x2
  201. OPNode* op6 = create_binary_op_node(OP_PLUS,op4,op5);//op6 = op4+op5
  202. OPNode* op7 = create_binary_op_node(OP_DIVID,x3,v6); //op7 = x3/v6
  203. OPNode* op8 = create_binary_op_node(OP_MINUS,op6,op7);//op8 = op6 - op7
  204. x1->val = -1.9;
  205. x2->val = 2;
  206. x3->val = 5./6.;
  207. return op8;
  208. }
  209. Node* build_linear_function2(vector<Node*>& list)
  210. {
  211. //f(x1,x2,x3) = -5*x1+-10*x1+10*x2-x3/6
  212. using namespace autodiff_placeholders;
  213. auto ten = boost::yap::make_terminal<autodiff_expr>(10);
  214. return to_auto_diff_node(
  215. -5 * 1_p + -ten * 1_p + 10 * 2_p - 3_p / 6,
  216. list,
  217. -1.9,
  218. 2,
  219. 5./6.
  220. );
  221. }
  222. Node* build_nl_function1_manually(vector<Node*>& list)
  223. {
  224. // (x1*x2 * sin(x1))/x3 + x2*x4 - x1/x2
  225. VNode* x1 = create_var_node();
  226. VNode* x2 = create_var_node();
  227. VNode* x3 = create_var_node();
  228. VNode* x4 = create_var_node();
  229. x1->val = -1.23;
  230. x2->val = 7.1231;
  231. x3->val = 2;
  232. x4->val = -10;
  233. list.push_back(x1);
  234. list.push_back(x2);
  235. list.push_back(x3);
  236. list.push_back(x4);
  237. OPNode* op1 = create_binary_op_node(OP_TIMES,x2,x1);
  238. OPNode* op2 = create_uary_op_node(OP_SIN,x1);
  239. OPNode* op3 = create_binary_op_node(OP_TIMES,op1,op2);
  240. OPNode* op4 = create_binary_op_node(OP_DIVID,op3,x3);
  241. OPNode* op5 = create_binary_op_node(OP_TIMES,x2,x4);
  242. OPNode* op6 = create_binary_op_node(OP_PLUS,op4,op5);
  243. OPNode* op7 = create_binary_op_node(OP_DIVID,x1,x2);
  244. OPNode* op8 = create_binary_op_node(OP_MINUS,op6,op7);
  245. return op8;
  246. }
  247. Node* build_nl_function1(vector<Node*>& list)
  248. {
  249. // (x1*x2 * sin(x1))/x3 + x2*x4 - x1/x2
  250. using namespace autodiff_placeholders;
  251. return to_auto_diff_node(
  252. (1_p * 2_p * sin_(1_p)) / 3_p + 2_p * 4_p - 1_p / 2_p,
  253. list,
  254. -1.23,
  255. 7.1231,
  256. 2,
  257. -10
  258. );
  259. }
  260. BOOST_AUTO_TEST_CASE( test_linear_fun1 )
  261. {
  262. BOOST_TEST_MESSAGE("test_linear_fun1");
  263. vector<Node*> list;
  264. Node* root = build_linear_fun1(list);
  265. vector<double> grad;
  266. double val1 = grad_reverse(root,list,grad);
  267. double val2 = eval_function(root);
  268. double x1g[] = {-5.5440211108893697744548489936278,10.0,-0.16666666666666666666666666666667};
  269. for(unsigned int i=0;i<3;i++){
  270. CHECK_CLOSE(grad[i],x1g[i]);
  271. }
  272. double eval = 30.394751221800913;
  273. CHECK_CLOSE(val1,eval);
  274. CHECK_CLOSE(val2,eval);
  275. EdgeSet s;
  276. nonlinearEdges(root,s);
  277. unsigned int n = nzHess(s);
  278. BOOST_CHECK_EQUAL(n,0);
  279. }
  280. BOOST_AUTO_TEST_CASE( test_grad_sin )
  281. {
  282. BOOST_TEST_MESSAGE("test_grad_sin");
  283. VNode* x1 = create_var_node();
  284. x1->val = 10;
  285. OPNode* root = create_uary_op_node(OP_SIN,x1);
  286. vector<Node*> nodes;
  287. nodes.push_back(x1);
  288. vector<double> grad;
  289. grad_reverse(root,nodes,grad);
  290. double x1g = -0.83907152907645244;
  291. //the matlab give cos(10) = -0.839071529076452
  292. CHECK_CLOSE(grad[0],x1g);
  293. BOOST_CHECK_EQUAL(nodes.size(),1);
  294. EdgeSet s;
  295. nonlinearEdges(root,s);
  296. unsigned int n = nzHess(s);
  297. BOOST_CHECK_EQUAL(n,1);
  298. }
  299. BOOST_AUTO_TEST_CASE(test_grad_single_node)
  300. {
  301. VNode* x1 = create_var_node();
  302. x1->val = -2;
  303. vector<Node*> nodes;
  304. nodes.push_back(x1);
  305. vector<double> grad;
  306. double val = grad_reverse(x1,nodes,grad);
  307. CHECK_CLOSE(grad[0],1);
  308. CHECK_CLOSE(val,-2);
  309. EdgeSet s;
  310. unsigned int n = 0;
  311. nonlinearEdges(x1,s);
  312. n = nzHess(s);
  313. BOOST_CHECK_EQUAL(n,0);
  314. grad.clear();
  315. nodes.clear();
  316. PNode* p = create_param_node(-10);
  317. //OPNode* op = create_binary_op_node(TIMES,p,create_param_node(2));
  318. val = grad_reverse(p,nodes,grad);
  319. BOOST_CHECK_EQUAL(grad.size(),0);
  320. CHECK_CLOSE(val,-10);
  321. s.clear();
  322. nonlinearEdges(p,s);
  323. n = nzHess(s);
  324. BOOST_CHECK_EQUAL(n,0);
  325. }
  326. BOOST_AUTO_TEST_CASE(test_grad_neg)
  327. {
  328. VNode* x1 = create_var_node();
  329. x1->val = 10;
  330. PNode* p2 = create_param_node(-1);
  331. vector<Node*> nodes;
  332. vector<double> grad;
  333. nodes.push_back(x1);
  334. Node* root = create_binary_op_node(OP_TIMES,x1,p2);
  335. grad_reverse(root,nodes,grad);
  336. CHECK_CLOSE(grad[0],-1);
  337. BOOST_CHECK_EQUAL(nodes.size(),1);
  338. nodes.clear();
  339. grad.clear();
  340. nodes.push_back(x1);
  341. root = create_uary_op_node(OP_NEG,x1);
  342. grad_reverse(root,nodes,grad);
  343. CHECK_CLOSE(grad[0],-1);
  344. EdgeSet s;
  345. unsigned int n = 0;
  346. nonlinearEdges(root,s);
  347. n = nzHess(s);
  348. BOOST_CHECK_EQUAL(n,0);
  349. }
  350. BOOST_AUTO_TEST_CASE( test_nl_function)
  351. {
  352. vector<Node*> list;
  353. Node* root = build_nl_function1(list);
  354. double val = eval_function(root);
  355. vector<double> grad;
  356. grad_reverse(root,list,grad);
  357. double eval =-66.929555552886214;
  358. double gx[] = {-4.961306690356109,-9.444611307649055,-2.064383410399700,7.123100000000000};
  359. CHECK_CLOSE(val,eval);
  360. for(unsigned int i=0;i<4;i++)
  361. {
  362. CHECK_CLOSE(grad[i],gx[i]);
  363. }
  364. unsigned int nzgrad = nzGrad(root);
  365. unsigned int tol = numTotalNodes(root);
  366. BOOST_CHECK_EQUAL(nzgrad,4);
  367. BOOST_CHECK_EQUAL(tol,16);
  368. EdgeSet s;
  369. nonlinearEdges(root,s);
  370. unsigned int n = nzHess(s);
  371. BOOST_CHECK_EQUAL(n,11);
  372. }
  373. BOOST_AUTO_TEST_CASE( test_hess_reverse_1)
  374. {
  375. vector<Node*> nodes;
  376. Node* root = build_linear_fun1(nodes);
  377. vector<double> grad;
  378. double val = grad_reverse(root,nodes,grad);
  379. double eval = eval_function(root);
  380. // cout<<eval<<"\t"<<grad[0]<<"\t"<<grad[1]<<"\t"<<grad[2]<<"\t"<<endl;
  381. CHECK_CLOSE(val,eval);
  382. for(unsigned int i=0;i<nodes.size();i++)
  383. {
  384. static_cast<VNode*>(nodes[i])->u = 0;
  385. }
  386. static_cast<VNode*>(nodes[0])->u = 1;
  387. double hval = 0;
  388. vector<double> dhess;
  389. hval = hess_reverse(root,nodes,dhess);
  390. CHECK_CLOSE(hval,eval);
  391. for(unsigned int i=0;i<dhess.size();i++)
  392. {
  393. CHECK_CLOSE(dhess[i],0);
  394. }
  395. }
  396. BOOST_AUTO_TEST_CASE( test_hess_reverse_2)
  397. {
  398. vector<Node*> nodes;
  399. Node* root = build_linear_function2(nodes);
  400. vector<double> grad;
  401. double val = grad_reverse(root,nodes,grad);
  402. double eval = eval_function(root);
  403. CHECK_CLOSE(val,eval);
  404. for(unsigned int i=0;i<nodes.size();i++)
  405. {
  406. static_cast<VNode*>(nodes[i])->u = 0;
  407. }
  408. static_cast<VNode*>(nodes[0])->u = 1;
  409. double hval = 0;
  410. vector<double> dhess;
  411. hval = hess_reverse(root,nodes,dhess);
  412. CHECK_CLOSE(hval,eval);
  413. for(unsigned int i=0;i<dhess.size();i++)
  414. {
  415. CHECK_CLOSE(dhess[i],0);
  416. }
  417. EdgeSet s;
  418. nonlinearEdges(root,s);
  419. unsigned int n = nzHess(s);
  420. BOOST_CHECK_EQUAL(n,0);
  421. }
  422. BOOST_AUTO_TEST_CASE( test_hess_reverse_4)
  423. {
  424. vector<Node*> nodes;
  425. // Node* root = build_nl_function1(nodes);
  426. VNode* x1 = create_var_node();
  427. nodes.push_back(x1);
  428. x1->val = 1;
  429. x1->u =1;
  430. Node* op = create_uary_op_node(OP_SIN,x1);
  431. Node* root = create_uary_op_node(OP_SIN,op);
  432. vector<double> grad;
  433. double eval = eval_function(root);
  434. vector<double> dhess;
  435. double hval = hess_reverse(root,nodes,dhess);
  436. CHECK_CLOSE(hval,eval);
  437. BOOST_CHECK_EQUAL(dhess.size(),1);
  438. CHECK_CLOSE(dhess[0], -0.778395788418109);
  439. EdgeSet s;
  440. nonlinearEdges(root,s);
  441. unsigned int n = nzHess(s);
  442. BOOST_CHECK_EQUAL(n,1);
  443. }
  444. BOOST_AUTO_TEST_CASE( test_hess_reverse_3)
  445. {
  446. vector<Node*> nodes;
  447. VNode* x1 = create_var_node();
  448. VNode* x2 = create_var_node();
  449. nodes.push_back(x1);
  450. nodes.push_back(x2);
  451. x1->val = 2.5;
  452. x2->val = -9;
  453. Node* op1 = create_binary_op_node(OP_TIMES,x1,x2);
  454. Node* root = create_binary_op_node(OP_TIMES,x1,op1);
  455. double eval = eval_function(root);
  456. for(unsigned int i=0;i<nodes.size();i++)
  457. {
  458. static_cast<VNode*>(nodes[i])->u = 0;
  459. }
  460. static_cast<VNode*>(nodes[0])->u = 1;
  461. vector<double> dhess;
  462. double hval = hess_reverse(root,nodes,dhess);
  463. BOOST_CHECK_EQUAL(dhess.size(),2);
  464. CHECK_CLOSE(hval,eval);
  465. double hx[]={-18,5};
  466. for(unsigned int i=0;i<dhess.size();i++)
  467. {
  468. //Print("\t["<<i<<"]="<<dhess[i]);
  469. CHECK_CLOSE(dhess[i],hx[i]);
  470. }
  471. EdgeSet s;
  472. nonlinearEdges(root,s);
  473. unsigned int n = nzHess(s);
  474. BOOST_CHECK_EQUAL(n,3);
  475. }
  476. BOOST_AUTO_TEST_CASE( test_hess_reverse_5)
  477. {
  478. vector<Node*> nodes;
  479. VNode* x1 = create_var_node();
  480. VNode* x2 = create_var_node();
  481. nodes.push_back(x1);
  482. nodes.push_back(x2);
  483. x1->val = 2.5;
  484. x2->val = -9;
  485. Node* op1 = create_binary_op_node(OP_TIMES,x1,x1);
  486. Node* op2 = create_binary_op_node(OP_TIMES,x2,x2);
  487. Node* op3 = create_binary_op_node(OP_MINUS,op1,op2);
  488. Node* op4 = create_binary_op_node(OP_PLUS,op1,op2);
  489. Node* root = create_binary_op_node(OP_TIMES,op3,op4);
  490. double eval = eval_function(root);
  491. for(unsigned int i=0;i<nodes.size();i++)
  492. {
  493. static_cast<VNode*>(nodes[i])->u = 0;
  494. }
  495. static_cast<VNode*>(nodes[0])->u = 1;
  496. vector<double> dhess;
  497. double hval = hess_reverse(root,nodes,dhess);
  498. CHECK_CLOSE(hval,eval);
  499. double hx[] ={75,0};
  500. for(unsigned int i=0;i<dhess.size();i++)
  501. {
  502. CHECK_CLOSE(dhess[i],hx[i]);
  503. }
  504. for(unsigned int i=0;i<nodes.size();i++)
  505. {
  506. static_cast<VNode*>(nodes[i])->u = 0;
  507. }
  508. static_cast<VNode*>(nodes[1])->u = 1;
  509. double hx2[] = {0, -972};
  510. hval = hess_reverse(root,nodes,dhess);
  511. for(unsigned int i=0;i<dhess.size();i++)
  512. {
  513. CHECK_CLOSE(dhess[i],hx2[i]);
  514. }
  515. EdgeSet s;
  516. nonlinearEdges(root,s);
  517. unsigned int n = nzHess(s);
  518. BOOST_CHECK_EQUAL(n,4);
  519. }
  520. BOOST_AUTO_TEST_CASE( test_hess_reverse_6)
  521. {
  522. vector<Node*> nodes;
  523. // Node* root = build_nl_function1(nodes);
  524. VNode* x1 = create_var_node();
  525. VNode* x2 = create_var_node();
  526. nodes.push_back(x1);
  527. nodes.push_back(x2);
  528. x1->val = 2.5;
  529. x2->val = -9;
  530. Node* root = create_binary_op_node(OP_POW,x1,x2);
  531. double eval = eval_function(root);
  532. static_cast<VNode*>(nodes[0])->u=1;static_cast<VNode*>(nodes[1])->u=0;
  533. vector<double> dhess;
  534. double hval = hess_reverse(root,nodes,dhess);
  535. CHECK_CLOSE(hval,eval);
  536. double hx1[] ={0.003774873600000 , -0.000759862823419};
  537. double hx2[] ={-0.000759862823419, 0.000220093141567};
  538. for(unsigned int i=0;i<dhess.size();i++)
  539. {
  540. CHECK_CLOSE(dhess[i],hx1[i]);
  541. }
  542. static_cast<VNode*>(nodes[0])->u=0;static_cast<VNode*>(nodes[1])->u=1;
  543. hess_reverse(root,nodes,dhess);
  544. for(unsigned int i=0;i<dhess.size();i++)
  545. {
  546. CHECK_CLOSE(dhess[i],hx2[i]);
  547. }
  548. EdgeSet s;
  549. nonlinearEdges(root,s);
  550. unsigned int n = nzHess(s);
  551. BOOST_CHECK_EQUAL(n,4);
  552. }
  553. BOOST_AUTO_TEST_CASE( test_hess_reverse_7)
  554. {
  555. vector<Node*> nodes;
  556. Node* root = build_nl_function1(nodes);
  557. double eval = eval_function(root);
  558. vector<double> dhess;
  559. double hx0[] ={-1.747958066718855,
  560. -0.657091724418110,
  561. 2.410459188139686,
  562. 0};
  563. double hx1[] ={ -0.657091724418110,
  564. 0.006806564792590,
  565. -0.289815306593997,
  566. 1.000000000000000};
  567. double hx2[] ={ 2.410459188139686,
  568. -0.289815306593997,
  569. 2.064383410399700,
  570. 0};
  571. double hx3[] ={0,1,0,0};
  572. for(unsigned int i=0;i<nodes.size();i++)
  573. {
  574. static_cast<VNode*>(nodes[i])->u = 0;
  575. }
  576. static_cast<VNode*>(nodes[0])->u = 1;
  577. double hval = hess_reverse(root,nodes,dhess);
  578. CHECK_CLOSE(hval,eval);
  579. for(unsigned int i=0;i<dhess.size();i++)
  580. {
  581. CHECK_CLOSE(dhess[i],hx0[i]);
  582. }
  583. for (unsigned int i = 0; i < nodes.size(); i++) {
  584. static_cast<VNode*>(nodes[i])->u = 0;
  585. }
  586. static_cast<VNode*>(nodes[1])->u = 1;
  587. hess_reverse(root, nodes, dhess);
  588. for (unsigned int i = 0; i < dhess.size(); i++) {
  589. CHECK_CLOSE(dhess[i], hx1[i]);
  590. }
  591. for (unsigned int i = 0; i < nodes.size(); i++) {
  592. static_cast<VNode*>(nodes[i])->u = 0;
  593. }
  594. static_cast<VNode*>(nodes[2])->u = 1;
  595. hess_reverse(root, nodes, dhess);
  596. for (unsigned int i = 0; i < dhess.size(); i++) {
  597. CHECK_CLOSE(dhess[i], hx2[i]);
  598. }
  599. for (unsigned int i = 0; i < nodes.size(); i++) {
  600. static_cast<VNode*>(nodes[i])->u = 0;
  601. }
  602. static_cast<VNode*>(nodes[3])->u = 1;
  603. hess_reverse(root, nodes, dhess);
  604. for (unsigned i = 0; i < dhess.size(); i++) {
  605. CHECK_CLOSE(dhess[i], hx3[i]);
  606. }
  607. }
  608. #if FORWARD_ENABLED
  609. void test_hess_forward(Node* root, unsigned int& nvar)
  610. {
  611. AutoDiff::num_var = nvar;
  612. unsigned int len = (nvar+3)*nvar/2;
  613. double* hess = new double[len];
  614. hess_forward(root,nvar,&hess);
  615. for(unsigned int i=0;i<len;i++){
  616. cout<<"hess["<<i<<"]="<<hess[i]<<endl;
  617. }
  618. delete[] hess;
  619. }
  620. #endif
  621. BOOST_AUTO_TEST_CASE( test_hess_reverse_8)
  622. {
  623. vector<Node*> list;
  624. vector<double> dhess;
  625. VNode* x1 = create_var_node();
  626. list.push_back(x1);
  627. static_cast<VNode*>(list[0])->val = -10.5;
  628. static_cast<VNode*>(list[0])->u = 1;
  629. double deval = hess_reverse(x1,list,dhess);
  630. CHECK_CLOSE(deval,-10.5);
  631. BOOST_CHECK_EQUAL(dhess.size(),1);
  632. BOOST_CHECK(isnan(dhess[0]));
  633. EdgeSet s;
  634. nonlinearEdges(x1,s);
  635. unsigned int n = nzHess(s);
  636. BOOST_CHECK_EQUAL(n,0);
  637. PNode* p1 = create_param_node(-1.5);
  638. list.clear();
  639. deval = hess_reverse(p1,list,dhess);
  640. CHECK_CLOSE(deval,-1.5);
  641. BOOST_CHECK_EQUAL(dhess.size(),0);
  642. s.clear();
  643. nonlinearEdges(p1,s);
  644. n = nzHess(s);
  645. BOOST_CHECK_EQUAL(n,0);
  646. }
  647. BOOST_AUTO_TEST_CASE( test_hess_revers9)
  648. {
  649. vector<Node*> list;
  650. vector<double> dhess;
  651. VNode* x1 = create_var_node();
  652. list.push_back(x1);
  653. static_cast<VNode*>(list[0])->val = 2.5;
  654. static_cast<VNode*>(list[0])->u =1;
  655. Node* op1 = create_binary_op_node(OP_TIMES,x1,x1);
  656. Node* root = create_binary_op_node(OP_TIMES,op1,op1);
  657. double deval = hess_reverse(root,list,dhess);
  658. double eval = eval_function(root);
  659. CHECK_CLOSE(eval,deval);
  660. BOOST_CHECK_EQUAL(dhess.size(),1);
  661. CHECK_CLOSE(dhess[0],75);
  662. EdgeSet s;
  663. nonlinearEdges(root,s);
  664. unsigned int n = nzHess(s);
  665. BOOST_CHECK_EQUAL(n,1);
  666. }
  667. BOOST_AUTO_TEST_CASE( test_hess_revers10)
  668. {
  669. vector<Node*> list;
  670. vector<double> dhess;
  671. VNode* x1 = create_var_node();
  672. VNode* x2 = create_var_node();
  673. list.push_back(x1);
  674. list.push_back(x2);
  675. Node* op1 = create_binary_op_node(OP_TIMES, x1,x2);
  676. Node* op2 = create_uary_op_node(OP_SIN,op1);
  677. Node* op3 = create_uary_op_node(OP_COS,op1);
  678. Node* root = create_binary_op_node(OP_TIMES, op2, op3);
  679. static_cast<VNode*>(list[0])->val = 2.1;
  680. static_cast<VNode*>(list[1])->val = 1.8;
  681. double eval = eval_function(root);
  682. //second column
  683. static_cast<VNode*>(list[0])->u = 0;
  684. static_cast<VNode*>(list[1])->u = 1;
  685. double deval = hess_reverse(root,list,dhess);
  686. CHECK_CLOSE(eval,deval);
  687. BOOST_CHECK_EQUAL(dhess.size(),2);
  688. CHECK_CLOSE(dhess[0], -6.945893481707861);
  689. CHECK_CLOSE(dhess[1], -8.441601940854081);
  690. //first column
  691. static_cast<VNode*>(list[0])->u = 1;
  692. static_cast<VNode*>(list[1])->u = 0;
  693. deval = hess_reverse(root,list,dhess);
  694. CHECK_CLOSE(eval,deval);
  695. BOOST_CHECK_EQUAL(dhess.size(),2);
  696. CHECK_CLOSE(dhess[0], -6.201993262668304);
  697. CHECK_CLOSE(dhess[1], -6.945893481707861);
  698. }
  699. BOOST_AUTO_TEST_CASE( test_grad_reverse11)
  700. {
  701. vector<Node*> list;
  702. VNode* x1 = create_var_node();
  703. Node* p2 = create_param_node(2);
  704. list.push_back(x1);
  705. Node* op1 = create_binary_op_node(OP_POW,x1,p2);
  706. static_cast<VNode*>(x1)->val = 0;
  707. vector<double> grad;
  708. grad_reverse(op1,list,grad);
  709. BOOST_CHECK_EQUAL(grad.size(),1);
  710. CHECK_CLOSE(grad[0],0);
  711. }
  712. BOOST_AUTO_TEST_CASE( test_hess_reverse12)
  713. {
  714. vector<Node*> list;
  715. VNode* x1 = create_var_node();
  716. Node* p2 = create_param_node(2);
  717. list.push_back(x1);
  718. Node* op1 = create_binary_op_node(OP_POW,x1,p2);
  719. x1->val = 0;
  720. x1->u = 1;
  721. vector<double> hess;
  722. hess_reverse(op1,list,hess);
  723. BOOST_CHECK_EQUAL(hess.size(),1);
  724. CHECK_CLOSE(hess[0],2);
  725. }
  726. BOOST_AUTO_TEST_CASE( test_grad_reverse13)
  727. {
  728. vector<Node*> list;
  729. VNode* x1 = create_var_node();
  730. PNode* p1 = create_param_node(0.090901);
  731. VNode* x2 = create_var_node();
  732. PNode* p2 = create_param_node(0.090901);
  733. list.push_back(x1);
  734. list.push_back(x2);
  735. Node* op1 = create_binary_op_node(OP_TIMES,x1,p1);
  736. Node* op2 = create_binary_op_node(OP_TIMES,x2,p2);
  737. Node* root = create_binary_op_node(OP_PLUS,op1,op2);
  738. x1->val = 1;
  739. x2->val = 1;
  740. vector<double> grad;
  741. grad_reverse(root,list,grad);
  742. BOOST_CHECK_EQUAL(grad.size(),2);
  743. CHECK_CLOSE(grad[0],0.090901);
  744. CHECK_CLOSE(grad[1],0.090901);
  745. }
  746. BOOST_AUTO_TEST_SUITE_END()