Main implementation class for Tensor object.
#include <tensor_impl.hpp>
- template<typename T> TensorImpl(const std::vector<T> &data, const std::vector<size_t> &shape, DeviceType device, DataType dtype)
  Construct a TensorImpl from a vector of data.
- TensorImpl(Storage storage, const std::vector<size_t> &shape, DataType dtype)
- void *data() const noexcept
- DataType type() const noexcept
- DeviceType device() const noexcept
- const std::vector<size_t> &shape() const noexcept
- const std::vector<detail::stride_t> &strides() const noexcept
- size_t numel() const noexcept
- TensorImpl reshape(std::vector<size_t> new_shape)
- TensorImpl squeeze(size_t dim)
- TensorImpl expand_dims(size_t dim)
- Scalar index(const std::vector<size_t> &idx)
- void copy(const TensorImpl &other) const
- void fill(Scalar item) const
- void print(std::ostream &os) const
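The shape accessors above (shape(), strides(), numel(), index()) are not documented in detail on this page. As a minimal sketch, assuming a contiguous row-major layout (the layout the constructor documents for its input data) and treating detail::stride_t as a signed 64-bit integer, the underlying quantities could be computed as follows. The names stride_t, compute_numel, row_major_strides, and flat_offset are hypothetical helpers, not part of the library.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for detail::stride_t; the real type may differ.
using stride_t = std::int64_t;

// Number of elements: the product of all dimensions (1 for a 0-D shape).
std::size_t compute_numel(const std::vector<std::size_t>& shape) {
    return std::accumulate(shape.begin(), shape.end(), std::size_t{1},
                           std::multiplies<std::size_t>{});
}

// Row-major strides: the last dimension has stride 1, and each earlier
// stride is the product of all dimensions after it.
std::vector<stride_t> row_major_strides(const std::vector<std::size_t>& shape) {
    std::vector<stride_t> strides(shape.size());
    stride_t running = 1;
    for (std::size_t i = shape.size(); i-- > 0;) {
        strides[i] = running;
        running *= static_cast<stride_t>(shape[i]);
    }
    return strides;
}

// Flat element offset of a multi-dimensional index: the dot product of the
// index with the strides.
std::size_t flat_offset(const std::vector<std::size_t>& idx,
                        const std::vector<stride_t>& strides) {
    assert(idx.size() == strides.size());
    stride_t offset = 0;
    for (std::size_t i = 0; i < idx.size(); ++i) {
        offset += static_cast<stride_t>(idx[i]) * strides[i];
    }
    return static_cast<std::size_t>(offset);
}
```

For example, a {2, 3, 4} shape gives strides {12, 4, 1}, so the index {1, 2, 3} maps to flat offset 1*12 + 2*4 + 3 = 23.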
TensorImpl contains a few core members: type_, shape_, and data_. Similar to PyTorch, Tensor/TensorImpl is not responsible for low-level data storage; that, including the device the data resides on, is managed by Storage. This is why TensorImpl has no member called device_.
- See also
- Tensor, Storage
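The split described above, where Storage owns the raw buffer and its device while TensorImpl holds only metadata, can be illustrated with a minimal sketch. Everything here is hypothetical: SimpleStorage, SimpleTensorImpl, and the two-value DeviceType enum are stand-ins, not the library's real Storage, TensorImpl, or DeviceType, and the dtype member (type_) is omitted for brevity.

```cpp
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical device enum; the library's DeviceType may have other values.
enum class DeviceType { CPU, CUDA };

// Hypothetical storage: owns a byte buffer and records which device it is on.
// A real implementation would allocate device memory for non-CPU devices;
// this sketch always allocates host memory.
class SimpleStorage {
public:
    SimpleStorage(std::size_t nbytes, DeviceType device)
        : bytes_(std::make_shared<std::vector<std::byte>>(nbytes)), device_(device) {}

    void* data() const noexcept { return bytes_->data(); }
    DeviceType device() const noexcept { return device_; }

private:
    // shared_ptr so that copies of the storage share one underlying buffer.
    std::shared_ptr<std::vector<std::byte>> bytes_;
    DeviceType device_;
};

// Hypothetical tensor metadata: there is no device_ member; device() simply
// forwards to the storage.
class SimpleTensorImpl {
public:
    SimpleTensorImpl(SimpleStorage storage, std::vector<std::size_t> shape)
        : data_(std::move(storage)), shape_(std::move(shape)) {}

    void* data() const noexcept { return data_.data(); }
    DeviceType device() const noexcept { return data_.device(); }
    const std::vector<std::size_t>& shape() const noexcept { return shape_; }

private:
    SimpleStorage data_;
    std::vector<std::size_t> shape_;
};
```

Keeping the device inside the storage object means there is a single source of truth for where the bytes live, so metadata-only operations on the tensor cannot get out of sync with the buffer's location.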
◆ TensorImpl()
template<typename T>
lmp::tensor::TensorImpl::TensorImpl(const std::vector<T> &data,
                                    const std::vector<size_t> &shape,
                                    DeviceType device,
                                    DataType dtype)
[inline, explicit]
Construct a TensorImpl from a vector of data.
- Template Parameters
  - T: The element type of the input data vector
- Parameters
  - data: Flat vector containing the tensor data in row-major order
  - shape: Dimensions of the tensor, e.g. {28, 28} for a 2D tensor
  - device: Target device where the tensor will be stored (CPU/GPU)
  - dtype: Data type for the tensor elements (may differ from T)
- Exceptions
  - std::runtime_error: if data.size() != product of shape dimensions
This constructor allocates storage on the specified device and copies the input data.
- Note
- The input data's element type T does not have to match dtype. For example, passing data = std::vector<int>{...} together with dtype = DataType::Float64 is valid.
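Because T is fixed at compile time while dtype is only known at run time, the constructor's size check and element conversion need some form of runtime dispatch. The following is a minimal sketch of that idea, assuming a hypothetical DataType enum with only three values and a plain host byte buffer (device allocation is not shown). The names DataType (as defined here), cast_into_bytes, and make_buffer are illustrative assumptions, not the library's API.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <vector>

// Hypothetical dtype enum; the library's DataType likely has more values.
enum class DataType { Int32, Float32, Float64 };

// Converts every element of `data` to U and packs the results into host bytes.
template <typename U, typename T>
std::vector<std::byte> cast_into_bytes(const std::vector<T>& data) {
    std::vector<std::byte> out(data.size() * sizeof(U));
    for (std::size_t i = 0; i < data.size(); ++i) {
        const U converted = static_cast<U>(data[i]);
        std::memcpy(out.data() + i * sizeof(U), &converted, sizeof(U));
    }
    return out;
}

// Sketch of the constructor's work: validate the size against the shape,
// then convert from the compile-time type T to the runtime-selected dtype.
template <typename T>
std::vector<std::byte> make_buffer(const std::vector<T>& data,
                                   const std::vector<std::size_t>& shape,
                                   DataType dtype) {
    const std::size_t expected = std::accumulate(
        shape.begin(), shape.end(), std::size_t{1}, std::multiplies<std::size_t>{});
    if (data.size() != expected) {
        // Same failure mode as documented above: std::runtime_error on mismatch.
        throw std::runtime_error("data.size() does not match the product of shape");
    }
    switch (dtype) {
        case DataType::Int32:   return cast_into_bytes<std::int32_t>(data);
        case DataType::Float32: return cast_into_bytes<float>(data);
        case DataType::Float64: return cast_into_bytes<double>(data);
    }
    throw std::invalid_argument("unsupported dtype");
}
```

With this approach, make_buffer(std::vector<int>{1, 2, 3, 4, 5, 6}, {2, 3}, DataType::Float64) yields 6 * sizeof(double) bytes holding 1.0 through 6.0, while a {2, 2} shape for the same six values throws.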
The documentation for this class was generated from the following files: