lamppp
Loading...
Searching...
No Matches
kernels.hpp
1#pragma once
2
#include <cmath>
#include <limits>

#include "lamppp/tensor/device_type.hpp"
#include "lamppp/tensor/dispatch_stub.hpp"
#include "lamppp/tensor/native/unary_ops.hpp"
#include "lamppp/tensor/native/matrix_ops.hpp"
#include "lamppp/tensor/tensor_impl.hpp"
9
11
/// Binary functor that adds two values of type `T`.
///
/// @tparam T operand and result type; must support `operator+`.
template <typename T>
struct AddFunctor {
  /// @param arg1 left operand
  /// @param arg2 right operand
  /// @return `arg1 + arg2`, converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg1, T arg2) const { return arg1 + arg2; }
};
/// Binary functor that subtracts the second value from the first.
///
/// @tparam T operand and result type; must support binary `operator-`.
template <typename T>
struct SubFunctor {
  /// @param arg1 minuend
  /// @param arg2 subtrahend
  /// @return `arg1 - arg2`, converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg1, T arg2) const { return arg1 - arg2; }
};
/// Binary functor that multiplies two values of type `T`.
///
/// @tparam T operand and result type; must support `operator*`.
template <typename T>
struct MulFunctor {
  /// @param arg1 left factor
  /// @param arg2 right factor
  /// @return `arg1 * arg2`, converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg1, T arg2) const { return arg1 * arg2; }
};
/// Binary functor that divides the first value by the second.
///
/// Uses the built-in `/` of `T`, so integral `T` truncates toward zero and the
/// caller is responsible for avoiding a zero divisor for integral `T`.
///
/// @tparam T operand and result type; must support `operator/`.
template <typename T>
struct DivFunctor {
  /// @param arg1 dividend
  /// @param arg2 divisor
  /// @return `arg1 / arg2`, converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg1, T arg2) const { return arg1 / arg2; }
};
/// Binary functor that raises the first value to the power of the second.
///
/// Delegates to `std::pow`; for integral `T` that call is evaluated in
/// floating point and the result is narrowed back to `T`.
///
/// @tparam T operand and result type accepted by `std::pow`.
template <typename T>
struct PowFunctor {
  /// @param arg1 base
  /// @param arg2 exponent
  /// @return `std::pow(arg1, arg2)` converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg1, T arg2) const {
    // Explicit cast documents the (possibly narrowing) conversion back to T.
    return static_cast<T>(::std::pow(arg1, arg2));
  }
};
/// Binary functor testing two values for equality.
///
/// The boolean outcome is returned as a `T` (for arithmetic `T`: 1 when the
/// comparison holds, 0 otherwise), not as `bool`.
///
/// @tparam T operand and result type; must support `operator==` and be
///         constructible from `bool`.
template <typename T>
struct EqFunctor {
  /// @return `arg1 == arg2` converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg1, T arg2) const {
    return static_cast<T>(arg1 == arg2);
  }
};
/// Binary functor testing two values for inequality.
///
/// The boolean outcome is returned as a `T` (for arithmetic `T`: 1 when the
/// comparison holds, 0 otherwise), not as `bool`.
///
/// @tparam T operand and result type; must support `operator!=` and be
///         constructible from `bool`.
template <typename T>
struct NeFunctor {
  /// @return `arg1 != arg2` converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg1, T arg2) const {
    return static_cast<T>(arg1 != arg2);
  }
};
/// Binary functor testing whether the first value is less than or equal to
/// the second.
///
/// The boolean outcome is returned as a `T` (for arithmetic `T`: 1 when the
/// comparison holds, 0 otherwise), not as `bool`.
///
/// @tparam T operand and result type; must support `operator<=` and be
///         constructible from `bool`.
template <typename T>
struct LeFunctor {
  /// @return `arg1 <= arg2` converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg1, T arg2) const {
    return static_cast<T>(arg1 <= arg2);
  }
};
/// Binary functor testing whether the first value is strictly less than the
/// second.
///
/// The boolean outcome is returned as a `T` (for arithmetic `T`: 1 when the
/// comparison holds, 0 otherwise), not as `bool`.
///
/// @tparam T operand and result type; must support `operator<` and be
///         constructible from `bool`.
template <typename T>
struct LtFunctor {
  /// @return `arg1 < arg2` converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg1, T arg2) const {
    return static_cast<T>(arg1 < arg2);
  }
};
/// Binary functor testing whether the first value is strictly greater than
/// the second.
///
/// The boolean outcome is returned as a `T` (for arithmetic `T`: 1 when the
/// comparison holds, 0 otherwise), not as `bool`.
///
/// @tparam T operand and result type; must support `operator>` and be
///         constructible from `bool`.
template <typename T>
struct GtFunctor {
  /// @return `arg1 > arg2` converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg1, T arg2) const {
    return static_cast<T>(arg1 > arg2);
  }
};
/// Binary functor testing whether the first value is greater than or equal to
/// the second.
///
/// The boolean outcome is returned as a `T` (for arithmetic `T`: 1 when the
/// comparison holds, 0 otherwise), not as `bool`.
///
/// @tparam T operand and result type; must support `operator>=` and be
///         constructible from `bool`.
template <typename T>
struct GeFunctor {
  /// @return `arg1 >= arg2` converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg1, T arg2) const {
    return static_cast<T>(arg1 >= arg2);
  }
};
/// Unary functor that negates a value of type `T`.
///
/// @tparam T operand and result type; must support unary `operator-`.
template <typename T>
struct NegFunctor {
  /// @param arg value to negate
  /// @return `-arg`, converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg) const { return -arg; }
};
/// Unary functor computing the natural logarithm of a value.
///
/// The argument is widened to `double`, passed to `std::log`, and the result
/// is narrowed back to `T`.
///
/// @tparam T operand and result type; must be convertible to and from `double`.
template <typename T>
struct LogFunctor {
  /// @param arg value whose natural logarithm is taken
  /// @return `std::log(arg)` evaluated in `double`, converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg) const {
    // <cmath> only guarantees the std:: name; the global ::log is optional.
    return static_cast<T>(::std::log(static_cast<double>(arg)));
  }
};
/// Unary functor computing the exponential (e raised to the argument).
///
/// The argument is widened to `double`, passed to `std::exp`, and the result
/// is narrowed back to `T`.
///
/// @tparam T operand and result type; must be convertible to and from `double`.
template <typename T>
struct ExpFunctor {
  /// @param arg exponent
  /// @return `std::exp(arg)` evaluated in `double`, converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg) const {
    // <cmath> only guarantees the std:: name; the global ::exp is optional.
    return static_cast<T>(::std::exp(static_cast<double>(arg)));
  }
};
// Unary functor whose call operator widens `arg` to double, applies ::sqrt,
// and narrows the result back to T.
// NOTE(review): the `struct ... {` header line (original line 74) is missing
// from this extract; presumably this is SqrtFunctor, by analogy with the
// sibling functors — confirm against the real header.
73template <typename T>
75 T operator()(T arg) {
76 return static_cast<T>(::sqrt(static_cast<double>(arg)));
77 }
78};
/// Unary functor computing the absolute value of its argument.
///
/// Integral types are handled in their own domain rather than being routed
/// through `double`, which cannot represent every 64-bit integer exactly
/// (magnitudes above 2^53 would be rounded). All other types keep the
/// widen-to-double / `std::abs` / narrow-back path.
///
/// @tparam T operand and result type.
template <typename T>
struct AbsFunctor {
  /// @param arg value whose magnitude is taken
  /// @return `|arg|` as a `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg) const {
    if constexpr (::std::numeric_limits<T>::is_integer) {
      if constexpr (::std::numeric_limits<T>::is_signed) {
        // Exact for every representable magnitude; the most negative value
        // has no positive counterpart, exactly as with std::abs.
        return arg < 0 ? static_cast<T>(-arg) : arg;
      } else {
        // Unsigned (and bool) values are already non-negative.
        return arg;
      }
    } else {
      return static_cast<T>(::std::abs(static_cast<double>(arg)));
    }
  }
};
/// Unary functor computing the sine of its argument.
///
/// The argument is widened to `double`, passed to `std::sin`, and the result
/// is narrowed back to `T`.
///
/// @tparam T operand and result type; must be convertible to and from `double`.
template <typename T>
struct SinFunctor {
  /// @param arg value passed to `std::sin`
  /// @return `std::sin(arg)` evaluated in `double`, converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg) const {
    // <cmath> only guarantees the std:: name; the global ::sin is optional.
    return static_cast<T>(::std::sin(static_cast<double>(arg)));
  }
};
/// Unary functor computing the cosine of its argument.
///
/// The argument is widened to `double`, passed to `std::cos`, and the result
/// is narrowed back to `T`.
///
/// @tparam T operand and result type; must be convertible to and from `double`.
template <typename T>
struct CosFunctor {
  /// @param arg value passed to `std::cos`
  /// @return `std::cos(arg)` evaluated in `double`, converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg) const {
    // <cmath> only guarantees the std:: name; the global ::cos is optional.
    return static_cast<T>(::std::cos(static_cast<double>(arg)));
  }
};
/// Unary functor computing the tangent of its argument.
///
/// The argument is widened to `double`, passed to `std::tan`, and the result
/// is narrowed back to `T`.
///
/// @tparam T operand and result type; must be convertible to and from `double`.
template <typename T>
struct TanFunctor {
  /// @param arg value passed to `std::tan`
  /// @return `std::tan(arg)` evaluated in `double`, converted to `T`.
  ///
  /// Const-qualified: the functor has no state, so it must be callable through
  /// const objects and const references.
  T operator()(T arg) const {
    // <cmath> only guarantees the std:: name; the global ::tan is optional.
    return static_cast<T>(::std::tan(static_cast<double>(arg)));
  }
};
103template <typename T>
105 explicit ClampFunctor(Scalar min_val, Scalar max_val)
106 : min_val_(min_val), max_val_(max_val) {}
107 T operator()(T arg) {
108 return arg < min_val_ ? min_val_ : (arg > max_val_ ? max_val_ : arg);
109 }
110
111 private:
112 Scalar min_val_, max_val_;
113};
// Binary functor combining two values with `+`; kIdentity (0) is the identity
// element of that operation (0 + x == x).
// NOTE(review): the `struct ... {` header line (original line 115) is missing
// from this extract, so the type name is not visible here; presumably the
// accumulator used by sum_cpu — confirm.
114template <typename T>
116 static constexpr T kIdentity = 0;
117 T operator()(T arg1, T arg2) { return arg1 + arg2; }
118};
// Binary functor returning the larger of two values via std::max; kIdentity is
// numeric_limits<T>::lowest(), the smallest finite value of T.
// NOTE(review): the `struct ... {` header line (original line 120) is missing
// from this extract, so the type name is not visible here; presumably the
// accumulator used by max_cpu — confirm.
// NOTE(review): for floating-point T, lowest() is greater than -infinity, so
// it is only an identity over finite inputs — confirm that is acceptable.
119template <typename T>
121 static constexpr T kIdentity = std::numeric_limits<T>::lowest();
122 T operator()(T arg1, T arg2) { return ::std::max(arg1, arg2); }
123};
// Binary functor returning the smaller of two values via std::min; kIdentity
// is numeric_limits<T>::max(), the largest finite value of T.
// NOTE(review): the `struct ... {` header line (original line 125) is missing
// from this extract, so the type name is not visible here; presumably the
// accumulator used by min_cpu — confirm.
// NOTE(review): for floating-point T, max() is less than +infinity, so it is
// only an identity over finite inputs — confirm that is acceptable.
124template <typename T>
126 static constexpr T kIdentity = std::numeric_limits<T>::max();
127 T operator()(T arg1, T arg2) { return ::std::min(arg1, arg2); }
128};
// Binary functor combining two values with `*`; kIdentity (1) is the identity
// element of that operation (1 * x == x).
// NOTE(review): the `struct ... {` header line (original line 130) is missing
// from this extract, so the type name is not visible here; presumably the
// accumulator used by prod_cpu — confirm.
129template <typename T>
131 static constexpr T kIdentity = 1;
132 T operator()(T arg1, T arg2) { return arg1 * arg2; }
133};
135
// Two-operand kernels: each takes tensors `a` and `b` by const reference and
// returns a new TensorImpl by value.
// NOTE(review): the matching `add_cpu` declaration (original line 137) is not
// visible in this extract.
138TensorImpl sub_cpu(const TensorImpl& a, const TensorImpl& b);
139TensorImpl mul_cpu(const TensorImpl& a, const TensorImpl& b);
140TensorImpl div_cpu(const TensorImpl& a, const TensorImpl& b);
141TensorImpl pow_cpu(const TensorImpl& a, const TensorImpl& b);
142TensorImpl eq_cpu(const TensorImpl& a, const TensorImpl& b);
143TensorImpl ne_cpu(const TensorImpl& a, const TensorImpl& b);
144TensorImpl le_cpu(const TensorImpl& a, const TensorImpl& b);
145TensorImpl lt_cpu(const TensorImpl& a, const TensorImpl& b);
146TensorImpl ge_cpu(const TensorImpl& a, const TensorImpl& b);
147TensorImpl gt_cpu(const TensorImpl& a, const TensorImpl& b);
148
// Single-operand kernels: each takes tensor `a` by const reference and returns
// a new TensorImpl by value; clamp_cpu additionally takes Scalar bounds.
149TensorImpl neg_cpu(const TensorImpl& a);
150TensorImpl log_cpu(const TensorImpl& a);
151TensorImpl exp_cpu(const TensorImpl& a);
152TensorImpl sqrt_cpu(const TensorImpl& a);
153TensorImpl abs_cpu(const TensorImpl& a);
154TensorImpl sin_cpu(const TensorImpl& a);
155TensorImpl cos_cpu(const TensorImpl& a);
156TensorImpl tan_cpu(const TensorImpl& a);
157TensorImpl clamp_cpu(const TensorImpl& a, Scalar min_val, Scalar max_val);
158
// Presumably matrix product and transpose — confirm against matrix_ops.hpp.
159TensorImpl matmul_cpu(const TensorImpl& a, const TensorImpl& b);
160TensorImpl transpose_cpu(const TensorImpl& a);
161
// Kernels taking a tensor and an `axis` index; presumably they reduce along
// that axis — TODO confirm against the implementations.
162TensorImpl sum_cpu(const TensorImpl& a, size_t axis);
163TensorImpl max_cpu(const TensorImpl& a, size_t axis);
164TensorImpl min_cpu(const TensorImpl& a, size_t axis);
165TensorImpl prod_cpu(const TensorImpl& a, size_t axis);
167
168} // namespace lmp::tensor::detail::cpu
Main implementation class for Tensor object.
Definition tensor_impl.hpp:28
Definition binary.cpp:4
TensorImpl add_cpu(const TensorImpl &a, const TensorImpl &b)
Definition kernels.hpp:80
Definition kernels.hpp:14
Definition kernels.hpp:104
Definition kernels.hpp:92
Definition kernels.hpp:26
Definition kernels.hpp:34
Definition kernels.hpp:68
Definition kernels.hpp:54
Definition kernels.hpp:50
Definition kernels.hpp:42
Definition kernels.hpp:62
Definition kernels.hpp:46
Definition kernels.hpp:120
Definition kernels.hpp:125
Definition kernels.hpp:22
Definition kernels.hpp:38
Definition kernels.hpp:58
Definition kernels.hpp:30
Definition kernels.hpp:130
Definition kernels.hpp:86
Definition kernels.hpp:74
Definition kernels.hpp:18
Definition kernels.hpp:115
Definition kernels.hpp:98