5#include "lamppp/tensor/align_utils.hpp"
6#include "lamppp/tensor/tensor_impl.hpp"
8namespace lmp::tensor::detail {
17 explicit OffsetUtil(
size_t ndim) : ndim(ndim) {};
21 std::vector<stride_t> init_padded_strides(
22 const std::vector<size_t>& shape,
const std::vector<stride_t>& stride)
const;
29template <
size_t NArgs>
32 explicit CPUOffsetUtil(::std::array<const TensorImpl*, NArgs> ins,
34 ::std::array<stride_t, NArgs + 1> get(
size_t idx)
const;
36 ::std::array<std::vector<stride_t>, NArgs + 1> arg_strides_;
40template <
size_t NArgs>
41std::unique_ptr<OffsetUtil>
offset_util_cpu(::std::array<const TensorImpl*, NArgs> ins,
43 return std::make_unique<cpu::CPUOffsetUtil<NArgs>>(ins, out);
48template <
size_t NArgs>
49using offset_util_fn = std::unique_ptr<OffsetUtil> (*)(::std::array<const TensorImpl*, NArgs>,
const TensorImpl&);
51LMP_DECLARE_DISPATCH(offset_util_fn<2>, offset_util_stub_2);
52LMP_DECLARE_DISPATCH(offset_util_fn<3>, offset_util_stub_3);
Main implementation class for Tensor object.
Definition tensor_impl.hpp:28
Offset utility for CPU.
Definition offset_util.hpp:15
Definition offset_util.hpp:30
std::unique_ptr< OffsetUtil > offset_util_cpu(::std::array< const TensorImpl *, NArgs > ins, const TensorImpl &out)
Definition offset_util.hpp:41