Xped
Loading...
Searching...
No Matches
vari_value.hpp
Go to the documentation of this file.
1#ifndef XPED_VARI_VALUE_HPP_
2#define XPED_VARI_VALUE_HPP_
3
4#include "stan/math/rev/core/vari.hpp"
5
7
/**
 * Type trait: is_tensor<T>::value is false for every type unless a partial specialization
 * applies (one follows for Xped::Tensor instantiations whose fifth template argument is false).
 * value is constexpr, hence implicitly inline in C++17, so it can be ODR-used (e.g. bound to a
 * const bool&) without an out-of-class definition.
 */
template <typename T>
struct is_tensor
{
    static constexpr bool value = false;
};
13
14template <typename Scalar, std::size_t Rank, std::size_t CoRank, typename Symmetry, typename AllocationPolicy>
15struct is_tensor<Xped::Tensor<Scalar, Rank, CoRank, Symmetry, false, AllocationPolicy>>
16{
17 static const bool value = true;
18};
19
/**
 * Type trait: is_tensor_var<T>::value is false for every type unless a partial specialization
 * applies (one follows for Xped::Tensor instantiations whose fifth template argument is true —
 * presumably the autodiff-enabled variant; confirm against Tensor.hpp).
 * value is constexpr, hence implicitly inline in C++17.
 */
template <typename T>
struct is_tensor_var
{
    static constexpr bool value = false;
};
25
26template <typename Scalar, std::size_t Rank, std::size_t CoRank, typename Symmetry>
27struct is_tensor_var<Xped::Tensor<Scalar, Rank, CoRank, Symmetry, true>>
28{
29 static const bool value = true;
30};
31
/**
 * Type trait: is_arena_tensor<T>::value is false for every type unless a partial specialization
 * applies (one follows for Xped::ArenaTensor instantiations whose fifth template argument is false).
 * value is constexpr, hence implicitly inline in C++17.
 */
template <typename T>
struct is_arena_tensor
{
    static constexpr bool value = false;
};
37
38template <typename Scalar, std::size_t Rank, std::size_t CoRank, typename Symmetry>
39struct is_arena_tensor<Xped::ArenaTensor<Scalar, Rank, CoRank, Symmetry, false>>
40{
41 static const bool value = true;
42};
43
/**
 * Type trait: is_arena_tensor_var<T>::value is false for every type unless a partial
 * specialization applies (one follows for Xped::ArenaTensor instantiations whose fifth template
 * argument is true). value is constexpr, hence implicitly inline in C++17.
 */
