3#include "lamppp/autograd/functions/expand_ops.hpp"
4#include "lamppp/autograd/functions/unary_ops.hpp"
5#include "lamppp/autograd/variable.hpp"
6#include "lamppp/tensor/data_type.hpp"
7#include "lamppp/tensor/tensor.hpp"
8#include "lamppp/tensor/utils/fill_like.hpp"
10namespace lmp::autograd {
12template <Variable (*OpTag)(const Variable&, const Variable&)>
13inline Variable binary_op(
const Variable& a,
const Variable& b) {
14 return (*OpTag)(a, b);
// Overload for `Variable <op> Scalar`.
// Builds a tensor holding the single value `s` (data vector `{s}`; the `{1}`
// argument is presumably its shape - confirm). It then wraps that tensor in a
// Variable and forwards to binary_op<OpTag>(v, <wrapped scalar>).
// NOTE(review): in this view the Tensor constructor call is cut off after
// `{1},`. Its remaining arguments (presumably device/dtype) are not shown, and
// no closing brace for this function is visible. Confirm against the full file.
16template <Variable (*OpTag)(const Variable&, const Variable&)>
17inline Variable binary_op(
const Variable& v, tensor::Scalar s) {
18 tensor::Tensor scalar_tensor(std::vector<tensor::Scalar>(1, s), {1},
21 return binary_op<OpTag>(v, Variable(scalar_tensor));
// Overload for `Scalar <op> Variable`.
// Mirrors the Variable/Scalar overload, but keeps the scalar on the LEFT so that
// non-commutative ops (e.g. subtraction, division) see operands in source order.
// It builds a tensor holding the single value `s` (data vector `{s}`; `{1}` is
// presumably its shape - confirm), wraps it in a Variable and forwards to
// binary_op<OpTag>(<wrapped scalar>, v).
// NOTE(review): the Tensor constructor call is cut off after `{1},` in this
// view. Its remaining arguments (presumably device/dtype) and the function's
// closing brace are not shown. Confirm against the full file.
23template <Variable (*OpTag)(const Variable&, const Variable&)>
24inline Variable binary_op(tensor::Scalar s,
const Variable& v) {
25 tensor::Tensor scalar_tensor(std::vector<tensor::Scalar>(1, s), {1},
28 return binary_op<OpTag>(Variable(scalar_tensor), v);
// Declares the three overloads of `operator op` for Variable:
//   Variable op Variable, Variable op Scalar, Scalar op Variable.
// Each overload forwards to binary_op<&tag>, so `tag` must name a function of
// type `Variable(const Variable&, const Variable&)`. The parentheses in
// `&(tag)` keep the address-of applied to the whole `tag` argument.
#define DECL_BINARY_OP(op, tag)                                        \
  inline Variable operator op(const Variable& a, const Variable& b) {  \
    return binary_op<&(tag)>(a, b);                                    \
  }                                                                    \
  inline Variable operator op(const Variable& v, tensor::Scalar s) {   \
    return binary_op<&(tag)>(v, s);                                    \
  }                                                                    \
  inline Variable operator op(tensor::Scalar s, const Variable& v) {   \
    return binary_op<&(tag)>(s, v);                                    \
  }
// X-macro over the supported binary operators. Its parameter `_` is invoked
// once per listed operator; presumably each entry has the form `_(op, tag)`,
// matching DECL_BINARY_OP's parameters. FORALL_BINARY_OPS(DECL_BINARY_OP)
// therefore stamps out the Variable/Scalar operator overloads for every listed
// op, and FORALL_BINARY_OPS is #undef'd right after use.
// NOTE(review): the entries of FORALL_BINARY_OPS are not visible in this view.
// Only FORALL_BINARY_OPS is #undef'd in the lines shown. If DECL_BINARY_OP is
// not #undef'd elsewhere, it stays defined for all following code, including
// includers if this is a header. Confirm.
42#define FORALL_BINARY_OPS(_) \
54FORALL_BINARY_OPS(DECL_BINARY_OP)
56#undef FORALL_BINARY_OPS
59inline Variable operator-(
const Variable& a) {