1#ifndef XPED_VARI_VALUE_HPP_
2#define XPED_VARI_VALUE_HPP_
4#include "stan/math/rev/core/vari.hpp"
11 static const bool value =
false;
14template <
typename Scalar, std::
size_t Rank, std::
size_t CoRank,
typename Symmetry,
typename AllocationPolicy>
15struct is_tensor<
Xped::Tensor<Scalar, Rank, CoRank, Symmetry, false, AllocationPolicy>>
23 static const bool value =
false;
26template <
typename Scalar, std::
size_t Rank, std::
size_t CoRank,
typename Symmetry>
35 static const bool value =
false;
38template <
typename Scalar, std::
size_t Rank, std::
size_t CoRank,
typename Symmetry>
47 static const bool value =
false;
50template <
typename Scalar, std::
size_t Rank, std::
size_t CoRank,
typename Symmetry>
97class vari_value<T,
require_tensor_v<T>> :
public vari_base,
public std::conditional_t<is_arena_tensor<T>::value, Empty, chainable_alloc>
100 using Scalar =
typename T::Scalar;
101 using Symmetry =
typename T::Symmetry;
102 using AllocationPolicy =
typename T::AllocationPolicy;
104 using vari_type = vari_value<T>;
105 static constexpr std::size_t Rank = T::Rank;
106 static constexpr std::size_t CoRank = T::CoRank;
112 template <
typename S>
113 explicit vari_value(
const S& x)
115 , adj_(x.uncoupledDomain(), x.uncoupledCodomain(), x.world())
118 stan::math::ChainableStack::instance_->var_stack_.push_back(
this);
121 template <
typename S>
122 explicit vari_value(
const S& x,
bool stacked)
124 , adj_(x.uncoupledDomain(), x.uncoupledCodomain(), x.world())
128 stan::math::ChainableStack::instance_->var_stack_.push_back(
this);
130 stan::math::ChainableStack::instance_->var_nochain_stack_.push_back(
this);
137 : val_(basis_domain, basis_codomain, world)
138 , adj_(basis_domain, basis_codomain, world)
141 stan::math::ChainableStack::instance_->var_nochain_stack_.push_back(
this);
144 inline const auto& val() const noexcept {
return val_; }
145 inline auto& val_op() noexcept {
return val_; }
147 inline auto& adj() noexcept {
return adj_; }
148 inline auto& adj() const noexcept {
return adj_; }
149 inline auto& adj_op() noexcept {
return adj_; }
151 constexpr std::size_t rank()
const {
return val_.rank(); }
152 constexpr std::size_t corank()
const {
return val_.corank(); }
154 virtual void chain() {}
156 inline void init_dependent() { adj_.setOnes(); }
158 inline void set_zero_adjoint() final { adj_.setZero(); }
160 friend std::ostream&
operator<<(std::ostream& os,
const vari_value<T>* v) {
return os <<
"val: \n" << v->val_ <<
" \nadj: \n" << v->adj_; }
163 template <
typename, std::
size_t, std::
size_t,
typename,
bool,
typename>
Definition: vari_value.hpp:81
Definition: Qbasis.hpp:39
Definition: Tensor.hpp:40
XpedWorld & getUniverse()
Definition: Mpi.hpp:49
std::ostream & operator<<(std::ostream &os, const FusionTree< depth, Symmetry > &tree)
Definition: FusionTree.hpp:93
Definition: vari_value.hpp:46
static const bool value
Definition: vari_value.hpp:47
Definition: vari_value.hpp:34
static const bool value
Definition: vari_value.hpp:35
Definition: vari_value.hpp:22
static const bool value
Definition: vari_value.hpp:23
Definition: vari_value.hpp:10
static const bool value
Definition: vari_value.hpp:11
stan::require_t< stan::bool_constant< is_tensor_var< T >::value > > require_tensor_var_v
Definition: vari_value.hpp:66
stan::require_t< stan::bool_constant< is_tensor< T >::value > > require_tensor_v
Definition: vari_value.hpp:57
stan::bool_constant< is_arena_tensor_var< T >::value > require_arena_tensor_var_t
Definition: vari_value.hpp:75
stan::require_t< stan::bool_constant< is_arena_tensor< T >::value > > require_arena_tensor_v
Definition: vari_value.hpp:69
stan::bool_constant< is_arena_tensor< T >::value > require_arena_tensor_t
Definition: vari_value.hpp:72
stan::bool_constant< is_tensor_var< T >::value > require_tensor_var_t
Definition: vari_value.hpp:63
stan::require_t< stan::bool_constant< is_arena_tensor_var< T >::value > > require_arena_tensor_var_v
Definition: vari_value.hpp:78
stan::bool_constant< is_tensor< T >::value > require_tensor_t
Definition: vari_value.hpp:60