// lamppp: ptr_pack.hpp
#pragma once

#include <array>
#include <cstddef>
#include <tuple>

6namespace lmp::tensor::detail::cpu::internal {
7
9template <typename U, typename V>
11 U operator()(void* p, std::size_t i) const {
12 return static_cast<U>(static_cast<V*>(p)[i]);
13 }
14};
16
18
23template <class OutT, class... SrcTs>
24class PtrPack {
25 public:
26 static constexpr std::size_t kN = sizeof...(SrcTs);
27
28 ::std::array<void*, kN + 1> data;
29 ::std::tuple<TransformFunctor<OutT, OutT>, TransformFunctor<OutT, SrcTs>...>
30 fns;
31
32 constexpr explicit PtrPack(OutT* out, SrcTs*... in)
33 : data{static_cast<void*>(out), static_cast<void*>(in)...},
36
37 void set_Out(std::size_t idx, OutT value) {
38 static_cast<OutT*>(data[0])[idx] = value;
39 }
40};
42
43} // namespace lmp::tensor::detail::cpu::internal