lamppp
lmp::tensor::TensorImpl Class Reference

Main implementation class for the Tensor object.

#include <tensor_impl.hpp>

Public Member Functions

template<typename T >
 TensorImpl (const std::vector< T > &data, const std::vector< size_t > &shape, DeviceType device, DataType dtype)
 Construct a TensorImpl from a vector of data.
 
 TensorImpl (Storage storage, const std::vector< size_t > &shape, DataType dtype)
 
void * data () const noexcept
 
DataType type () const noexcept
 
DeviceType device () const noexcept
 
const std::vector< size_t > & shape () const noexcept
 
const std::vector< detail::stride_t > & strides () const noexcept
 
size_t numel () const noexcept
 
TensorImpl reshape (std::vector< size_t > new_shape)
 
TensorImpl squeeze (size_t dim)
 
TensorImpl expand_dims (size_t dim)
 
Scalar index (const std::vector< size_t > &idx)
 
void copy (const TensorImpl &other) const
 
void fill (Scalar item) const
 
void print (std::ostream &os) const
 

Friends

class Tensor
 

Detailed Description

Main implementation class for the Tensor object.

TensorImpl contains a few core members: type_, shape_, and data_. Similar to PyTorch, Tensor/TensorImpl is not responsible for low-level data storage; that is managed by Storage. This is why TensorImpl has no member called device_: the device is tracked by the Storage.
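The split between metadata and storage can be illustrated with a small, self-contained sketch. MiniStorage, MiniTensorImpl, Device and DType are hypothetical stand-ins, not lamppp types; the real classes carry more state (for example strides) and perform real device allocation. The sketch only shows the idea described above: the impl keeps type_, shape_ and data_, and answers device() by asking its storage rather than keeping a device_ of its own.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-ins; lamppp's DeviceType, DataType and Storage differ.
enum class Device { CPU, CUDA };
enum class DType { Int32, Float32, Float64 };

// Owns the raw bytes and records which device they belong to.
class MiniStorage {
 public:
  MiniStorage(std::size_t nbytes, Device device)
      : bytes_(nbytes), device_(device) {}
  void* data() noexcept { return bytes_.data(); }
  Device device() const noexcept { return device_; }

 private:
  std::vector<unsigned char> bytes_;  // host buffer only, for illustration
  Device device_;
};

// Holds metadata plus a handle to the storage; deliberately no device_ member.
class MiniTensorImpl {
 public:
  MiniTensorImpl(std::shared_ptr<MiniStorage> storage,
                 std::vector<std::size_t> shape, DType dtype)
      : data_(std::move(storage)), shape_(std::move(shape)), type_(dtype) {
    assert(data_ != nullptr);  // an impl must always point at some storage
  }

  // The device is answered by the storage, not stored in the impl.
  Device device() const noexcept { return data_->device(); }
  DType type() const noexcept { return type_; }
  const std::vector<std::size_t>& shape() const noexcept { return shape_; }
  void* data() const noexcept { return data_->data(); }

 private:
  std::shared_ptr<MiniStorage> data_;
  std::vector<std::size_t> shape_;
  DType type_;
};
```

Holding the storage through std::shared_ptr is a choice made for this sketch: it lets, say, a reshaped impl reuse the same buffer without copying, since only shape_ changes.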

See also
Tensor, Storage

Constructor & Destructor Documentation

◆ TensorImpl()

template<typename T >
lmp::tensor::TensorImpl::TensorImpl (const std::vector< T > & data,
                                     const std::vector< size_t > & shape,
                                     DeviceType device,
                                     DataType dtype)
inline, explicit

Construct a TensorImpl from a vector of data.

Template Parameters
T: The element type of the input data vector
Parameters
data: Flat vector containing the tensor data in row-major order
shape: Dimensions of the tensor, e.g. {28, 28} for a 2D tensor
device: Target device where the tensor will be stored (CPU/GPU)
dtype: Data type for the tensor elements (may differ from T)
Exceptions
std::runtime_error: if data.size() != the product of the shape dimensions
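To make the data/shape contract concrete, here is a sketch of the size check and the row-major layout implied by the parameters above. The helper names (shape_numel, check_data_matches_shape, row_major_strides, flat_offset) are hypothetical, not lamppp API, and lamppp stores strides as detail::stride_t rather than the std::size_t assumed here.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helpers illustrating a flat, row-major data vector.

// Number of elements implied by a shape (1 for an empty shape {}).
std::size_t shape_numel(const std::vector<std::size_t>& shape) {
  std::size_t n = 1;
  for (std::size_t d : shape) n *= d;
  return n;
}

// Mirrors the documented check: throw if data.size() != product of shape.
template <typename T>
void check_data_matches_shape(const std::vector<T>& data,
                              const std::vector<std::size_t>& shape) {
  if (data.size() != shape_numel(shape)) {
    throw std::runtime_error("data size does not match shape");
  }
}

// Row-major strides: the last dimension is contiguous (stride 1), and each
// earlier stride is the next stride times the next dimension's size.
std::vector<std::size_t> row_major_strides(const std::vector<std::size_t>& shape) {
  std::vector<std::size_t> strides(shape.size(), 1);
  for (std::size_t i = shape.size(); i > 1; --i) {
    strides[i - 2] = strides[i - 1] * shape[i - 1];
  }
  return strides;
}

// Flat offset of a multi-index such as {row, col} into the data vector.
std::size_t flat_offset(const std::vector<std::size_t>& idx,
                        const std::vector<std::size_t>& strides) {
  assert(idx.size() == strides.size());
  std::size_t off = 0;
  for (std::size_t i = 0; i < idx.size(); ++i) off += idx[i] * strides[i];
  return off;
}
```

For a {2, 3} shape the strides come out as {3, 1}, so element {1, 2} lives at flat offset 1 * 3 + 2 * 1 = 5.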

This constructor allocates storage on the specified device and copies the input data.

Note
The input element type T does NOT have to match dtype. For example, passing data = std::vector<int>{...} with dtype = DataType::Float64 is valid.
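One plausible way to honor this is to convert element by element with static_cast while filling the destination buffer. This is a host-only sketch under that assumption: DType, dtype_size, store_as and convert_to_dtype are hypothetical names, not lamppp API, and a real implementation would also copy the buffer to the requested device.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

// Hypothetical dtype enum for illustration; not lamppp's DataType.
enum class DType { Int32, Float32, Float64 };

// Size in bytes of one element of the given dtype.
std::size_t dtype_size(DType t) {
  switch (t) {
    case DType::Int32: return sizeof(std::int32_t);
    case DType::Float32: return sizeof(float);
    case DType::Float64: return sizeof(double);
  }
  throw std::invalid_argument("unknown dtype");
}

// Converts one value to To and writes its bytes to dst.
template <typename To, typename From>
void store_as(unsigned char* dst, From value) {
  const To converted = static_cast<To>(value);
  std::memcpy(dst, &converted, sizeof(To));
}

// Converts every element of src (type T) to dtype and packs the results
// into a host byte buffer. A device copy would follow in a real backend.
template <typename T>
std::vector<unsigned char> convert_to_dtype(const std::vector<T>& src, DType dtype) {
  const std::size_t elem = dtype_size(dtype);
  std::vector<unsigned char> out(src.size() * elem);
  for (std::size_t i = 0; i < src.size(); ++i) {
    unsigned char* dst = out.data() + i * elem;
    switch (dtype) {
      case DType::Int32: store_as<std::int32_t>(dst, src[i]); break;
      case DType::Float32: store_as<float>(dst, src[i]); break;
      case DType::Float64: store_as<double>(dst, src[i]); break;
    }
  }
  return out;
}
```

With the note's example, convert_to_dtype(std::vector<int>{1, 2, 3}, DType::Float64) yields a buffer of 3 * sizeof(double) bytes holding 1.0, 2.0 and 3.0. Going the other way, static_cast truncates, so 2.5 stored as Int32 becomes 2.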

The documentation for this class was generated from the following files: