3#include "lamppp/autograd/forward_function.hpp"
4#include "lamppp/autograd/function.hpp"
6namespace lmp::autograd::ops {
12 variable_list apply(
const variable_list& gradOutputs)
override;
18 variable_list apply(
const variable_list& gradOutputs)
override;
24 variable_list apply(
const variable_list& gradOutputs)
override;
30 variable_list apply(
const variable_list& gradOutputs)
override;
36 explicit Summation(
size_t axis) : axis_(axis) {}
43 explicit Maximum(
size_t axis) : axis_(axis) {}
50 explicit Minimum(
size_t axis) : axis_(axis) {}
57 explicit Product(
size_t axis) : axis_(axis) {}
69 return VariableOpFact::apply<Summation>({a}, axis)[0];
79 return VariableOpFact::apply<Maximum>({a}, axis)[0];
88inline Variable min(
const Variable& a,
size_t axis) {
89 return VariableOpFact::apply<Minimum>({a}, axis)[0];
98inline Variable prod(
const Variable& a,
size_t axis) {
99 return VariableOpFact::apply<Product>({a}, axis)[0];
Definition variable.hpp:48
Main tensor object for Lamppp.
Definition tensor.hpp:29
Definition forward_function.hpp:11
Definition function.hpp:12
Definition reduct_ops.hpp:15
Definition reduct_ops.hpp:40
Definition reduct_ops.hpp:21
Definition reduct_ops.hpp:47
Definition reduct_ops.hpp:27
Definition reduct_ops.hpp:54
Definition reduct_ops.hpp:9
Definition reduct_ops.hpp:33