lamppp
Loading...
Searching...
No Matches
overloads.hpp
1#pragma once
2
#include <vector>

#include "lamppp/autograd/functions/expand_ops.hpp"
#include "lamppp/autograd/functions/unary_ops.hpp"
#include "lamppp/autograd/variable.hpp"
#include "lamppp/tensor/fill_like.hpp"
#include "lamppp/tensor/scalar.hpp"
#include "lamppp/tensor/tensor.hpp"
9
10namespace lmp::autograd {
11
12template <Variable (*OpTag)(const Variable&, const Variable&)>
13inline Variable binary_op(const Variable& a, const Variable& b) {
14 return (*OpTag)(a, b);
15}
16template <Variable (*OpTag)(const Variable&, const Variable&)>
17inline Variable binary_op(const Variable& v, tensor::Scalar s) {
18 tensor::Tensor scalar_tensor(std::vector<tensor::Scalar>(1, s), {1}, v.data().device(),
19 v.data().type()); // rely on broadcasting
20 return binary_op<OpTag>(v, Variable(scalar_tensor));
21}
22template <Variable (*OpTag)(const Variable&, const Variable&)>
23inline Variable binary_op(tensor::Scalar s, const Variable& v) {
24 tensor::Tensor scalar_tensor(std::vector<tensor::Scalar>(1, s), {1}, v.data().device(),
25 v.data().type()); // rely on broadcasting
26 return binary_op<OpTag>(Variable(scalar_tensor), v);
27}
28
// Generates the three overloads of `operator OP_TOKEN`:
//   Variable OP Variable, Variable OP Scalar, Scalar OP Variable.
// Every overload forwards to binary_op<&OP_FN>, which handles scalar wrapping.
#define DECL_BINARY_OP(OP_TOKEN, OP_FN)                      \
  inline Variable operator OP_TOKEN(const Variable& lhs,     \
                                    const Variable& rhs) {   \
    return binary_op<&OP_FN>(lhs, rhs);                      \
  }                                                          \
  inline Variable operator OP_TOKEN(const Variable& lhs,     \
                                    tensor::Scalar rhs) {    \
    return binary_op<&OP_FN>(lhs, rhs);                      \
  }                                                          \
  inline Variable operator OP_TOKEN(tensor::Scalar lhs,      \
                                    const Variable& rhs) {   \
    return binary_op<&OP_FN>(lhs, rhs);                      \
  }
39
// X-macro: invokes `X(token, fn)` once for every overloaded binary operator,
// pairing the operator token with the autograd op it dispatches to.
#define FORALL_BINARY_OPS(X) \
  /* arithmetic */           \
  X(+, ops::add)             \
  X(-, ops::sub)             \
  X(*, ops::mul)             \
  X(/, ops::div)             \
  /* comparisons */          \
  X(==, ops::eq)             \
  X(!=, ops::ne)             \
  X(>=, ops::ge)             \
  X(<=, ops::le)             \
  X(>, ops::gt)              \
  X(<, ops::lt)
51
52FORALL_BINARY_OPS(DECL_BINARY_OP)
53
54#undef FORALL_BINARY_OPS
55#undef DECL_BINARY_OP
56
57inline Variable operator-(const Variable& a) { return ops::neg(a); }
58
59} // namespace lmp::autograd