lamppp
Loading...
Searching...
No Matches
matrix_ops.hpp
1#pragma once
2
3#include "lamppp/autograd/forward_function.hpp"
4#include "lamppp/autograd/function.hpp"
5
6namespace lmp::autograd::ops {
7
10 variable_list apply(const variable_list& gradOutputs) override;
11};
12
13struct TransposeBackward : public Function {
14 variable_list apply(const variable_list& gradOutputs) override;
15};
16
17struct MatrixMultiplication : public ForwardFunction<MatrixMultiplication> {
19 tensor::Tensor execute(const variable_list& inputs);
20};
21
22struct Transpose : public ForwardFunction<Transpose> {
24 tensor::Tensor execute(const variable_list& inputs);
25};
27
34inline Variable matmul(const Variable& a, const Variable& b) {
35 return VariableOpFact::apply<MatrixMultiplication>({a, b})[0];
36}
37
44inline Variable transpose(const Variable& a) {
45 return VariableOpFact::apply<Transpose>({a})[0];
46}
47
48} // namespace lmp::autograd::ops
Definition variable.hpp:48
Main tensor object for Lamppp.
Definition tensor.hpp:29
Definition forward_function.hpp:11
Definition function.hpp:12
Definition matrix_ops.hpp:17
Definition matrix_ops.hpp:13
Definition matrix_ops.hpp:22