lamppp: unary_decl.hpp
1#pragma once
2
3#include "lamppp/autograd/functions/reduct_ops.hpp"
4#include "lamppp/autograd/variable.hpp"
5#include "lamppp/common/assert.hpp"
6
// Entry point for defining the backward `apply` of a unary autograd node.
//
// `args` must be a parenthesized pair `(grad_fn, self_grad)`. Writing
// `LMP_AUTOGRAD_FN_UNARY_DECL_HELPER args` splices that pair in as the
// helper's argument list. This lets the whole pair travel through a single
// macro parameter.
// See LMP_AUTOGRAD_FN_UNARY_DECL_HELPER for the meaning of each element.
#define LMP_AUTOGRAD_FN_UNARY_DECL(args) LMP_AUTOGRAD_FN_UNARY_DECL_HELPER args
9
// Defines `grad_fn::apply`, the backward pass of a single-input autograd node.
//
// Parameters:
//   grad_fn   - the node class whose `apply` member is being defined.
//   self_grad - an expression passed to `self.incr_grad(...)`. It is
//               evaluated inside `apply`, where `grad` (the one incoming
//               gradient, `gradOutputs[0]`) and `self` (`(*saved_inputs)[0]`)
//               are in scope.
//
// The generated `apply`:
//   1. asserts that exactly one gradient was received;
//   2. calls `incr_grad(self_grad)` on the first saved input;
//   3. returns `{}` (a default-constructed `variable_list`).
//
// NOTE(review): `saved_inputs` is assumed to be non-null and to hold at least
// one entry; the generated code does not check this.
#define LMP_AUTOGRAD_FN_UNARY_DECL_HELPER(grad_fn, self_grad)                 \
  variable_list grad_fn::apply(const variable_list& gradOutputs) {            \
    LMP_INTERNAL_ASSERT(gradOutputs.size() == 1) << "Output size mismatch.";  \
    /* maybe_unused: a self_grad that ignores grad must not warn. */           \
    [[maybe_unused]] const Variable& grad = gradOutputs[0];                   \
    Variable& self = (*saved_inputs)[0];                                      \
    self.incr_grad(self_grad);                                                \
    return {};                                                                \
  }
19
20
// Entry point for defining the forward `execute` of a unary autograd function.
//
// `args` must be a parenthesized tuple `(grad_fn, ten_fn[, extra...])`.
// Writing `LMP_AUTOGRAD_FFN_UNARY_DECL_HELPER args` splices the tuple in as the
// helper's argument list, so any number of trailing arguments pass through
// one macro parameter.
// See LMP_AUTOGRAD_FFN_UNARY_DECL_HELPER for the meaning of each element.
#define LMP_AUTOGRAD_FFN_UNARY_DECL(args) LMP_AUTOGRAD_FFN_UNARY_DECL_HELPER args
23
// Defines `grad_fn::execute`, the forward computation of a single-input
// autograd function.
//
// Parameters:
//   grad_fn - the class whose const `execute` member is being defined.
//   ten_fn  - a callable invoked with the input's `data()` as first argument.
//   ...     - optional extra arguments forwarded to `ten_fn` after `data()`.
//
// The generated `execute`:
//   - asserts exactly one input;
//   - returns `ten_fn(inputs[0].data(), extra...)`.
// `__VA_OPT__` (standard since C++20) emits the separating comma only when
// extra arguments are supplied.
#define LMP_AUTOGRAD_FFN_UNARY_DECL_HELPER(grad_fn, ten_fn, ...)                \
  tensor::Tensor grad_fn::execute(const variable_list& inputs) const {         \
    LMP_INTERNAL_ASSERT(inputs.size() == 1) << "Function must take one input"; \
    const Variable& self = inputs[0];                                          \
    return ten_fn(self.data() __VA_OPT__(, __VA_ARGS__));                      \
  }