6#include "data_type.hpp"
7#include "device_type.hpp"
8#include "dispatch_type.hpp"
9#include "lamppp/common/assert.hpp"
10#include "lamppp/tensor/native/memory_ops.hpp"
11#include "lamppp/tensor/storage.hpp"
12#include "lamppp/tensor/utils/align_utils.hpp"
14namespace lmp::tensor {
48 const std::vector<size_t>& shape, DeviceType device,
50 : data_(LMP_DISPATCH_ALL_TYPES(
52 [&] {
return Storage(
data.size() *
sizeof(scalar_t), device); })),
55 strides_(std::vector<detail::stride_t>(shape.size())),
56 numel_(shape.empty() ? 0
57 : std::accumulate(shape.begin(), shape.end(), 1,
58 std::multiplies<>())) {
59 LMP_CHECK(
data.size() == numel_)
60 <<
"Size mismatch, product of shape must equal num elements";
62 ops::copy_stub()(DeviceType::CPU, device,
data.data(), data_.data(), numel_,
68 explicit TensorImpl(Storage storage,
const std::vector<size_t>& shape,
/// @brief Returns an untyped (`void*`) pointer to the tensor's data.
/// Declared `noexcept`; callable on const objects.
/// NOTE(review): presumably forwards to the underlying Storage's data pointer
/// (the definition lives in the .cpp file) — confirm.
72 void*
data() const noexcept;
/// @brief Returns the `DataType` tag associated with this object. Declared `noexcept`.
73 DataType type() const noexcept;
/// @brief Returns the `DeviceType` associated with this object. Declared `noexcept`.
/// NOTE(review): presumably the device the backing Storage was created on — confirm.
74 DeviceType device() const noexcept;
/// @brief Returns the shape as a const reference to a vector of `size_t`.
/// The result is a reference, not a copy.
/// NOTE(review): presumably refers to the `shape_` member, so it is only
/// valid while this object is alive — confirm.
75 const std::vector<
size_t>& shape() const noexcept;
/// @brief Returns the strides as a const reference to a vector of `detail::stride_t`.
/// NOTE(review): presumably refers to the `strides_` member — confirm.
76 const std::vector<detail::stride_t>& strides() const noexcept;
/// @brief Returns the element count as a `size_t`. Declared `noexcept`.
/// NOTE(review): presumably returns `numel_`, which the data-vector constructor
/// sets to the product of the shape dimensions (0 for an empty shape) — confirm.
77 size_t numel() const noexcept;
/// @brief Produces a `TensorImpl` (returned by value) from a requested new shape.
/// @param new_shape Target shape; taken by value.
/// NOTE(review): whether the result shares storage with this object, and how
/// an element-count mismatch is handled, is not visible here — confirm in the .cpp.
79 TensorImpl reshape(std::vector<
size_t> new_shape);
/// @brief Returns a `TensorImpl` (by value) derived from this one for dimension `dim`.
/// @param dim Index of the dimension to operate on.
/// NOTE(review): presumably removes a size-1 dimension at `dim`; behaviour for
/// other sizes or out-of-range `dim` is not visible here — confirm.
80 TensorImpl squeeze(
size_t dim);
/// @brief Returns a `TensorImpl` (by value) derived from this one for dimension `dim`.
/// @param dim Index of the dimension to operate on.
/// NOTE(review): presumably inserts a size-1 dimension at position `dim` — confirm.
81 TensorImpl expand_dims(
size_t dim);
/// @brief Returns a `Scalar` for the given index vector.
/// @param idx Indices, passed by const reference.
/// NOTE(review): presumably one index per dimension, selecting a single element;
/// bounds handling is not visible here — confirm.
82 Scalar index(const std::vector<
size_t>& idx);
/// @brief Copy operation involving another `TensorImpl`.
/// @param other The other tensor implementation, passed by const reference.
/// NOTE(review): the method is const-qualified, so if it writes into this
/// tensor's data it does so through the storage pointer rather than by
/// mutating members — confirm the copy direction (other -> this?) in the .cpp.
84 void copy(const TensorImpl& other) const;
/// @brief Fill operation taking a single `Scalar` value.
/// @param item The scalar value, passed by value.
/// NOTE(review): presumably writes `item` into every element of the data buffer;
/// the method is const-qualified, so no members are reassigned — confirm.
85 void fill(Scalar item) const;
/// @brief Writes a representation of this object to an output stream.
/// @param os Destination stream, passed by non-const reference.
/// NOTE(review): the output format is defined in the .cpp and is not visible here.
86 void print(std::ostream& os) const;
/// @brief Non-const helper with no arguments and no return value.
/// NOTE(review): presumably recomputes `strides_` from `shape_`; the stride
/// convention used (e.g. row-major) is not visible here — confirm.
97 void update_strides();
/// Shape of the tensor, stored as a vector of `size_t`.
/// NOTE(review): presumably one entry per dimension — confirm.
102 std::vector<
size_t> shape_;
/// Strides of the tensor. The data-vector constructor initialises this with
/// `shape.size()` entries, i.e. one stride per dimension.
103 std::vector<detail::stride_t> strides_;
Low-level data manager for Tensor and TensorImpl.
Definition storage.hpp:34
Main implementation class for Tensor object.
Definition tensor_impl.hpp:27
TensorImpl(const std::vector< T > &data, const std::vector< size_t > &shape, DeviceType device, DataType dtype)
Construct a TensorImpl from a vector of data.
Definition tensor_impl.hpp:47
void * data() const noexcept
Definition tensor_impl.cpp:30
Main tensor object for Lamppp.
Definition tensor.hpp:29