lamppp
Loading...
Searching...
No Matches
dispatch_stub.hpp
1#pragma once
2
#include <array>
#include <cstddef>
#include <type_traits>
#include <utility>
#include "device_type.hpp"
#include "lamppp/common/assert.hpp"
#include "lamppp/tensor/data_type.hpp"
8
9namespace lmp::tensor::detail {
10
12
21template <typename Fn>
23 using fn_type = Fn;
24 std::array<fn_type, static_cast<size_t>(DeviceType::Count)> table_{};
25 constexpr DispatchStub() noexcept = default;
26
27 void register_kernel(DeviceType dev, fn_type f) {
28 table_[static_cast<size_t>(dev)] = f;
29 }
30
31 template <typename... Args>
32 decltype(auto) operator()(DeviceType dev, Args&&... args) const {
33 fn_type f = table_[static_cast<size_t>(dev)];
34 LMP_CHECK(f) << "Kernel for this backend not registered";
35 return f(std::forward<Args>(args)...);
36 }
37};
38
39} // namespace lmp::tensor::detail
40
/// @brief Declares `accessor()`, a function returning a reference to the
/// `DispatchStub<kernel_sig>` that LMP_DEFINE_DISPATCH defines for it.
///
/// Place this in a header so that every translation unit calling or
/// registering kernels for the operator sees the same accessor. The expansion
/// already ends in `;`.
///
/// @param kernel_sig Kernel type stored in the stub.
/// @param accessor   Name of the accessor function to declare.
#define LMP_DECLARE_DISPATCH(kernel_sig, accessor) \
  auto accessor() -> ::lmp::tensor::detail::DispatchStub<kernel_sig>&;
50
/// @brief Defines `stub_name()`, returning a reference to a single
/// `DispatchStub<fn_type>` instance.
///
/// The stub is a function-local static. It is initialized the first time
/// control passes through `stub_name()`, not at an unspecified point during
/// cross-TU dynamic initialization. This lets registrars in other translation
/// units call `stub_name()` from their own static initializers safely. Use it
/// in exactly one translation unit per stub.
///
/// The expansion ends at the function's closing brace with no trailing `;`.
/// The previous stray `;` formed an empty declaration, which `-Wpedantic` /
/// `-Wextra-semi` flag.
///
/// @param fn_type   Kernel type stored in the stub.
/// @param stub_name Name of the accessor function to define.
#define LMP_DEFINE_DISPATCH(fn_type, stub_name)                 \
  ::lmp::tensor::detail::DispatchStub<fn_type>& stub_name() {   \
    static ::lmp::tensor::detail::DispatchStub<fn_type> s;      \
    return s;                                                   \
  }
56
/// @brief Registers `kernel_fn` for device `dev` on the stub returned by
/// `stub_name()`.
///
/// The expansion opens an anonymous namespace, so it must be used at namespace
/// scope, normally in a .cpp file. Inside it, the macro defines a registrar
/// type and one object of that type. The object's constructor calls
/// `stub_name().register_kernel(dev, kernel_fn)` when the translation unit's
/// namespace-scope objects are initialized.
///
/// `kernel_fn` must be a plain identifier, because it is token-pasted to form
/// the registrar's names. Registering the same `kernel_fn` twice in one
/// translation unit therefore collides.
///
/// The generated names no longer start with `_` + uppercase letter
/// (previously `_Reg...`); such identifiers are reserved for the
/// implementation in every scope.
///
/// NOTE(review): if the translation unit using this macro lives in a static
/// library that nothing else references, the linker may drop it and the
/// registration never runs — confirm the link setup.
///
/// @param stub_name Accessor declared with LMP_DECLARE_DISPATCH.
/// @param dev       `DeviceType` to register for.
/// @param kernel_fn Kernel to store (identifier).
#define LMP_REGISTER_DISPATCH(stub_name, dev, kernel_fn)          \
  namespace {                                                     \
  struct LmpDispatchRegistrar_##kernel_fn {                       \
    LmpDispatchRegistrar_##kernel_fn() {                          \
      stub_name().register_kernel(dev, kernel_fn);                \
    }                                                             \
  } lmp_dispatch_registrar_##kernel_fn;                           \
  }
65
66
simple static registration class
Definition dispatch_stub.hpp:22