lamppp
Loading...
Searching...
No Matches
align_utils.hpp
1#pragma once
2
3#include <vector>
4#include "lamppp/tensor/data_type.hpp"
5
/// Compile-time limit of 16 dimensions.
/// NOTE(review): presumably the maximum tensor rank supported by the library;
/// its users are not visible in this header, so confirm before relying on it.
/// Kept as a macro so any existing preprocessor uses keep working.
#define LMP_MAX_DIMS 16
8
namespace lmp::tensor::detail {

/// Signed integer type used for a single stride value.
using stride_t = int64_t;
/// List of stride values.
using stride_list = std::vector<stride_t>;
/// List of extents (unsigned sizes).
using shape_list = std::vector<size_t>;

/**
 * @brief Utility class for aligning shapes of tensors.
 *
 * Constructed from two input shapes; the results are exposed as public data
 * members. The constructor and both helpers are defined outside this header.
 *
 * NOTE(review): "aligning" presumably means computing a common (broadcast)
 * shape for the two operands, but that cannot be confirmed from the
 * declarations alone -- verify against the implementation file.
 */
class AlignUtil {
 public:
  /**
   * @param a_shape Shape of the first operand.
   * @param b_shape Shape of the second operand.
   */
  explicit AlignUtil(const shape_list& a_shape, const shape_list& b_shape);

  // Declaration order below is preserved on purpose: members are initialized
  // in this order, and the out-of-line constructor may depend on it.

  /// Presumably the result of calc_aligned_shape(a_shape, b_shape) -- confirm.
  shape_list aligned_shape_;
  /// Presumably the result of calc_aligned_stride() -- confirm.
  stride_list aligned_stride_;
  /// Defaulted to 0 so the member is never read uninitialized, even if a
  /// constructor fails to set it. Presumably the element count of
  /// aligned_shape_ -- confirm.
  size_t aligned_size_ = 0;

 private:
  /// Static helper: takes the two input shapes and returns a single shape.
  static shape_list calc_aligned_shape(const shape_list& a_shape,
                                       const shape_list& b_shape);
  /// Non-static, argument-free helper returning a stride list; presumably
  /// derived from aligned_shape_ -- confirm. Left non-const to match the
  /// out-of-line definition, which is not visible here.
  stride_list calc_aligned_stride();
};

}  // namespace lmp::tensor::detail
Utility class for aligning shapes of tensors.
Definition align_utils.hpp:24