Main implementation class for Tensor object.
#include <tensor_impl.hpp>
| template<typename T > | TensorImpl (const std::vector< T > &data, const std::vector< size_t > &shape, DeviceType device, DataType dtype) |
| | Construct a TensorImpl from a vector of data. |
| | TensorImpl (Storage storage, const std::vector< size_t > &shape, DataType dtype) |
| void * | data () const noexcept |
| DataType | type () const noexcept |
| DeviceType | device () const noexcept |
| const std::vector< size_t > & | shape () const noexcept |
| const std::vector< detail::stride_t > & | strides () const noexcept |
| size_t | numel () const noexcept |
| TensorImpl | reshape (std::vector< size_t > new_shape) |
| TensorImpl | squeeze (size_t dim) |
| TensorImpl | expand_dims (size_t dim) |
| Scalar | index (const std::vector< size_t > &idx) |
| void | copy (const TensorImpl &other) const |
| void | fill (Scalar item) const |
| void | print (std::ostream &os) const |
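The accessors shape(), strides(), and numel() are listed without further description. As a rough illustration of how such quantities usually relate for row-major data, here is a minimal sketch. The helper names numel_of, row_major_strides, and flat_index are hypothetical and not part of this library, and std::size_t stands in for detail::stride_t, whose definition is not shown here. Whether TensorImpl computes strides this way is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: element count implied by a shape (1 for an empty shape).
inline std::size_t numel_of(const std::vector<std::size_t>& shape) {
  std::size_t n = 1;
  for (std::size_t d : shape) n *= d;
  return n;
}

// Hypothetical helper: element strides for a contiguous row-major layout.
// The last dimension has stride 1; each earlier stride is the product of
// all later dimension sizes.
inline std::vector<std::size_t> row_major_strides(
    const std::vector<std::size_t>& shape) {
  std::vector<std::size_t> strides(shape.size(), 1);
  for (std::size_t i = shape.size(); i > 1; --i) {
    strides[i - 2] = strides[i - 1] * shape[i - 1];
  }
  return strides;
}

// Hypothetical helper: flat offset of a multi-dimensional index.
// Assumes idx has one entry per dimension.
inline std::size_t flat_index(const std::vector<std::size_t>& idx,
                              const std::vector<std::size_t>& strides) {
  assert(idx.size() == strides.size());
  std::size_t offset = 0;
  for (std::size_t i = 0; i < idx.size(); ++i) {
    offset += idx[i] * strides[i];
  }
  return offset;
}
```

Under these assumptions, a shape of {2, 3, 4} gives 24 elements and strides {12, 4, 1}, so the index {1, 2, 3} maps to flat offset 23, the last element.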
TensorImpl contains a few core members: type_, shape_, and data_. Similar to PyTorch, Tensor/TensorImpl is not responsible for the low-level data storage. In particular, TensorImpl has no member called device_; the device is managed by Storage.
See also: Tensor, Storage
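The split between TensorImpl and Storage described above can be pictured with a small stand-alone sketch. Every name below (MiniDevice, MiniDType, MiniStorage, MiniTensorImpl) is a hypothetical stand-in and not this library's code. The sketch only assumes that the storage object owns the buffer and records the device, while the tensor implementation keeps type_, shape_, and data_ and forwards device queries to its storage.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-ins for DeviceType and DataType.
enum class MiniDevice { CPU, GPU };
enum class MiniDType { Int32, Float32, Float64 };

// Hypothetical stand-in for Storage: owns the raw bytes and the device tag.
class MiniStorage {
 public:
  MiniStorage(std::size_t nbytes, MiniDevice device)
      : bytes_(nbytes), device_(device) {}

  void* data() noexcept { return bytes_.data(); }
  std::size_t nbytes() const noexcept { return bytes_.size(); }
  MiniDevice device() const noexcept { return device_; }

 private:
  std::vector<unsigned char> bytes_;  // host memory only in this sketch
  MiniDevice device_;
};

// Hypothetical stand-in for TensorImpl: note there is no device_ member.
class MiniTensorImpl {
 public:
  MiniTensorImpl(MiniStorage storage, std::vector<std::size_t> shape,
                 MiniDType dtype)
      : type_(dtype), shape_(std::move(shape)), data_(std::move(storage)) {}

  MiniDType type() const noexcept { return type_; }
  const std::vector<std::size_t>& shape() const noexcept { return shape_; }

  // The device is looked up through the storage rather than stored here.
  MiniDevice device() const noexcept { return data_.device(); }

 private:
  MiniDType type_;
  std::vector<std::size_t> shape_;
  MiniStorage data_;
};
```

One consequence of this kind of layout is that the device is recorded in exactly one place, so the tensor metadata and the buffer cannot disagree about where the data lives.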
◆ TensorImpl()
template<typename T >
lmp::tensor::TensorImpl::TensorImpl (const std::vector< T > &data, const std::vector< size_t > &shape, DeviceType device, DataType dtype)
inline explicit
Construct a TensorImpl from a vector of data.
Template Parameters
- T: The element type of the input data vector
Parameters
- data: Flat vector containing the tensor data in row-major order
- shape: Dimensions of the tensor, e.g. {28, 28} for a 2D tensor
- device: Target device where the tensor will be stored (CPU/GPU)
- dtype: Data type for the tensor elements (may differ from T)
Exceptions
- std::runtime_error: if data.size() != product of shape dimensions
This constructor allocates storage on the specified device and copies the input data.
Note
- The element type T of the input data does not have to match dtype. For example, passing dtype = DataType::Float64 together with data = std::vector<int>{...} is valid.
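The size check and the type conversion described for this constructor can be sketched in isolation. The helpers shape_product and checked_convert below are hypothetical and not part of this library. The destination type is chosen at compile time here, whereas the real constructor selects it at run time from dtype, and the sketch stays on the host rather than allocating device storage.

```cpp
#include <cstddef>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <vector>

// Hypothetical helper: number of elements implied by a shape.
inline std::size_t shape_product(const std::vector<std::size_t>& shape) {
  return std::accumulate(shape.begin(), shape.end(), std::size_t{1},
                         std::multiplies<std::size_t>());
}

// Hypothetical helper: reject mismatched sizes, then copy the input while
// converting each element from T to Dst.
template <typename Dst, typename T>
std::vector<Dst> checked_convert(const std::vector<T>& data,
                                 const std::vector<std::size_t>& shape) {
  if (data.size() != shape_product(shape)) {
    throw std::runtime_error(
        "data.size() does not match the product of the shape dimensions");
  }
  std::vector<Dst> out;
  out.reserve(data.size());
  for (const T& value : data) {
    out.push_back(static_cast<Dst>(value));
  }
  return out;
}
```

With these helpers, a std::vector<int> holding six values converts to six doubles for the shape {2, 3}, while the shape {4, 2} throws std::runtime_error because it implies eight elements.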
The documentation for this class was generated from the following files: