lamppp
Loading...
Searching...
No Matches
overloads.hpp
1#pragma once
2
3#include "lamppp/tensor/native/expand_ops.hpp"
4#include "lamppp/tensor/native/unary_ops.hpp"
5#include "lamppp/tensor/tensor.hpp"
6
7namespace lmp::tensor {
8
11 template <Tensor (*OpTag)(const Tensor&, const Tensor&)>
12 static inline Tensor binary_tensor_op(const Tensor& a, const Tensor& b) {
13 return (*OpTag)(a, b);
14 }
15
16 template <Tensor (*OpTag)(const Tensor&, const Tensor&)>
17 static inline Tensor binary_tensor_op(const Tensor& tensor, Scalar scalar) {
18 Tensor scalar_tensor(std::vector<Scalar>(1, scalar), {1}, tensor.device(),
19 tensor.type()); // rely on broadcasting
20 return binary_tensor_op<OpTag>(tensor, scalar_tensor);
21 }
22
23 template <Tensor (*OpTag)(const Tensor&, const Tensor&)>
24 static inline Tensor binary_tensor_op(Scalar scalar, const Tensor& tensor) {
25 Tensor scalar_tensor(std::vector<Scalar>(1, scalar), {1}, tensor.device(),
26 tensor.type());
27 return binary_tensor_op<OpTag>(scalar_tensor, tensor);
28 }
29};
30
// Generates the three free-function overloads of `operator symbol`:
// Tensor/Tensor, Tensor/Scalar and Scalar/Tensor. Each one forwards to the
// TensorOpFact::binary_tensor_op overload with the same argument types,
// instantiated with the address of `kernel`.
#define DECL_BINARY_OP(symbol, kernel)                                  \
  inline Tensor operator symbol(const Tensor& lhs, const Tensor& rhs) { \
    return TensorOpFact::binary_tensor_op<&kernel>(lhs, rhs);           \
  }                                                                     \
  inline Tensor operator symbol(const Tensor& lhs, Scalar value) {      \
    return TensorOpFact::binary_tensor_op<&kernel>(lhs, value);         \
  }                                                                     \
  inline Tensor operator symbol(Scalar value, const Tensor& rhs) {      \
    return TensorOpFact::binary_tensor_op<&kernel>(value, rhs);         \
  }
41
// X-macro table pairing each overloaded binary operator token with the ops::
// function that implements it. `X` is invoked once per (token, function)
// pair, in the order listed.
#define FORALL_BINARY_OPS(X) \
  X(+, ops::add)             \
  X(-, ops::sub)             \
  X(*, ops::mul)             \
  X(/, ops::div)             \
  X(==, ops::eq)             \
  X(!=, ops::ne)             \
  X(>=, ops::ge)             \
  X(<=, ops::le)             \
  X(>, ops::gt)              \
  X(<, ops::lt)
53
54FORALL_BINARY_OPS(DECL_BINARY_OP)
55
56#undef FORALL_BINARY_OPS
57#undef DECL_BINARY_OP
58
64inline Tensor operator-(const Tensor& a) { return ops::neg(a); }
65
67
68} // namespace lmp::tensor
Main tensor object for Lamppp.
Definition tensor.hpp:29
Definition overloads.hpp:10