// Shared handle to the Function associated with this variable's gradient.
// NOTE(review): presumably the op that produced this value in the autograd
// graph, and possibly null for leaf variables built without a grad_fn — confirm.
32 std::shared_ptr<Function> _grad_fn;
38 requires_grad(requires_grad),
41 bool requires_grad, std::shared_ptr<Function> grad_fn)
44 requires_grad(requires_grad),
45 _grad_fn(std::move(grad_fn)) {}
52 : impl_(std::make_shared<VariableImpl>(data, requires_grad)) {}
// Read-only accessor for the stored grad_fn handle. Returns by const
// reference, so no shared_ptr copy / refcount change happens; declared
// noexcept. NOTE(review): the reference presumably aliases storage inside the
// shared impl — don't keep it beyond the owning Variable's lifetime; confirm
// against the (not visible) definition.
56 const std::shared_ptr<Function>& grad_fn()
const noexcept;
// Reports the requires_grad flag; const and noexcept.
// NOTE(review): presumably returns the flag supplied at construction and
// stored on the impl — definition not visible here.
57 bool requires_grad()
const noexcept;
// Stores a new grad_fn handle. The parameter is taken by value, so callers can
// std::move a shared_ptr in without an extra refcount bump.
// NOTE(review): presumably replaces any previously stored grad_fn — confirm.
61 void set_grad_fn(std::shared_ptr<Function> grad_fn);
// Non-const mutator taking a single tensor::Scalar.
// NOTE(review): presumably sets every element of the underlying data to
// `item`; whether this interacts with autograd state is not visible — confirm.
64 void fill(tensor::Scalar item);
// Stream insertion for Variable. Declared friend so the implementation can
// read non-public members such as impl_; the output format is defined
// elsewhere. Returns `os` by reference to allow chaining.
67 friend std::ostream& operator<<(std::ostream& os,
const Variable& obj);
70 explicit Variable(std::shared_ptr<VariableImpl> impl)
71 : impl_(std::move(impl)) {}
// Shared state behind this Variable.
// NOTE(review): because it is a shared_ptr, copying a Variable presumably
// yields a second handle to the same VariableImpl (aliasing, not deep copy),
// assuming copy operations are defaulted — confirm.
72 std::shared_ptr<VariableImpl> impl_;
// Returns a vector of Variables, presumably the graph reachable from this one
// in topological order (likely built via dfs below) — confirm.
// NOTE(review): declared non-const while dfs is const; if the definition does
// not mutate *this, consider marking it const (declaration and definition must
// change together).
73 std::vector<Variable> topological_sort();
// Const depth-first-search helper. `visited` tracks already-seen nodes and
// `topo` is an output vector.
// NOTE(review): presumably appends nodes to `topo` in post-order for
// topological_sort, and keys `visited` by each node's impl address — confirm.
// A typed key (e.g. const VariableImpl*) would be safer than void*.
74 void dfs(
const Variable& v, std::unordered_set<void*>& visited,
75 std::vector<Variable>& topo)
const;