// lamppp -- reduct_ops.hpp
// Declarations of the single-axis reduction operations (sum, max, min, prod)
// and their backend dispatch stubs.
#pragma once

#include <cstddef>

#include "lamppp/tensor/device_type.hpp"
#include "lamppp/tensor/dispatch_stub.hpp"
#include "lamppp/tensor/tensor.hpp"
#include "lamppp/tensor/tensor_impl.hpp"

8namespace lmp::tensor::ops {
9
11using sum_fn = TensorImpl (*)(const TensorImpl&, size_t axis);
12using max_fn = TensorImpl (*)(const TensorImpl&, size_t axis);
13using min_fn = TensorImpl (*)(const TensorImpl&, size_t axis);
14using prod_fn = TensorImpl (*)(const TensorImpl&, size_t axis);
15
16LMP_DECLARE_DISPATCH(sum_fn, sum_stub);
17LMP_DECLARE_DISPATCH(max_fn, max_stub);
18LMP_DECLARE_DISPATCH(min_fn, min_stub);
19LMP_DECLARE_DISPATCH(prod_fn, prod_stub);
21
30Tensor sum(const Tensor& a, size_t axis);
31
40Tensor max(const Tensor& a, size_t axis);
41
50Tensor min(const Tensor& a, size_t axis);
51
60Tensor prod(const Tensor& a, size_t axis);
61
62} // namespace lmp::tensor::ops