first push
This commit is contained in:
54
RAFT/alt_cuda_corr/correlation.cpp
Executable file
54
RAFT/alt_cuda_corr/correlation.cpp
Executable file
@ -0,0 +1,54 @@
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
|
||||
// CUDA forward declarations
|
||||
std::vector<torch::Tensor> corr_cuda_forward(
|
||||
torch::Tensor fmap1,
|
||||
torch::Tensor fmap2,
|
||||
torch::Tensor coords,
|
||||
int radius);
|
||||
|
||||
std::vector<torch::Tensor> corr_cuda_backward(
|
||||
torch::Tensor fmap1,
|
||||
torch::Tensor fmap2,
|
||||
torch::Tensor coords,
|
||||
torch::Tensor corr_grad,
|
||||
int radius);
|
||||
|
||||
// C++ interface
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
std::vector<torch::Tensor> corr_forward(
|
||||
torch::Tensor fmap1,
|
||||
torch::Tensor fmap2,
|
||||
torch::Tensor coords,
|
||||
int radius) {
|
||||
CHECK_INPUT(fmap1);
|
||||
CHECK_INPUT(fmap2);
|
||||
CHECK_INPUT(coords);
|
||||
|
||||
return corr_cuda_forward(fmap1, fmap2, coords, radius);
|
||||
}
|
||||
|
||||
|
||||
std::vector<torch::Tensor> corr_backward(
|
||||
torch::Tensor fmap1,
|
||||
torch::Tensor fmap2,
|
||||
torch::Tensor coords,
|
||||
torch::Tensor corr_grad,
|
||||
int radius) {
|
||||
CHECK_INPUT(fmap1);
|
||||
CHECK_INPUT(fmap2);
|
||||
CHECK_INPUT(coords);
|
||||
CHECK_INPUT(corr_grad);
|
||||
|
||||
return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &corr_forward, "CORR forward");
|
||||
m.def("backward", &corr_backward, "CORR backward");
|
||||
}
|
Reference in New Issue
Block a user