lamppp
Loading...
Searching...
No Matches
view_decl.hpp
1#pragma once
2
3#include "lamppp/autograd/functions/view_ops.hpp"
4#include "lamppp/autograd/variable.hpp"
5#include "lamppp/common/assert.hpp"
6
/// Declares a view op's backward pass from one parenthesized tuple.
///
/// `fn_args` must itself be a parenthesized list `(grad_fn, grad_expr)`.
/// Writing the helper's name directly in front of it turns the tuple into
/// the helper's argument list, so
///   LMP_AUTOGRAD_FN_VIEW_DECL((Foo, expr))
/// expands to
///   LMP_AUTOGRAD_FN_VIEW_DECL_HELPER(Foo, expr)
/// As with any macro argument, `expr` may contain commas only inside nested
/// parentheses.
#define LMP_AUTOGRAD_FN_VIEW_DECL(fn_args) LMP_AUTOGRAD_FN_VIEW_DECL_HELPER fn_args
9
// Defines `grad_fn::apply`, the backward pass of a single-input view op.
//
// Parameters:
//   grad_fn   - the autograd Function class whose `apply` is being defined.
//   grad_expr - expression evaluated inside `apply` to produce the gradient
//               for the saved input. It may refer to the locals
//                 `grad` - the single incoming gradient (gradOutputs[0]);
//                 `self` - the first saved input ((*saved_inputs)[0]).
//               These local names are part of this macro's contract: callers'
//               expressions depend on them, so they must not be renamed.
//
// The result of `grad_expr` is passed to `self.incr_grad(...)` (presumably
// accumulating into the input's gradient). It is not placed in the returned
// list, so `apply` returns an empty variable_list.
//
// NOTE(review): the added guard assumes `saved_inputs` is a pointer-like
// member (raw/unique/shared pointer) comparable with nullptr. The code only
// shows that it is dereferenced with `*`, so confirm against the Function
// definition.
// Comments inside the macro body must be /* */: a `//` comment would be
// spliced with the following continued line and swallow it.
#define LMP_AUTOGRAD_FN_VIEW_DECL_HELPER(grad_fn, grad_expr)                 \
  variable_list grad_fn::apply(const variable_list& gradOutputs) {           \
    LMP_INTERNAL_ASSERT(gradOutputs.size() == 1) << "Output size mismatch."; \
    /* Guard the otherwise unchecked dereference and index below. */         \
    LMP_INTERNAL_ASSERT(saved_inputs != nullptr && !saved_inputs->empty())   \
        << "View backward has no saved input.";                              \
    const Variable& grad = gradOutputs[0];                                   \
    Variable& self = (*saved_inputs)[0];                                     \
                                                                             \
    self.incr_grad(grad_expr);                                               \
                                                                             \
    /* Gradient was delivered via incr_grad; nothing is returned. */         \
    return {};                                                               \
  }
21
/// Declares a view op's forward pass from one parenthesized tuple.
///
/// `fn_args` must itself be a parenthesized list
/// `(fn_name, ten_fn, extra_args...)`. Placing the helper's name right before
/// it makes the tuple the helper's argument list, so
///   LMP_AUTOGRAD_FFN_VIEW_DECL((Foo, foo, a, b))
/// expands to
///   LMP_AUTOGRAD_FFN_VIEW_DECL_HELPER(Foo, foo, a, b)
#define LMP_AUTOGRAD_FFN_VIEW_DECL(fn_args) LMP_AUTOGRAD_FFN_VIEW_DECL_HELPER fn_args
24
// Defines `fn_name::execute`, the forward pass of a single-input view op.
//
// Parameters:
//   fn_name - the autograd Function class whose `execute` is being defined.
//   ten_fn  - name of the tensor member function to call on the input's data.
//   ...     - extra arguments forwarded verbatim to `ten_fn`. They are
//             expanded inside `execute`, so they may name members of fn_name.
//
// The generated body asserts exactly one input, then returns
// `inputs[0].data().ten_fn(...)`. The locals `inputs` and `self` are kept
// stable because forwarded arguments could refer to them.
// Note: before C++20, passing no extra arguments (empty `...`) is only
// accepted by compilers as an extension.
#define LMP_AUTOGRAD_FFN_VIEW_DECL_HELPER(fn_name, ten_fn, ...)                \
  tensor::Tensor                                                               \
  fn_name::execute(const variable_list& inputs) {                              \
    LMP_INTERNAL_ASSERT(inputs.size() == 1) << "Function must take one input"; \
    const auto& self = inputs[0];                                              \
    return self.data().ten_fn(__VA_ARGS__);                                    \
  }