3#include "lamppp/tensor/device_type.hpp"
4#include "lamppp/tensor/dispatch_stub.hpp"
5#include "lamppp/tensor/tensor.hpp"
6#include "lamppp/tensor/tensor_impl.hpp"
8namespace lmp::tensor::ops {
11using matmul_fn = TensorImpl (*)(
const TensorImpl&,
const TensorImpl&);
12using transpose_fn = TensorImpl (*)(
const TensorImpl&);
// Dispatch stubs for the kernels typed above. LMP_DECLARE_DISPATCH comes from
// the project's dispatch machinery (presumably dispatch_stub.hpp, included
// above). Based on the device_type.hpp include, it likely declares a
// per-device table of kernels that backends fill in and the public ops below
// call through. NOTE(review): confirm against the macro definition.
LMP_DECLARE_DISPATCH(matmul_fn, matmul_stub);
LMP_DECLARE_DISPATCH(transpose_fn, transpose_stub);
24Tensor matmul(
const Tensor& a,
const Tensor& b);
32Tensor transpose(
const Tensor& a);