// lamppp: expand_ops.hpp
1#pragma once
2
3#include "lamppp/tensor/device_type.hpp"
4#include "lamppp/tensor/dispatch_stub.hpp"
5#include "lamppp/tensor/tensor.hpp"
6#include "lamppp/tensor/tensor_impl.hpp"
7
8namespace lmp::tensor::ops {
9
11using add_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&);
12using sub_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&);
13using mul_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&);
14using div_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&);
15using pow_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&);
16using eq_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&);
17using ne_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&);
18using ge_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&);
19using le_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&);
20using gt_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&);
21using lt_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&);
22
23LMP_DECLARE_DISPATCH(add_fn, add_stub);
24LMP_DECLARE_DISPATCH(sub_fn, sub_stub);
25LMP_DECLARE_DISPATCH(mul_fn, mul_stub);
26LMP_DECLARE_DISPATCH(div_fn, div_stub);
27LMP_DECLARE_DISPATCH(pow_fn, pow_stub);
28LMP_DECLARE_DISPATCH(eq_fn, eq_stub);
29LMP_DECLARE_DISPATCH(ne_fn, ne_stub);
30LMP_DECLARE_DISPATCH(ge_fn, ge_stub);
31LMP_DECLARE_DISPATCH(le_fn, le_stub);
32LMP_DECLARE_DISPATCH(gt_fn, gt_stub);
33LMP_DECLARE_DISPATCH(lt_fn, lt_stub);
35
42Tensor add(const Tensor& a, const Tensor& b);
43
50Tensor sub(const Tensor& a, const Tensor& b);
51
58Tensor mul(const Tensor& a, const Tensor& b);
59
67Tensor div(const Tensor& a, const Tensor& b);
68
75Tensor pow(const Tensor& a, const Tensor& b);
76
83Tensor eq(const Tensor& a, const Tensor& b);
84
91Tensor ne(const Tensor& a, const Tensor& b);
92
99Tensor ge(const Tensor& a, const Tensor& b);
100
107Tensor gt(const Tensor& a, const Tensor& b);
108
115Tensor le(const Tensor& a, const Tensor& b);
116
123Tensor lt(const Tensor& a, const Tensor& b);
124
125} // namespace lmp::tensor::ops