lamppp
Loading...
Searching...
No Matches
module.hpp
1#pragma once
2
3#include <functional>
4#include <map>
5#include <memory>
6#include <string>
7#include <unordered_map>
8#include <utility>
9#include <vector>
10#include "parameter.hpp"
11
12namespace lmp::nets {
13
namespace detail {
// Forward declaration so later declarations in this header can name
// detail::UnsafeModuleAccessor before its definition appears.
template <class T>
class UnsafeModuleAccessor;
}  // namespace detail
18
20 public:
21 std::vector<Parameter> parameters() const;
22 std::multimap<std::string, Parameter> named_parameters() const;
23
24 std::vector<std::reference_wrapper<Parameter>> parameters();
25 std::multimap<std::string, std::reference_wrapper<Parameter>> named_parameters();
26 void eval();
27 void train();
28
29 protected:
30 void register_parameter(const std::string& name, Parameter& param);
31 void register_module(const std::string& name, std::shared_ptr<ModuleImpl> module);
32
33 bool trainable_ = true;
34 std::unordered_map<std::string, std::shared_ptr<ModuleImpl>> modules_; // problem, this is not type-specific, no operator()
35 std::unordered_map<std::string, std::reference_wrapper<Parameter>> params_;
36};
37
38template <typename Derived>
39class Module {
40 public:
41 template <typename... Args>
42 explicit Module(Args&&... args)
43 : impl_(std::make_shared<Derived>(std::forward<Args>(args)...)) {}
44
45 std::vector<Parameter> parameters() const;
46 std::multimap<std::string, Parameter> named_parameters() const;
47
48 std::vector<std::reference_wrapper<Parameter>> parameters();
49 std::multimap<std::string, std::reference_wrapper<Parameter>> named_parameters();
50 void eval();
51 void train();
52
53 template <typename... Args>
54 auto operator()(Args&&... args)
55 -> std::invoke_result_t<decltype(&Derived::forward), Derived, Args...> {
56 return static_cast<Derived*>(impl_.get())
57 ->forward(std::forward<Args>(args)...);
58 }
59
60 protected:
61 std::shared_ptr<Derived> impl_;
62
63 template <typename T>
65};
66
67namespace detail {
68// @internal
69template <typename T>
71 static std::shared_ptr<T> getImpl(const Module<T>& mod) {
72 return mod.impl_;
73 }
74};
75// @endinternal
76} // namespace detail
77
78template <typename Derived>
79std::vector<std::reference_wrapper<Parameter>> Module<Derived>::parameters() {
80 return impl_->parameters();
81}
82
83template <typename Derived>
84std::multimap<std::string, std::reference_wrapper<Parameter>> Module<Derived>::named_parameters() {
85 return impl_->named_parameters();
86}
87
88template <typename Derived>
89std::vector<Parameter> Module<Derived>::parameters() const {
90 return impl_->parameters();
91}
92
93template <typename Derived>
94std::multimap<std::string, Parameter> Module<Derived>::named_parameters() const {
95 return impl_->named_parameters();
96}
97
98template <typename Derived>
99void Module<Derived>::eval() {
100 impl_->eval();
101}
102
103template <typename Derived>
104void Module<Derived>::train() {
105 impl_->train();
106}
107
108} // namespace lmp::nets
109
// Declares `module` as a thin handle around the implementation class `impl`.
// The handle derives from lmp::nets::Module<impl> and inherits its forwarding
// constructor, so `module m(args...)` constructs an `impl` from `args...`.
//
// The base is fully qualified because these macros are defined after namespace
// lmp::nets has closed. An unqualified `Module` would only resolve when the
// macro is expanded inside that namespace.
#define LMP_DEFINE_MODULE_IMPL(module, impl)         \
  struct module : public ::lmp::nets::Module<impl> { \
    using ::lmp::nets::Module<impl>::Module;         \
  };

// Shorthand for the `<Name>Impl` naming convention: LMP_DEFINE_MODULE(Linear)
// wraps the implementation class `LinearImpl` in a handle named `Linear`.
#define LMP_DEFINE_MODULE(module) LMP_DEFINE_MODULE_IMPL(module, module##Impl)
Definition module.hpp:19
Definition module.hpp:39
Definition parameter.hpp:7