lamppp
Loading...
Searching...
No Matches
any.hpp
1#pragma once
2
3#include <any>
4#include <memory>
5#include <type_traits>
6#include <utility>
7#include <vector>
8#include "lamppp/common/assert.hpp"
9#include "lamppp/nets/module.hpp"
10
11namespace lmp::nets {
12
/// Type-erased handle to a module: it is constructed from any Module<Derived>
/// and exposes a uniform call(std::vector<std::any>) -> std::any interface,
/// so modules with different forward signatures can be stored side by side.
13class AnyModule {
14 public:
/// Wraps `mod`. Part of this constructor's body is not visible in this
/// listing; the visible tail passes &Derived::forward as the final argument of
/// a call — presumably make_holder(...), which requires `forward` to be a
/// const member function (TODO confirm against the full header).
15 template <typename Derived>
16 explicit AnyModule(Module<Derived> mod) {
18 &Derived::forward);
19 }
20
/// Returns the wrapped module implementation. Defined out of line (not shown);
/// presumably delegates to impl_->getImpl() — verify.
21 std::shared_ptr<ModuleImpl> getImpl();
/// Invokes the wrapped module's forward with one std::any per parameter and
/// returns its result boxed in std::any. Defined out of line (not shown);
/// presumably delegates to impl_->call(args) — verify.
22 std::any call(const std::vector<std::any>& args) const;
23
24 private:
// Type-erased interface (Placeholder) and its concrete, per-signature
// implementation (Holder<MImpl, Args...>); both are defined out of class
// later in this header.
25 class Placeholder;
26 template <typename MImpl, typename... Args>
27 class Holder;
28
// Shared ownership of the type-erased holder that owns the module.
29 std::shared_ptr<Placeholder> impl_;
// NOTE(review): make_holder is `protected` although AnyModule has no visible
// virtual members; `private` would suffice unless a subclass relies on it.
30 protected:
/// Builds a Holder for `m`, deducing the parameter pack Args... from the
/// signature of the const member function `fp`. The deduced return type R is
/// not forwarded: Holder recomputes it from decltype(&Impl::forward).
/// @param m   Module implementation; moved into the holder.
/// @param fp  Const member-function pointer the holder will invoke.
/// @return    The new holder, upcast to its Placeholder interface.
31 template <typename Impl, typename R, typename... Args>
32 std::shared_ptr<AnyModule::Placeholder> make_holder(std::shared_ptr<Impl> m,
33 R (Impl::*fp)(Args...)
34 const) {
35 using H = typename AnyModule::Holder<Impl, Args...>;
36 return std::make_shared<H>(std::move(m), fp);
37 }
38};
39
40
// Abstract interface behind AnyModule::impl_; AnyModule::Holder<MImpl, Args...>
// is the concrete implementation that overrides both pure virtuals.
// NOTE(review): the class header line (`class AnyModule::Placeholder {`) is
// missing from this listing.
42 public:
// Polymorphic base: virtual destructor so derived holders are destroyed
// correctly through a base pointer.
43 virtual ~Placeholder() = default;
44
// Calls the wrapped module's forward with type-erased positional arguments
// and returns the result boxed in std::any. Non-const, although
// AnyModule::call is const; that works because impl_ points to non-const.
45 virtual std::any call(const std::vector<std::any>& args) = 0;
// Returns the wrapped module implementation as its ModuleImpl base.
46 virtual std::shared_ptr<ModuleImpl> getImpl() = 0;
47};
48
49
50template <typename MImpl, typename... Args>
51class AnyModule::Holder : public AnyModule::Placeholder {
52 using FuncPtr = std::invoke_result_t<decltype(&MImpl::forward), MImpl*,
53 Args...> (MImpl::*)(Args...) const;
54
55 public:
56 ~Holder() override = default;
57 explicit Holder(std::shared_ptr<MImpl> mod, FuncPtr forward)
58 : mod_(mod), forward_(forward) {};
59
60 std::any call(const std::vector<std::any>& args) override {
61 LMP_CHECK(args.size() == sizeof...(Args)) << "Invalid forward arguments";
62 return invoke(args, std::index_sequence_for<Args...>{});
63 }
64 std::shared_ptr<ModuleImpl> getImpl() override { return mod_; };
65
66 private:
67 template <size_t... Idx>
68 std::any invoke(const std::vector<std::any>& args,
69 std::index_sequence<Idx...> /*seq*/) {
70 return std::any((static_cast<MImpl*>(mod_.get())->*forward_)(
71 any_cast<Args>(args[Idx])...));
72 }
73
74 std::shared_ptr<MImpl> mod_;
75 FuncPtr forward_;
76};
77
78} // namespace lmp::nets
Definition any.hpp:13
Definition module.hpp:39