6#include "data_type.hpp"
7#include "device_type.hpp"
8#include "dispatch_type.hpp"
9#include "lamppp/common/assert.hpp"
10#include "lamppp/tensor/align_utils.hpp"
11#include "lamppp/tensor/native/memory_ops.hpp"
12#include "lamppp/tensor/storage.hpp"
15namespace lmp::tensor {
49 const std::vector<size_t>& shape, DeviceType device,
51 : data_(LMP_DISPATCH_ALL_TYPES(
53 [&] {
return Storage(
data.size() *
sizeof(scalar_t), device); })),
56 strides_(std::vector<detail::stride_t>(shape.size())),
57 numel_(shape.empty() ? 0
58 : std::accumulate(shape.begin(), shape.end(), 1,
59 std::multiplies<>())) {
60 LMP_CHECK(
data.size() == numel_) <<
61 "Size mismatch, product of shape must equal num elements";
63 ops::copy_stub()(DeviceType::CPU, device,
data.data(),
64 data_.data(), numel_, src_dtype, type_);
69 explicit TensorImpl(Storage storage,
const std::vector<size_t>& shape,
// Read-only accessors. All are declared `const noexcept`: they do not modify
// this object's members and do not propagate exceptions.

/// Untyped pointer to the element buffer (defined in tensor_impl.cpp).
/// NOTE(review): presumably returns data_.data(), the Storage buffer the
/// vector constructor copies elements into — confirm.
73 void*
data() const noexcept;
/// Element data type of this tensor (presumably the stored type_ — confirm).
74 DataType type() const noexcept;
/// Device the element buffer was allocated on (presumably the `device`
/// passed to Storage at construction — confirm).
75 DeviceType device() const noexcept;
/// Size of each dimension; returned by const reference, so no copy is made.
76 const std::vector<
size_t>& shape() const noexcept;
/// Per-dimension strides, returned by const reference.
77 const std::vector<detail::stride_t>& strides() const noexcept;
/// Total number of elements.
/// NOTE(review): the vector constructor sets numel_ to 0 when the shape is
/// empty, so a 0-d tensor reports 0 elements — confirm this is intended.
/// NOTE(review): that constructor calls std::accumulate with the int literal 1
/// as the initial value, so the shape product is computed in int and can
/// overflow for large shapes; size_t{1} would keep it in size_t.
78 size_t numel() const noexcept;
// Shape-manipulation and element-access operations.
// NOTE(review): none of these are const-qualified, although three of them
// return a new TensorImpl; if they do not modify *this, consider marking
// them const.

/// Returns a TensorImpl viewed with shape `new_shape`. The shape is taken by
/// value, so callers can move a temporary in.
/// Presumably the element count must match numel() — confirm.
80 TensorImpl reshape(std::vector<
size_t> new_shape);
/// Returns a TensorImpl with dimension `dim` removed (presumably only valid
/// when that dimension has size 1 — confirm).
81 TensorImpl squeeze(
size_t dim);
/// Returns a TensorImpl with a new dimension inserted at `dim` (presumably of
/// size 1 — confirm).
82 TensorImpl expand_dims(
size_t dim);
/// Returns the element at multi-dimensional index `idx` as a Scalar
/// (presumably one index per dimension — confirm).
83 Scalar index(const std::vector<
size_t>& idx);
// NOTE(review): copy() and fill() are declared const, so they cannot reassign
// this object's members. Presumably they write element data through the
// storage buffer instead (const on the handle, not on the elements). Confirm
// this is intended, since a caller holding a const TensorImpl can still
// change its contents.

/// Copies the element data of `other` into this tensor (presumably requires
/// a matching element count — confirm).
85 void copy(const TensorImpl& other) const;
/// Sets every element to `item` (presumably converted to this tensor's type —
/// confirm).
86 void fill(Scalar item) const;
/// Writes a textual representation of the tensor to `os`.
87 void print(std::ostream& os) const;
/// Presumably recomputes strides_ from the current shape_ (for example after
/// the shape changes) — inferred from the name; confirm against the
/// definition in tensor_impl.cpp.
98 void update_strides();
/// Size of each dimension; the constructor multiplies these to obtain numel_.
103 std::vector<
size_t> shape_;
/// One stride per dimension. The vector constructor sizes this to
/// shape.size(); presumably it is filled in by update_strides() — confirm.
104 std::vector<detail::stride_t> strides_;
Low-level data manager for Tensor and TensorImpl.
Definition storage.hpp:34
Main implementation class for the Tensor object.
Definition tensor_impl.hpp:28
TensorImpl(const std::vector< T > &data, const std::vector< size_t > &shape, DeviceType device, DataType dtype)
Construct a TensorImpl from a vector of data.
Definition tensor_impl.hpp:48
void * data() const noexcept
Definition tensor_impl.cpp:31
Main tensor object for Lamppp.
Definition tensor.hpp:29