lamppp
Loading...
Searching...
No Matches
overloads.hpp
1#pragma once
2
3#include "lamppp/autograd/functions/expand_ops.hpp"
4#include "lamppp/autograd/functions/unary_ops.hpp"
5#include "lamppp/autograd/variable.hpp"
6#include "lamppp/tensor/data_type.hpp"
7#include "lamppp/tensor/tensor.hpp"
8#include "lamppp/tensor/utils/fill_like.hpp"
9
10namespace lmp::autograd {
11
12template <Variable (*OpTag)(const Variable&, const Variable&)>
13inline Variable binary_op(const Variable& a, const Variable& b) {
14 return (*OpTag)(a, b);
15}
16template <Variable (*OpTag)(const Variable&, const Variable&)>
17inline Variable binary_op(const Variable& v, tensor::Scalar s) {
18 tensor::Tensor scalar_tensor(std::vector<tensor::Scalar>(1, s), {1},
19 v.data().device(),
20 v.data().type()); // rely on broadcasting
21 return binary_op<OpTag>(v, Variable(scalar_tensor));
22}
23template <Variable (*OpTag)(const Variable&, const Variable&)>
24inline Variable binary_op(tensor::Scalar s, const Variable& v) {
25 tensor::Tensor scalar_tensor(std::vector<tensor::Scalar>(1, s), {1},
26 v.data().device(),
27 v.data().type()); // rely on broadcasting
28 return binary_op<OpTag>(Variable(scalar_tensor), v);
29}
30
31#define DECL_BINARY_OP(op, tag) \
32 inline Variable operator op(const Variable& a, const Variable& b) { \
33 return binary_op<&(tag)>(a, b); \
34 } \
35 inline Variable operator op(const Variable& v, tensor::Scalar s) { \
36 return binary_op<&(tag)>(v, s); \
37 } \
38 inline Variable operator op(tensor::Scalar s, const Variable& v) { \
39 return binary_op<&(tag)>(s, v); \
40 }
41
42#define FORALL_BINARY_OPS(_) \
43 _(+, ops::add) \
44 _(-, ops::sub) \
45 _(*, ops::mul) \
46 _(/, ops::div) \
47 _(==, ops::eq) \
48 _(!=, ops::ne) \
49 _(>=, ops::ge) \
50 _(<=, ops::le) \
51 _(>, ops::gt) \
52 _(<, ops::lt)
53
54FORALL_BINARY_OPS(DECL_BINARY_OP)
55
56#undef FORALL_BINARY_OPS
57#undef DECL_BINARY_OP
58
59inline Variable operator-(const Variable& a) {
60 return ops::neg(a);
61}
62
63} // namespace lmp::autograd