lamppp
Utility class for aligning shapes of tensors.
#include <align_utils.hpp>
Public Member Functions
AlignUtil(const std::vector<size_t>& a_shape, const std::vector<size_t>& b_shape)
Public Attributes
std::vector<size_t> aligned_shape_
std::vector<stride_t> aligned_stride_
size_t aligned_size_
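This page does not document how the three attributes relate to each other. Assuming that aligned_stride_ holds row-major (C-contiguous) strides for aligned_shape_ and that aligned_size_ holds its total element count (neither is confirmed here), they could be derived from the aligned shape as in this minimal standalone sketch. The helper names row_major_strides and element_count, and the stride_t alias, are hypothetical and not part of lamppp's documented API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical alias: lamppp's real stride_t may be a different type.
using stride_t = std::int64_t;

// Row-major (C-contiguous) strides: the last dimension has stride 1, and each
// earlier dimension's stride is the product of all dimensions after it.
std::vector<stride_t> row_major_strides(const std::vector<std::size_t>& shape) {
    std::vector<stride_t> strides(shape.size());
    stride_t running = 1;
    for (std::size_t i = shape.size(); i-- > 0;) {
        strides[i] = running;
        running *= static_cast<stride_t>(shape[i]);
    }
    return strides;
}

// Total number of elements: the product of all dimensions (1 for a 0-d shape).
std::size_t element_count(const std::vector<std::size_t>& shape) {
    std::size_t n = 1;
    for (std::size_t d : shape) {
        n *= d;
    }
    return n;
}
```

For a shape of {2, 3, 4}, this gives strides {12, 4, 1} and an element count of 24.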
Detailed Description
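This class aligns the shapes of two tensors so that they can be broadcast together. Alignment follows NumPy's broadcasting rules: the shapes are compared dimension by dimension starting from the trailing end, a missing leading dimension is treated as 1, and two dimensions are compatible when they are equal or when one of them is 1.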
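The shape rule described above can be illustrated with a minimal standalone sketch. It is not lamppp's AlignUtil implementation: the function name broadcast_shapes is hypothetical, and throwing std::invalid_argument on incompatible shapes is an assumed error-handling choice.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Broadcast two shapes under NumPy's rules. The shorter shape behaves as if
// padded with leading 1s; each pair of dimensions, compared from the trailing
// end, must be equal or contain a 1, and the result takes the non-1 size.
std::vector<std::size_t> broadcast_shapes(const std::vector<std::size_t>& a,
                                          const std::vector<std::size_t>& b) {
    const std::size_t ndim = std::max(a.size(), b.size());
    std::vector<std::size_t> out(ndim);
    for (std::size_t i = 0; i < ndim; ++i) {
        // Walk both shapes right to left; missing leading dims count as 1.
        const std::size_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
        const std::size_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
        if (da != db && da != 1 && db != 1) {
            throw std::invalid_argument("shapes are not broadcast-compatible");
        }
        out[ndim - 1 - i] = (da == 1) ? db : da;
    }
    return out;
}
```

For example, shapes {8, 1, 6, 1} and {7, 1, 5} broadcast to {8, 7, 6, 5}, while {2, 3} and {4} are rejected because their trailing dimensions (3 and 4) differ and neither is 1.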