// lamppp: grad_utils.hpp
#pragma once

#include <cstddef>
#include <vector>

#include "lamppp/tensor/tensor.hpp"

3namespace lmp::autograd::detail {
4
6
12tensor::Tensor sum_broadcast_axis(const tensor::Tensor& grad,
13 const std::vector<size_t>& orig_shape);
14
16
17} // namespace lmp::autograd::detail