lamppp
Loading...
Searching...
No Matches
view_ops.hpp
1#pragma once
2
3#include "lamppp/autograd/forward_function.hpp"
4#include "lamppp/autograd/function.hpp"
5#include "lamppp/autograd/functions/overloads.hpp"
6#include "lamppp/tensor/device_type.hpp"
7
8namespace lmp::autograd::ops {
9
11struct ReshapeBackward : public Function {
12 std::vector<size_t> shape;
13 explicit ReshapeBackward(std::vector<size_t> shape)
14 : shape(std::move(shape)) {}
15 variable_list apply(const variable_list& gradOutputs) override;
16};
17struct Reshape : public ForwardFunction<Reshape> {
19 std::vector<size_t> shape;
20 explicit Reshape(std::vector<size_t> shape) : shape(std::move(shape)) {}
21 tensor::Tensor execute(const variable_list& inputs) const;
22};
23
24struct SqueezeBackward : public Function {
25 size_t axis;
26 explicit SqueezeBackward(size_t axis) : axis(axis) {}
27 variable_list apply(const variable_list& gradOutputs) override;
28};
29struct Squeeze : public ForwardFunction<Squeeze> {
31 size_t axis;
32 explicit Squeeze(size_t axis) : axis(axis) {}
33 tensor::Tensor execute(const variable_list& inputs) const;
34};
35
37 size_t axis;
38 explicit ExpandDimsBackward(size_t axis) : axis(axis) {}
39 variable_list apply(const variable_list& gradOutputs) override;
40};
41struct ExpandDims : public ForwardFunction<ExpandDims> {
43 size_t axis;
44 explicit ExpandDims(size_t axis) : axis(axis) {}
45 tensor::Tensor execute(const variable_list& inputs) const;
46};
47
48struct ToBackward : public Function {
49 tensor::DeviceType device;
50 explicit ToBackward(tensor::DeviceType device) : device(device) {}
51 variable_list apply(const variable_list& gradOutputs) override;
52};
53struct To : public ForwardFunction<To> {
55 tensor::DeviceType device;
56 explicit To(tensor::DeviceType device) : device(device) {}
57 tensor::Tensor execute(const variable_list& inputs) const;
58};
60
67inline Variable reshape(const Variable& a, const std::vector<size_t>& shape) {
68 return VariableOpFact::apply<Reshape>({a}, shape)[0];
69}
70
77inline Variable squeeze(const Variable& a, size_t axis) {
78 return VariableOpFact::apply<Squeeze>({a}, axis)[0];
79}
80
87inline Variable expand_dims(const Variable& a, size_t axis) {
88 return VariableOpFact::apply<ExpandDims>({a}, axis)[0];
89}
90
97inline Variable to(const Variable& a, tensor::DeviceType device) {
98 return VariableOpFact::apply<To>({a}, device)[0];
99}
100
101} // namespace lmp::autograd::ops
Definition variable.hpp:48
Main tensor object for Lamppp.
Definition tensor.hpp:29
Definition forward_function.hpp:11
Definition function.hpp:12
Definition view_ops.hpp:36
Definition view_ops.hpp:41
Definition view_ops.hpp:11
Definition view_ops.hpp:17
Definition view_ops.hpp:24
Definition view_ops.hpp:29
Definition view_ops.hpp:48
Definition view_ops.hpp:53