lamppp
Simple static registration class.
#include <dispatch_stub.hpp>
Public Types
  using fn_type = Fn
Public Member Functions
  void register_kernel (DeviceType dev, fn_type f)
  template<typename... Args> decltype(auto) operator() (DeviceType dev, Args &&... args) const
Public Attributes
  std::array< fn_type, static_cast< size_t >(DeviceType::Count)> table_ {}
This class keeps a dispatch table (table_) with one slot per DeviceType and uses it at runtime to route each call to the kernel registered for the requested device. For example, if add_cuda has been registered for DeviceType::CUDA, calling add_stub()(DeviceType::CUDA, a, b) calls add_cuda on the tensors a and b.
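The following is a minimal sketch of how such a dispatch table could work. It is not the lamppp implementation. The names DispatchStubSketch, add_cpu, register_add_kernels, and cuda_calls are hypothetical, the DeviceType enumerators other than Count and CUDA are assumed, fn_type is assumed to be a plain function pointer, and the kernels operate on ints instead of tensors. The check for an empty slot is an addition for the sketch and may not exist in the real class.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>

// Assumed stand-in for lamppp's DeviceType; the real enumerators may differ.
enum class DeviceType : std::size_t { CPU = 0, CUDA = 1, Count = 2 };

// Hypothetical sketch of the documented interface, not the library's code.
template <typename Fn>
struct DispatchStubSketch {
  using fn_type = Fn;

  // Store the kernel in the table slot that belongs to `dev`.
  void register_kernel(DeviceType dev, fn_type f) {
    table_[static_cast<std::size_t>(dev)] = f;
  }

  // Look up the kernel for `dev` and forward the arguments to it.
  template <typename... Args>
  decltype(auto) operator()(DeviceType dev, Args&&... args) const {
    fn_type f = table_[static_cast<std::size_t>(dev)];
    if (f == nullptr) {  // assumption: report unregistered devices loudly
      throw std::runtime_error("no kernel registered for this device");
    }
    return f(std::forward<Args>(args)...);
  }

  // Value-initialized, so every slot starts out as a null function pointer.
  std::array<fn_type, static_cast<std::size_t>(DeviceType::Count)> table_{};
};

// Hypothetical kernels; cuda_calls only records which kernel actually ran.
int cuda_calls = 0;
int add_cpu(int a, int b) { return a + b; }
int add_cuda(int a, int b) {
  ++cuda_calls;
  return a + b;
}

using add_fn = int (*)(int, int);

// Function-local static, so every caller shares a single table.
DispatchStubSketch<add_fn>& add_stub() {
  static DispatchStubSketch<add_fn> stub;
  return stub;
}

// Explicit registration step, used here in place of static initialization.
void register_add_kernels() {
  add_stub().register_kernel(DeviceType::CPU, add_cpu);
  add_stub().register_kernel(DeviceType::CUDA, add_cuda);
}
```

After register_add_kernels() has run, add_stub()(DeviceType::CUDA, 2, 3) looks up slot 1 of the table and calls add_cuda(2, 3). Keeping the table in a std::array indexed by the enum value makes each dispatch a single array lookup followed by an indirect call.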