// lamppp - binary_decl.hpp
1#pragma once
2
3#include "lamppp/autograd/functions/reduct_ops.hpp"
4#include "lamppp/autograd/variable.hpp"
5#include "lamppp/common/assert.hpp"
6
/*
 * LMP_AUTOGRAD_FN_BINARY_DECL((FnClass, SelfGradExpr, OtherGradExpr))
 *
 * Entry point for defining a binary node's backward pass. The caller wraps
 * all three arguments in one extra pair of parentheses. This macro removes
 * that outer layer by placing the helper's name directly in front of the
 * parenthesized pack.
 */
#define LMP_AUTOGRAD_FN_BINARY_DECL(arg_pack) \
  LMP_AUTOGRAD_FN_BINARY_DECL_HELPER arg_pack

/*
 * Expands to the out-of-class definition of `fn_class::apply`.
 *
 * The generated function:
 *   - asserts that exactly one entry was passed in `gradOutputs`;
 *   - binds `self` to (*saved_inputs)[0], `other` to (*saved_inputs)[1],
 *     and `grad` to gradOutputs[0];
 *   - calls self.incr_grad(self_grad_expr), then other.incr_grad(other_grad_expr);
 *   - returns an empty variable_list.
 *
 * self_grad_expr and other_grad_expr are substituted verbatim into the body.
 * They can therefore refer to the locals `grad`, `self` and `other`, which
 * makes those local names part of this macro's interface.
 *
 * Inside the macro body, comments must use the block form. A line comment
 * would swallow the following line once the backslash continuations are
 * spliced.
 *
 * NOTE(review): saved_inputs is indexed at [0] and [1] without a size check.
 * Presumably it always holds the two operands captured by the forward call;
 * confirm this against the code that populates it.
 */
#define LMP_AUTOGRAD_FN_BINARY_DECL_HELPER(fn_class, self_grad_expr,         \
                                           other_grad_expr)                  \
  variable_list fn_class::apply(const variable_list& gradOutputs) {         \
    LMP_INTERNAL_ASSERT(gradOutputs.size() == 1) << "Output size mismatch."; \
    /* Saved operands: index 0 -> self, index 1 -> other. */                \
    Variable& self = (*saved_inputs)[0];                                    \
    Variable& other = (*saved_inputs)[1];                                   \
    /* The single incoming gradient. */                                     \
    const Variable& grad = gradOutputs[0];                                  \
    self.incr_grad(self_grad_expr);                                         \
    other.incr_grad(other_grad_expr);                                       \
    return variable_list{};                                                 \
  }
21
22
/*
 * LMP_AUTOGRAD_FFN_BINARY_DECL((FnClass, TensorFn [, extra args...]))
 *
 * Entry point for defining a binary node's forward computation. The caller
 * wraps the arguments in one extra pair of parentheses. This macro removes
 * that outer layer and forwards the contents to
 * LMP_AUTOGRAD_FFN_BINARY_DECL_HELPER.
 */
#define LMP_AUTOGRAD_FFN_BINARY_DECL(args) \
  LMP_AUTOGRAD_FFN_BINARY_DECL_HELPER args

/*
 * Expands to the out-of-class definition of `grad_fn::execute` (a const
 * member function returning tensor::Tensor).
 *
 * The generated function asserts that exactly two inputs were supplied and
 * binds them to `self` (inputs[0]) and `other` (inputs[1]). It then returns:
 *
 *     ten_fn(self.data(), other.data() [, extra args...])
 *
 * Any trailing macro arguments are pasted verbatim after the two tensors.
 * When there are none, __VA_OPT__ (standardized in C++20) drops the
 * separating comma, so the call is simply ten_fn(self.data(), other.data()).
 *
 * The assertion message states the same arity that the check enforces
 * (two inputs).
 */
#define LMP_AUTOGRAD_FFN_BINARY_DECL_HELPER(grad_fn, ten_fn, ...)        \
  tensor::Tensor grad_fn::execute(const variable_list& inputs) const {  \
    LMP_INTERNAL_ASSERT(inputs.size() == 2)                             \
        << "Function must take two inputs";                             \
    const Variable& self = inputs[0];                                   \
    const Variable& other = inputs[1];                                  \
    return ten_fn(self.data(), other.data() __VA_OPT__(, __VA_ARGS__)); \
  }