3#include "lamppp/tensor/cpu/kernels.hpp"
4#include "lamppp/tensor/cpu/meta_handler.hpp"
5#include "lamppp/tensor/cpu/ptr_pack.hpp"
6#include "lamppp/tensor/tensor_impl.hpp"
11template <
typename PtrList,
typename OpFn>
12void vectorized_reduct_kernel(PtrList ptr_, OpFn fn_,
size_t i,
13 size_t axis,
const size_t* shape,
14 const stride_t* strides);
// Declaration of the `reduct_kernel_launcher` function template, parameterized
// on `PtrList` (type of `ptr_`) and `OpFn` (type of `fn_`). The visible
// parameters are `ptr_`, `fn_`, `size`, `axis`, `shape` and `strides`.
//
// NOTE(review): this declaration is damaged in the text as it stands:
//   * the leading digits (16, 17, 18) look like source line numbers that were
//     fused into the code by an extraction step and are not valid C++;
//   * the parameter list is cut off after `const stride_t* strides,` -- the
//     remaining parameter(s) and the closing `);` are missing.
// Restore the full declaration from the upstream header rather than guessing
// at the missing parameters.
16template <
typename PtrList,
typename OpFn>
17void reduct_kernel_launcher(PtrList ptr_, OpFn fn_,
size_t size,
size_t axis,
18 const size_t* shape,
const stride_t* strides,
// Declaration of the `reduct_dispatch_handler` function template.
// `OpFunctor` is a template template parameter taking a single type argument;
// `Args` is a type parameter pack. The visible parameters are a
// `ReductMetaHandler&` named `meta` and a `size_t` named `axis`.
//
// NOTE(review): this declaration is damaged in the text as it stands:
//   * the leading digits (21, 22) look like source line numbers fused into the
//     code by an extraction step and are not valid C++;
//   * the parameter list is cut off after `size_t axis,` -- the remaining
//     parameter(s) and the closing `);` are missing.
// The explicit instantiation declarations further down name only
// `(ReductMetaHandler&, size_t)`, so the missing tail is presumably a pack
// expansion over `Args` that is empty for those instantiations -- TODO confirm
// against the upstream header before restoring it.
21template <
template <
typename>
class OpFunctor,
typename... Args>
22void reduct_dispatch_handler(ReductMetaHandler& meta,
size_t axis,
25extern template void reduct_dispatch_handler<SumFunctor>(ReductMetaHandler&,
size_t);
26extern template void reduct_dispatch_handler<MaxFunctor>(ReductMetaHandler&,
size_t);
27extern template void reduct_dispatch_handler<MinFunctor>(ReductMetaHandler&,
size_t);
28extern template void reduct_dispatch_handler<ProdFunctor>(ReductMetaHandler&,
size_t);