3#include "lamppp/tensor/device_type.hpp"
4#include "lamppp/tensor/dispatch_stub.hpp"
5#include "lamppp/tensor/tensor.hpp"
6#include "lamppp/tensor/tensor_impl.hpp"
8namespace lmp::tensor::ops {
10using conv1d_fn = TensorImpl (*)(
const TensorImpl&,
const TensorImpl&, size_t, size_t, size_t);
11using conv2d_fn = TensorImpl (*)(
const TensorImpl&,
const TensorImpl&, size_t, size_t, size_t);
12using conv3d_fn = TensorImpl (*)(
const TensorImpl&,
const TensorImpl&, size_t, size_t, size_t);
14LMP_DECLARE_DISPATCH(conv1d_fn, conv1d_stub);
15LMP_DECLARE_DISPATCH(conv2d_fn, conv2d_stub);
16LMP_DECLARE_DISPATCH(conv3d_fn, conv3d_stub);
18Tensor conv1d(
const Tensor& input,
const Tensor& kernel,
size_t stride,
size_t padding,
size_t dilation);
19Tensor conv2d(
const Tensor& input,
const Tensor& kernel,
size_t stride,
size_t padding,
size_t dilation);
20Tensor conv3d(
const Tensor& input,
const Tensor& kernel,
size_t stride,
size_t padding,
size_t dilation);