21 std::vector<Parameter> parameters()
const;
22 std::multimap<std::string, Parameter> named_parameters()
const;
24 std::vector<std::reference_wrapper<Parameter>> parameters();
25 std::multimap<std::string, std::reference_wrapper<Parameter>> named_parameters();
30 void register_parameter(
const std::string& name,
Parameter& param);
31 void register_module(
const std::string& name, std::shared_ptr<ModuleImpl> module);
33 bool trainable_ =
true;
34 std::unordered_map<std::string, std::shared_ptr<ModuleImpl>> modules_;
35 std::unordered_map<std::string, std::reference_wrapper<Parameter>> params_;
41 template <
typename... Args>
42 explicit Module(Args&&... args)
43 : impl_(std::make_shared<Derived>(std::forward<Args>(args)...)) {}
45 std::vector<Parameter> parameters()
const;
46 std::multimap<std::string, Parameter> named_parameters()
const;
48 std::vector<std::reference_wrapper<Parameter>> parameters();
49 std::multimap<std::string, std::reference_wrapper<Parameter>> named_parameters();
// Call operator: forwards `args` to the wrapped implementation's forward()
// and returns whatever forward() returns.
//
// NOTE(review): the return type is computed as
// std::invoke_result_t<decltype(&Derived::forward), Derived, Args...>.
// That has two problems:
//  (1) `&Derived::forward` is ill-formed when forward is overloaded or a
//      member template, so such implementations cannot be called through
//      this operator.
//  (2) It models the call on an rvalue Derived, while the body calls forward
//      on an lvalue (through a pointer).
// `decltype(std::declval<Derived&>().forward(std::declval<Args>()...))`
// would match the actual call.
// The static_cast is a no-op: impl_ is a std::shared_ptr<Derived>, so
// impl_.get() is already a Derived*.
// The function's closing brace is missing from this listing.
53 template <
typename... Args>
54 auto operator()(Args&&... args)
55 -> std::invoke_result_t<
decltype(&Derived::forward), Derived, Args...> {
56 return static_cast<Derived*
>(impl_.get())
57 ->forward(std::forward<Args>(args)...);
61 std::shared_ptr<Derived> impl_;