lamppp
Loading...
Searching...
No Matches
matrix_ops.hpp
1#pragma once
2
3#include "lamppp/tensor/device_type.hpp"
4#include "lamppp/tensor/dispatch_stub.hpp"
5#include "lamppp/tensor/tensor.hpp"
6#include "lamppp/tensor/tensor_impl.hpp"
7
8namespace lmp::tensor::ops {
9
11using matmul_fn = TensorImpl (*)(const TensorImpl&, const TensorImpl&);
12using transpose_fn = TensorImpl (*)(const TensorImpl&);
13
14LMP_DECLARE_DISPATCH(matmul_fn, matmul_stub);
15LMP_DECLARE_DISPATCH(transpose_fn, transpose_stub);
17
24Tensor matmul(const Tensor& a, const Tensor& b);
25
32Tensor transpose(const Tensor& a);
33
34} // namespace lmp::tensor::ops