Xped
Loading...
Searching...
No Matches
MatrixMultiplication.hpp
Go to the documentation of this file.
1#ifndef MATRIX_MULTIPLICATION_H_
2#define MATRIX_MULTIPLICATION_H_
#include <array>
#include <cstddef>
#include <vector>

namespace internal {

/**
 * Cost of forming the product A*B under the standard m*k*n cost model.
 *
 * @param dimsA {rows, cols} of A.
 * @param dimsB {rows, cols} of B. dimsB[0] is assumed to equal dimsA[1]; this is not checked.
 * @return dimsA[0] * dimsA[1] * dimsB[1], evaluated in std::size_t.
 */
template <typename IndexTypeA, typename IndexTypeB>
inline std::size_t mult_cost(const std::array<IndexTypeA, 2>& dimsA, const std::array<IndexTypeB, 2>& dimsB)
{
    // Widen every factor before multiplying. With a 32-bit IndexType (e.g. int), the product of
    // three moderate dimensions overflows (UB for signed types) before reaching std::size_t.
    const auto rowsA = static_cast<std::size_t>(dimsA[0]);
    const auto colsA = static_cast<std::size_t>(dimsA[1]);
    const auto colsB = static_cast<std::size_t>(dimsB[1]);
    return rowsA * colsA * colsB;
}

/**
 * Cost of the two possible parenthesizations of the triple product A*B*C.
 *
 * @param dimsA {rows, cols} of A.
 * @param dimsB {rows, cols} of B.
 * @param dimsC {rows, cols} of C. Inner dimensions are assumed to match; this is not checked.
 * @return Vector of size 2: [0] = cost of (AB)C, [1] = cost of A(BC).
 */
template <typename IndexTypeA, typename IndexTypeB, typename IndexTypeC>
inline std::vector<std::size_t>
mult_cost(const std::array<IndexTypeA, 2>& dimsA, const std::array<IndexTypeB, 2>& dimsB, const std::array<IndexTypeC, 2>& dimsC)
{
    // Widen before multiplying, for the same overflow reason as in the two-matrix overload.
    const auto rowsA = static_cast<std::size_t>(dimsA[0]);
    const auto colsA = static_cast<std::size_t>(dimsA[1]);
    const auto rowsC = static_cast<std::size_t>(dimsC[0]);
    const auto colsC = static_cast<std::size_t>(dimsC[1]);

    // (AB)C: form AB (rowsA x colsB), then multiply it by C.
    const std::size_t cost_AB_C = mult_cost(dimsA, dimsB) + rowsA * rowsC * colsC;

    // A(BC): form BC (rowsB x colsC), then multiply A by it.
    const std::size_t cost_A_BC = mult_cost(dimsB, dimsC) + rowsA * colsA * colsC;

    return {cost_AB_C, cost_A_BC};
}

// NOTE(review): disabled overload kept for reference. Unlike the active overloads above, it takes
// matrix objects (rows()/cols()) rather than dimension arrays, so it must be ported before use.
// The (AB)(CD) term previously used C.cols(). AB is A.rows x B.cols and CD is C.rows x D.cols,
// so their product costs A.rows * B.cols * D.cols.
//
// /**Cost to multiply 4 matrices in 5 possible ways.*/
// template <typename MatrixTypeA, typename MatrixTypeB, typename MatrixTypeC, typename MatrixTypeD>
// std::vector<size_t> mult_cost(const MatrixTypeA& A, const MatrixTypeB& B, const MatrixTypeC& C, const MatrixTypeD& D)
// {
// std::vector<size_t> out(5);
// // (AB)(CD)
// out[0] = mult_cost(A, B) + mult_cost(C, D) + A.rows() * B.cols() * D.cols();
//
// // ((AB)C)D
// out[1] = mult_cost(A, B) + A.rows() * C.rows() * C.cols() + A.rows() * D.rows() * D.cols();
//
// // (A(BC))D
// out[2] = mult_cost(B, C) + A.rows() * A.cols() * C.cols() + A.rows() * D.rows() * D.cols();
//
// // A((BC)D)
// out[3] = mult_cost(B, C) + B.rows() * D.rows() * D.cols() + A.rows() * A.cols() * D.cols();
//
// // A(B(CD))
// out[4] = mult_cost(C, D) + B.rows() * B.cols() * D.cols() + A.rows() * A.cols() * D.cols();
// return out;
// }
} // namespace internal
49#endif
Definition: MatrixMultiplication.hpp:4