3#include "lamppp/tensor/device_type.hpp"
4#include "lamppp/tensor/dispatch_stub.hpp"
5#include "lamppp/tensor/tensor.hpp"
6#include "lamppp/tensor/tensor_impl.hpp"
8namespace lmp::tensor::ops {
11using sum_fn = TensorImpl (*)(
const TensorImpl&,
size_t axis);
12using max_fn = TensorImpl (*)(
const TensorImpl&,
size_t axis);
13using min_fn = TensorImpl (*)(
const TensorImpl&,
size_t axis);
14using prod_fn = TensorImpl (*)(
const TensorImpl&,
size_t axis);
16LMP_DECLARE_DISPATCH(sum_fn, sum_stub);
17LMP_DECLARE_DISPATCH(max_fn, max_stub);
18LMP_DECLARE_DISPATCH(min_fn, min_stub);
19LMP_DECLARE_DISPATCH(prod_fn, prod_stub);
30Tensor sum(
const Tensor& a,
size_t axis);
40Tensor max(
const Tensor& a,
size_t axis);
50Tensor min(
const Tensor& a,
size_t axis);
60Tensor prod(
const Tensor& a,
size_t axis);