lamppp
Loading...
Searching...
No Matches
meta_handler.hpp
1#pragma once
2
3#include <memory>
4#include <vector>
5#include "lamppp/tensor/cpu/offset_util.hpp"
6#include "lamppp/tensor/data_type.hpp"
7#include "lamppp/tensor/tensor_impl.hpp"
8
9namespace lmp::tensor::detail {
10
11using tensor_list = std::vector<const lmp::tensor::TensorImpl*>;
12
13template <typename... Args>
15 public:
16 static constexpr std::size_t kNumElem =
17 (0 + ... + std::size_t{std::is_same_v<const TensorImpl*, Args>});
18 explicit TensorMetaHandler(Args... args);
19
20 TensorImpl& out() noexcept { return *outTen_; }
21 tensor_list& in() noexcept { return inTens_; }
22 const OffsetUtil* offset() const noexcept {
23 LMP_INTERNAL_ASSERT(expand_) << "Must have expand = True to get offset";
24 return outOffset_.get();
25 }
26 bool expand() const noexcept { return expand_; }
27
28 private:
29 DataType outDtype_;
30 size_t outSize_;
31 std::vector<size_t> outShape_;
32
33 bool expand_;
34 std::unique_ptr<OffsetUtil> outOffset_;
35 std::unique_ptr<TensorImpl> outTen_;
36 tensor_list inTens_;
37};
38
43
44} // namespace lmp::tensor::detail
Main implementation class for Tensor object.
Definition tensor_impl.hpp:28
Offset utility for CPU.
Definition offset_util.hpp:15
Definition meta_handler.hpp:14