// lamppp — constructor.hpp
// (Factory functions for building autograd Variables.)
1#pragma once
2
3#include "lamppp/tensor/scalar.hpp" // TODO : maybe move scalar somewhere ?
4#include "variable.hpp"
5
6namespace lmp::autograd {
7
8using std::multiplies;
9
18Variable zeros(const std::vector<size_t>& shape, tensor::DeviceType device,
19 tensor::DataType dtype, bool requires_grad);
20
29Variable ones(const std::vector<size_t>& shape, tensor::DeviceType device,
30 tensor::DataType dtype, bool requires_grad);
31
40Variable rand(const std::vector<size_t>& shape, tensor::DeviceType device,
41 tensor::DataType dtype, bool requires_grad);
42
53Variable randn(tensor::Scalar mean, tensor::Scalar var, const std::vector<size_t>& shape, tensor::DeviceType device,
54 tensor::DataType dtype, bool requires_grad);
55
/// @brief Compile-time predicate: `IsVector<T>::value` is true exactly when
///        T is a specialization of std::vector (with any allocator).
///
/// Primary template: every type that is not a std::vector yields false.
template <typename T>
struct IsVector : std::integral_constant<bool, false> {};

/// Partial specialization matching std::vector<Elem, Allocator>: yields true.
template <typename Elem, typename Allocator>
struct IsVector<std::vector<Elem, Allocator>>
    : std::integral_constant<bool, true> {};
62
65 std::vector<tensor::Scalar> data;
66 std::vector<size_t> shape;
67 template <typename T>
68 void unroll(const std::vector<T>& tensor, size_t depth = 0) {
69 if (depth >= shape.size()) {
70 shape.push_back(tensor.size());
71 }
72 LMP_CHECK(tensor.size() == shape[depth]) <<
73 "Dimensions along axis must be consistent.";
74 if constexpr (IsVector<T>::value) {
75 for (const T& t : tensor) {
76 unroll(t, depth + 1);
77 }
78 } else {
79 data.insert(data.end(), tensor.begin(), tensor.end());
80 }
81 }
82};
84
86
96template <typename V>
97Variable tensor(const std::vector<V>& data, tensor::DeviceType device,
98 tensor::DataType dtype, bool requires_grad) {
99 TensorHelper constr;
100 constr.unroll(data);
101 return Variable(tensor::Tensor(constr.data, constr.shape, device, dtype),
102 requires_grad);
103}
104
105} // namespace lmp::autograd
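// Cross-references from the generated documentation:
//   variable.hpp:48     — definition of Variable (presumably)
//   tensor.hpp:29       — "Main tensor object for Lamppp." (presumably tensor::Tensor)
//   constructor.hpp:58  — IsVector
//   constructor.hpp:64  — the flattening helper struct (presumably TensorHelper)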