lamppp
Loading...
Searching...
No Matches
unary_ops.hpp
1#pragma once
2
3#include "lamppp/autograd/forward_function.hpp"
4#include "lamppp/autograd/function.hpp"
5
6namespace lmp::autograd::ops {
7
9struct NegationBackward : public Function {
10 variable_list apply(const variable_list& gradOutputs) override;
11};
12struct Negation : public ForwardFunction<Negation> {
14 tensor::Tensor execute(const variable_list& inputs) const;
15};
16
18 variable_list apply(const variable_list& gradOutputs) override;
19};
20struct Exponential : public ForwardFunction<Exponential> {
22 tensor::Tensor execute(const variable_list& inputs) const;
23};
24
25struct LogarithmBackward : public Function {
26 variable_list apply(const variable_list& gradOutputs) override;
27};
28struct Logarithm : public ForwardFunction<Logarithm> {
30 tensor::Tensor execute(const variable_list& inputs) const;
31};
32
34 variable_list apply(const variable_list& gradOutputs) override;
35};
36struct SquareRoot : public ForwardFunction<SquareRoot> {
38 tensor::Tensor execute(const variable_list& inputs) const;
39};
40
42 variable_list apply(const variable_list& gradOutputs) override;
43};
44struct AbsoluteValue : public ForwardFunction<AbsoluteValue> {
46 tensor::Tensor execute(const variable_list& inputs) const;
47};
48
49struct SineBackward : public Function {
50 variable_list apply(const variable_list& gradOutputs) override;
51};
52struct Sine : public ForwardFunction<Sine> {
54 tensor::Tensor execute(const variable_list& inputs) const;
55};
56
57struct CosineBackward : public Function {
58 variable_list apply(const variable_list& gradOutputs) override;
59};
60struct Cosine : public ForwardFunction<Cosine> {
62 tensor::Tensor execute(const variable_list& inputs) const;
63};
64
65struct TangentBackward : public Function {
66 variable_list apply(const variable_list& gradOutputs) override;
67};
68struct Tangent : public ForwardFunction<Tangent> {
70 tensor::Tensor execute(const variable_list& inputs) const;
71};
72
73struct ClampBackward : public Function {
74 tensor::Scalar min_val_, max_val_;
75 explicit ClampBackward(tensor::Scalar min_val, tensor::Scalar max_val)
76 : min_val_(min_val), max_val_(max_val) {}
77 variable_list apply(const variable_list& gradOutputs) override;
78};
79struct Clamp : public ForwardFunction<Clamp> {
81 tensor::Scalar min_val_, max_val_;
82 explicit Clamp(tensor::Scalar min_val, tensor::Scalar max_val)
83 : min_val_(min_val), max_val_(max_val) {}
84 tensor::Tensor execute(const variable_list& inputs) const;
85};
87
93inline Variable neg(const Variable& a) {
94 return VariableOpFact::apply<Negation>({a})[0];
95}
96
102inline Variable exp(const Variable& a) {
103 return VariableOpFact::apply<Exponential>({a})[0];
104}
105
111inline Variable log(const Variable& a) {
112 return VariableOpFact::apply<Logarithm>({a})[0];
113}
114
120inline Variable sqrt(const Variable& a) {
121 return VariableOpFact::apply<SquareRoot>({a})[0];
122}
123
129inline Variable abs(const Variable& a) {
130 return VariableOpFact::apply<AbsoluteValue>({a})[0];
131}
132
138inline Variable sin(const Variable& a) {
139 return VariableOpFact::apply<Sine>({a})[0];
140}
141
147inline Variable cos(const Variable& a) {
148 return VariableOpFact::apply<Cosine>({a})[0];
149}
150
156inline Variable tan(const Variable& a) {
157 return VariableOpFact::apply<Tangent>({a})[0];
158}
159
167inline Variable clamp(const Variable& a, tensor::Scalar min_val,
168 tensor::Scalar max_val) {
169 return VariableOpFact::apply<Clamp>({a}, min_val, max_val)[0];
170}
171
172} // namespace lmp::autograd::ops
Definition variable.hpp:48
Main tensor object for Lamppp.
Definition tensor.hpp:29
Definition forward_function.hpp:11
Definition function.hpp:12
Definition unary_ops.hpp:44
Definition unary_ops.hpp:73
Definition unary_ops.hpp:79
Definition unary_ops.hpp:57
Definition unary_ops.hpp:60
Definition unary_ops.hpp:17
Definition unary_ops.hpp:20
Definition unary_ops.hpp:25
Definition unary_ops.hpp:28
Definition unary_ops.hpp:9
Definition unary_ops.hpp:12
Definition unary_ops.hpp:49
Definition unary_ops.hpp:52
Definition unary_ops.hpp:33
Definition unary_ops.hpp:36
Definition unary_ops.hpp:65
Definition unary_ops.hpp:68