lamppp
Loading...
Searching...
No Matches
reduct.hpp
1#pragma once
2
3#include "lamppp/tensor/cpu/kernels.hpp"
4#include "lamppp/tensor/cpu/meta_handler.hpp"
5#include "lamppp/tensor/cpu/ptr_pack.hpp"
6#include "lamppp/tensor/tensor_impl.hpp"
7
9
11template <typename PtrList, typename OpFn>
12void vectorized_reduct_kernel(PtrList ptr_, OpFn fn_, size_t i,
13 size_t axis, const size_t* shape,
14 const stride_t* strides);
15
16template <typename PtrList, typename OpFn>
17void reduct_kernel_launcher(PtrList ptr_, OpFn fn_, size_t size, size_t axis,
18 const size_t* shape, const stride_t* strides,
19 size_t ndims);
20
21template <template <typename> class OpFunctor, typename... Args>
22void reduct_dispatch_handler(ReductMetaHandler& meta, size_t axis,
23 Args&&... args);
24
25extern template void reduct_dispatch_handler<SumFunctor>(ReductMetaHandler&, size_t);
26extern template void reduct_dispatch_handler<MaxFunctor>(ReductMetaHandler&, size_t);
27extern template void reduct_dispatch_handler<MinFunctor>(ReductMetaHandler&, size_t);
28extern template void reduct_dispatch_handler<ProdFunctor>(ReductMetaHandler&, size_t);
29
31
32} // namespace lmp::tensor::detail::cpu
Definition binary.cpp:4