3#include "lamppp/autograd/functions/expand_ops.hpp"
4#include "lamppp/autograd/functions/unary_ops.hpp"
5#include "lamppp/autograd/variable.hpp"
6#include "lamppp/tensor/fill_like.hpp"
7#include "lamppp/tensor/scalar.hpp"
8#include "lamppp/tensor/tensor.hpp"
10namespace lmp::autograd {
12template <Variable (*OpTag)(const Variable&, const Variable&)>
13inline Variable binary_op(
const Variable& a,
const Variable& b) {
14 return (*OpTag)(a, b);
// Variable <op> Scalar overload of binary_op.
// Wraps the scalar `s` in a single-element tensor (built from a vector holding
// one copy of `s`, with `{1}` presumably as its shape) on the same device as
// `v`. It then forwards to the Variable-Variable overload, with `v` kept as the
// left operand.
// NOTE(review): the tensor constructor call below appears truncated in this
// view. Its last argument(s) (presumably a dtype taken from `v`) and the
// closing `);` are not visible. TODO confirm against the full file.
16template <Variable (*OpTag)(const Variable&, const Variable&)>
17inline Variable binary_op(
const Variable& v, tensor::Scalar s) {
// Promote the scalar to a tensor so it can flow through the tensor-level op.
// NOTE(review): presumably the op broadcasts the {1}-shaped tensor against `v`;
// verify.
18 tensor::Tensor scalar_tensor(std::vector<tensor::Scalar>(1, s), {1}, v.data().device(),
// The scalar side is wrapped via Variable(tensor). Whether it tracks gradients
// depends on that constructor's defaults, which are not visible here.
20 return binary_op<OpTag>(v, Variable(scalar_tensor));
// Scalar <op> Variable overload of binary_op: the mirror image of the
// Variable <op> Scalar overload.
// Wraps `s` in a single-element tensor on `v`'s device, then forwards with the
// wrapped scalar as the LEFT operand. Operand order is therefore preserved for
// non-commutative ops such as `-` and `/`.
// NOTE(review): the tensor constructor call below appears truncated in this
// view. Its last argument(s) (presumably a dtype taken from `v`) and the
// closing `);` are not visible. TODO confirm against the full file.
22template <Variable (*OpTag)(const Variable&, const Variable&)>
23inline Variable binary_op(tensor::Scalar s,
const Variable& v) {
// Promote the scalar to a tensor; `{1}` is presumably the shape argument.
24 tensor::Tensor scalar_tensor(std::vector<tensor::Scalar>(1, s), {1}, v.data().device(),
26 return binary_op<OpTag>(Variable(scalar_tensor), v);
// DECL_BINARY_OP(SYM, FN) stamps out three overloads of `operator SYM`:
//   Variable SYM Variable
//   Variable SYM Scalar
//   Scalar   SYM Variable
// Each overload delegates to binary_op<&FN>, which picks the matching
// binary_op overload for its operand kinds.
#define DECL_BINARY_OP(SYM, FN)                                            \
  inline Variable operator SYM(const Variable& lhs, const Variable& rhs) { \
    return binary_op<&FN>(lhs, rhs);                                       \
  }                                                                        \
  inline Variable operator SYM(const Variable& lhs, tensor::Scalar rhs) {  \
    return binary_op<&FN>(lhs, rhs);                                       \
  }                                                                        \
  inline Variable operator SYM(tensor::Scalar lhs, const Variable& rhs) {  \
    return binary_op<&FN>(lhs, rhs);                                       \
  }
// X-macro over the supported binary operators. Its entry list is not visible
// in this view. Because it is expanded with DECL_BINARY_OP(op, tag) below, it
// presumably invokes `_(op, tag)` once per (operator symbol, op function) pair.
// TODO confirm.
40#define FORALL_BINARY_OPS(_) \
52FORALL_BINARY_OPS(DECL_BINARY_OP)
// FORALL_BINARY_OPS is only needed for the expansion above, so it is
// undefined here to keep it out of every file that includes this header.
// NOTE(review): DECL_BINARY_OP is not #undef'd in the visible source. If
// nothing later in the file uses it, consider undefining it as well so the
// macro does not leak to includers.
54#undef FORALL_BINARY_OPS
57inline Variable operator-(
const Variable& a) {
return ops::neg(a); }