// lamppp: module.hpp
#pragma once

#include <cstdlib>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "lamppp/common/assert.hpp"
#include "parameter.hpp"

/// Public handle for a module.
///
/// Pointer-to-implementation layout: the only data member is an owning
/// pointer to the nested `ModuleImpl`, which is merely forward-declared here
/// and defined further down in this header.
///
/// Because `ModuleImpl` is incomplete at this point, the implicitly generated
/// destructor (which deletes `impl_`) is only valid where `ModuleImpl` is a
/// complete type. That holds in any translation unit that includes this whole
/// header.
class Module {
public:
  /// Default construction leaves `impl_` null. Nothing in this header
  /// allocates a ModuleImpl.
  /// NOTE(review): confirm whether an out-of-line constructor or factory is
  /// expected to populate impl_.
  Module() = default;

private:
  class ModuleImpl;
  std::unique_ptr<ModuleImpl> impl_;
};

20public:
21 std::vector<Parameter> parameters();
22 void eval();
23 void train();
24
25 template<typename Ret, typename... Args>
26 Ret forward(Args&&... /*args*/) {
27 LMP_INTERNAL_ASSERT(false) << "Not Implemented";
28 }
29
30 template<typename Ret, typename... Args>
31 Ret operator()(Args&&... args) {
32 return forward<Ret>(std::forward<Args>(args)...);
33 }
34
35protected:
36 template<typename T>
37 T& register_parameter(const std::string& name, T&& param) {
38 params_[name] = std::forward<T>(param);
39 return params_[name];
40 }
41
42 template<typename T>
43 T& register_module(const std::string& name, T&& module) {
44 modules_[name] = std::make_unique<T>(std::forward<T>(module));
45 return *static_cast<T*>(modules_[name].get());
46 }
47
48private:
49 bool trainable_ = true;
50 std::unordered_map<std::string, std::unique_ptr<ModuleImpl>> modules_;
51 std::unordered_map<std::string, Parameter> params_;
52};
// Doxygen cross-references: definitions at module.hpp:19 and module.hpp:10.