lamppp
Loading...
Searching...
No Matches
tensor.hpp
1#pragma once
2
3#include <iostream>
4#include <memory>
5#include <span>
6#include <utility>
7#include <vector>
8#include "data_type.hpp"
9#include "device_type.hpp"
10#include "dispatch_type.hpp"
11#include "fill_like.hpp"
12#include "tensor_impl.hpp"
13#include "lamppp/tensor/native/memory_ops.hpp"
14
15namespace lmp::tensor {
16
17namespace detail {
18class UnsafeTensorAccessor;
19}
20
29class Tensor {
30 public:
31 Tensor() = default;
32
40 template <typename T>
41 explicit Tensor(const std::vector<T>& data, const std::vector<size_t>& shape,
42 DeviceType device = DeviceType::CPU,
43 DataType dtype = DataType::Float64)
44 : impl_(std::make_shared<TensorImpl>(data, shape, device, dtype)) {}
45
46 void* data() const noexcept;
47 DataType type() const noexcept;
48 DeviceType device() const noexcept;
49 const std::vector<size_t>& shape() const noexcept;
50 const std::vector<detail::stride_t>& strides() const noexcept;
51 size_t numel() const noexcept;
52
54 template <typename T>
55 std::vector<T> to_vector() const {
56 std::vector<T> converted_data(impl_->numel());
57 LMP_DISPATCH_ALL_TYPES(impl_->type(), [&] {
58 std::unique_ptr<scalar_t[]> original_data = std::make_unique<scalar_t[]>(numel());
59 ops::copy_stub()(device(), DeviceType::CPU, data(),
60 original_data.get(), numel(), type(), type());
61
62 for (size_t i = 0; i < impl_->numel(); ++i) {
63 converted_data[i] = static_cast<T>(original_data[i]);
64 }
65 });
66 return converted_data;
67 }
68 Scalar index(const std::vector<size_t>& idx) const;
69
74 Tensor reshape(std::vector<size_t> new_shape) const;
75 Tensor squeeze(size_t dim) const;
76 Tensor expand_dims(size_t dim) const;
77
82 Tensor to(DeviceType device) const;
83
85 void copy(const Tensor& other);
86 void fill(Scalar item);
87
88 friend std::ostream& operator<<(std::ostream& os, const Tensor& obj);
89 friend class TensorOpFact;
90 friend class detail::UnsafeTensorAccessor;
91
92 private:
93 explicit Tensor(std::shared_ptr<TensorImpl> ptr) : impl_(std::move(ptr)) {}
94 std::shared_ptr<TensorImpl> impl_;
95};
96
97namespace detail {
98// @internal
104struct UnsafeTensorAccessor {
105 static std::shared_ptr<TensorImpl> getImpl(const Tensor& ten) {
106 return ten.impl_;
107 }
108 static Tensor fromImpl(std::shared_ptr<TensorImpl> ptr) {
109 return Tensor(std::move(ptr));
110 }
111};
112// @endinternal
113} // namespace detail
114
115} // namespace lmp::tensor
Main implementation class for Tensor object.
Definition tensor_impl.hpp:28
Main tensor object for Lamppp.
Definition tensor.hpp:29
std::vector< T > to_vector() const
Definition tensor.hpp:55
Tensor(const std::vector< T > &data, const std::vector< size_t > &shape, DeviceType device=DeviceType::CPU, DataType dtype=DataType::Float64)
Construct a tensor from a vector.
Definition tensor.hpp:41