3#include "lamppp/autograd/forward_function.hpp"
4#include "lamppp/autograd/function.hpp"
5#include "lamppp/autograd/functions/overloads.hpp"
6#include "lamppp/tensor/device_type.hpp"
8namespace lmp::autograd::ops {
12 std::vector<size_t> shape;
14 : shape(std::move(shape)) {}
15 variable_list apply(
const variable_list& gradOutputs)
override;
19 std::vector<size_t> shape;
20 explicit Reshape(std::vector<size_t> shape) : shape(std::move(shape)) {}
27 variable_list apply(
const variable_list& gradOutputs)
override;
32 explicit Squeeze(
size_t axis) : axis(axis) {}
39 variable_list apply(
const variable_list& gradOutputs)
override;
44 explicit ExpandDims(
size_t axis) : axis(axis) {}
49 tensor::DeviceType device;
50 explicit ToBackward(tensor::DeviceType device) : device(device) {}
51 variable_list apply(
const variable_list& gradOutputs)
override;
55 tensor::DeviceType device;
56 explicit To(tensor::DeviceType device) : device(device) {}
67inline Variable reshape(
const Variable& a,
const std::vector<size_t>& shape) {
68 return VariableOpFact::apply<Reshape>({a}, shape)[0];
78 return VariableOpFact::apply<Squeeze>({a}, axis)[0];
87inline Variable expand_dims(
const Variable& a,
size_t axis) {
88 return VariableOpFact::apply<ExpandDims>({a}, axis)[0];
97inline Variable to(
const Variable& a, tensor::DeviceType device) {
98 return VariableOpFact::apply<To>({a}, device)[0];
Definition variable.hpp:48
Main tensor object for Lamppp.
Definition tensor.hpp:29
Definition forward_function.hpp:11
Definition function.hpp:12
Definition view_ops.hpp:36
Definition view_ops.hpp:41
Definition view_ops.hpp:11
Definition view_ops.hpp:17
Definition view_ops.hpp:24
Definition view_ops.hpp:29
Definition view_ops.hpp:48
Definition view_ops.hpp:53