lamppp
Loading...
Searching...
No Matches
unary_ops.hpp
#pragma once

#include "lamppp/tensor/device_type.hpp"
#include "lamppp/tensor/dispatch_stub.hpp"
#include "lamppp/tensor/tensor.hpp"
#include "lamppp/tensor/tensor_impl.hpp"

8namespace lmp::tensor::ops {
9
11using neg_fn = TensorImpl (*)(const TensorImpl&);
12using exp_fn = TensorImpl (*)(const TensorImpl&);
13using log_fn = TensorImpl (*)(const TensorImpl&);
14using sqrt_fn = TensorImpl (*)(const TensorImpl&);
15using abs_fn = TensorImpl (*)(const TensorImpl&);
16using sin_fn = TensorImpl (*)(const TensorImpl&);
17using cos_fn = TensorImpl (*)(const TensorImpl&);
18using tan_fn = TensorImpl (*)(const TensorImpl&);
19using clamp_fn = TensorImpl (*)(const TensorImpl&, Scalar, Scalar);
20
21LMP_DECLARE_DISPATCH(neg_fn, neg_stub);
22LMP_DECLARE_DISPATCH(exp_fn, exp_stub);
23LMP_DECLARE_DISPATCH(log_fn, log_stub);
24LMP_DECLARE_DISPATCH(sqrt_fn, sqrt_stub);
25LMP_DECLARE_DISPATCH(abs_fn, abs_stub);
26LMP_DECLARE_DISPATCH(sin_fn, sin_stub);
27LMP_DECLARE_DISPATCH(cos_fn, cos_stub);
28LMP_DECLARE_DISPATCH(tan_fn, tan_stub);
29LMP_DECLARE_DISPATCH(clamp_fn, clamp_stub);
31
37Tensor neg(const Tensor& a);
38
44Tensor exp(const Tensor& a);
45
51Tensor log(const Tensor& a);
52
58Tensor sqrt(const Tensor& a);
59
65Tensor abs(const Tensor& a);
66
72Tensor sin(const Tensor& a);
73
79Tensor cos(const Tensor& a);
80
86Tensor tan(const Tensor& a);
87
95Tensor clamp(const Tensor& a, Scalar min_val, Scalar max_val);
96
97} // namespace lmp::tensor::ops