4#include "lamppp/tensor/device_type.hpp"
5#include "lamppp/tensor/dispatch_stub.hpp"
6#include "lamppp/tensor/native/unary_ops.hpp"
7#include "lamppp/tensor/native/matrix_ops.hpp"
8#include "lamppp/tensor/tensor_impl.hpp"
15 T operator()(T arg1, T arg2) {
return arg1 + arg2; }
19 T operator()(T arg1, T arg2) {
return arg1 - arg2; }
23 T operator()(T arg1, T arg2) {
return arg1 * arg2; }
27 T operator()(T arg1, T arg2) {
return arg1 / arg2; }
31 T operator()(T arg1, T arg2) { return ::std::pow(arg1, arg2); }
35 T operator()(T arg1, T arg2) {
return arg1 == arg2; }
39 T operator()(T arg1, T arg2) {
return arg1 != arg2; }
43 T operator()(T arg1, T arg2) {
return arg1 <= arg2; }
47 T operator()(T arg1, T arg2) {
return arg1 < arg2; }
51 T operator()(T arg1, T arg2) {
return arg1 > arg2; }
55 T operator()(T arg1, T arg2) {
return arg1 >= arg2; }
59 T operator()(T arg) {
return (-arg); }
64 return static_cast<T
>(::log(
static_cast<double>(arg)));
70 return static_cast<T
>(::exp(
static_cast<double>(arg)));
76 return static_cast<T
>(::sqrt(
static_cast<double>(arg)));
82 return static_cast<T
>(::std::abs(
static_cast<double>(arg)));
88 return static_cast<T
>(::sin(
static_cast<double>(arg)));
94 return static_cast<T
>(::cos(
static_cast<double>(arg)));
100 return static_cast<T
>(::tan(
static_cast<double>(arg)));
106 : min_val_(min_val), max_val_(max_val) {}
107 T operator()(T arg) {
108 return arg < min_val_ ? min_val_ : (arg > max_val_ ? max_val_ : arg);
112 Scalar min_val_, max_val_;
116 static constexpr T kIdentity = 0;
117 T operator()(T arg1, T arg2) {
return arg1 + arg2; }
121 static constexpr T kIdentity = std::numeric_limits<T>::lowest();
122 T operator()(T arg1, T arg2) { return ::std::max(arg1, arg2); }
126 static constexpr T kIdentity = std::numeric_limits<T>::max();
127 T operator()(T arg1, T arg2) { return ::std::min(arg1, arg2); }
131 static constexpr T kIdentity = 1;
132 T operator()(T arg1, T arg2) {
return arg1 * arg2; }
Main implementation class for Tensor object.
Definition tensor_impl.hpp:28
TensorImpl add_cpu(const TensorImpl &a, const TensorImpl &b)
Definition kernels.hpp:80
Definition kernels.hpp:14
Definition kernels.hpp:104
Definition kernels.hpp:92
Definition kernels.hpp:26
Definition kernels.hpp:34
Definition kernels.hpp:68
Definition kernels.hpp:54
Definition kernels.hpp:50
Definition kernels.hpp:42
Definition kernels.hpp:62
Definition kernels.hpp:46
Definition kernels.hpp:120
Definition kernels.hpp:125
Definition kernels.hpp:22
Definition kernels.hpp:38
Definition kernels.hpp:58
Definition kernels.hpp:30
Definition kernels.hpp:130
Definition kernels.hpp:86
Definition kernels.hpp:74
Definition kernels.hpp:18
Definition kernels.hpp:115
Definition kernels.hpp:98