Xped
Loading...
Searching...
No Matches
ADTensor.hpp
Go to the documentation of this file.
1#ifndef XPED_ADTENSOR_HPP_
2#define XPED_ADTENSOR_HPP_
3
4#include <cmath>
5
6#include "stan/math/rev.hpp"
7
9
10#include "Xped/Util/Bool.hpp"
11
12#include "Xped/Core/Tensor.hpp"
13
16
17namespace Xped {
18
19template <bool AD, typename Scalar, std::size_t Rank, std::size_t CoRank, typename Symmetry, typename AllocationPolicy = HeapPolicy>
20using XTensor = std::conditional_t<AD,
23
24template <bool AD, typename Scalar>
25using XScalar = std::conditional_t<AD, stan::math::var_value<Scalar>, Scalar>;
26
27template <typename Scalar_, std::size_t Rank, std::size_t CoRank, typename Symmetry_, typename AllocationPolicy_>
28class Tensor<Scalar_, Rank, CoRank, Symmetry_, true, AllocationPolicy_>
29{
30public:
31 using Scalar = Scalar_;
33
34 using Symmetry = Symmetry_;
35 using qType = typename Symmetry::qType;
36
37 using AllocationPolicy = AllocationPolicy_;
38
47 using value_type = Tensor<Scalar, Rank, CoRank, Symmetry, false, AllocationPolicy>; // type in vari_value -->ArenaTensor or Tensor.
48 using vari_type = stan::math::vari_value<value_type>;
49
51
52 inline bool is_uninitialized() noexcept { return (vi_ == nullptr); }
53
55 : vi_(nullptr)
56 {}
57
59 : vi_(new vari_type(x, false))
60 {}
61
63 : vi_(vi)
64 {}
65
66 Tensor(const std::array<Qbasis<Symmetry, 1, AllocationPolicy>, Rank>& basis_domain,
67 const std::array<Qbasis<Symmetry, 1, AllocationPolicy>, CoRank>& basis_codomain,
69 : vi_(new vari_type(basis_domain, basis_codomain, world))
70 {}
71
72 void setRandom() { val_op().setRandom(); }
73 void setIdentity() { val_op().setIdentity(); }
74 void setZero() { val_op().setZero(); }
75 void setConstant(Scalar c) { val_op().setConstant(c); }
76
77 inline const auto& val() const noexcept { return vi_->val(); }
78 inline auto& val_op() noexcept { return vi_->val_op(); }
79 inline const auto& detach() const noexcept { return vi_->val(); }
80
81 inline auto& adj() noexcept { return vi_->adj(); }
82 inline auto& adj() const noexcept { return vi_->adj(); }
83 inline auto& adj_op() noexcept { return vi_->adj(); }
84
85 constexpr std::size_t rank() const noexcept { return vi_->rank(); }
86 constexpr std::size_t corank() const noexcept { return vi_->corank(); }
87 const std::string name() const { return val().name(); }
88
89 inline const auto& sector() const { return val().sector(); }
90 inline const qType sector(std::size_t i) const { return val().sector(i); }
91
92 constexpr bool CONTIGUOUS_STORAGE() const { return val().CONTIGUOUS_STORAGE(); }
93 constexpr bool AD_TENSOR() { return true; }
94
95 std::size_t plainSize() const { return val().plainSize(); }
96
97 inline const std::array<Qbasis<Symmetry, 1, AllocationPolicy>, Rank>& uncoupledDomain() const { return val().uncoupledDomain(); }
98 inline const std::array<Qbasis<Symmetry, 1, AllocationPolicy>, CoRank>& uncoupledCodomain() const { return val().uncoupledCodomain(); }
99
100 inline const Qbasis<Symmetry, Rank, AllocationPolicy>& coupledDomain() const { return val().coupledDomain(); }
101 inline const Qbasis<Symmetry, CoRank, AllocationPolicy>& coupledCodomain() const { return val().coupledCodomain(); }
102
103 const mpi::XpedWorld& world() const { return val().world(); }
104
105 inline auto begin() const { return val_op().begin(); }
106 inline auto end() const { return val_op().end(); }
107
108 inline const auto cbegin() const { return val().cbegin(); }
109 inline const auto cend() const { return val().cend(); }
110
111 inline auto gradbegin() const { return adj_op().begin(); }
112 inline auto gradend() const { return adj_op().end(); }
113
114 inline const auto cgradbegin() const { return adj().cbegin(); }
115 inline const auto cgradend() const { return adj().cend(); }
116
117 inline void set_data(const Scalar* data, std::size_t size) { val_op().set_data(data, size); }
118
119 // inline vari_type& operator*() { return *vi_; }
120
121 inline vari_type* operator->() { return vi_; }
122
124
125 template <bool TRACK = true>
126 auto adjoint() const
127 {
128 if constexpr(TRACK) {
130 stan::math::reverse_pass_callback([curr = *this, out]() mutable {
131 curr.adj() += out.adj().adjoint().eval();
132 SPDLOG_WARN("reverse adjoint of {}, input adj norm={}, output adj norm={}", curr.name(), out.adj().norm(), curr.adj().norm());
133 });
134 return out;
135 } else {
136 return val().adjoint().eval();
137 }
138 }
139
140 template <int shift, std::size_t... p, bool TRACK>
142 {
143 if constexpr(TRACK) {
144 Xped::Tensor<Scalar, Rank - shift, CoRank + shift, Symmetry, true, AllocationPolicy> out(
145 val().template permute<shift, p...>(Bool<false>{}));
146 stan::math::reverse_pass_callback([curr = *this, out]() mutable {
147 using inverse = decltype(Xped::util::constFct::inverse_permutation<seq::iseq<std::size_t, p...>>());
148 curr.adj() += out.adj().template permute<-shift>(inverse{}, Bool<false>{});
149 SPDLOG_WARN("reverse permute of {}, input adj norm={}, output adj norm={}", curr.name(), out.adj().norm(), curr.adj().norm());
150 });
151 return out;
152 } else {
153 return val().template permute<shift, p...>(Bool<false>{});
154 }
155 }
156
157 template <int shift, std::size_t... p, bool TRACK>
158 Tensor<Scalar, Rank - shift, CoRank + shift, Symmetry, true, AllocationPolicy> permute(seq::iseq<std::size_t, p...>, Bool<TRACK>) const
159 {
160 return permute<shift, p...>(Bool<TRACK>{});
161 }
162
163#if XPED_HAS_NTTP
164 template <auto a1,
165 auto a2,
166 std::size_t ResRank,
167 bool TRACK = true,
168 typename OtherScalar,
169 std::size_t OtherRank,
170 std::size_t OtherCoRank,
171 bool ENABLE_AD>
173 {
174 constexpr auto perms = util::constFct::get_permutations<a1, Rank, a2, OtherRank, ResRank>();
175 constexpr auto p1 = std::get<0>(perms);
176 constexpr auto shift1 = std::get<1>(perms);
177 SPDLOG_INFO("shift1={}, p1={}", shift1, p1);
178 constexpr auto p2 = std::get<2>(perms);
179 constexpr auto shift2 = std::get<3>(perms);
180 SPDLOG_INFO("shift2={}, p2={}", shift2, p2);
181 constexpr auto pres = std::get<4>(perms);
182 constexpr auto shiftres = std::get<5>(perms);
183 SPDLOG_INFO("shiftres={}, pres={}", shiftres, pres);
184 // auto left_p = this->template permute<shift1>(util::constFct::as_sequence<p1>(), Bool<TRACK>{});
185 // auto right_p = other.template permute<shift2>(util::constFct::as_sequence<p2>(), Bool<TRACK>{});
186 // return operator*<TRACK>(left_p, right_p).template permute<shiftres>(util::constFct::as_sequence<pres>(), Bool<TRACK>{});
187 return operator*<TRACK>(this->template permute<shift1>(util::constFct::as_sequence<p1>(), Bool<TRACK>{}),
188 other.template permute<shift2>(util::constFct::as_sequence<p2>(), Bool<TRACK>{}))
189 .template permute<shiftres>(util::constFct::as_sequence<pres>(), Bool<TRACK>{});
190 }
191#endif
192
193 template <bool TRACK = true>
194 std::tuple<XTensor<TRACK, Scalar, Rank, 1, Symmetry, AllocationPolicy>,
195 XTensor<TRACK, Scalar, 1, 1, Symmetry, AllocationPolicy>,
196 XTensor<TRACK, Scalar, 1, CoRank, Symmetry, AllocationPolicy>>
197 tSVD(std::size_t maxKeep,
198 RealScalar eps_svd,
199 RealScalar& truncWeight,
200 RealScalar& entropy,
201 std::map<qarray<Symmetry::Nq>, VectorType>& SVspec,
202 bool PRESERVE_MULTIPLETS = true,
203 bool RETURN_SPEC = true) XPED_CONST
204 {
205 auto [Uval, Sval_real, Vdagval] = val().tSVD(maxKeep, eps_svd, truncWeight, entropy, SVspec, PRESERVE_MULTIPLETS, RETURN_SPEC);
206 XTensor<TRACK, Scalar, 1, 1, Symmetry, AllocationPolicy> Sval = Sval_real.template cast<Scalar>().eval();
207 if constexpr(not TRACK) {
208 return std::make_tuple(Uval, Sval, Vdagval);
209 } else {
213
214 stan::math::reverse_pass_callback([curr = *this, U, S, Vdag]() mutable {
215 for(std::size_t i = 0; i < curr.sector().size(); ++i) {
216 SPDLOG_INFO("i={}", i);
217 auto it = S.val().dict().find(curr.val().sector(i));
218 if(it == S.val().dict().end()) { continue; }
219 auto j = it->second;
220 auto U_b = U.val().block(j);
221 auto S_b = S.val().block(j);
222 auto S_b_real = S_b.real().eval();
223 auto Vdag_b = Vdag.val().block(j);
224 // fmt::print("max dU={}\n", U.adj().block(j).array().abs().matrix().maxCoeff());
225 // fmt::print("max dVdag={}\n", Vdag.adj().block(j).array().abs().matrix().maxCoeff());
226 // fmt::print("dS={}\n", S.adj().block(j).imag().norm());
227 SPDLOG_INFO("i={}:\tU.val: {}x{}, U.adj: {}x{}, Vdag.val: {}x{}, Vdag.adj: {}x{}, S.val: {}x{}",
228 i,
229 U_b.rows(),
230 U_b.cols(),
231 U.adj().block(i).rows(),
232 U.adj().block(i).cols(),
233 Vdag_b.rows(),
234 Vdag_b.cols(),
235 Vdag.adj().block(i).rows(),
236 Vdag.adj().block(i).cols(),
237 S_b.rows(),
238 S_b.cols());
239
240 auto F_inv = PlainInterface::construct<RealScalar>(PlainInterface::rows(S_b), PlainInterface::cols(S_b), S.val().world());
241 PlainInterface::vec_diff(S_b_real.array().square().matrix().diagonal().eval(), F_inv);
242 auto F = PlainInterface::unaryFunc<RealScalar>(
243 F_inv, [](RealScalar d) { return (std::abs(d) < 1.e-12) ? d / (d * d + 1.e-12) : 1. / d; });
244
245 // auto Udag_dU = U_b.adjoint() * U.adj().block(j);
246 auto Vdag_dV = (Vdag_b * Vdag.adj().block(j).adjoint()).eval();
247
248 auto J = (F.array() * (U_b.adjoint() * U.adj().block(j)).array()).matrix().eval();
249 auto K = (F.array() * Vdag_dV.array()).matrix().eval();
250 if constexpr(ScalarTraits<Scalar>::IS_COMPLEX()) {
251 using namespace std::complex_literals;
252 auto L = (1i * Eigen::MatrixXcd(Vdag_dV.diagonal().imag().asDiagonal())).eval();
253 // fmt::print("imag term={}\n", L.norm());
254 curr.adj().block(i) +=
255 U_b * (S.adj().block(j) + (J + J.adjoint()) * S_b + S_b * (K + K.adjoint()) - S_b.diagonal().asDiagonal().inverse() * L) *
256 Vdag_b;
257 } else {
258 curr.adj().block(i) += U_b * (S.adj().block(j) + (J + J.adjoint()) * S_b + S_b * (K + K.adjoint())) * Vdag_b;
259 }
260 if(U_b.rows() > S_b.rows()) {
261 curr.adj().block(i) += (PlainInterface::Identity<Scalar>(U_b.rows(), U_b.rows(), U.val().world()) - U_b * U_b.adjoint()) *
262 U.adj().block(j) * S_b.diagonal().asDiagonal().inverse() * Vdag_b;
263 }
264 if(Vdag_b.cols() > S_b.rows()) {
265 curr.adj().block(i) +=
266 U_b * S_b.diagonal().asDiagonal().inverse() * Vdag.adj().block(j) *
267 (PlainInterface::Identity<Scalar>(Vdag_b.cols(), Vdag_b.cols(), Vdag.val().world()) - Vdag_b.adjoint() * Vdag_b);
268 }
269 // fmt::print("dA=\n{}\n", fmt::streamed(curr.adj().block(i)));
270 // fmt::print("max dA={}\n", curr.adj().block(i).array().abs().matrix().maxCoeff());
271 }
272 SPDLOG_WARN("reverse svd. U.adj.norm()={}, S.adj.norm()={}, Vdag.adj.norm()={}, output adj norm={}",
273 U.adj().norm(),
274 S.adj().norm(),
275 Vdag.adj().norm(),
276 curr.adj().norm());
277 });
278
279 return std::make_tuple(U, S, Vdag);
280 }
281 }
282
283 template <bool TRACK = true>
284 std::tuple<XTensor<TRACK, Scalar, Rank, 1, Symmetry, AllocationPolicy>,
287 tSVD(std::size_t maxKeep, RealScalar eps_svd, RealScalar& truncWeight, bool PRESERVE_MULTIPLETS = true) XPED_CONST
288 {
289 RealScalar S_dumb;
290 std::map<qarray<Symmetry::Nq>, VectorType> SVspec_dumb;
291 return tSVD<TRACK>(
292 maxKeep, eps_svd, truncWeight, S_dumb, SVspec_dumb, PRESERVE_MULTIPLETS, false); // false: Dont return singular value spectrum
293 }
294
295 template <bool TRACK = true>
297 {
298 if constexpr(TRACK) {
299 Scalar tmp = val().norm();
300 stan::math::var_value<Scalar> res(tmp);
301 stan::math::reverse_pass_callback([curr = *this, res]() mutable {
302 curr.adj() += (curr.val() * (res.adj() / res.val())).eval();
303 SPDLOG_WARN("reverse norm of {}, input adj norm={}, output adj norm={}", curr.name(), res.adj(), curr.adj().norm());
304 });
305 return res;
306 } else {
307 return val().norm();
308 }
309 }
310
311 template <bool TRACK = true>
313 {
314 if constexpr(TRACK) {
315 std::size_t max_block;
318 Scalar tmp = val().abs().maxCoeff(max_block, max_row, max_col);
319 stan::math::var_value<Scalar> res(tmp);
320 stan::math::reverse_pass_callback([curr = *this, res, max_block, max_row, max_col]() mutable {
321 Tensor<Scalar, Rank, CoRank, Symmetry, false> Zero(curr.uncoupledDomain(), curr.uncoupledCodomain(), curr.adj().world());
322 Zero.setZero();
323 if constexpr(not ScalarTraits<Scalar>::IS_COMPLEX()) {
324 Zero.block(max_block)(max_row, max_col) = std::signbit(curr.val().block(max_block)(max_row, max_col)) ? -1. : 1.;
325 } else {
326 using namespace std::complex_literals;
327 Zero.block(max_block)(max_row, max_col) =
328 (std::real(curr.val().block(max_block)(max_row, max_col)) + 1i * std::imag(curr.val().block(max_block)(max_row, max_col))) /
329 std::abs(curr.val().block(max_block)(max_row, max_col));
330 }
331 curr.adj() += Zero * res.adj();
332 SPDLOG_WARN("reverse norm of {}, input adj norm={}, output adj norm={}", curr.name(), res.adj(), curr.adj().norm());
333 });
334 return res;
335 } else {
336 return val().maxNorm();
337 }
338 }
339
340 template <bool TRACK = true>
342 {
343 if constexpr(TRACK) {
344 Scalar tmp = val().trace();
345 stan::math::var_value<Scalar> res(tmp);
346 stan::math::reverse_pass_callback([curr = *this, res]() mutable {
347 auto Id =
348 Tensor<Scalar, Rank, CoRank, Symmetry, false>::Identity(curr.uncoupledDomain(), curr.uncoupledCodomain(), curr.adj().world());
349 curr.adj() += Id * res.adj();
350 SPDLOG_WARN("reverse trace of {}, input adj norm={}, output adj norm={}", curr.name(), res.adj(), curr.adj().norm());
351 });
352 return res;
353 } else {
354 return val().trace();
355 }
356 }
357
358 // Tensor<Scalar, Rank, CoRank, Symmetry, true, AllocationPolicy> sqrt() const
359 // {
360 // Tensor<Scalar, Rank, CoRank, Symmetry, true, AllocationPolicy> res(val().sqrt().eval());
361 // stan::math::reverse_pass_callback([curr = *this, res]() mutable {
362 // SPDLOG_WARN("reverse sqrt, in adj norm={}", res.adj().norm());
363 // curr.adj() += res.adj().binaryExpr((res.val().inv()) / 2., [](Scalar d1, Scalar d2) { return d1 * d2; });
364 // });
365 // return res;
366 // }
367
368 template <bool TRACK = true>
370 {
371 if constexpr(TRACK) {
373 stan::math::reverse_pass_callback([curr = *this, res]() mutable {
374 curr.adj() += res.adj().binaryExpr(res.val().diag_inv().eval(), [](Scalar d1, Scalar d2) { return d1 * d2 * 0.5; });
375 SPDLOG_WARN("reverse sqrt of {}, input adj norm={}, output adj norm={}", curr.name(), res.adj().norm(), curr.adj().norm());
376 });
377 return res;
378 } else {
379 return val().diag_sqrt().eval();
380 }
381 }
382
383 template <bool TRACK = true>
385 {
386 if constexpr(TRACK) {
388 stan::math::reverse_pass_callback([curr = *this, res]() mutable {
389 curr.adj() += res.adj().diagBinaryExpr((res.val().square()) * -1., [](Scalar d1, Scalar d2) { return d1 * d2; }).eval();
390 SPDLOG_WARN("reverse diag_inv of {}, input adj norm={}, output adj norm={}", curr.name(), res.adj().norm(), curr.adj().norm());
391 });
392 return res;
393 } else {
394 return val().diag_inv().eval();
395 }
396 }
397
398 template <bool TRACK = true>
400 {
401 if constexpr(TRACK) {
403 stan::math::reverse_pass_callback([curr = *this, res, leg]() mutable {
404 curr.adj() += res.adj().twist(leg);
405 SPDLOG_WARN("reverse twist of {}, input adj norm={}, output adj norm={}", curr.name(), res.adj().norm(), curr.adj().norm());
406 });
407 return res;
408 } else {
409 return val().twist(leg);
410 }
411 }
412 // Tensor<Scalar, Rank, CoRank, Symmetry, true, AllocationPolicy>& operator-=(const Scalar s)
413 // {
414 // val_op() = val() - s;
415 // return *this;
416 // }
417
418 void print(std::ostream& o, bool PRINT_MATRICES = true) const { val().print(o, PRINT_MATRICES); }
419
420 friend std::ostream& operator<<(std::ostream& os, const Tensor<Scalar, Rank, CoRank, Symmetry, true, AllocationPolicy>& v)
421 {
422 if(v.vi_ == nullptr) { return os << "uninitialized"; }
423 return os << v.val();
424 }
425};
426
427template <bool TRACK = true, typename Scalar, std::size_t Rank, std::size_t CoRank, typename Symmetry>
429{
430 if constexpr(TRACK) {
431 SPDLOG_CRITICAL("BLOCKER");
433 return out;
434 } else {
435 return (t.val() - s).eval();
436 }
437}
438
439template <bool TRACK = true, typename Scalar, std::size_t Rank, std::size_t CoRank, typename Symmetry>
441{
442 if constexpr(TRACK) {
443 SPDLOG_CRITICAL("BLOCKER");
445 return out;
446 } else {
447 return (t.val() + s).eval();
448 }
449}
450
451template <bool TRACK = true, typename Scalar, std::size_t Rank, std::size_t CoRank, typename Symmetry>
453{
454 if constexpr(TRACK) {
455 Tensor<Scalar, Rank, CoRank, Symmetry, true> res((t.val() * s).eval());
456 stan::math::reverse_pass_callback([res, t, s]() mutable {
457 t.adj() += (res.adj() * s).eval();
458 SPDLOG_WARN("reverse vt*d with vt={}, input adj norm={}, output adj norm={}", t.name(), res.adj().norm(), t.adj().norm());
459 });
460 return res;
461 } else {
462 return (t.val() * s).eval();
463 }
464}
465
466template <bool TRACK = true, typename Scalar, std::size_t Rank, std::size_t CoRank, typename Symmetry>
468{
469 if constexpr(TRACK) {
470 Tensor<Scalar, Rank, CoRank, Symmetry, true> res((t.val() * s.val()).eval());
471 stan::math::reverse_pass_callback([res, t, s]() mutable {
472 t.adj() += (res.adj() * s.val()).eval();
473 s.adj() += (res.adj() * t.val().adjoint()).trace();
474 SPDLOG_WARN(
475 "reverse vt*v with vt={}, input adj norm={}, vt adj norm={}, v adj norm={}", t.name(), res.adj().norm(), t.adj().norm(), s.adj());
476 });
477 return res;
478 } else {
479 return (t.val() * s.val()).eval();
480 }
481}
482
483template <bool TRACK = true, typename Scalar, typename OtherScalar, std::size_t Rank, std::size_t MiddleRank, std::size_t CoRank, typename Symmetry>
484XTensor<TRACK, std::common_type_t<Scalar, OtherScalar>, Rank, CoRank, Symmetry>
486{
487 if constexpr(TRACK) {
488 Tensor<std::common_type_t<Scalar, OtherScalar>, Rank, CoRank, Symmetry, true> res(left * right.val());
489 Xped::reverse_pass_callback_alloc([res, left, right]() mutable {
490 right.adj() += (left.adjoint() * res.adj());
491 SPDLOG_WARN("reverse t*vt with t={} and vt={}, input adj norm={}, output adj norm={}",
492 left.name(),
493 right.name(),
494 res.adj().norm(),
495 right.adj().norm());
496 });
497 return res;
498 } else {
499 return left * right.val();
500 }
501}
502
503template <bool TRACK = true, typename Scalar, typename OtherScalar, std::size_t Rank, std::size_t MiddleRank, std::size_t CoRank, typename Symmetry>
504XTensor<TRACK, std::common_type_t<Scalar, OtherScalar>, Rank, CoRank, Symmetry>
506{
507 if constexpr(TRACK) {
508 Tensor<std::common_type_t<Scalar, OtherScalar>, Rank, CoRank, Symmetry, true> res(left.val() * right);
509 Xped::reverse_pass_callback_alloc([res, left, right]() mutable {
510 left.adj() += res.adj() * right.adjoint();
511 SPDLOG_WARN("reverse vt*t with vt={} and t={}, input adj norm={}, output adj norm={}",
512 left.name(),
513 right.name(),
514 res.adj().norm(),
515 left.adj().norm());
516 });
517 return res;
518 } else {
519 return left.val() * right;
520 }
521}
522
523template <bool TRACK = true, typename Scalar, typename OtherScalar, std::size_t Rank, std::size_t MiddleRank, std::size_t CoRank, typename Symmetry>
524XTensor<TRACK, std::common_type_t<Scalar, OtherScalar>, Rank, CoRank, Symmetry>
526{
527 if constexpr(TRACK) {
528 Tensor<std::common_type_t<Scalar, OtherScalar>, Rank, CoRank, Symmetry, true> res(left.val() * right.val());
529 stan::math::reverse_pass_callback([res, left, right]() mutable {
530 right.adj() += (left.val().adjoint() * res.adj());
531 left.adj() += res.adj() * right.val().adjoint();
532 SPDLOG_WARN("reverse vt*vt with vtl={} and vtr={}, in adj norm={}, vtl adj norm={}, vtr adj norm={}",
533 left.name(),
534 right.name(),
535 res.adj().norm(),
536 left.adj().norm(),
537 right.adj().norm());
538 });
539 return res;
540 } else {
541 return left.val() * right.val();
542 }
543}
544
545// template <bool TRACK = true, typename Scalar, std::size_t Rank, typename Symmetry>
546// stan::math::var operator*(const Tensor<Scalar, 0, Rank, Symmetry, true>& left, const Tensor<Scalar, Rank, 0, Symmetry, true>& right)
547// {
548// stan::math::var res = (left.val() * right.val()).block(0)(0, 0);
549// stan::math::reverse_pass_callback([res, left, right]() mutable {
550// right.adj() += left.val().adjoint() * res.adj();
551// left.adj() += res.adj() * right.val().adjoint();
552// });
553// return res;
554// }
555
556// template <bool TRACK = true, typename Scalar, std::size_t Rank, typename Symmetry>
557// stan::math::var operator*(const Tensor<Scalar, 0, Rank, Symmetry, false>& left, const Tensor<Scalar, Rank, 0, Symmetry, true>& right)
558// {
559// stan::math::var res = (left * right.val()).block(0)(0, 0);
560// Xped::reverse_pass_callback_alloc([res, left, right]() mutable { right.adj() += left.adjoint() * res.adj(); });
561// return res;
562// }
563
564// template <typename Scalar, std::size_t Rank, std::size_t CoRank, typename Symmetry>
565// stan::math::var_value<Xped::Tensor<Scalar, CoRank, Rank, Symmetry>>
566// adjoint(const stan::math::var_value<Xped::Tensor<Scalar, Rank, CoRank, Symmetry>>& t)
567// {
568// stan::math::var_value<Xped::Tensor<Scalar, CoRank, Rank, Symmetry>> res(t.val().adjoint().eval());
569// stan::math::reverse_pass_callback([res, t]() mutable { t.adj() += res.adj().adjoint().eval(); });
570// return res;
571// }
572
573} // namespace Xped
574
575namespace std {
576
577stan::math::var_value<double> real(const stan::math::var_value<double>& z) { return z; }
578
579} // namespace std
580
581#endif
Definition: Qbasis.hpp:39
void print(std::ostream &o, bool PRINT_MATRICES=true) const
Definition: ADTensor.hpp:418
XTensor< TRACK, Scalar, Rank, CoRank, Symmetry, AllocationPolicy > twist(std::size_t leg) const
Definition: ADTensor.hpp:399
constexpr std::size_t rank() const noexcept
Definition: ADTensor.hpp:85
PlainInterface::cMapMType< Scalar > MatrixcMapType
Definition: ADTensor.hpp:43
PlainInterface::Indextype IndexType
Definition: ADTensor.hpp:39
std::size_t plainSize() const
Definition: ADTensor.hpp:95
PlainInterface::TType< Scalar, Rank+CoRank > TensorType
Definition: ADTensor.hpp:44
Tensor(const Xped::Tensor< Scalar, Rank, CoRank, Symmetry, false, AllocationPolicy > &x)
Definition: ADTensor.hpp:58
const Qbasis< Symmetry, CoRank, AllocationPolicy > & coupledCodomain() const
Definition: ADTensor.hpp:101
const auto cgradbegin() const
Definition: ADTensor.hpp:114
typename Symmetry::qType qType
Definition: ADTensor.hpp:35
XTensor< TRACK, Scalar, Rank, CoRank, Symmetry, AllocationPolicy > diag_inv() const
Definition: ADTensor.hpp:384
friend std::ostream & operator<<(std::ostream &os, const Tensor< Scalar, Rank, CoRank, Symmetry, true, AllocationPolicy > &v)
Definition: ADTensor.hpp:420
Tensor< Scalar, Rank, CoRank, Symmetry, true, AllocationPolicy > eval() const
Definition: ADTensor.hpp:123
PlainInterface::MapMType< Scalar > MatrixMapType
Definition: ADTensor.hpp:42
const auto & sector() const
Definition: ADTensor.hpp:89
auto & adj() const noexcept
Definition: ADTensor.hpp:82
stan::math::vari_value< value_type > vari_type
Definition: ADTensor.hpp:48
PlainInterface::MapTType< Scalar, Rank+CoRank > TensorMapType
Definition: ADTensor.hpp:45
const std::array< Qbasis< Symmetry, 1, AllocationPolicy >, Rank > & uncoupledDomain() const
Definition: ADTensor.hpp:97
const std::string name() const
Definition: ADTensor.hpp:87
std::tuple< XTensor< TRACK, Scalar, Rank, 1, Symmetry, AllocationPolicy >, XTensor< TRACK, Scalar, 1, 1, Symmetry, AllocationPolicy >, XTensor< TRACK, Scalar, 1, CoRank, Symmetry, AllocationPolicy > > tSVD(std::size_t maxKeep, RealScalar eps_svd, RealScalar &truncWeight, bool PRESERVE_MULTIPLETS=true) XPED_CONST
Definition: ADTensor.hpp:287
const auto & detach() const noexcept
Definition: ADTensor.hpp:79
XTensor< TRACK, Scalar, Rank, CoRank, Symmetry, AllocationPolicy > diag_sqrt() const
Definition: ADTensor.hpp:369
auto permute(Bool< TRACK >) const
Definition: ADTensor.hpp:141
void set_data(const Scalar *data, std::size_t size)
Definition: ADTensor.hpp:117
const qType sector(std::size_t i) const
Definition: ADTensor.hpp:90
const auto & val() const noexcept
Definition: ADTensor.hpp:77
const Qbasis< Symmetry, Rank, AllocationPolicy > & coupledDomain() const
Definition: ADTensor.hpp:100
XScalar< TRACK, Scalar > trace() const
Definition: ADTensor.hpp:341
Tensor(const std::array< Qbasis< Symmetry, 1, AllocationPolicy >, Rank > &basis_domain, const std::array< Qbasis< Symmetry, 1, AllocationPolicy >, CoRank > &basis_codomain, mpi::XpedWorld &world=mpi::getUniverse())
Definition: ADTensor.hpp:66
Tensor< Scalar, Rank - shift, CoRank+shift, Symmetry, true, AllocationPolicy > permute(seq::iseq< std::size_t, p... >, Bool< TRACK >) const
Definition: ADTensor.hpp:158
PlainInterface::MType< Scalar > MatrixType
Definition: ADTensor.hpp:41
const std::array< Qbasis< Symmetry, 1, AllocationPolicy >, CoRank > & uncoupledCodomain() const
Definition: ADTensor.hpp:98
XScalar< TRACK, Scalar > maxNorm() const
Definition: ADTensor.hpp:312
constexpr std::size_t corank() const noexcept
Definition: ADTensor.hpp:86
XScalar< TRACK, Scalar > norm() const
Definition: ADTensor.hpp:296
const mpi::XpedWorld & world() const
Definition: ADTensor.hpp:103
typename ScalarTraits< Scalar >::Real RealScalar
Definition: ADTensor.hpp:32
PlainInterface::cMapTType< Scalar, Rank+CoRank > TensorcMapType
Definition: ADTensor.hpp:46
std::tuple< XTensor< TRACK, Scalar, Rank, 1, Symmetry, AllocationPolicy >, XTensor< TRACK, Scalar, 1, 1, Symmetry, AllocationPolicy >, XTensor< TRACK, Scalar, 1, CoRank, Symmetry, AllocationPolicy > > tSVD(std::size_t maxKeep, RealScalar eps_svd, RealScalar &truncWeight, RealScalar &entropy, std::map< qarray< Symmetry::Nq >, VectorType > &SVspec, bool PRESERVE_MULTIPLETS=true, bool RETURN_SPEC=true) XPED_CONST
Definition: ADTensor.hpp:197
PlainInterface::VType< Scalar > VectorType
Definition: ADTensor.hpp:40
constexpr bool CONTIGUOUS_STORAGE() const
Definition: ADTensor.hpp:92
AllocationPolicy_ AllocationPolicy
Definition: ADTensor.hpp:37
Definition: Tensor.hpp:40
XpedWorld & getUniverse()
Definition: Mpi.hpp:49
constexpr auto inverse_permutation()
Definition: Constfct.hpp:40
Definition: bench.cpp:62
std::conditional_t< AD, Tensor< Scalar, Rank, CoRank, Symmetry, true, AllocationPolicy >, Tensor< Scalar, Rank, CoRank, Symmetry, false, AllocationPolicy > > XTensor
Definition: ADTensor.hpp:22
XTensor< TRACK, Scalar, Rank, CoRank, Symmetry > operator-(const Tensor< Scalar, Rank, CoRank, Symmetry, true > &t, Scalar s)
Definition: ADTensor.hpp:428
XTensor< TRACK, Scalar, Rank, CoRank, Symmetry > operator+(const Tensor< Scalar, Rank, CoRank, Symmetry, true > &t, Scalar s)
Definition: ADTensor.hpp:440
void reverse_pass_callback_alloc(F &&functor)
Definition: reverse_pass_callback_alloc.hpp:40
XTensor< TRACK, Scalar, Rank, CoRank, Symmetry > operator*(const Tensor< Scalar, Rank, CoRank, Symmetry, true > &t, Scalar s)
Definition: ADTensor.hpp:452
std::conditional_t< AD, stan::math::var_value< Scalar >, Scalar > XScalar
Definition: ADTensor.hpp:25
Definition: Bool.hpp:8
CTF::Matrix< Scalar > MapMType
Definition: MatrixInterface_Cyclops_impl.hpp:45
static MIndextype cols(const MType< Scalar > &M)
Definition: MatrixInterface_Cyclops_impl.cpp:90
static MIndextype rows(const MType< Scalar > &M)
Definition: MatrixInterface_Cyclops_impl.cpp:84
const CTF::Matrix< Scalar > cMapMType
Definition: MatrixInterface_Cyclops_impl.hpp:47
int MIndextype
Definition: MatrixInterface_Cyclops_impl.hpp:49
CTF::Matrix< Scalar > MType
Definition: MatrixInterface_Cyclops_impl.hpp:40
static void vec_diff(VT &&vec, MType< typename ctf_traits< VT >::Scalar > &res)
int Indextype
Definition: PlainInterface_Cyclops_impl.hpp:11
Definition: ScalarTraits.hpp:10
nda::dense_array< Scalar, Rank > TType
Definition: TensorInterface_Array_impl.hpp:44
nda::dense_array_ref< Scalar, Rank > MapTType
Definition: TensorInterface_Array_impl.hpp:49
nda::const_dense_array_ref< Scalar, Rank > cMapTType
Definition: TensorInterface_Array_impl.hpp:51
CTF::Vector< Scalar > VType
Definition: VectorInterface_Cyclops_impl.hpp:12
Definition: Mpi.hpp:34
Definition: qarray.hpp:30