lamppp
Loading...
Searching...
No Matches
linear.hpp
1#pragma once
2
3#include <cstddef>
4#include "lamppp/autograd/variable.hpp"
5#include "lamppp/nets/module.hpp"
6#include "lamppp/nets/parameter.hpp"
7
8namespace lmp::nets {
9
10using ssize_t = ptrdiff_t; // signed size_t
11
12class LinearImpl : public ModuleImpl {
13 public:
14 explicit LinearImpl(size_t in_features, size_t out_features, bool bias = true,
15 tensor::DeviceType device = DEFAULT_DEVICE,
16 tensor::DataType dtype = DEFAULT_DTYPE);
17 autograd::Variable forward(const autograd::Variable& x) const;
18
19 private:
20 Parameter weights_;
21 Parameter bias_;
22 bool requires_bias_;
23};
24LMP_DEFINE_MODULE(Linear);
25
26class FlattenImpl : public ModuleImpl {
27 public:
28 explicit FlattenImpl(ssize_t start_dim = 1, ssize_t end_dim = -1);
29 autograd::Variable forward(const autograd::Variable& x) const;
30
31 private:
32 ssize_t start_dim_;
33 ssize_t end_dim_;
34};
35LMP_DEFINE_MODULE(Flatten);
36
37}
Definition variable.hpp:37
Definition linear.hpp:26
Definition linear.hpp:12
Definition module.hpp:19
Definition parameter.hpp:7