lamppp
Loading...
Searching...
No Matches
expand_ops.hpp
1#pragma once
2
3#include "lamppp/autograd/forward_function.hpp"
4#include "lamppp/autograd/function.hpp"
5
6namespace lmp::autograd::ops {
7
9struct AddBackward : public Function {
10 variable_list apply(const variable_list& gradOutputs) override;
11};
12
13struct SubtractBackward : public Function {
14 variable_list apply(const variable_list& gradOutputs) override;
15};
16
17struct MultiplyBackward : public Function {
18 variable_list apply(const variable_list& gradOutputs) override;
19};
20
21struct DivideBackward : public Function {
22 variable_list apply(const variable_list& gradOutputs) override;
23};
24
25struct PowerBackward : public Function {
26 variable_list apply(const variable_list& gradOutputs) override;
27};
28
29struct Add : public ForwardFunction<Add> {
31 tensor::Tensor execute(const variable_list& inputs) const;
32};
33
34struct Subtract : public ForwardFunction<Subtract> {
36 tensor::Tensor execute(const variable_list& inputs) const;
37};
38
39struct Multiply : public ForwardFunction<Multiply> {
41 tensor::Tensor execute(const variable_list& inputs) const;
42};
43
44struct Divide : public ForwardFunction<Divide> {
46 tensor::Tensor execute(const variable_list& inputs) const;
47};
48
49struct Power : public ForwardFunction<Power> {
51 tensor::Tensor execute(const variable_list& inputs) const;
52};
53
54struct EqualBackward : public Function {
55 variable_list apply(const variable_list& gradOutputs) override;
56};
57
58struct LessBackward : public Function {
59 variable_list apply(const variable_list& gradOutputs) override;
60};
61
62struct LessEqualBackward : public Function {
63 variable_list apply(const variable_list& gradOutputs) override;
64};
65
66struct NotEqualBackward : public Function {
67 variable_list apply(const variable_list& gradOutputs) override;
68};
69
70struct GreaterBackward : public Function {
71 variable_list apply(const variable_list& gradOutputs) override;
72};
73
75 variable_list apply(const variable_list& gradOutputs) override;
76};
77
78struct Equal : public ForwardFunction<Equal> {
80 tensor::Tensor execute(const variable_list& inputs) const;
81};
82
83struct Less : public ForwardFunction<Less> {
85 tensor::Tensor execute(const variable_list& inputs) const;
86};
87
88struct LessEqual : public ForwardFunction<LessEqual> {
90 tensor::Tensor execute(const variable_list& inputs) const;
91};
92
93struct NotEqual : public ForwardFunction<NotEqual> {
95 tensor::Tensor execute(const variable_list& inputs) const;
96};
97
98struct Greater : public ForwardFunction<Greater> {
100 tensor::Tensor execute(const variable_list& inputs) const;
101};
102
103struct GreaterEqual : public ForwardFunction<GreaterEqual> {
105 tensor::Tensor execute(const variable_list& inputs) const;
106};
108
115inline Variable add(const Variable& a, const Variable& b) {
116 return VariableOpFact::apply<Add>({a, b})[0];
117}
118
125inline Variable sub(const Variable& a, const Variable& b) {
126 return VariableOpFact::apply<Subtract>({a, b})[0];
127}
128
135inline Variable mul(const Variable& a, const Variable& b) {
136 return VariableOpFact::apply<Multiply>({a, b})[0];
137}
138
145inline Variable div(const Variable& a, const Variable& b) {
146 return VariableOpFact::apply<Divide>({a, b})[0];
147}
148
155inline Variable pow(const Variable& a, const Variable& b) {
156 return VariableOpFact::apply<Power>({a, b})[0];
157}
158
165inline Variable eq(const Variable& a, const Variable& b) {
166 return VariableOpFact::apply<Equal>({a, b})[0];
167}
168
175inline Variable ne(const Variable& a, const Variable& b) {
176 return VariableOpFact::apply<NotEqual>({a, b})[0];
177}
178
185inline Variable ge(const Variable& a, const Variable& b) {
186 return VariableOpFact::apply<GreaterEqual>({a, b})[0];
187}
188
195inline Variable le(const Variable& a, const Variable& b) {
196 return VariableOpFact::apply<LessEqual>({a, b})[0];
197}
198
205inline Variable gt(const Variable& a, const Variable& b) {
206 return VariableOpFact::apply<Greater>({a, b})[0];
207}
208
215inline Variable lt(const Variable& a, const Variable& b) {
216 return VariableOpFact::apply<Less>({a, b})[0];
217}
218
219} // namespace lmp::autograd::ops
Definition variable.hpp:48
Main tensor object for Lamppp.
Definition tensor.hpp:29
Definition forward_function.hpp:11
Definition function.hpp:12
Definition expand_ops.hpp:9
Definition expand_ops.hpp:29
Definition expand_ops.hpp:21
Definition expand_ops.hpp:44
Definition expand_ops.hpp:54
Definition expand_ops.hpp:78
Definition expand_ops.hpp:70
Definition expand_ops.hpp:74
Definition expand_ops.hpp:103
Definition expand_ops.hpp:98
Definition expand_ops.hpp:58
Definition expand_ops.hpp:62
Definition expand_ops.hpp:88
Definition expand_ops.hpp:83
Definition expand_ops.hpp:17
Definition expand_ops.hpp:39
Definition expand_ops.hpp:66
Definition expand_ops.hpp:93
Definition expand_ops.hpp:25
Definition expand_ops.hpp:49
Definition expand_ops.hpp:13
Definition expand_ops.hpp:34