lamppp
Simple static registration class.
#include <dispatch_stub.hpp>
Public Types
  using fn_type = Fn

Public Member Functions
  void register_kernel(DeviceType dev, fn_type f)
  template<typename... Args> decltype(auto) operator()(DeviceType dev, Args&&... args) const

Public Attributes
  std::array<fn_type, static_cast<size_t>(DeviceType::Count)> table_ {}
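A minimal sketch of how the members listed above could fit together. It assumes that Fn is a plain function-pointer type and that DeviceType is an enum class whose last enumerator is Count, as the size of table_ suggests. The enumerators CPU and CUDA, the asserts, and throwing on an empty slot are illustrative assumptions, not lamppp's documented behavior.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>

// Assumed device enum; Count must stay last so it equals the number of devices.
enum class DeviceType : std::size_t { CPU, CUDA, Count };

// Sketch of a dispatch stub mirroring the documented members.
// Fn is assumed to be a function-pointer type, e.g. int (*)(int, int).
template <typename Fn>
struct DispatchStub {
  using fn_type = Fn;

  // Store f in the slot for dev, overwriting any previously registered kernel.
  void register_kernel(DeviceType dev, fn_type f) {
    const auto idx = static_cast<std::size_t>(dev);
    assert(idx < table_.size() && "Count is not a real device");
    table_[idx] = f;
  }

  // Look up the kernel for dev and forward the remaining arguments to it.
  template <typename... Args>
  decltype(auto) operator()(DeviceType dev, Args&&... args) const {
    const auto idx = static_cast<std::size_t>(dev);
    assert(idx < table_.size() && "Count is not a real device");
    const fn_type f = table_[idx];
    if (f == nullptr) {
      // Assumed behavior: calling a device with no kernel is an error.
      throw std::runtime_error("no kernel registered for this device");
    }
    return f(std::forward<Args>(args)...);
  }

  // One slot per device; the {} initializer sets every slot to nullptr.
  std::array<fn_type, static_cast<std::size_t>(DeviceType::Count)> table_{};
};
```

With this shape, a dispatch is one array index plus an indirect call, and sizing the array with DeviceType::Count keeps the table in step with the enum as devices are added.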
This class keeps a dispatch table (table_) that directs each call to the correct device-specific kernel at runtime. For example, calling add_stub()(DeviceType::CUDA, a, b) calls add_cuda on the tensors a and b.
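One way the add_stub example could be wired up, as a hypothetical sketch rather than lamppp's actual code. It assumes add_stub() returns a function-local static stub, that add_cpu and add_cuda are registered during static initialization through a hypothetical Registrar helper, and that Tensor can be stood in for by std::vector<float>. Here add_cuda is a host-side placeholder that only counts its calls, so the dispatch can be observed without a GPU.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-ins for lamppp's device enum and tensor type.
enum class DeviceType : std::size_t { CPU, CUDA, Count };
using Tensor = std::vector<float>;

// Same shape as the documented class: a table of kernels indexed by DeviceType.
template <typename Fn>
struct DispatchStub {
  using fn_type = Fn;
  void register_kernel(DeviceType dev, fn_type f) {
    table_.at(static_cast<std::size_t>(dev)) = f;
  }
  template <typename... Args>
  decltype(auto) operator()(DeviceType dev, Args&&... args) const {
    const fn_type f = table_.at(static_cast<std::size_t>(dev));
    if (f == nullptr) throw std::runtime_error("no kernel registered");
    return f(std::forward<Args>(args)...);
  }
  std::array<fn_type, static_cast<std::size_t>(DeviceType::Count)> table_{};
};

using add_fn = Tensor (*)(const Tensor&, const Tensor&);

// One process-wide stub, constructed on the first call.
DispatchStub<add_fn>& add_stub() {
  static DispatchStub<add_fn> stub;
  return stub;
}

// Host kernel: elementwise a + b.
Tensor add_cpu(const Tensor& a, const Tensor& b) {
  assert(a.size() == b.size());
  Tensor out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) out[i] = a[i] + b[i];
  return out;
}

// Placeholder for a CUDA launcher: counts calls, then computes on the host.
int cuda_calls = 0;
Tensor add_cuda(const Tensor& a, const Tensor& b) {
  ++cuda_calls;
  return add_cpu(a, b);
}

// Hypothetical helper that registers a kernel during static initialization.
struct Registrar {
  Registrar(DeviceType dev, add_fn f) { add_stub().register_kernel(dev, f); }
};
static Registrar register_add_cpu(DeviceType::CPU, &add_cpu);
static Registrar register_add_cuda(DeviceType::CUDA, &add_cuda);
```

The function-local static inside add_stub() is constructed on first use, so the registrars can call it safely even though they run before main(). This also avoids depending on initialization order when registrars live in different translation units.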