1#ifndef MATRIX_MULTIPLICATION_H_
2#define MATRIX_MULTIPLICATION_H_
/// Cost of the matrix product A * B under the naive algorithm:
/// rows(A) * cols(A) * cols(B).
///
/// @param dimsA  {rows, cols} of A.
/// @param dimsB  {rows, cols} of B. dimsB[0] is not read; the caller is
///               expected to pass conformable shapes (dimsA[1] == dimsB[0]),
///               which is not checked here.
/// @return dimsA[0] * dimsA[1] * dimsB[1], computed in std::size_t.
template <typename IndexTypeA, typename IndexTypeB>
static std::size_t mult_cost(const std::array<IndexTypeA, 2>& dimsA,
                             const std::array<IndexTypeB, 2>& dimsB)
{
  // Widen every factor *before* multiplying. With a narrow index type such as
  // int, the product of three dimensions can overflow (undefined behaviour for
  // signed types) before the implicit conversion to std::size_t on return.
  const std::size_t rows_a = static_cast<std::size_t>(dimsA[0]);
  const std::size_t inner  = static_cast<std::size_t>(dimsA[1]);
  const std::size_t cols_b = static_cast<std::size_t>(dimsB[1]);
  return rows_a * inner * cols_b;
}
13template <
typename IndexTypeA,
typename IndexTypeB,
typename IndexTypeC>
14static std::vector<std::size_t>
15mult_cost(
const std::array<IndexTypeA, 2>& dimsA,
const std::array<IndexTypeB, 2>& dimsB,
const std::array<IndexTypeC, 2>& dimsC)
17 std::vector<std::size_t> out(2);
19 out[0] = mult_cost(dimsA, dimsB) + dimsA[0] * dimsC[0] * dimsC[1];
22 out[1] = mult_cost(dimsB, dimsC) + dimsA[0] * dimsA[1] * dimsC[1];
Definition: MatrixMultiplication.hpp:4