lamppp
Loading...
Searching...
No Matches
forward_function.hpp
#pragma once

#include <algorithm>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <utility>

#include "function.hpp"
#include "variable.hpp"

7namespace lmp::autograd {
8
10template <typename Derived>
11struct ForwardFunction : public Function {
12 static bool requires_grad(const variable_list& variables) {
13 return std::accumulate(variables.begin(), variables.end(), false,
14 [](bool accumulated, const Variable& b) {
15 return accumulated || b.requires_grad();
16 });
17 }
18
19 variable_list apply(const variable_list& inputs) override {
20 throw std::runtime_error(
21 "Forward function should not be called without template args");
22 return {};
23 }
24
25 template <typename... Args>
26 variable_list apply(const variable_list& inputs, Args&&... args) {
27 bool requires_grad_ = requires_grad(inputs);
28 Variable result =
29 Variable(static_cast<Derived*>(this)->execute(inputs), requires_grad_);
30 if (requires_grad_) {
31 auto backward_fn = std::make_shared<typename Derived::DefaultBackward>(
32 std::forward<Args>(args)...);
33 backward_fn->saved_inputs = std::make_unique<variable_list>(inputs);
34 result.set_grad_fn(backward_fn);
35 }
36
37 return {result};
38 }
39};
41
42} // namespace lmp::autograd
Definition variable.hpp:48
Definition forward_function.hpp:11
Definition function.hpp:12