Xped
Loading...
Searching...
No Matches
TensorBase.hpp
Go to the documentation of this file.
1#ifndef XPED_BASE_H_
2#define XPED_BASE_H_
3
5
6namespace Xped {
7
8template <typename Derived>
10{};
11
12// forward declarations
13template <typename>
14class AdjointOp;
15
16template <typename, typename>
17class CoeffUnaryOp;
18
19template <typename>
20class BlockUnaryOp;
21
22template <typename>
24
25template <typename, typename>
26class CoeffBinaryOp;
27
28template <typename, typename>
30
31template <typename Scalar, std::size_t Rank, std::size_t CoRank, typename Symmetry, bool ENABLE_AD, typename AllocationPolicy>
32class Tensor;
33
34template <typename Derived>
36{
37public:
41
42 static constexpr std::size_t Rank = TensorTraits<Derived>::Rank;
43 static constexpr std::size_t CoRank = TensorTraits<Derived>::CoRank;
44
46
47 XPED_CONST AdjointOp<Derived> adjoint() XPED_CONST;
48
49 // Unary operations
50 template <typename ReturnScalar>
51 XPED_CONST CoeffUnaryOp<Derived, ReturnScalar> unaryExpr(const std::function<ReturnScalar(Scalar)>& coeff_func) XPED_CONST;
52
53 XPED_CONST BlockUnaryOp<Derived> unaryExpr(const std::function<MatrixType(const MatrixType&)>& coeff_func) XPED_CONST;
54
55 XPED_CONST CoeffUnaryOp<Derived, Scalar> sqrt() XPED_CONST;
56 XPED_CONST CoeffUnaryOp<Derived, Scalar> inv() XPED_CONST;
57 XPED_CONST CoeffUnaryOp<Derived, Scalar> square() XPED_CONST;
58 XPED_CONST CoeffUnaryOp<Derived, typename ScalarTraits<Scalar>::Real> abs() XPED_CONST;
59 template <typename OtherScalar>
60 XPED_CONST CoeffUnaryOp<Derived, OtherScalar> cast() XPED_CONST;
61
62 XPED_CONST BlockUnaryOp<Derived> msqrt() XPED_CONST;
63 XPED_CONST BlockUnaryOp<Derived> mexp(Scalar factor) XPED_CONST;
64
65 Derived& operator+=(const Scalar offset);
66 Derived& operator-=(const Scalar offset);
67 Derived& operator*=(const Scalar factor);
68 Derived& operator/=(const Scalar divisor);
69
70 XPED_CONST DiagCoeffUnaryOp<Derived> diagUnaryExpr(const std::function<Scalar(Scalar)>& coeff_func) XPED_CONST;
71
72 XPED_CONST DiagCoeffUnaryOp<Derived> diag_inv() XPED_CONST;
73 XPED_CONST DiagCoeffUnaryOp<Derived> diag_sqrt() XPED_CONST;
74
75 // Binary operations
76 template <typename OtherDerived>
77 XPED_CONST DiagCoeffBinaryOp<Derived, OtherDerived> diagBinaryExpr(XPED_CONST TensorBase<OtherDerived>& other,
78 const std::function<Scalar(Scalar, Scalar)>& coeff_func) XPED_CONST;
79
80 template <typename OtherDerived>
81 XPED_CONST CoeffBinaryOp<Derived, OtherDerived> binaryExpr(XPED_CONST TensorBase<OtherDerived>& other,
82 const std::function<Scalar(Scalar, Scalar)>& coeff_func) XPED_CONST;
83
84 template <typename OtherDerived>
85 Derived& operator+=(XPED_CONST TensorBase<OtherDerived>& other);
86 template <typename OtherDerived>
87 Derived& operator-=(XPED_CONST TensorBase<OtherDerived>& other);
88
89 template <bool = false, typename OtherDerived>
90 Tensor<std::common_type_t<Scalar, typename TensorTraits<OtherDerived>::Scalar>,
91 TensorTraits<Derived>::Rank,
92 TensorTraits<typename std::remove_const<std::remove_reference_t<OtherDerived>>::type>::CoRank,
94 false,
96 operator*(XPED_CONST TensorBase<OtherDerived>& other) XPED_CONST;
97
98 template <bool TRACK = false, typename OtherDerived>
99 Tensor<std::common_type_t<Scalar, typename TensorTraits<OtherDerived>::Scalar>,
100 Rank,
101 TensorTraits<typename std::remove_const<std::remove_reference_t<OtherDerived>>::type>::CoRank,
102 Symmetry,
103 false,
105 operator*(TensorBase<OtherDerived>&& other) XPED_CONST
106 {
107 TensorBase<OtherDerived>& tmp = other;
108 return this->operator*<TRACK>(tmp);
109 }
110
111 template <bool = false>
112 Scalar trace() XPED_CONST;
113
114 typename ScalarTraits<Scalar>::Real maxNorm() XPED_CONST;
115
116 typename ScalarTraits<Scalar>::Real squaredNorm() XPED_CONST;
117
118 inline typename ScalarTraits<Scalar>::Real norm() XPED_CONST { return std::sqrt(squaredNorm()); }
119
121 maxCoeff(std::size_t& max_block, PlainInterface::MIndextype& max_row, PlainInterface::MIndextype& max_col) XPED_CONST;
122
124 {
126 };
127
128 inline const Derived& derived() const { return *static_cast<const Derived*>(this); }
129 inline Derived& derived() { return *static_cast<Derived*>(this); }
130
131protected:
132 template <typename, std::size_t, std::size_t, typename, bool, typename>
133 friend class Tensor;
134 template <typename OtherDerived>
135 friend class TensorBase;
136
137 // inline const Derived& derived() const { return *static_cast<const Derived*>(this); }
138 // inline Derived& derived() { return *static_cast<Derived*>(this); }
139};
140
141template <typename DerivedLeft, typename DerivedRight>
143{
144 return left.binaryExpr(right, [](const typename DerivedLeft::Scalar s1, const typename DerivedRight::Scalar s2) { return s1 + s2; });
145}
146
147template <typename DerivedLeft, typename DerivedRight>
149{
150 TensorBase<DerivedLeft>& tmp_left = left;
151 TensorBase<DerivedRight>& tmp_right = right;
152 return tmp_left.binaryExpr(tmp_right, [](const typename DerivedLeft::Scalar s1, const typename DerivedRight::Scalar s2) { return s1 + s2; });
153}
154
155template <typename DerivedLeft, typename DerivedRight>
157{
158 return left.binaryExpr(right, [](const typename DerivedLeft::Scalar s1, const typename DerivedRight::Scalar s2) { return s1 - s2; });
159}
160
161template <typename Derived, typename Scalar>
163{
164 return left.template unaryExpr<std::common_type_t<typename Derived::Scalar, Scalar>>(
165 [offset](const typename Derived::Scalar s) { return offset + s; });
166}
167
168template <typename Derived, typename Scalar>
170{
171 return right.template unaryExpr<std::common_type_t<typename Derived::Scalar, Scalar>>(
172 [offset](const typename Derived::Scalar s) { return offset + s; });
173};
174
175template <typename Derived, typename Scalar>
177{
178 return left.template unaryExpr<std::common_type_t<typename Derived::Scalar, Scalar>>(
179 [offset](const typename Derived::Scalar s) { return s - offset; });
180}
181
182template <bool = false, typename Derived, typename Scalar>
184{
185 return left.template unaryExpr<std::common_type_t<typename Derived::Scalar, Scalar>>(
186 [factor](const typename Derived::Scalar s) { return s * factor; });
187}
188
189template <bool = false, typename Derived, typename Scalar>
191{
192 return right.template unaryExpr<std::common_type_t<typename Derived::Scalar, Scalar>>(
193 [factor](const typename Derived::Scalar s) { return s * factor; });
194}
195
196template <bool = false, typename Derived, typename Scalar>
198{
199 TensorBase<Derived>& tmp_right = right;
200 return tmp_right.template unaryExpr<std::common_type_t<typename Derived::Scalar, Scalar>>(
201 [factor](const typename Derived::Scalar s) { return s * factor; });
202}
203
204template <typename Derived, typename Scalar>
206{
207 return left.template unaryExpr<std::common_type_t<typename Derived::Scalar, Scalar>>(
208 [divisor](const typename Derived::Scalar s) { return s / divisor; });
209}
210
211} // namespace Xped
212
213#ifndef XPED_COMPILED_LIB
214# include "Core/TensorBase.cpp"
215#endif
216
217#endif
Definition: AdjointOp.hpp:25
Definition: BlockUnaryOp.hpp:25
Definition: CoeffBinaryOp.hpp:27
Definition: CoeffUnaryOp.hpp:27
Definition: DiagCoeffBinaryOp.hpp:27
Definition: DiagCoeffUnaryOp.hpp:25
Definition: TensorBase.hpp:36
ScalarTraits< Scalar >::Real norm() XPED_CONST
Definition: TensorBase.hpp:118
const Derived & derived() const
Definition: TensorBase.hpp:128
ScalarTraits< Scalar >::Real maxNorm() XPED_CONST
Definition: TensorBase.cpp:37
static constexpr std::size_t CoRank
Definition: TensorBase.hpp:43
XPED_CONST BlockUnaryOp< Derived > msqrt() XPED_CONST
Definition: TensorBase.cpp:149
XPED_CONST CoeffUnaryOp< Derived, Scalar > square() XPED_CONST
Definition: TensorBase.cpp:130
Derived & derived()
Definition: TensorBase.hpp:129
PlainInterface::MType< Scalar > MatrixType
Definition: TensorBase.hpp:45
XPED_CONST BlockUnaryOp< Derived > mexp(Scalar factor) XPED_CONST
Definition: TensorBase.cpp:155
XPED_CONST CoeffUnaryOp< Derived, ReturnScalar > unaryExpr(const std::function< ReturnScalar(Scalar)> &coeff_func) XPED_CONST
Definition: TensorBase.cpp:78
Scalar trace() XPED_CONST
TensorTraits< Derived >::Scalar Scalar
Definition: TensorBase.hpp:38
typename TensorTraits< Derived >::AllocationPolicy AllocationPolicy
Definition: TensorBase.hpp:40
XPED_CONST DiagCoeffBinaryOp< Derived, OtherDerived > diagBinaryExpr(XPED_CONST TensorBase< OtherDerived > &other, const std::function< Scalar(Scalar, Scalar)> &coeff_func) XPED_CONST
Definition: TensorBase.cpp:181
TensorTraits< Derived >::Symmetry Symmetry
Definition: TensorBase.hpp:39
XPED_CONST CoeffUnaryOp< Derived, Scalar > sqrt() XPED_CONST
Definition: TensorBase.cpp:118
ScalarTraits< Scalar >::Real maxCoeff(std::size_t &max_block, PlainInterface::MIndextype &max_row, PlainInterface::MIndextype &max_col) XPED_CONST
Definition: TensorBase.cpp:56
XPED_CONST CoeffUnaryOp< Derived, Scalar > inv() XPED_CONST
Definition: TensorBase.cpp:124
XPED_CONST CoeffUnaryOp< Derived, OtherScalar > cast() XPED_CONST
Definition: TensorBase.cpp:143
XPED_CONST DiagCoeffUnaryOp< Derived > diag_inv() XPED_CONST
Definition: TensorBase.cpp:167
Tensor< Scalar, Rank, CoRank, Symmetry, false, AllocationPolicy > eval() const
Definition: TensorBase.hpp:123
XPED_CONST CoeffBinaryOp< Derived, OtherDerived > binaryExpr(XPED_CONST TensorBase< OtherDerived > &other, const std::function< Scalar(Scalar, Scalar)> &coeff_func) XPED_CONST
Definition: TensorBase.cpp:188
XPED_CONST CoeffUnaryOp< Derived, typename ScalarTraits< Scalar >::Real > abs() XPED_CONST
Definition: TensorBase.cpp:136
XPED_CONST DiagCoeffUnaryOp< Derived > diag_sqrt() XPED_CONST
Definition: TensorBase.cpp:173
static constexpr std::size_t Rank
Definition: TensorBase.hpp:42
ScalarTraits< Scalar >::Real squaredNorm() XPED_CONST
Definition: TensorBase.cpp:47
XPED_CONST AdjointOp< Derived > adjoint() XPED_CONST
Definition: TensorBase.cpp:71
XPED_CONST DiagCoeffUnaryOp< Derived > diagUnaryExpr(const std::function< Scalar(Scalar)> &coeff_func) XPED_CONST
Definition: TensorBase.cpp:161
Definition: Tensor.hpp:40
Definition: bench.cpp:62
XTensor< TRACK, Scalar, Rank, CoRank, Symmetry > operator-(const Tensor< Scalar, Rank, CoRank, Symmetry, true > &t, Scalar s)
Definition: ADTensor.hpp:428
XTensor< TRACK, Scalar, Rank, CoRank, Symmetry > operator+(const Tensor< Scalar, Rank, CoRank, Symmetry, true > &t, Scalar s)
Definition: ADTensor.hpp:440
XTensor< TRACK, Scalar, Rank, CoRank, Symmetry > operator*(const Tensor< Scalar, Rank, CoRank, Symmetry, true > &t, Scalar s)
Definition: ADTensor.hpp:452
XPED_CONST CoeffUnaryOp< Derived, std::common_type_t< typename Derived::Scalar, Scalar > > operator/(XPED_CONST TensorBase< Derived > &left, Scalar divisor)
Definition: TensorBase.hpp:205
int MIndextype
Definition: MatrixInterface_Cyclops_impl.hpp:49
CTF::Matrix< Scalar > MType
Definition: MatrixInterface_Cyclops_impl.hpp:40
Definition: ScalarTraits.hpp:10
Definition: TensorBase.hpp:10