template <typename T>
struct is_arena_tensor_var
{
    static constexpr bool value = false;
};
49
50template <typename Scalar, std::size_t Rank, std::size_t CoRank, typename Symmetry>
51struct is_arena_tensor_var<Xped::ArenaTensor<Scalar, Rank, CoRank, Symmetry, true>>
52{
53 static const bool value = true;
54};
55
56template <typename T>
57using require_tensor_v = stan::require_t<stan::bool_constant<is_tensor<T>::value>>;
58
59template <typename T>
60using require_tensor_t = stan::bool_constant<is_tensor<T>::value>;
61
62template <typename T>
63using require_tensor_var_t = stan::bool_constant<is_tensor_var<T>::value>;
64
65template <typename T>
66using require_tensor_var_v = stan::require_t<stan::bool_constant<is_tensor_var<T>::value>>;
67
68template <typename T>
69using require_arena_tensor_v = stan::require_t<stan::bool_constant<is_arena_tensor<T>::value>>;
70
71template <typename T>
72using require_arena_tensor_t = stan::bool_constant<is_arena_tensor<T>::value>;
73
74template <typename T>
75using require_arena_tensor_var_t = stan::bool_constant<is_arena_tensor_var<T>::value>;
76
77template <typename T>
78using require_arena_tensor_var_v = stan::require_t<stan::bool_constant<is_arena_tensor_var<T>::value>>;
79
/// Stateless placeholder base: vari_value picks it through std::conditional_t in place of
/// chainable_alloc when is_arena_tensor<T>::value is true.
class Empty {};
82
83// template <typename Derived>
84// class vari_view_base
85// {
86// vari_view_base() = default;
87// friend Derived;
88
89// inline Derived& derived() { return static_cast<Derived&>(*this); }
90// inline const Derived& derived() const { return static_cast<const Derived&>(*this); }
91
92// public:
93// };
94
95namespace stan::math {
96template <typename T>
97class vari_value<T, require_tensor_v<T>> : public vari_base, public std::conditional_t<is_arena_tensor<T>::value, Empty, chainable_alloc>
98{
99 using value_type = T; // The underlying type for this class
100 using Scalar = typename T::Scalar; // A floating point type
101 using Symmetry = typename T::Symmetry;
102 using AllocationPolicy = typename T::AllocationPolicy;
103 // using Storage = typename T::Storage; // A floating point type
104 using vari_type = vari_value<T>;
105 static constexpr std::size_t Rank = T::Rank;
106 static constexpr std::size_t CoRank = T::CoRank;
107
108 T val_;
109
110 T adj_;
111
112 template <typename S>
113 explicit vari_value(const S& x)
114 : val_(x)
115 , adj_(x.uncoupledDomain(), x.uncoupledCodomain(), x.world())
116 {
117 adj_.setZero();
118 stan::math::ChainableStack::instance_->var_stack_.push_back(this);
119 }
120
121 template <typename S>
122 explicit vari_value(const S& x, bool stacked)
123 : val_(x)
124 , adj_(x.uncoupledDomain(), x.uncoupledCodomain(), x.world())
125 {
126 adj_.setZero();
127 if(stacked) {
128 stan::math::ChainableStack::instance_->var_stack_.push_back(this);
129 } else {
130 stan::math::ChainableStack::instance_->var_nochain_stack_.push_back(this);
131 }
132 }
133
134 vari_value(const std::array<Xped::Qbasis<Symmetry, 1, AllocationPolicy>, Rank>& basis_domain,
135 const std::array<Xped::Qbasis<Symmetry, 1, AllocationPolicy>, CoRank>& basis_codomain,
137 : val_(basis_domain, basis_codomain, world)
138 , adj_(basis_domain, basis_codomain, world)
139 {
140 adj_.setZero();
141 stan::math::ChainableStack::instance_->var_nochain_stack_.push_back(this);
142 }
143
144 inline const auto& val() const noexcept { return val_; }
145 inline auto& val_op() noexcept { return val_; }
146
147 inline auto& adj() noexcept { return adj_; }
148 inline auto& adj() const noexcept { return adj_; }
149 inline auto& adj_op() noexcept { return adj_; }
150
151 constexpr std::size_t rank() const { return val_.rank(); }
152 constexpr std::size_t corank() const { return val_.corank(); }
153
154 virtual void chain() {}
155
156 inline void init_dependent() { adj_.setOnes(); }
157
158 inline void set_zero_adjoint() final { adj_.setZero(); }
159
160 friend std::ostream& operator<<(std::ostream& os, const vari_value<T>* v) { return os << "val: \n" << v->val_ << " \nadj: \n" << v->adj_; }
161
162private:
163 template <typename, std::size_t, std::size_t, typename, bool, typename>
164 friend class Xped::Tensor;
165};
166} // namespace stan::math
167#endif
Definition: vari_value.hpp:81
Definition: Qbasis.hpp:39
Definition: Tensor.hpp:40
XpedWorld & getUniverse()
Definition: Mpi.hpp:49
Definition: bench.cpp:62
std::ostream & operator<<(std::ostream &os, const FusionTree< depth, Symmetry > &tree)
Definition: FusionTree.hpp:93
Definition: Mpi.hpp:34
Definition: vari_value.hpp:46
static const bool value
Definition: vari_value.hpp:47
Definition: vari_value.hpp:34
static const bool value
Definition: vari_value.hpp:35
Definition: vari_value.hpp:22
static const bool value
Definition: vari_value.hpp:23
Definition: vari_value.hpp:10
static const bool value
Definition: vari_value.hpp:11
stan::require_t< stan::bool_constant< is_tensor_var< T >::value > > require_tensor_var_v
Definition: vari_value.hpp:66
stan::require_t< stan::bool_constant< is_tensor< T >::value > > require_tensor_v
Definition: vari_value.hpp:57
stan::bool_constant< is_arena_tensor_var< T >::value > require_arena_tensor_var_t
Definition: vari_value.hpp:75
stan::require_t< stan::bool_constant< is_arena_tensor< T >::value > > require_arena_tensor_v
Definition: vari_value.hpp:69
stan::bool_constant< is_arena_tensor< T >::value > require_arena_tensor_t
Definition: vari_value.hpp:72
stan::bool_constant< is_tensor_var< T >::value > require_tensor_var_t
Definition: vari_value.hpp:63
stan::require_t< stan::bool_constant< is_arena_tensor_var< T >::value > > require_arena_tensor_var_v
Definition: vari_value.hpp:78
stan::bool_constant< is_tensor< T >::value > require_tensor_t
Definition: vari_value.hpp:60