lamppp
Loading...
Searching...
No Matches
conv_ops.hpp
1#pragma once
2
3#include "lamppp/tensor/device_type.hpp"
4#include "lamppp/tensor/dispatch_stub.hpp"
5#include "lamppp/tensor/tensor.hpp"
6#include "lamppp/tensor/tensor_impl.hpp"
7
8namespace lmp::tensor::ops {
9
10using conv1d_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&, size_t, size_t, size_t);
11using conv2d_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&, size_t, size_t, size_t);
12using conv3d_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&, size_t, size_t, size_t);
13
14LMP_DECLARE_DISPATCH(conv1d_fn, conv1d_stub);
15LMP_DECLARE_DISPATCH(conv2d_fn, conv2d_stub);
16LMP_DECLARE_DISPATCH(conv3d_fn, conv3d_stub);
17
18Tensor conv1d(const Tensor& input, const Tensor& kernel, size_t stride, size_t padding, size_t dilation);
19Tensor conv2d(const Tensor& input, const Tensor& kernel, size_t stride, size_t padding, size_t dilation);
20Tensor conv3d(const Tensor& input, const Tensor& kernel, size_t stride, size_t padding, size_t dilation);
21
22} // namespace lmp::tensor::ops