lamppp
Loading...
Searching...
No Matches
data_type.hpp
1#pragma once
2
3#include <cstdint>
4#include <ostream>
5
namespace lmp::tensor {

/// @brief Scalar type alias for the tensor module; currently `double`.
using Scalar = double;

/**
 * @brief Element data types known to the tensor module.
 *
 * The underlying numeric values are load-bearing: type_upcast() promotes a
 * pair of types by picking the one with the larger value. Enumerators must
 * therefore stay ordered from "narrowest" (Bool) to "widest" (Float64).
 * Do not reorder or renumber them.
 */
enum class DataType : uint8_t {
  Bool = 0,
  Int16 = 1,
  Int32 = 2,
  Int64 = 3,
  Float32 = 4,
  Float64 = 5
};

/**
 * @brief Maps a concrete C++ type (e.g. `int`) to its DataType enumerator
 *        (e.g. DataType::Int32) through `TypeMeta<T>::kValue`.
 *
 * The primary template is declared but never defined. Naming
 * `TypeMeta<T>::kValue` for a type without a specialization below is
 * therefore a compile-time error, not a silent fallback.
 */
template <typename T>
struct TypeMeta;

template <>
struct TypeMeta<bool> {
  static constexpr DataType kValue = DataType::Bool;
};

template <>
struct TypeMeta<int16_t> {
  static constexpr DataType kValue = DataType::Int16;
};

// `int` is labelled Int32 below. Make that width assumption explicit so a
// platform with a different `int` size fails to build instead of
// mislabelling data.
static_assert(sizeof(int) == sizeof(int32_t),
              "TypeMeta<int> maps to DataType::Int32, which requires a 32-bit int");

template <>
struct TypeMeta<int> {
  static constexpr DataType kValue = DataType::Int32;
};

template <>
struct TypeMeta<int64_t> {
  static constexpr DataType kValue = DataType::Int64;
};

template <>
struct TypeMeta<float> {
  static constexpr DataType kValue = DataType::Float32;
};

template <>
struct TypeMeta<double> {
  static constexpr DataType kValue = DataType::Float64;
};

/**
 * @brief Returns the promoted ("wider") of two data types.
 *
 * Promotion follows DataType's enumerator order, so the result is whichever
 * argument has the larger underlying value. For example, Int32 with Float32
 * yields Float32, and Int64 with Float32 also yields Float32.
 *
 * @param a_type First operand type.
 * @param b_type Second operand type.
 * @return The operand type that sorts later in DataType.
 */
[[nodiscard]] inline constexpr DataType type_upcast(DataType a_type,
                                                    DataType b_type) noexcept {
  // Compare directly instead of calling std::max, which needs <algorithm>.
  // This header does not include <algorithm>.
  return static_cast<uint8_t>(a_type) >= static_cast<uint8_t>(b_type) ? a_type
                                                                       : b_type;
}

/**
 * @brief Writes the enumerator name of @p dtype (e.g. "Float32") to @p os.
 *
 * A value that matches no declared enumerator (e.g. one produced by a bad
 * static_cast) is written as "Unknown DataType".
 *
 * @param os    Destination stream.
 * @param dtype Data type to print.
 * @return @p os, so that insertions can be chained.
 */
inline std::ostream& operator<<(std::ostream& os, DataType dtype) {
  const char* name = "Unknown DataType";
  switch (dtype) {
    case DataType::Bool:
      name = "Bool";
      break;
    case DataType::Int16:
      name = "Int16";
      break;
    case DataType::Int32:
      name = "Int32";
      break;
    case DataType::Int64:
      name = "Int64";
      break;
    case DataType::Float32:
      name = "Float32";
      break;
    case DataType::Float64:
      name = "Float64";
      break;
  }
  return os << name;
}

}  // namespace lmp::tensor
108
/// X-macro over the six C++ types that have a TypeMeta / DataType mapping,
/// listed in DataType enumerator order (Bool .. Float64). It expands `X(T)`
/// once for each type `T`, with nothing between the expansions.
#define LMP_X_TYPES(X) \
  X(bool)              \
  X(int16_t)           \
  X(int)               \
  X(int64_t)           \
  X(float)             \
  X(double)
116
/// The same six types as LMP_X_TYPES, in the same order, written as one
/// parenthesized, comma-separated list. Keep the two lists in sync.
#define LMP_LIST_TYPES \
  (bool, int16_t, int, int64_t, float, double)
Simple template that maps a concrete C++ type (e.g. `int`) to its `DataType` enumerator (e.g. `DataType::Int32`).
Definition data_type.hpp:38