lamppp
Loading...
Searching...
No Matches
offset_util.hpp
1#pragma once
2
#include <array>
#include <cstddef>
#include <memory>
#include <vector>
#include "lamppp/tensor/align_utils.hpp"
#include "lamppp/tensor/tensor_impl.hpp"
7
8namespace lmp::tensor::detail {
9
11
16 public:
17 explicit OffsetUtil(size_t ndim) : ndim(ndim) {};
18 size_t ndim;
19
20 protected:
21 std::vector<stride_t> init_padded_strides(
22 const std::vector<size_t>& shape, const std::vector<stride_t>& stride) const;
23};
25
26namespace cpu {
27
29template <size_t NArgs>
30class CPUOffsetUtil : public OffsetUtil {
31 public:
32 explicit CPUOffsetUtil(::std::array<const TensorImpl*, NArgs> ins,
33 const TensorImpl& outs);
34 ::std::array<stride_t, NArgs + 1> get(size_t idx) const;
35
36 ::std::array<std::vector<stride_t>, NArgs + 1> arg_strides_;
37};
39
40template <size_t NArgs>
41std::unique_ptr<OffsetUtil> offset_util_cpu(::std::array<const TensorImpl*, NArgs> ins,
42 const TensorImpl& out) {
43 return std::make_unique<cpu::CPUOffsetUtil<NArgs>>(ins, out);
44}
45
46} // namespace cpu
47
48template <size_t NArgs>
49using offset_util_fn = std::unique_ptr<OffsetUtil> (*)(::std::array<const TensorImpl*, NArgs>, const TensorImpl&);
50
51LMP_DECLARE_DISPATCH(offset_util_fn<2>, offset_util_stub_2);
52LMP_DECLARE_DISPATCH(offset_util_fn<3>, offset_util_stub_3);
53
54}; // namespace lmp::tensor::detail
Main implementation class for Tensor object.
Definition tensor_impl.hpp:28
Offset utility for CPU.
Definition offset_util.hpp:15
Definition offset_util.hpp:30
std::unique_ptr< OffsetUtil > offset_util_cpu(::std::array< const TensorImpl *, NArgs > ins, const TensorImpl &out)
Definition offset_util.hpp:41