lamppp
Loading...
Searching...
No Matches
reduct_ops.hpp
1#pragma once
2
3#include "lamppp/autograd/forward_function.hpp"
4#include "lamppp/autograd/function.hpp"
5
6namespace lmp::autograd::ops {
7
9struct SummationBackward : public Function {
10 size_t axis_;
11 explicit SummationBackward(size_t axis) : axis_(axis) {}
12 variable_list apply(const variable_list& gradOutputs) override;
13};
14
15struct MaximumBackward : public Function {
16 size_t axis_;
17 explicit MaximumBackward(size_t axis) : axis_(axis) {}
18 variable_list apply(const variable_list& gradOutputs) override;
19};
20
21struct MinimumBackward : public Function {
22 size_t axis_;
23 explicit MinimumBackward(size_t axis) : axis_(axis) {}
24 variable_list apply(const variable_list& gradOutputs) override;
25};
26
27struct ProductBackward : public Function {
28 size_t axis_;
29 explicit ProductBackward(size_t axis) : axis_(axis) {}
30 variable_list apply(const variable_list& gradOutputs) override;
31};
32
33struct Summation : public ForwardFunction<Summation> {
35 size_t axis_;
36 explicit Summation(size_t axis) : axis_(axis) {}
37 tensor::Tensor execute(const variable_list& inputs) const;
38};
39
40struct Maximum : public ForwardFunction<Maximum> {
42 size_t axis_;
43 explicit Maximum(size_t axis) : axis_(axis) {}
44 tensor::Tensor execute(const variable_list& inputs) const;
45};
46
47struct Minimum : public ForwardFunction<Minimum> {
49 size_t axis_;
50 explicit Minimum(size_t axis) : axis_(axis) {}
51 tensor::Tensor execute(const variable_list& inputs) const;
52};
53
54struct Product : public ForwardFunction<Product> {
56 size_t axis_;
57 explicit Product(size_t axis) : axis_(axis) {}
58 tensor::Tensor execute(const variable_list& inputs) const;
59};
61
68inline Variable sum(const Variable& a, size_t axis) {
69 return VariableOpFact::apply<Summation>({a}, axis)[0];
70}
71
78inline Variable max(const Variable& a, size_t axis) {
79 return VariableOpFact::apply<Maximum>({a}, axis)[0];
80}
81
88inline Variable min(const Variable& a, size_t axis) {
89 return VariableOpFact::apply<Minimum>({a}, axis)[0];
90}
91
98inline Variable prod(const Variable& a, size_t axis) {
99 return VariableOpFact::apply<Product>({a}, axis)[0];
100}
101
102} // namespace lmp::autograd::ops
Definition variable.hpp:48
Main tensor object for Lamppp.
Definition tensor.hpp:29
Definition forward_function.hpp:11
Definition function.hpp:12
Definition reduct_ops.hpp:15
Definition reduct_ops.hpp:40
Definition reduct_ops.hpp:21
Definition reduct_ops.hpp:47
Definition reduct_ops.hpp:27
Definition reduct_ops.hpp:54
Definition reduct_ops.hpp:9
Definition reduct_ops.hpp:33