lamppp
Loading...
Searching...
No Matches
constructor.hpp
1#pragma once
2
3#include "lamppp/tensor/data_type.hpp"
4#include "lamppp/common/config.hpp"
5#include "lamppp/autograd/variable.hpp"
6
7namespace lmp::autograd {
8
9using std::multiplies;
10
19Variable zeros(const std::vector<size_t>& shape, bool requires_grad = false,
20 tensor::DeviceType device = DEFAULT_DEVICE,
21 tensor::DataType dtype = DEFAULT_DTYPE);
22
31Variable ones(const std::vector<size_t>& shape, bool requires_grad = false, tensor::DeviceType device = DEFAULT_DEVICE,
32 tensor::DataType dtype = DEFAULT_DTYPE);
33
42Variable rand(const std::vector<size_t>& shape, bool requires_grad = false,
43 tensor::DeviceType device = DEFAULT_DEVICE,
44 tensor::DataType dtype = DEFAULT_DTYPE);
45
56Variable randn(tensor::Scalar mean, tensor::Scalar var,
57 const std::vector<size_t>& shape, bool requires_grad = false,
58 tensor::DeviceType device = DEFAULT_DEVICE,
59 tensor::DataType dtype = DEFAULT_DTYPE);
60
/// @brief Compile-time predicate: `IsVector<T>::value` is true exactly when
///        `T` is a `std::vector` specialization (any element type, any
///        allocator) and false for every other type.
///
/// Only the exact type is matched: cv- or reference-qualified vectors such as
/// `const std::vector<int>` or `std::vector<int>&` fall back to the primary
/// template and report false.
template <typename T>
struct IsVector : std::bool_constant<false> {};

/// Partial specialization selected for every `std::vector<Elem, Allocator>`.
template <typename Elem, typename Allocator>
struct IsVector<std::vector<Elem, Allocator>> : std::bool_constant<true> {};
67
70 std::vector<tensor::Scalar> data;
71 std::vector<size_t> shape;
72 template <typename T>
73 void unroll(const std::vector<T>& tensor, size_t depth = 0) {
74 if (depth >= shape.size()) {
75 shape.push_back(tensor.size());
76 }
77 LMP_CHECK(tensor.size() == shape[depth])
78 << "Dimensions along axis must be consistent.";
79 if constexpr (IsVector<T>::value) {
80 for (const T& t : tensor) {
81 unroll(t, depth + 1);
82 }
83 } else {
84 data.insert(data.end(), tensor.begin(), tensor.end());
85 }
86 }
87};
89
91
101template <typename V>
102Variable tensor(const std::vector<V>& data, bool requires_grad = false,
103 tensor::DeviceType device = DEFAULT_DEVICE,
104 tensor::DataType dtype = DEFAULT_DTYPE) {
105 TensorHelper constr;
106 constr.unroll(data);
107 return Variable(tensor::Tensor(constr.data, constr.shape, device, dtype),
108 requires_grad);
109}
110
111} // namespace lmp::autograd
Definition variable.hpp:37
Main tensor object for Lamppp.
Definition tensor.hpp:29
Definition constructor.hpp:63
Definition constructor.hpp:69