lamppp
Loading...
Searching...
No Matches
tensor_impl.hpp
1#pragma once
2
3#include <iostream>
4#include <numeric>
5#include <vector>
6#include "data_type.hpp"
7#include "device_type.hpp"
8#include "dispatch_type.hpp"
9#include "lamppp/common/assert.hpp"
10#include "lamppp/tensor/align_utils.hpp"
11#include "lamppp/tensor/native/memory_ops.hpp"
12#include "lamppp/tensor/storage.hpp"
13#include "scalar.hpp"
14
15namespace lmp::tensor {
16
18
29 public:
47 template <typename T>
48 explicit TensorImpl(const std::vector<T>& data,
49 const std::vector<size_t>& shape, DeviceType device,
50 DataType dtype)
51 : data_(LMP_DISPATCH_ALL_TYPES(
52 dtype,
53 [&] { return Storage(data.size() * sizeof(scalar_t), device); })),
54 shape_(shape),
55 type_(dtype),
56 strides_(std::vector<detail::stride_t>(shape.size())),
57 numel_(shape.empty() ? 0
58 : std::accumulate(shape.begin(), shape.end(), 1,
59 std::multiplies<>())) {
60 LMP_CHECK(data.size() == numel_) <<
61 "Size mismatch, product of shape must equal num elements";
62 DataType src_dtype = TypeMeta<T>::value;
63 ops::copy_stub()(DeviceType::CPU, device, data.data(),
64 data_.data(), numel_, src_dtype, type_);
65 update_strides();
66 }
69 explicit TensorImpl(Storage storage, const std::vector<size_t>& shape,
70 DataType dtype);
72
73 void* data() const noexcept;
74 DataType type() const noexcept;
75 DeviceType device() const noexcept;
76 const std::vector<size_t>& shape() const noexcept;
77 const std::vector<detail::stride_t>& strides() const noexcept;
78 size_t numel() const noexcept;
79
80 TensorImpl reshape(std::vector<size_t> new_shape);
81 TensorImpl squeeze(size_t dim);
82 TensorImpl expand_dims(size_t dim);
83 Scalar index(const std::vector<size_t>& idx);
84
85 void copy(const TensorImpl& other) const;
86 void fill(Scalar item) const;
87 void print(std::ostream& os) const;
88
89 private:
90 friend class Tensor;
91
98 void update_strides();
99
100 DataType type_;
101 Storage data_;
102 size_t numel_;
103 std::vector<size_t> shape_;
104 std::vector<detail::stride_t> strides_;
105};
107
108} // namespace lmp::tensor
Low-level data manager for Tensor and TensorImpl.
Definition storage.hpp:34
Main implementation class for the Tensor object.
Definition tensor_impl.hpp:28
TensorImpl(const std::vector< T > &data, const std::vector< size_t > &shape, DeviceType device, DataType dtype)
Construct a TensorImpl from a vector of data.
Definition tensor_impl.hpp:48
void * data() const noexcept
Definition tensor_impl.cpp:31
Main tensor object for Lamppp.
Definition tensor.hpp:29
Simple template that maps a concrete C++ type (such as int) to its DataType enumerator (such as Int32).
Definition data_type.hpp:31