lamppp
Loading...
Searching...
No Matches
variable.hpp
1#pragma once
2
3#include <memory>
4#include <unordered_set>
5#include <utility>
6#include <vector>
7#include "lamppp/tensor/core.hpp"
8#include "lamppp/tensor/data_type.hpp"
9#include "lamppp/tensor/tensor.hpp"
10#include "lamppp/tensor/utils/fill_like.hpp"
11
12namespace lmp::autograd {
13
14class Function;
15class Variable;
16
18 tensor::Tensor data;
19 tensor::Tensor grad;
20 std::shared_ptr<Function> _grad_fn;
21 bool requires_grad;
22
23 explicit VariableImpl(const tensor::Tensor& data, bool requires_grad = false)
24 : data(tensor::Tensor(data)),
25 grad(requires_grad ? zeros_like(data) : tensor::Tensor()),
26 requires_grad(requires_grad),
27 _grad_fn(nullptr) {}
28 explicit VariableImpl(const tensor::Tensor& data, const tensor::Tensor& grad,
29 bool requires_grad,
30 const std::shared_ptr<Function>& grad_fn)
31 : data(tensor::Tensor(data)),
32 grad(tensor::Tensor(grad)),
33 requires_grad(requires_grad),
34 _grad_fn(std::move(grad_fn)) {}
35};
36
37class Variable {
38 public:
39 Variable() = default;
40 explicit Variable(const tensor::Tensor& data, bool requires_grad = false)
41 : impl_(std::make_shared<VariableImpl>(data, requires_grad)) {}
42
43 const tensor::Tensor& grad() const;
44 const tensor::Tensor& data() const noexcept;
45 std::weak_ptr<Function> grad_fn() const noexcept;
46 bool requires_grad() const noexcept;
47
48 void zero_grad();
49 void incr_grad(const tensor::Tensor& other_grad);
50 void set_grad_fn(std::shared_ptr<Function> grad_fn);
51
52 void copy(const Variable& other);
53 void fill(tensor::Scalar item);
54
55 void backward();
56 friend std::ostream& operator<<(std::ostream& os, const Variable& obj);
57
58 private:
59 explicit Variable(std::shared_ptr<VariableImpl> impl)
60 : impl_(std::move(impl)) {}
61 std::shared_ptr<VariableImpl> impl_;
62 std::vector<Variable> topological_sort();
63 void dfs(const Variable& v, std::unordered_set<void*>& visited,
64 std::vector<Variable>& topo) const;
65};
66
67using variable_list = std::vector<Variable>;
68
70 template <typename Op, typename... Args>
71 static variable_list apply(const variable_list& variables, Args&&... args) {
72 Op op_fn(std::forward<Args>(args)...);
73 variable_list result =
74 op_fn.template apply<Args...>(variables, std::forward<Args>(args)...);
75 return result;
76 }
77};
78
79} // namespace lmp::autograd
Definition variable.hpp:37
Main tensor object for Lamppp.
Definition tensor.hpp:29
Definition variable.hpp:17
Definition variable.hpp:69