lamppp
Loading...
Searching...
No Matches
variable.hpp
1#pragma once
2
3#include <memory>
4#include <unordered_set>
5#include <utility>
6#include <vector>
7#include "lamppp/tensor/core.hpp"
8#include "lamppp/tensor/data_type.hpp"
9#include "lamppp/tensor/fill_like.hpp"
10#include "lamppp/tensor/tensor.hpp"
11
12namespace lmp::autograd {
13
14class Function;
15class Variable;
16
30 tensor::Tensor data;
31 tensor::Tensor grad;
32 std::shared_ptr<Function> _grad_fn;
33 bool requires_grad;
34
35 explicit VariableImpl(const tensor::Tensor& data, bool requires_grad = false)
36 : data(tensor::Tensor(data)),
37 grad(requires_grad ? zeros_like(data) : tensor::Tensor()),
38 requires_grad(requires_grad),
39 _grad_fn(nullptr) {}
40 explicit VariableImpl(const tensor::Tensor& data, const tensor::Tensor& grad,
41 bool requires_grad, std::shared_ptr<Function> grad_fn)
42 : data(tensor::Tensor(data)),
43 grad(tensor::Tensor(grad)),
44 requires_grad(requires_grad),
45 _grad_fn(std::move(grad_fn)) {}
46};
47
48class Variable {
49 public:
50 Variable() = default;
51 explicit Variable(const tensor::Tensor& data, bool requires_grad = false)
52 : impl_(std::make_shared<VariableImpl>(data, requires_grad)) {}
53
54 const tensor::Tensor& grad() const noexcept;
55 const tensor::Tensor& data() const noexcept;
56 const std::shared_ptr<Function>& grad_fn() const noexcept;
57 bool requires_grad() const noexcept;
58
59 void zero_grad();
60 void incr_grad(const tensor::Tensor& other_grad);
61 void set_grad_fn(std::shared_ptr<Function> grad_fn);
62
63 void copy(const Variable& other);
64 void fill(tensor::Scalar item);
65
66 void backward();
67 friend std::ostream& operator<<(std::ostream& os, const Variable& obj);
68
69 private:
70 explicit Variable(std::shared_ptr<VariableImpl> impl)
71 : impl_(std::move(impl)) {}
72 std::shared_ptr<VariableImpl> impl_;
73 std::vector<Variable> topological_sort();
74 void dfs(const Variable& v, std::unordered_set<void*>& visited,
75 std::vector<Variable>& topo) const;
76};
77
78using variable_list = std::vector<Variable>;
79
81 template <typename Op, typename... Args>
82 static variable_list apply(variable_list variables, Args&&... args) {
83 Op op_fn(std::forward<Args>(args)...);
84 variable_list result =
85 op_fn.template apply<Args...>(variables, std::forward<Args>(args)...);
86 return result;
87 }
88};
89
90} // namespace lmp::autograd
Definition variable.hpp:48
Main tensor object for Lamppp.
Definition tensor.hpp:29
Main autograd object for Lamppp.
Definition variable.hpp:29
Definition variable.hpp:80