lamppp
Loading...
Searching...
No Matches
csrc
include
lamppp
nets
layers
linear.hpp
1
#pragma once
2
3
#include <cstddef>
4
#include "lamppp/autograd/variable.hpp"
5
#include "lamppp/nets/module.hpp"
6
#include "lamppp/nets/parameter.hpp"
7
8
namespace
lmp::nets {
9
10
/// Signed integer type used for dimension arguments in this header;
/// FlattenImpl's constructor takes it so that a negative default (-1) can be
/// expressed.
/// Spelled std::ptrdiff_t because <cstddef> only guarantees the std::-qualified
/// name; the global-namespace `ptrdiff_t` is optional for that header.
using ssize_t = std::ptrdiff_t;  // signed type, same type as before
11
12
class
LinearImpl
:
public
ModuleImpl
{
13
public
:
14
explicit
LinearImpl
(
size_t
in_features,
size_t
out_features,
bool
bias =
true
,
15
tensor::DeviceType device = DEFAULT_DEVICE,
16
tensor::DataType dtype = DEFAULT_DTYPE);
17
autograd::Variable
forward(
const
autograd::Variable
& x)
const
;
18
19
private
:
20
Parameter
weights_;
21
Parameter
bias_;
22
bool
requires_bias_;
23
};
24
LMP_DEFINE_MODULE(Linear);
25
26
class
FlattenImpl
:
public
ModuleImpl
{
27
public
:
28
explicit
FlattenImpl
(ssize_t start_dim = 1, ssize_t end_dim = -1);
29
autograd::Variable
forward(
const
autograd::Variable
& x)
const
;
30
31
private
:
32
ssize_t start_dim_;
33
ssize_t end_dim_;
34
};
35
LMP_DEFINE_MODULE(Flatten);
36
37
}
lmp::autograd::Variable
Definition
variable.hpp:37
lmp::nets::FlattenImpl
Definition
linear.hpp:26
lmp::nets::LinearImpl
Definition
linear.hpp:12
lmp::nets::ModuleImpl
Definition
module.hpp:19
lmp::nets::Parameter
Definition
parameter.hpp:7
Generated by
1.9.8