first push
.gitignore (vendored, Normal file, 128 lines)
@@ -0,0 +1,128 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# log
log/

# data
*.pth
*.jpg
*.png
*.tar
*.h5
*.mp4
*.avi
utils/goodsCategory.json
utils/monitor.json
tmp.txt
RAFT/LICENSE (Executable file, 29 lines)
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2020, princeton-vl
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
RAFT/__init__.py (Executable file, 0 lines)
RAFT/alt_cuda_corr/correlation.cpp (Executable file, 54 lines)
@@ -0,0 +1,54 @@
#include <torch/extension.h>
#include <vector>

// CUDA forward declarations
std::vector<torch::Tensor> corr_cuda_forward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  int radius);

std::vector<torch::Tensor> corr_cuda_backward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor corr_grad,
  int radius);

// C++ interface
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> corr_forward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  int radius) {
  CHECK_INPUT(fmap1);
  CHECK_INPUT(fmap2);
  CHECK_INPUT(coords);

  return corr_cuda_forward(fmap1, fmap2, coords, radius);
}


std::vector<torch::Tensor> corr_backward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor corr_grad,
  int radius) {
  CHECK_INPUT(fmap1);
  CHECK_INPUT(fmap2);
  CHECK_INPUT(coords);
  CHECK_INPUT(corr_grad);

  return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &corr_forward, "CORR forward");
  m.def("backward", &corr_backward, "CORR backward");
}
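These bindings are consumed from Python as alt_cuda_corr.forward/backward (see RAFT/core/corr.py further down). A minimal shape-check sketch, assuming the extension has been built via the setup.py below and a CUDA device is available; the sizes are arbitrary:

import torch
import alt_cuda_corr   # built from this directory via setup.py

B, H, W, C, r = 1, 32, 32, 64, 4
# channels-last feature maps, float32, as the kernels index them
fmap1 = torch.randn(B, H, W, C, device='cuda')
fmap2 = torch.randn(B, H, W, C, device='cuda')
# one (x, y) lookup coordinate per pixel; zeros are valid lookups
coords = torch.zeros(B, 1, H, W, 2, device='cuda')
corr, = alt_cuda_corr.forward(fmap1, fmap2, coords, r)
print(corr.shape)   # torch.Size([1, 1, 81, 32, 32]) since (2*r+1)**2 == 81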
RAFT/alt_cuda_corr/correlation_kernel.cu (Executable file, 324 lines)
@@ -0,0 +1,324 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>


#define BLOCK_H 4
#define BLOCK_W 8
#define BLOCK_HW BLOCK_H * BLOCK_W
#define CHANNEL_STRIDE 32


__forceinline__ __device__
bool within_bounds(int h, int w, int H, int W) {
  return h >= 0 && h < H && w >= 0 && w < W;
}

template <typename scalar_t>
__global__ void corr_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
    int r)
{
  const int b = blockIdx.x;
  const int h0 = blockIdx.y * blockDim.x;
  const int w0 = blockIdx.z * blockDim.y;
  const int tid = threadIdx.x * blockDim.y + threadIdx.y;

  const int H1 = fmap1.size(1);
  const int W1 = fmap1.size(2);
  const int H2 = fmap2.size(1);
  const int W2 = fmap2.size(2);
  const int N = coords.size(1);
  const int C = fmap1.size(3);

  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t x2s[BLOCK_HW];
  __shared__ scalar_t y2s[BLOCK_HW];

  for (int c=0; c<C; c+=CHANNEL_STRIDE) {
    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      int k1 = k + tid / CHANNEL_STRIDE;
      int h1 = h0 + k1 / BLOCK_W;
      int w1 = w0 + k1 % BLOCK_W;
      int c1 = tid % CHANNEL_STRIDE;

      auto fptr = fmap1[b][h1][w1];
      if (within_bounds(h1, w1, H1, W1))
        f1[c1][k1] = fptr[c+c1];
      else
        f1[c1][k1] = 0.0;
    }

    __syncthreads();

    for (int n=0; n<N; n++) {
      int h1 = h0 + threadIdx.x;
      int w1 = w0 + threadIdx.y;
      if (within_bounds(h1, w1, H1, W1)) {
        x2s[tid] = coords[b][n][h1][w1][0];
        y2s[tid] = coords[b][n][h1][w1][1];
      }

      scalar_t dx = x2s[tid] - floor(x2s[tid]);
      scalar_t dy = y2s[tid] - floor(y2s[tid]);

      int rd = 2*r + 1;
      for (int iy=0; iy<rd+1; iy++) {
        for (int ix=0; ix<rd+1; ix++) {
          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            int k1 = k + tid / CHANNEL_STRIDE;
            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            int c2 = tid % CHANNEL_STRIDE;

            auto fptr = fmap2[b][h2][w2];
            if (within_bounds(h2, w2, H2, W2))
              f2[c2][k1] = fptr[c+c2];
            else
              f2[c2][k1] = 0.0;
          }

          __syncthreads();

          scalar_t s = 0.0;
          for (int k=0; k<CHANNEL_STRIDE; k++)
            s += f1[k][tid] * f2[k][tid];

          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
          int ix_ne = H1*W1*((iy-1) + rd*ix);
          int ix_sw = H1*W1*(iy + rd*(ix-1));
          int ix_se = H1*W1*(iy + rd*ix);

          scalar_t nw = s * (dy) * (dx);
          scalar_t ne = s * (dy) * (1-dx);
          scalar_t sw = s * (1-dy) * (dx);
          scalar_t se = s * (1-dy) * (1-dx);

          scalar_t* corr_ptr = &corr[b][n][0][h1][w1];

          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_nw) += nw;

          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_ne) += ne;

          if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_sw) += sw;

          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_se) += se;
        }
      }
    }
  }
}


template <typename scalar_t>
__global__ void corr_backward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
    int r)
{

  const int b = blockIdx.x;
  const int h0 = blockIdx.y * blockDim.x;
  const int w0 = blockIdx.z * blockDim.y;
  const int tid = threadIdx.x * blockDim.y + threadIdx.y;

  const int H1 = fmap1.size(1);
  const int W1 = fmap1.size(2);
  const int H2 = fmap2.size(1);
  const int W2 = fmap2.size(2);
  const int N = coords.size(1);
  const int C = fmap1.size(3);

  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];

  __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];

  __shared__ scalar_t x2s[BLOCK_HW];
  __shared__ scalar_t y2s[BLOCK_HW];

  for (int c=0; c<C; c+=CHANNEL_STRIDE) {

    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      int k1 = k + tid / CHANNEL_STRIDE;
      int h1 = h0 + k1 / BLOCK_W;
      int w1 = w0 + k1 % BLOCK_W;
      int c1 = tid % CHANNEL_STRIDE;

      auto fptr = fmap1[b][h1][w1];
      if (within_bounds(h1, w1, H1, W1))
        f1[c1][k1] = fptr[c+c1];
      else
        f1[c1][k1] = 0.0;

      f1_grad[c1][k1] = 0.0;
    }

    __syncthreads();

    int h1 = h0 + threadIdx.x;
    int w1 = w0 + threadIdx.y;

    for (int n=0; n<N; n++) {
      x2s[tid] = coords[b][n][h1][w1][0];
      y2s[tid] = coords[b][n][h1][w1][1];

      scalar_t dx = x2s[tid] - floor(x2s[tid]);
      scalar_t dy = y2s[tid] - floor(y2s[tid]);

      int rd = 2*r + 1;
      for (int iy=0; iy<rd+1; iy++) {
        for (int ix=0; ix<rd+1; ix++) {
          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            int k1 = k + tid / CHANNEL_STRIDE;
            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            int c2 = tid % CHANNEL_STRIDE;

            auto fptr = fmap2[b][h2][w2];
            if (within_bounds(h2, w2, H2, W2))
              f2[c2][k1] = fptr[c+c2];
            else
              f2[c2][k1] = 0.0;

            f2_grad[c2][k1] = 0.0;
          }

          __syncthreads();

          const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
          scalar_t g = 0.0;

          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
          int ix_ne = H1*W1*((iy-1) + rd*ix);
          int ix_sw = H1*W1*(iy + rd*(ix-1));
          int ix_se = H1*W1*(iy + rd*ix);

          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_nw) * dy * dx;

          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_ne) * dy * (1-dx);

          if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_sw) * (1-dy) * dx;

          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);

          for (int k=0; k<CHANNEL_STRIDE; k++) {
            f1_grad[k][tid] += g * f2[k][tid];
            f2_grad[k][tid] += g * f1[k][tid];
          }

          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            int k1 = k + tid / CHANNEL_STRIDE;
            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            int c2 = tid % CHANNEL_STRIDE;

            scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
            if (within_bounds(h2, w2, H2, W2))
              atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
          }
        }
      }
    }
    __syncthreads();


    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      int k1 = k + tid / CHANNEL_STRIDE;
      int h1 = h0 + k1 / BLOCK_W;
      int w1 = w0 + k1 % BLOCK_W;
      int c1 = tid % CHANNEL_STRIDE;

      scalar_t* fptr = &fmap1_grad[b][h1][w1][0];
      if (within_bounds(h1, w1, H1, W1))
        fptr[c+c1] += f1_grad[c1][k1];
    }
  }
}



std::vector<torch::Tensor> corr_cuda_forward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  int radius)
{
  const auto B = coords.size(0);
  const auto N = coords.size(1);
  const auto H = coords.size(2);
  const auto W = coords.size(3);

  const auto rd = 2 * radius + 1;
  auto opts = fmap1.options();
  auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);

  const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
  const dim3 threads(BLOCK_H, BLOCK_W);

  corr_forward_kernel<float><<<blocks, threads>>>(
    fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    radius);

  return {corr};
}

std::vector<torch::Tensor> corr_cuda_backward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor corr_grad,
  int radius)
{
  const auto B = coords.size(0);
  const auto N = coords.size(1);

  const auto H1 = fmap1.size(1);
  const auto W1 = fmap1.size(2);
  const auto H2 = fmap2.size(1);
  const auto W2 = fmap2.size(2);
  const auto C = fmap1.size(3);

  auto opts = fmap1.options();
  auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
  auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
  auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);

  const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
  const dim3 threads(BLOCK_H, BLOCK_W);


  corr_backward_kernel<float><<<blocks, threads>>>(
    fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    radius);

  return {fmap1_grad, fmap2_grad, coords_grad};
}
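The launch configuration in the two host functions maps block-x to the batch index, block-y/z to 4x8 tiles of the feature map, and each of the 32 threads in a block to one pixel of its tile. A small sketch of that arithmetic (the feature-map size here is hypothetical):

# Mirrors the dim3 computation in corr_cuda_forward/backward above.
B, H, W = 2, 46, 62          # hypothetical 1/8-resolution feature map
BLOCK_H, BLOCK_W = 4, 8
blocks = (B, (H + BLOCK_H - 1) // BLOCK_H, (W + BLOCK_W - 1) // BLOCK_W)
threads = (BLOCK_H, BLOCK_W)
print(blocks, threads)       # (2, 12, 8) (4, 8)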
RAFT/alt_cuda_corr/setup.py (Executable file, 15 lines)
@@ -0,0 +1,15 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
    name='correlation',
    ext_modules=[
        CUDAExtension('alt_cuda_corr',
            sources=['correlation.cpp', 'correlation_kernel.cu'],
            extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
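The extension is optional; RAFT/core/corr.py wraps its import in a try/except and falls back to the pure-PyTorch CorrBlock. A typical build is `python setup.py install` run from inside RAFT/alt_cuda_corr with a CUDA toolkit matching the installed PyTorch, followed by a smoke test such as this sketch:

# Post-build smoke test (a sketch).
try:
    import alt_cuda_corr
    print('alt_cuda_corr available')
except ImportError:
    print('extension not built; RAFT falls back to the pure-PyTorch CorrBlock')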
RAFT/analysis_video.py (Executable file, 233 lines)
@@ -0,0 +1,233 @@
import sys
sys.path.append('RAFT/core')
import argparse
import glob, cv2, os, pdb, time
import numpy as np
import torch
from PIL import Image
from raft import RAFT
from RAFT.core.utils import flow_viz
from RAFT.core.utils.utils import InputPadder
from utils.tools import EvaluteMap, ManagingFeature
from utils.config import cfg
from utils.updateObs import Addimg_content
from utils.retrieval_feature import AntiFraudFeatureDataset

DEVICE = 'cuda'
global Result
pre_area = 0

def load_image(imfile):
    #img = np.array(Image.open(imfile)).astype(np.uint8)
    img = np.array(imfile).astype(np.uint8)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img[None].to(DEVICE)

def viz(img, flo):
    img = img[0].permute(1,2,0).cpu().numpy()
    flo = flo[0].permute(1,2,0).cpu().numpy()
    flo = flow_viz.flow_to_image(flo)
    return flo

def raft_init_model(args):
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))
    model = model.module
    model.to(DEVICE)
    model.eval()
    return model

def analysis_video(model, video_path, result_path, uuid_barcode, maskpath=None, net=None, transform=None, ms=None, match=True):
    imfile1, imfile2 = None, None
    result = None   # initialized so the post-loop check cannot hit an unbound name
    affd = AntiFraudFeatureDataset()
    barcode = uuid_barcode.split('_')[-1]
    search_r = ManagingFeature().getfeature(barcode)
    #print('search_r>>>>>>>>', len(search_r))
    ori_mask = cv2.imread(maskpath, 0)
    nn, nu = 0, 1
    Result = '03'
    img_dic = {}
    ex_ocrList, resultList = [], []
    flag = True
    fgbg = cv2.createBackgroundSubtractorMOG2(history=500, varThreshold=20, detectShadows=False)
    #oriimg = cv2.imread(cfg.fgbgmask)
    with torch.no_grad():
        capture = cv2.VideoCapture(video_path)
        ret, frame = capture.read(0)
        if frame.shape[0] < frame.shape[1]:   # 1024x1280
            oriimg = cv2.imread(cfg.fgbgmask)
            MASKIMG = cfg.MASKIMG
        else:                                 # 1280x1024
            oriimg = cv2.imread(cfg.fgbgmask_old)
            MASKIMG = cfg.MASKIMG_old
        while True:
            ret, frame = capture.read()
            if not ret: break
            frame_show = frame
            nn += 1
            if flag:   # MOG2 pre-filter: wait for an intrusion frame
                if nn%2==0 or nn%3==0: continue   # frame skipping
                flag = img_filter(frame, oriimg, fgbg, nn)
                if flag: continue
            else:      # RAFT localization
                if nn%2==0: continue
                height, width = frame.shape[:2]
                frame = cv2.GaussianBlur(frame, (5,5), 0)
                frame = cv2.resize(frame, (int(width/2), int(height/2)), interpolation=cv2.INTER_CUBIC)
                if nu == 1:
                    imfile1 = frame
                    nu += 1
                    continue
                else:
                    imfile2 = frame
                    image1 = load_image(imfile1)
                    image2 = load_image(imfile2)

                    padder = InputPadder(image1.shape)
                    image1, image2 = padder.pad(image1, image2)

                    flow_low, flow_up = model(image1, image2, iters=2, test_mode=True)
                    flo = viz(image1, flow_up)
                    result = get_target(result_path, flo, imfile1, nu, ori_mask, uuid_barcode, MASKIMG)
                    imfile1 = imfile2
                    flag, nu, Result = detect(match, affd, net, result, transform, ms, search_r, nn, nu, Result)
                    if flag: break
    Addimg_content(uuid_barcode, frame_show)   # upload the image
    #if not Result=='03':
    if result is not None:
        cv2.imwrite(os.sep.join([cfg.Ocrimg, uuid_barcode+'_'+str(nu)+'.jpg']), result)   # give ocr img
    else:
        cv2.imwrite(os.sep.join([cfg.Ocrimg, uuid_barcode+'_'+str(nu)+'.jpg']), frame_show)   # give ocr img
    return Result

def detect(match, affd, net, result, transform, ms, search_r, nn, nu, Result):
    flag = False
    if match:
        if result is not None:
            feature = affd.extractFeature_o(net, result, transform, ms)
            res = EvaluteMap().match_images(feature, search_r, mod='single')
            if Result=='03':
                Result = ''
                Result += str(res)+','
            else:
                Result += str(res)+','

            if res < cfg.THRESHOLD:
                #Result ='04'
                flag = True
            nu += 1
            if nu > cfg.NUM_RAFT:
                #Result = '02'
                flag = True
        if nn > 100 and nu > 2:
            flag = True
    else:
        if result is not None:
            nu += 1
            if nu > cfg.NUM_RAFT:
                flag = True
        if nn > 100 and nu > 2:
            flag = True
    return flag, nu, Result

def get_target(path, img, ori_img, nu, ori_mask, uuid_barcode, MASKIMG):
    global pre_area
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    ret, mask = cv2.threshold(img, 249, 255, cv2.THRESH_BINARY)
    mask_max_area, mask_max_contour = 0, 0
    mask = cv2.bitwise_not(mask)
    mask_image = np.zeros((ori_img.shape[0], ori_img.shape[1], 1), np.uint8)
    if (cv2.__version__).split('.')[0] == '3':
        _, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    else:
        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    if len(contours) > 100:
        return None
    for contour in contours:
        mask_area_now = cv2.contourArea(contour)
        if mask_area_now > mask_max_area:
            mask_max_area = mask_area_now
            mask_max_contour = contour
    if mask_max_area == 0: return None   # mask_max_area is the area of the target region
    (x, y, w, h) = cv2.boundingRect(mask_max_contour)
    if (w*h)/(img.shape[0]*img.shape[1]) > 0.80:
        return None
    if min(w,h) < 100 or max(w,h) > 1000:
        return None
    coordination = [x, y, x + w, y + h]
    mask_image = cv2.fillPoly(mask_image, [mask_max_contour], (255))
    if pre_area == 0:
        pre_area = mask_max_area
        return None
    else:
        if abs(mask_max_area-pre_area)/pre_area > 0.4:
            pre_area = mask_max_area
            #print('abs:',abs(mask_max_area-pre_area)/pre_area)
            return None
        else:
            pre_area = mask_max_area
    A, B, C = mask_image, mask_image, mask_image
    mask_image = cv2.merge([A, B, C])

    # this step removes interference outside the target box
    if not get_iou_ratio(mask_image, MASKIMG):
        return None

    show = cv2.bitwise_and(ori_img, mask_image)
    #show = show[coordination[1]:coordination[3], coordination[0]:coordination[2]]
    show = ori_img[coordination[1]:coordination[3], coordination[0]:coordination[2]]
    #cv2.imwrite(os.sep.join([cfg.Ocrimg, str(nu-1)+'_'+uuid_barcode+'.jpg']), show)
    return show

def get_iou_ratio(oimg, MASKIMG):
    mimg = cv2.imread(MASKIMG)
    iimg = cv2.bitwise_and(oimg, mimg)
    iimgarea = get_area(iimg)
    oimgarea = get_area(oimg)
    if iimgarea/oimgarea < 0.1:
        return False
    else: return True

def get_area(img):
    kernel = np.ones((3, 3), dtype=np.uint8)
    img = cv2.dilate(img, kernel, 1)
    img = cv2.erode(img, kernel, 1)
    maxcontour, nu = 0, 0
    contours, _ = cv2.findContours(img[:,:,1], cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    if len(contours) == 0:
        return 0
    for i in range(len(contours)):
        if maxcontour < len(contours[i]):
            maxcontour = len(contours[i])
            nu = i
    area = cv2.contourArea(contours[nu])
    return area

def img_filter(frame, oriimg, fgbg, nn):
    dic, dics = {}, {}
    iouArea = 0
    frame = cv2.GaussianBlur(frame, (5, 5), 0)
    height, width = frame.shape[:2]
    frame = cv2.resize(frame, (int(width/2), int(height/2)), interpolation=cv2.INTER_CUBIC)
    # compute the foreground mask
    fgmask = fgbg.apply(frame)
    draw1 = cv2.threshold(fgmask, 25, 255, cv2.THRESH_BINARY)[1]
    if nn==2: return True
    draw1 = cv2.bitwise_and(oriimg[:, :, 0], draw1)
    contours_m, hierarchy_m = cv2.findContours(draw1.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours_m:
        dics[len(contour)] = contour
    if len(dics.keys()) > 0:
        cc = sorted(dics.keys())
        iouArea = cv2.contourArea(dics[cc[-1]])
    if iouArea > 10000 and iouArea < 40000:
        return False
    return True

if __name__ == '__main__':
    # raft_init_model needs an args object; these fields mirror what RAFT
    # reads, and the checkpoint path here is a placeholder.
    args = argparse.Namespace(model='raft_model.pth', small=False,
                              mixed_precision=False, alternate_corr=False)
    model = raft_init_model(args)
    from utils.tools import createNet
    net, transform, ms = createNet()
    video_path = '../data/videos/20220625-094651_37dd99b0-520d-457b-8615-efdb7f53b5b4_6907992825762.mp4'   # video_path
    uuid_barcode = '6907992825762'
    analysis = analysis_video(model=model, video_path=video_path, result_path='', uuid_barcode=uuid_barcode, maskpath=None, net=net, transform=transform, ms=ms)
    # analysis_video(model, video_path, result_path)
RAFT/chairs_split.txt (Executable file, 22872 lines)
File diff suppressed because it is too large
RAFT/core/__init__.py (Executable file, 0 lines)
RAFT/core/corr.py (Executable file, 94 lines)
@@ -0,0 +1,94 @@
import torch
import torch.nn.functional as F
import sys
sys.path.append('utils')
#from utils.utils import bilinear_sampler, coords_grid
from RAFT.core.utils.utils import bilinear_sampler, coords_grid

try:
    import alt_cuda_corr
except ImportError:
    # alt_cuda_corr is not compiled
    pass


class CorrBlock:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            dx = torch.linspace(-r, r, 2*r+1)
            dy = torch.linspace(-r, r, 2*r+1)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)

            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)
        fmap2 = fmap2.view(batch, dim, ht*wd)

        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())


class AlternateCorrBlock:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / torch.sqrt(torch.tensor(dim).float())
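A minimal shape-check sketch for CorrBlock (random features, run from the repository root so the RAFT.core imports resolve). With num_levels=4 and radius=4 each pixel gets 4 * (2*4+1)**2 = 324 correlation channels:

import torch
from RAFT.core.corr import CorrBlock
from RAFT.core.utils.utils import coords_grid

B, D, H, W = 1, 256, 32, 40
fmap1 = torch.randn(B, D, H, W)
fmap2 = torch.randn(B, D, H, W)
coords = coords_grid(B, H, W)            # identity sampling grid, (B, 2, H, W)
corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)
out = corr_fn(coords)
print(out.shape)                         # torch.Size([1, 324, 32, 40])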
RAFT/core/datasets.py (Executable file, 235 lines)
@@ -0,0 +1,235 @@
# Data loading based on https://github.com/NVIDIA/flownet2-pytorch

import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F

import os
import math
import random
from glob import glob
import os.path as osp

from utils import frame_utils
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor


class FlowDataset(data.Dataset):
    def __init__(self, aug_params=None, sparse=False):
        self.augmentor = None
        self.sparse = sparse
        if aug_params is not None:
            if sparse:
                self.augmentor = SparseFlowAugmentor(**aug_params)
            else:
                self.augmentor = FlowAugmentor(**aug_params)

        self.is_test = False
        self.init_seed = False
        self.flow_list = []
        self.image_list = []
        self.extra_info = []

    def __getitem__(self, index):

        if self.is_test:
            img1 = frame_utils.read_gen(self.image_list[index][0])
            img2 = frame_utils.read_gen(self.image_list[index][1])
            img1 = np.array(img1).astype(np.uint8)[..., :3]
            img2 = np.array(img2).astype(np.uint8)[..., :3]
            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
            return img1, img2, self.extra_info[index]

        if not self.init_seed:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                torch.manual_seed(worker_info.id)
                np.random.seed(worker_info.id)
                random.seed(worker_info.id)
                self.init_seed = True

        index = index % len(self.image_list)
        valid = None
        if self.sparse:
            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
        else:
            flow = frame_utils.read_gen(self.flow_list[index])

        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])

        flow = np.array(flow).astype(np.float32)
        img1 = np.array(img1).astype(np.uint8)
        img2 = np.array(img2).astype(np.uint8)

        # grayscale images
        if len(img1.shape) == 2:
            img1 = np.tile(img1[...,None], (1, 1, 3))
            img2 = np.tile(img2[...,None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]

        if self.augmentor is not None:
            if self.sparse:
                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
            else:
                img1, img2, flow = self.augmentor(img1, img2, flow)

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()

        if valid is not None:
            valid = torch.from_numpy(valid)
        else:
            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)

        return img1, img2, flow, valid.float()


    def __rmul__(self, v):
        self.flow_list = v * self.flow_list
        self.image_list = v * self.image_list
        return self

    def __len__(self):
        return len(self.image_list)


class MpiSintel(FlowDataset):
    def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
        super(MpiSintel, self).__init__(aug_params)
        flow_root = osp.join(root, split, 'flow')
        image_root = osp.join(root, split, dstype)

        if split == 'test':
            self.is_test = True

        for scene in os.listdir(image_root):
            image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
            for i in range(len(image_list)-1):
                self.image_list += [ [image_list[i], image_list[i+1]] ]
                self.extra_info += [ (scene, i) ] # scene and frame_id

            if split != 'test':
                self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))


class FlyingChairs(FlowDataset):
    def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
        super(FlyingChairs, self).__init__(aug_params)

        images = sorted(glob(osp.join(root, '*.ppm')))
        flows = sorted(glob(osp.join(root, '*.flo')))
        assert (len(images)//2 == len(flows))

        split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
        for i in range(len(flows)):
            xid = split_list[i]
            if (split=='training' and xid==1) or (split=='validation' and xid==2):
                self.flow_list += [ flows[i] ]
                self.image_list += [ [images[2*i], images[2*i+1]] ]


class FlyingThings3D(FlowDataset):
    def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
        super(FlyingThings3D, self).__init__(aug_params)

        for cam in ['left']:
            for direction in ['into_future', 'into_past']:
                image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
                image_dirs = sorted([osp.join(f, cam) for f in image_dirs])

                flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
                flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])

                for idir, fdir in zip(image_dirs, flow_dirs):
                    images = sorted(glob(osp.join(idir, '*.png')))
                    flows = sorted(glob(osp.join(fdir, '*.pfm')))
                    for i in range(len(flows)-1):
                        if direction == 'into_future':
                            self.image_list += [ [images[i], images[i+1]] ]
                            self.flow_list += [ flows[i] ]
                        elif direction == 'into_past':
                            self.image_list += [ [images[i+1], images[i]] ]
                            self.flow_list += [ flows[i+1] ]


class KITTI(FlowDataset):
    def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
        super(KITTI, self).__init__(aug_params, sparse=True)
        if split == 'testing':
            self.is_test = True

        root = osp.join(root, split)
        images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
        images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))

        for img1, img2 in zip(images1, images2):
            frame_id = img1.split('/')[-1]
            self.extra_info += [ [frame_id] ]
            self.image_list += [ [img1, img2] ]

        if split == 'training':
            self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))


class HD1K(FlowDataset):
    def __init__(self, aug_params=None, root='datasets/HD1k'):
        super(HD1K, self).__init__(aug_params, sparse=True)

        seq_ix = 0
        while 1:
            flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
            images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))

            if len(flows) == 0:
                break

            for i in range(len(flows)-1):
                self.flow_list += [flows[i]]
                self.image_list += [ [images[i], images[i+1]] ]

            seq_ix += 1


def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
    """ Create the data loader for the corresponding training set """

    if args.stage == 'chairs':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
        train_dataset = FlyingChairs(aug_params, split='training')

    elif args.stage == 'things':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
        train_dataset = clean_dataset + final_dataset

    elif args.stage == 'sintel':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
        sintel_final = MpiSintel(aug_params, split='training', dstype='final')

        if TRAIN_DS == 'C+T+K+S+H':
            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things

        elif TRAIN_DS == 'C+T+K/S':
            train_dataset = 100*sintel_clean + 100*sintel_final + things

    elif args.stage == 'kitti':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
        train_dataset = KITTI(aug_params, split='training')

    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)

    print('Training with %d image pairs' % len(train_dataset))
    return train_loader
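A sketch of driving fetch_dataloader for the 'chairs' stage. It assumes RAFT/core is on sys.path, the FlyingChairs data is unpacked under datasets/FlyingChairs_release/data, and chairs_split.txt sits in the working directory (the loader hard-codes both paths); the batch and crop sizes below are illustrative assumptions:

import sys
sys.path.append('RAFT/core')             # datasets.py imports `utils` by bare name
from argparse import Namespace
from datasets import fetch_dataloader

args = Namespace(stage='chairs', image_size=[368, 496], batch_size=6)
train_loader = fetch_dataloader(args)
for img1, img2, flow, valid in train_loader:
    print(img1.shape, flow.shape)        # [6, 3, 368, 496] [6, 2, 368, 496]
    break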
RAFT/core/extractor.py (Executable file, 267 lines)
@@ -0,0 +1,267 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)


    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)



class BottleneckBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes//4)
            self.norm2 = nn.BatchNorm2d(planes//4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes//4)
            self.norm2 = nn.InstanceNorm2d(planes//4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)


    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)

class BasicEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x


class SmallEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
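Both encoders downsample by a factor of 8 (the stride-2 stem plus two stride-2 residual stages). A quick shape sketch with a dummy batch, assuming RAFT/core is on sys.path:

import sys
sys.path.append('RAFT/core')
import torch
from extractor import BasicEncoder

enc = BasicEncoder(output_dim=256, norm_fn='instance').eval()
x = torch.randn(2, 3, 256, 320)
with torch.no_grad():
    print(enc(x).shape)                  # torch.Size([2, 256, 32, 40])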
RAFT/core/raft.py (Executable file, 145 lines)
@@ -0,0 +1,145 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from update import BasicUpdateBlock, SmallUpdateBlock
from extractor import BasicEncoder, SmallEncoder
from corr import CorrBlock, AlternateCorrBlock
from RAFT.core.utils.utils import bilinear_sampler, coords_grid, upflow8

try:
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # dummy autocast for PyTorch < 1.6
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass


class RAFT(nn.Module):
    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args
        #args.small = True

        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4

        if 'dropout' not in self.args:
            self.args.dropout = 0

        if 'alternate_corr' not in self.args:
            self.args.alternate_corr = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0 """
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//8, W//8).to(img.device)
        coords1 = coords_grid(N, H//8, W//8).to(img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)


    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """

        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()
            corr = corr_fn(coords1)  # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        if test_mode:
            return coords1 - coords0, flow_up

        return flow_predictions
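A sketch of one inference call with randomly initialized weights, purely to show the two outputs in test_mode: an eighth-resolution flow and the convex-upsampled full-resolution flow. The args fields mirror what this file reads; no checkpoint is loaded, and image sizes must be divisible by 8:

import sys
sys.path.append('RAFT/core')             # raft.py imports its siblings by bare name
import torch
from argparse import Namespace
from raft import RAFT

args = Namespace(small=False, mixed_precision=False, alternate_corr=False)
model = RAFT(args).eval()                # random weights; shapes only
im1 = torch.randint(0, 255, (1, 3, 368, 496)).float()
im2 = torch.randint(0, 255, (1, 3, 368, 496)).float()
with torch.no_grad():
    flow_low, flow_up = model(im1, im2, iters=12, test_mode=True)
print(flow_low.shape, flow_up.shape)     # [1, 2, 46, 62] [1, 2, 368, 496]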
RAFT/core/update.py (Executable file, 139 lines)
@@ -0,0 +1,139 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FlowHead(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))


class ConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)

        z = torch.sigmoid(self.convz(hx))
        r = torch.sigmoid(self.convr(hx))
        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))

        h = (1-z) * h + z * q
        return h


class SepConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))

        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))

    def forward(self, h, x):
        # horizontal
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q

        # vertical
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q

        return h


class SmallMotionEncoder(nn.Module):
    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        cor = F.relu(self.convc1(corr))
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))
        cor_flo = torch.cat([cor, flo], dim=1)
        out = F.relu(self.conv(cor_flo))
        return torch.cat([out, flow], dim=1)


class BasicMotionEncoder(nn.Module):
    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))

        cor_flo = torch.cat([cor, flo], dim=1)
        out = F.relu(self.conv(cor_flo))
        return torch.cat([out, flow], dim=1)


class SmallUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)
        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        return net, None, delta_flow


class BasicUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow
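A quick shape check for the update machinery above (a hedged sketch, not part of the commit: the `args` namespace fields and the 46x62 grid are assumptions standing in for a 1/8-resolution feature map):

import argparse
import torch

# assumed correlation settings: 4 pyramid levels, lookup radius 4
args = argparse.Namespace(corr_levels=4, corr_radius=4)
block = BasicUpdateBlock(args, hidden_dim=128)
net  = torch.randn(1, 128, 46, 62)               # GRU hidden state
inp  = torch.randn(1, 128, 46, 62)               # context features
corr = torch.randn(1, 4 * (2*4 + 1)**2, 46, 62)  # correlation volume
flow = torch.randn(1, 2, 46, 62)                 # current flow estimate
net, mask, delta_flow = block(net, inp, corr, flow)
# delta_flow: (1, 2, 46, 62); mask: (1, 64*9, 46, 62) for convex upsampling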
0
RAFT/core/utils/__init__.py
Executable file
246
RAFT/core/utils/augmentor.py
Executable file
@@ -0,0 +1,246 @@
import numpy as np
import random
import math
from PIL import Image

import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

import torch
from torchvision.transforms import ColorJitter
import torch.nn.functional as F


class FlowAugmentor:
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):

        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        """ Photometric augmentation """

        # asymmetric
        if np.random.rand() < self.asymmetric_color_aug_prob:
            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)

        # symmetric
        else:
            image_stack = np.concatenate([img1, img2], axis=0)
            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
            img1, img2 = np.split(image_stack, 2, axis=0)

        return img1, img2

    def eraser_transform(self, img1, img2, bounds=[50, 100]):
        """ Occlusion augmentation """

        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(bounds[0], bounds[1])
                dy = np.random.randint(bounds[0], bounds[1])
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def spatial_transform(self, img1, img2, flow):
        # randomly sample scale
        ht, wd = img1.shape[:2]
        min_scale = np.maximum(
            (self.crop_size[0] + 8) / float(ht),
            (self.crop_size[1] + 8) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = scale
        scale_y = scale
        if np.random.rand() < self.stretch_prob:
            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)

        scale_x = np.clip(scale_x, min_scale, None)
        scale_y = np.clip(scale_y, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = flow * [scale_x, scale_y]

        if self.do_flip:
            if np.random.rand() < self.h_flip_prob: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]

            if np.random.rand() < self.v_flip_prob: # v-flip
                img1 = img1[::-1, :]
                img2 = img2[::-1, :]
                flow = flow[::-1, :] * [1.0, -1.0]

        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]

        return img1, img2, flow

    def __call__(self, img1, img2, flow):
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow = self.spatial_transform(img1, img2, flow)

        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)

        return img1, img2, flow


class SparseFlowAugmentor:
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        image_stack = np.concatenate([img1, img2], axis=0)
        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
        img1, img2 = np.split(image_stack, 2, axis=0)
        return img1, img2

    def eraser_transform(self, img1, img2):
        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(50, 100)
                dy = np.random.randint(50, 100)
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
        ht, wd = flow.shape[:2]
        coords = np.meshgrid(np.arange(wd), np.arange(ht))
        coords = np.stack(coords, axis=-1)

        coords = coords.reshape(-1, 2).astype(np.float32)
        flow = flow.reshape(-1, 2).astype(np.float32)
        valid = valid.reshape(-1).astype(np.float32)

        coords0 = coords[valid>=1]
        flow0 = flow[valid>=1]

        ht1 = int(round(ht * fy))
        wd1 = int(round(wd * fx))

        coords1 = coords0 * [fx, fy]
        flow1 = flow0 * [fx, fy]

        xx = np.round(coords1[:,0]).astype(np.int32)
        yy = np.round(coords1[:,1]).astype(np.int32)

        v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
        xx = xx[v]
        yy = yy[v]
        flow1 = flow1[v]

        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
        valid_img = np.zeros([ht1, wd1], dtype=np.int32)

        flow_img[yy, xx] = flow1
        valid_img[yy, xx] = 1

        return flow_img, valid_img

    def spatial_transform(self, img1, img2, flow, valid):
        # randomly sample scale
        ht, wd = img1.shape[:2]
        min_scale = np.maximum(
            (self.crop_size[0] + 1) / float(ht),
            (self.crop_size[1] + 1) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = np.clip(scale, min_scale, None)
        scale_y = np.clip(scale, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)

        if self.do_flip:
            if np.random.rand() < 0.5: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]
                valid = valid[:, ::-1]

        margin_y = 20
        margin_x = 50

        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
        x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)

        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        return img1, img2, flow, valid

    def __call__(self, img1, img2, flow, valid):
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)

        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)
        valid = np.ascontiguousarray(valid)

        return img1, img2, flow, valid
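A hedged usage sketch for FlowAugmentor (not part of the commit; the 368x496 crop mirrors the chairs stage in train_mixed.sh, and all inputs are synthetic placeholders):

import numpy as np

aug = FlowAugmentor(crop_size=(368, 496))
img1 = np.random.randint(0, 255, (400, 720, 3), dtype=np.uint8)
img2 = np.random.randint(0, 255, (400, 720, 3), dtype=np.uint8)
flow = np.random.randn(400, 720, 2).astype(np.float32)
img1, img2, flow = aug(img1, img2, flow)
# images are now (368, 496, 3) and the flow (368, 496, 2), cropped consistently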
132
RAFT/core/utils/flow_viz.py
Executable file
@@ -0,0 +1,132 @@
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization


# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions.
#
# Author: Tom Runia
# Date Created: 2018-08-03

import numpy as np

def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
    col = col+RY
    # YG
    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
    colorwheel[col:col+YG, 1] = 255
    col = col+YG
    # GC
    colorwheel[col:col+GC, 1] = 255
    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
    col = col+GC
    # CB
    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
    colorwheel[col:col+CB, 2] = 255
    col = col+CB
    # BM
    colorwheel[col:col+BM, 2] = 255
    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
    col = col+BM
    # MR
    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
    colorwheel[col:col+MR, 0] = 255
    return colorwheel


def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    colorwheel = make_colorwheel()  # shape [55x3]
    ncols = colorwheel.shape[0]
    rad = np.sqrt(np.square(u) + np.square(v))
    a = np.arctan2(-v, -u)/np.pi
    fk = (a+1) / 2*(ncols-1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0
    f = fk - k0
    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:,i]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1-f)*col0 + f*col1
        idx = (rad <= 1)
        col[idx] = 1 - rad[idx] * (1-col[idx])
        col[~idx] = col[~idx] * 0.75   # out of range
        # Note the 2-i => BGR instead of RGB
        ch_idx = 2-i if convert_to_bgr else i
        flow_image[:,:,ch_idx] = np.floor(255 * col)
    return flow_image


def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Expects a two dimensional flow image of shape [H,W,2].

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)
    u = flow_uv[:,:,0]
    v = flow_uv[:,:,1]
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)
    epsilon = 1e-5
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)
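A hedged sketch of flow_to_image on a synthetic flow field (not part of the commit):

import numpy as np

H, W = 64, 96
u, v = np.meshgrid(np.linspace(-1, 1, W), np.linspace(-1, 1, H))
flow_uv = np.stack([u, v], axis=-1).astype(np.float32)   # [H,W,2]
rgb = flow_to_image(flow_uv)                             # uint8, [H,W,3]
# flow direction maps to hue, magnitude to saturation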
137
RAFT/core/utils/frame_utils.py
Executable file
@@ -0,0 +1,137 @@
import numpy as np
from PIL import Image
from os.path import *
import re

import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

TAG_CHAR = np.array([202021.25], np.float32)

def readFlow(fn):
    """ Read .flo file in Middlebury format """
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy

    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
    with open(fn, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        if 202021.25 != magic:
            print('Magic number incorrect. Invalid .flo file')
            return None
        else:
            w = np.fromfile(f, np.int32, count=1)
            h = np.fromfile(f, np.int32, count=1)
            data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
            # Reshape data into 3D array (columns, rows, bands)
            # The reshape here is for visualization, the original code is (w,h,2)
            return np.resize(data, (int(h), int(w), 2))

def readPFM(file):
    file = open(file, 'rb')

    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().rstrip()
    if header == b'PF':
        color = True
    elif header == b'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0: # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>' # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data

def writeFlow(filename, uv, v=None):
    """ Write optical flow to file.

    If v is None, uv is assumed to contain both u and v channels,
    stacked in depth.
    Original code by Deqing Sun, adapted from Daniel Scharstein.
    """
    nBands = 2

    if v is None:
        assert(uv.ndim == 3)
        assert(uv.shape[2] == 2)
        u = uv[:,:,0]
        v = uv[:,:,1]
    else:
        u = uv

    assert(u.shape == v.shape)
    height, width = u.shape
    f = open(filename, 'wb')
    # write the header
    f.write(TAG_CHAR)
    np.array(width).astype(np.int32).tofile(f)
    np.array(height).astype(np.int32).tofile(f)
    # arrange into matrix form
    tmp = np.zeros((height, width*nBands))
    tmp[:, np.arange(width)*2] = u
    tmp[:, np.arange(width)*2 + 1] = v
    tmp.astype(np.float32).tofile(f)
    f.close()


def readFlowKITTI(filename):
    flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
    flow = flow[:,:,::-1].astype(np.float32)
    flow, valid = flow[:, :, :2], flow[:, :, 2]
    flow = (flow - 2**15) / 64.0
    return flow, valid

def readDispKITTI(filename):
    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
    valid = disp > 0.0
    flow = np.stack([-disp, np.zeros_like(disp)], -1)
    return flow, valid


def writeFlowKITTI(filename, uv):
    uv = 64.0 * uv + 2**15
    valid = np.ones([uv.shape[0], uv.shape[1], 1])
    uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
    cv2.imwrite(filename, uv[..., ::-1])


def read_gen(file_name, pil=False):
    ext = splitext(file_name)[-1]
    if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
        return Image.open(file_name)
    elif ext == '.bin' or ext == '.raw':
        return np.load(file_name)
    elif ext == '.flo':
        return readFlow(file_name).astype(np.float32)
    elif ext == '.pfm':
        flow = readPFM(file_name).astype(np.float32)
        if len(flow.shape) == 2:
            return flow
        else:
            return flow[:, :, :-1]
    return []
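A hedged round-trip sketch for the Middlebury .flo reader/writer above (not part of the commit; the temp path is illustrative):

import os
import tempfile
import numpy as np

flow = np.random.randn(32, 48, 2).astype(np.float32)
path = os.path.join(tempfile.mkdtemp(), 'tmp.flo')
writeFlow(path, flow)
flow2 = readFlow(path)
assert np.allclose(flow, flow2)   # header + interleaved u/v survive the trip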
82
RAFT/core/utils/utils.py
Executable file
@@ -0,0 +1,82 @@
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate


class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """
    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        if mode == 'sintel':
            self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
        else:
            self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]

    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    def unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]


def forward_interpolate(flow):
    flow = flow.detach().cpu().numpy()
    dx, dy = flow[0], flow[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    x1 = x0 + dx
    y1 = y0 + dy

    x1 = x1.reshape(-1)
    y1 = y1.reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1 = x1[valid]
    y1 = y1[valid]
    dx = dx[valid]
    dy = dy[valid]

    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)

    flow_y = interpolate.griddata(
        (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)

    flow = np.stack([flow_x, flow_y], axis=0)
    return torch.from_numpy(flow).float()


def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1,1], dim=-1)
    xgrid = 2*xgrid/(W-1) - 1
    ygrid = 2*ygrid/(H-1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img


def coords_grid(batch, ht, wd):
    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)


def upflow8(flow, mode='bilinear'):
    new_size = (8 * flow.shape[2], 8 * flow.shape[3])
    return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
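A hedged sketch of InputPadder on a Sintel-sized tensor (not part of the commit):

import torch

x = torch.randn(1, 3, 436, 1024)
padder = InputPadder(x.shape)            # sintel mode: pads height 436 -> 440
x_pad, = padder.pad(x)
assert x_pad.shape[-2:] == (440, 1024)
assert padder.unpad(x_pad).shape[-2:] == (436, 1024)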
75
RAFT/demo.py
Executable file
@@ -0,0 +1,75 @@
import sys
sys.path.append('core')

import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image

from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder


DEVICE = 'cuda'

def load_image(imfile):
    img = np.array(Image.open(imfile)).astype(np.uint8)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img[None].to(DEVICE)


def viz(img, flo):
    img = img[0].permute(1,2,0).cpu().numpy()
    flo = flo[0].permute(1,2,0).cpu().numpy()

    # map flow to rgb image
    flo = flow_viz.flow_to_image(flo)
    img_flo = np.concatenate([img, flo], axis=0)

    # import matplotlib.pyplot as plt
    # plt.imshow(img_flo / 255.0)
    # plt.show()

    cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
    cv2.waitKey()


def demo(args):
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))

    model = model.module
    model.to(DEVICE)
    model.eval()

    with torch.no_grad():
        images = glob.glob(os.path.join(args.path, '*.png')) + \
                 glob.glob(os.path.join(args.path, '*.jpg'))

        images = sorted(images)
        for imfile1, imfile2 in zip(images[:-1], images[1:]):
            image1 = load_image(imfile1)
            image2 = load_image(imfile2)

            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
            viz(image1, flow_up)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="restore checkpoint")
    parser.add_argument('--path', help="dataset for evaluation")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
    args = parser.parse_args()

    demo(args)
3
RAFT/download_models.sh
Executable file
@@ -0,0 +1,3 @@
#!/bin/bash
wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip
unzip models.zip
197
RAFT/evaluate.py
Executable file
@@ -0,0 +1,197 @@
import sys
sys.path.append('core')

from PIL import Image
import argparse
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

import datasets
from utils import flow_viz
from utils import frame_utils

from raft import RAFT
from utils.utils import InputPadder, forward_interpolate


@torch.no_grad()
def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
    """ Create submission for the Sintel leaderboard """
    model.eval()
    for dstype in ['clean', 'final']:
        test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)

        flow_prev, sequence_prev = None, None
        for test_id in range(len(test_dataset)):
            image1, image2, (sequence, frame) = test_dataset[test_id]
            if sequence != sequence_prev:
                flow_prev = None

            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())

            flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
            flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()

            if warm_start:
                flow_prev = forward_interpolate(flow_low[0])[None].cuda()

            output_dir = os.path.join(output_path, dstype, sequence)
            output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))

            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

            frame_utils.writeFlow(output_file, flow)
            sequence_prev = sequence


@torch.no_grad()
def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
    """ Create submission for the KITTI leaderboard """
    model.eval()
    test_dataset = datasets.KITTI(split='testing', aug_params=None)

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    for test_id in range(len(test_dataset)):
        image1, image2, (frame_id, ) = test_dataset[test_id]
        padder = InputPadder(image1.shape, mode='kitti')
        image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())

        _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
        flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()

        output_filename = os.path.join(output_path, frame_id)
        frame_utils.writeFlowKITTI(output_filename, flow)


@torch.no_grad()
def validate_chairs(model, iters=24):
    """ Perform evaluation on the FlyingChairs (test) split """
    model.eval()
    epe_list = []

    val_dataset = datasets.FlyingChairs(split='validation')
    for val_id in range(len(val_dataset)):
        image1, image2, flow_gt, _ = val_dataset[val_id]
        image1 = image1[None].cuda()
        image2 = image2[None].cuda()

        _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
        epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
        epe_list.append(epe.view(-1).numpy())

    epe = np.mean(np.concatenate(epe_list))
    print("Validation Chairs EPE: %f" % epe)
    return {'chairs': epe}


@torch.no_grad()
def validate_sintel(model, iters=32):
    """ Perform validation using the Sintel (train) split """
    model.eval()
    results = {}
    for dstype in ['clean', 'final']:
        val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
        epe_list = []

        for val_id in range(len(val_dataset)):
            image1, image2, flow_gt, _ = val_dataset[val_id]
            image1 = image1[None].cuda()
            image2 = image2[None].cuda()

            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
            flow = padder.unpad(flow_pr[0]).cpu()

            epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
            epe_list.append(epe.view(-1).numpy())

        epe_all = np.concatenate(epe_list)
        epe = np.mean(epe_all)
        px1 = np.mean(epe_all<1)
        px3 = np.mean(epe_all<3)
        px5 = np.mean(epe_all<5)

        print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
        results[dstype] = epe  # mean over all pixels (per-image lists are ragged)

    return results


@torch.no_grad()
def validate_kitti(model, iters=24):
    """ Perform validation using the KITTI-2015 (train) split """
    model.eval()
    val_dataset = datasets.KITTI(split='training')

    out_list, epe_list = [], []
    for val_id in range(len(val_dataset)):
        image1, image2, flow_gt, valid_gt = val_dataset[val_id]
        image1 = image1[None].cuda()
        image2 = image2[None].cuda()

        padder = InputPadder(image1.shape, mode='kitti')
        image1, image2 = padder.pad(image1, image2)

        flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
        flow = padder.unpad(flow_pr[0]).cpu()

        epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
        mag = torch.sum(flow_gt**2, dim=0).sqrt()

        epe = epe.view(-1)
        mag = mag.view(-1)
        val = valid_gt.view(-1) >= 0.5

        out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
        epe_list.append(epe[val].mean().item())
        out_list.append(out[val].cpu().numpy())

    epe_list = np.array(epe_list)
    out_list = np.concatenate(out_list)

    epe = np.mean(epe_list)
    f1 = 100 * np.mean(out_list)

    print("Validation KITTI: %f, %f" % (epe, f1))
    return {'kitti-epe': epe, 'kitti-f1': f1}


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="restore checkpoint")
    parser.add_argument('--dataset', help="dataset for evaluation")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
    args = parser.parse_args()

    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))

    model.cuda()
    model.eval()

    # create_sintel_submission(model.module, warm_start=True)
    # create_kitti_submission(model.module)

    with torch.no_grad():
        if args.dataset == 'chairs':
            validate_chairs(model.module)

        elif args.dataset == 'sintel':
            validate_sintel(model.module)

        elif args.dataset == 'kitti':
            validate_kitti(model.module)
247
RAFT/train.py
Executable file
@@ -0,0 +1,247 @@
from __future__ import print_function, division
import sys
sys.path.append('core')

import argparse
import os
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader
from raft import RAFT
import evaluate
import datasets

from torch.utils.tensorboard import SummaryWriter

try:
    from torch.cuda.amp import GradScaler
except ImportError:
    # dummy GradScaler for PyTorch < 1.6
    # (accepts `enabled` so the call below works on old PyTorch too)
    class GradScaler:
        def __init__(self, enabled=False):
            pass
        def scale(self, loss):
            return loss
        def unscale_(self, optimizer):
            pass
        def step(self, optimizer):
            optimizer.step()
        def update(self):
            pass


# exclude extremely large displacements
MAX_FLOW = 400
SUM_FREQ = 100
VAL_FREQ = 5000


def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
    """ Loss function defined over sequence of flow predictions """

    n_predictions = len(flow_preds)
    flow_loss = 0.0

    # exclude invalid pixels and extremely large displacements
    mag = torch.sum(flow_gt**2, dim=1).sqrt()
    valid = (valid >= 0.5) & (mag < max_flow)

    for i in range(n_predictions):
        i_weight = gamma**(n_predictions - i - 1)
        i_loss = (flow_preds[i] - flow_gt).abs()
        flow_loss += i_weight * (valid[:, None] * i_loss).mean()

    epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
    epe = epe.view(-1)[valid.view(-1)]

    metrics = {
        'epe': epe.mean().item(),
        '1px': (epe < 1).float().mean().item(),
        '3px': (epe < 3).float().mean().item(),
        '5px': (epe < 5).float().mean().item(),
    }

    return flow_loss, metrics


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def fetch_optimizer(args, model):
    """ Create the optimizer and learning rate scheduler """
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)

    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
        pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')

    return optimizer, scheduler


class Logger:
    def __init__(self, model, scheduler):
        self.model = model
        self.scheduler = scheduler
        self.total_steps = 0
        self.running_loss = {}
        self.writer = None

    def _print_training_status(self):
        metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
        training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
        metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)

        # print the training status
        print(training_str + metrics_str)

        if self.writer is None:
            self.writer = SummaryWriter()

        for k in self.running_loss:
            self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)
            self.running_loss[k] = 0.0

    def push(self, metrics):
        self.total_steps += 1

        for key in metrics:
            if key not in self.running_loss:
                self.running_loss[key] = 0.0

            self.running_loss[key] += metrics[key]

        if self.total_steps % SUM_FREQ == SUM_FREQ-1:
            self._print_training_status()
            self.running_loss = {}

    def write_dict(self, results):
        if self.writer is None:
            self.writer = SummaryWriter()

        for key in results:
            self.writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        self.writer.close()


def train(args):

    model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
    print("Parameter Count: %d" % count_parameters(model))

    if args.restore_ckpt is not None:
        model.load_state_dict(torch.load(args.restore_ckpt), strict=False)

    model.cuda()
    model.train()

    if args.stage != 'chairs':
        model.module.freeze_bn()

    train_loader = datasets.fetch_dataloader(args)
    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
    scaler = GradScaler(enabled=args.mixed_precision)
    logger = Logger(model, scheduler)

    VAL_FREQ = 5000
    add_noise = True

    should_keep_training = True
    while should_keep_training:

        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()
            image1, image2, flow, valid = [x.cuda() for x in data_blob]

            if args.add_noise:
                stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
                image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)

            flow_predictions = model(image1, image2, iters=args.iters)

            loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            scaler.step(optimizer)
            scheduler.step()
            scaler.update()

            logger.push(metrics)

            if total_steps % VAL_FREQ == VAL_FREQ - 1:
                PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
                torch.save(model.state_dict(), PATH)

                results = {}
                for val_dataset in args.validation:
                    if val_dataset == 'chairs':
                        results.update(evaluate.validate_chairs(model.module))
                    elif val_dataset == 'sintel':
                        results.update(evaluate.validate_sintel(model.module))
                    elif val_dataset == 'kitti':
                        results.update(evaluate.validate_kitti(model.module))

                logger.write_dict(results)

                model.train()
                if args.stage != 'chairs':
                    model.module.freeze_bn()

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

    logger.close()
    PATH = 'checkpoints/%s.pth' % args.name
    torch.save(model.state_dict(), PATH)

    return PATH


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default='raft', help="name your experiment")
    parser.add_argument('--stage', help="determines which dataset to use for training")
    parser.add_argument('--restore_ckpt', help="restore checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--validation', type=str, nargs='+')

    parser.add_argument('--lr', type=float, default=0.00002)
    parser.add_argument('--num_steps', type=int, default=100000)
    parser.add_argument('--batch_size', type=int, default=6)
    parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
    parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')

    parser.add_argument('--iters', type=int, default=12)
    parser.add_argument('--wdecay', type=float, default=.00005)
    parser.add_argument('--epsilon', type=float, default=1e-8)
    parser.add_argument('--clip', type=float, default=1.0)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
    parser.add_argument('--add_noise', action='store_true')
    args = parser.parse_args()

    torch.manual_seed(1234)
    np.random.seed(1234)

    if not os.path.isdir('checkpoints'):
        os.mkdir('checkpoints')

    train(args)
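A hedged sketch of sequence_loss (defined above) on dummy predictions (not part of the commit). With gamma=0.8 and three predictions the weights are 0.64, 0.8, 1.0, so later refinement iterations dominate the loss:

import torch

flow_gt = torch.zeros(2, 2, 46, 62)
valid = torch.ones(2, 46, 62)
flow_preds = [torch.randn(2, 2, 46, 62) for _ in range(3)]
loss, metrics = sequence_loss(flow_preds, flow_gt, valid, gamma=0.8)
print(loss.item(), metrics['epe'])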
6
RAFT/train_mixed.sh
Executable file
@@ -0,0 +1,6 @@
#!/bin/bash
mkdir -p checkpoints
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision
6
RAFT/train_standard.sh
Executable file
@@ -0,0 +1,6 @@
#!/bin/bash
mkdir -p checkpoints
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85
BIN
cirtorch/.DS_Store
vendored
Normal file
Binary file not shown.
BIN
cirtorch/IamgeRetrieval_dataset/train.pkl
Normal file
Binary file not shown.
6
cirtorch/__init__.py
Executable file
@@ -0,0 +1,6 @@
from . import datasets, examples, layers, networks, utils

from .datasets import datahelpers, genericdataset, testdataset, traindataset
from .layers import functional, loss, normalization, pooling
from .networks import imageretrievalnet
from .utils import general, download, evaluate, whiten
0
cirtorch/datasets/__init__.py
Executable file
56
cirtorch/datasets/datahelpers.py
Executable file
@@ -0,0 +1,56 @@
import os
from PIL import Image

import torch

def cid2filename(cid, prefix):
    """
    Creates a training image path out of its CID name

    Arguments
    ---------
    cid      : name of the image
    prefix   : root directory where images are saved

    Returns
    -------
    filename : full image filename
    """
    return os.path.join(prefix, cid[-2:], cid[-4:-2], cid[-6:-4], cid)

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)

def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)

def imresize(img, imsize):
    # resize in place so that the longer side equals imsize
    img.thumbnail((imsize, imsize), Image.ANTIALIAS)
    return img

def flip(x, dim):
    # reverse a tensor along `dim` (index trick works on both CPU and CUDA)
    xsize = x.size()
    dim = x.dim() + dim if dim < 0 else dim
    x = x.view(-1, *xsize[dim:])
    x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1, -1, -1), ('cpu','cuda')[x.is_cuda])().long(), :]
    return x.view(xsize)

def collate_tuples(batch):
    if len(batch) == 1:
        return [batch[0][0]], [batch[0][1]]
    return [batch[i][0] for i in range(len(batch))], [batch[i][1] for i in range(len(batch))]
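A hedged sketch of the cid2filename path scheme (not part of the commit; the CID and prefix are made up):

p = cid2filename('abcdef123456', '/data/train/ims')
# -> '/data/train/ims/56/34/12/abcdef123456'
# the last three 2-character chunks of the CID, reversed, become subdirectories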
121
cirtorch/datasets/genericdataset.py
Executable file
@@ -0,0 +1,121 @@
import os
import pdb

import torch
import torch.utils.data as data

from cirtorch.datasets.datahelpers import default_loader, imresize


class ImagesFromList(data.Dataset):
    """A generic data loader that loads images from a list
        (Based on ImageFolder from pytorch)
    Args:
        root (string): Root directory path.
        images (list): Relative image paths as strings.
        imsize (int, Default: None): Defines the maximum size of longer image side
        bbxs (list): List of (x1,y1,x2,y2) tuples to crop the query images
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        loader (callable, optional): A function to load an image given its path.
    Attributes:
        images_fn (list): List of full image filenames
    """

    def __init__(self, root, images, imsize=None, bbxs=None, transform=None, loader=default_loader):

        images_fn = [os.path.join(root, images[i]) for i in range(len(images))]

        if len(images_fn) == 0:
            raise(RuntimeError("Dataset contains 0 images!"))

        self.root = root
        self.images = images
        self.imsize = imsize
        self.images_fn = images_fn
        self.bbxs = bbxs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            image (PIL): Loaded image
        """
        path = self.images_fn[index]
        img = self.loader(path)
        imfullsize = max(img.size)

        if self.bbxs is not None:
            print('self.bbxs>>>ok')
            img = img.crop(self.bbxs[index])

        if self.imsize is not None:
            if self.bbxs is not None:
                print('self.bbxs and self.imsize>>>ok')
                img = imresize(img, self.imsize * max(img.size) / imfullsize)
            else:
                print('not self.bbxs and self.imsize>>>ok')
                img = imresize(img, self.imsize)

        if self.transform is not None:
            print('self.transform>>>>>ok')
            img = self.transform(img)

        return img, path

    def __len__(self):
        return len(self.images_fn)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of images: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


class ImagesFromDataList(data.Dataset):
    """A generic data loader that loads images given as an array of pytorch tensors
        (Based on ImageFolder from pytorch)
    Args:
        images (list): Images as tensors.
        transform (callable, optional): A function/transform that takes an image as a tensor
            and returns a transformed version. E.g, ``normalize`` with mean and std
    """

    def __init__(self, images, transform=None):

        if len(images) == 0:
            raise(RuntimeError("Dataset contains 0 images!"))

        self.images = images
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            image (Tensor): Loaded image
        """
        img = self.images[index]
        if self.transform is not None:
            img = self.transform(img)

        if len(img.size()):
            img = img.unsqueeze(0)

        return img

    def __len__(self):
        return len(self.images)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of images: {}\n'.format(self.__len__())
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
38
cirtorch/datasets/testdataset.py
Executable file
@@ -0,0 +1,38 @@
import os
import pickle

DATASETS = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']

def configdataset(dataset, dir_main):

    dataset = dataset.lower()

    if dataset not in DATASETS:
        raise ValueError('Unknown dataset: {}!'.format(dataset))

    # loading imlist, qimlist, and gnd, in cfg as a dict
    gnd_fname = os.path.join(dir_main, dataset, 'gnd_{}.pkl'.format(dataset))
    with open(gnd_fname, 'rb') as f:
        cfg = pickle.load(f)
    cfg['gnd_fname'] = gnd_fname

    cfg['ext'] = '.jpg'
    cfg['qext'] = '.jpg'
    cfg['dir_data'] = os.path.join(dir_main, dataset)
    cfg['dir_images'] = os.path.join(cfg['dir_data'], 'jpg')

    cfg['n'] = len(cfg['imlist'])
    cfg['nq'] = len(cfg['qimlist'])

    cfg['im_fname'] = config_imname
    cfg['qim_fname'] = config_qimname

    cfg['dataset'] = dataset

    return cfg

def config_imname(cfg, i):
    return os.path.join(cfg['dir_images'], cfg['imlist'][i] + cfg['ext'])

def config_qimname(cfg, i):
    return os.path.join(cfg['dir_images'], cfg['qimlist'][i] + cfg['qext'])
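A hedged sketch of what configdataset returns (not part of the commit; it assumes the gnd_roxford5k.pkl ground-truth file already sits under dir_main):

# cfg = configdataset('roxford5k', '/path/to/test')
# cfg['n'], cfg['nq']        # number of database / query images
# cfg['im_fname'](cfg, 0)    # full path of the first database image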
247
cirtorch/datasets/traindataset.py
Executable file
@@ -0,0 +1,247 @@
import os
import pickle
import pdb
import numpy as np
import torch
import torch.utils.data as data

from cirtorch.datasets.datahelpers import default_loader, imresize, cid2filename
from cirtorch.datasets.genericdataset import ImagesFromList
from cirtorch.utils.general import get_data_root

class TuplesDataset(data.Dataset):
    """Data loader that loads training and validation tuples of
        Radenovic et al., ECCV16: CNN image retrieval learns from BoW

    Args:
        name (string): dataset name: 'retrieval-sfm-120k'
        mode (string): 'train' or 'val' for training and validation parts of dataset
        imsize (int, Default: None): Defines the maximum size of longer image side
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        loader (callable, optional): A function to load an image given its path.
        nnum (int, Default:5): Number of negatives for a query image in a training tuple
        qsize (int, Default:1000): Number of query images, i.e. number of (q,p,n1,...nN) tuples, to be processed in one epoch
        poolsize (int, Default:10000): Pool size for negative image re-mining

    Attributes:
        images (list): List of full filenames for each image
        clusters (list): List of clusterID per image
        qpool (list): List of all query image indexes
        ppool (list): List of positive image indexes, each corresponding to query at the same position in qpool

        qidxs (list): List of qsize query image indexes to be processed in an epoch
        pidxs (list): List of qsize positive image indexes, each corresponding to query at the same position in qidxs
        nidxs (list): List of qsize tuples of negative images
            Each nidxs tuple contains nnum images corresponding to query image at the same position in qidxs

        Lists qidxs, pidxs, nidxs are refreshed by calling the ``create_epoch_tuples()`` method,
            i.e. new q-p pairs are picked and negative images are re-mined
    """

    def __init__(self, name, mode, imsize=None, nnum=5, qsize=2000, poolsize=20000, transform=None, loader=default_loader):

        if not (mode == 'train' or mode == 'val'):
            raise(RuntimeError("MODE should be either train or val, passed as string"))

        if name.startswith('retrieval-SfM'):
            # setting up paths
            #data_root = get_data_root()
            #db_root = os.path.join(data_root, 'train', name)
            #ims_root = os.path.join(db_root, 'ims')
            db_root = '/home/lc/project/Search_By_Image_Upgrade/cirtorch/IamgeRetrieval_dataset'
            ims_root = '/home/lc/project/Search_By_Image_Upgrade/cirtorch/IamgeRetrieval_dataset/train'
            # loading db
            db_fn = os.path.join(db_root, '{}.pkl'.format('train'))
            with open(db_fn, 'rb') as f:
                db = pickle.load(f)[mode]

            # setting fullpath for images
            self.images = [cid2filename(db['cids'][i], ims_root) for i in range(len(db['cids']))]

            #elif name.startswith('gl'):
            ## TODO: NOT IMPLEMENTED YET PROPERLY (WITH AUTOMATIC DOWNLOAD)

            # setting up paths
            #db_root = '/mnt/fry2/users/datasets/landmarkscvprw18/recognition/'
            #ims_root = os.path.join(db_root, 'images', 'train')

            # loading db
            #db_fn = os.path.join(db_root, '{}.pkl'.format('train'))
            #with open(db_fn, 'rb') as f:
            #    db = pickle.load(f)[mode]

            # setting fullpath for images
            # NOTE: this flat join overrides the SfM-style nested paths built above
            self.images = [os.path.join(ims_root, db['cids'][i]) for i in range(len(db['cids']))]
        else:
            raise(RuntimeError("Unknown dataset name!"))

        # initializing tuples dataset
        self.name = name
        self.mode = mode
        self.imsize = imsize
        self.clusters = db['cluster']
        self.qpool = db['qidxs']
        self.ppool = db['pidxs']

        ## If we want to keep only unique q-p pairs
        ## However, ordering of pairs will change, although that is not important
        # qpidxs = list(set([(self.qidxs[i], self.pidxs[i]) for i in range(len(self.qidxs))]))
        # self.qidxs = [qpidxs[i][0] for i in range(len(qpidxs))]
        # self.pidxs = [qpidxs[i][1] for i in range(len(qpidxs))]

        # size of training subset for an epoch
        self.nnum = nnum
        self.qsize = min(qsize, len(self.qpool))
        self.poolsize = min(poolsize, len(self.images))
        self.qidxs = None
        self.pidxs = None
        self.nidxs = None

        self.transform = transform
        self.loader = loader

        self.print_freq = 10

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            images tuple (q,p,n1,...,nN): Loaded train/val tuple at index of self.qidxs
        """
        if self.__len__() == 0:
            raise(RuntimeError("List qidxs is empty. Run ``dataset.create_epoch_tuples(net)`` method to create subset for train/val!"))

        output = []
        # query image
        output.append(self.loader(self.images[self.qidxs[index]]))
        # positive image
        output.append(self.loader(self.images[self.pidxs[index]]))
        # negative images
        for i in range(len(self.nidxs[index])):
            output.append(self.loader(self.images[self.nidxs[index][i]]))

        if self.imsize is not None:
            output = [imresize(img, self.imsize) for img in output]

        if self.transform is not None:
            output = [self.transform(output[i]).unsqueeze_(0) for i in range(len(output))]

        target = torch.Tensor([-1, 1] + [0]*len(self.nidxs[index]))

        return output, target

    def __len__(self):
        # if not self.qidxs:
        #     return 0
        # return len(self.qidxs)
        return self.qsize

    def __repr__(self):
        fmt_str = self.__class__.__name__ + '\n'
        fmt_str += '    Name and mode: {} {}\n'.format(self.name, self.mode)
        fmt_str += '    Number of images: {}\n'.format(len(self.images))
        fmt_str += '    Number of training tuples: {}\n'.format(len(self.qpool))
        fmt_str += '    Number of negatives per tuple: {}\n'.format(self.nnum)
        fmt_str += '    Number of tuples processed in an epoch: {}\n'.format(self.qsize)
        fmt_str += '    Pool size for negative re-mining: {}\n'.format(self.poolsize)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def create_epoch_tuples(self, net):

        print('>> Creating tuples for an epoch of {}-{}...'.format(self.name, self.mode))
        print(">>>> used network: ")
        print(net.meta_repr())

        ## ------------------------
        ## SELECTING POSITIVE PAIRS
        ## ------------------------

        # draw qsize random queries for tuples
        idxs2qpool = torch.randperm(len(self.qpool))[:self.qsize]
        self.qidxs = [self.qpool[i] for i in idxs2qpool]
        self.pidxs = [self.ppool[i] for i in idxs2qpool]

        ## ------------------------
        ## SELECTING NEGATIVE PAIRS
        ## ------------------------

        # if nnum = 0 create dummy nidxs
        # useful when only positives used for training
        if self.nnum == 0:
            self.nidxs = [[] for _ in range(len(self.qidxs))]
            return 0

        # draw poolsize random images for pool of negatives images
        idxs2images = torch.randperm(len(self.images))[:self.poolsize]

        # prepare network
        net.cuda()
        net.eval()

        # no gradients computed, to reduce memory and increase speed
        with torch.no_grad():

            print('>> Extracting descriptors for query images...')
            # prepare query loader
            loader = torch.utils.data.DataLoader(
                ImagesFromList(root='', images=[self.images[i] for i in self.qidxs], imsize=self.imsize, transform=self.transform),
                batch_size=1, shuffle=False, num_workers=8, pin_memory=True
            )
            # extract query vectors
            qvecs = torch.zeros(net.meta['outputdim'], len(self.qidxs)).cuda()
            for i, input in enumerate(loader):
                #print('*********************', input, type(input))
                #print('#######################', type(input))
                qvecs[:, i] = net(input[0].cuda()).data.squeeze()
                if (i+1) % self.print_freq == 0 or (i+1) == len(self.qidxs):
                    print('\r>>>> {}/{} done...'.format(i+1, len(self.qidxs)), end='')
            print('')

            print('>> Extracting descriptors for negative pool...')
||||
# prepare negative pool data loader
|
||||
loader = torch.utils.data.DataLoader(
|
||||
ImagesFromList(root='', images=[self.images[i] for i in idxs2images], imsize=self.imsize, transform=self.transform),
|
||||
batch_size=1, shuffle=False, num_workers=8, pin_memory=True
|
||||
)
|
||||
# extract negative pool vectors
|
||||
poolvecs = torch.zeros(net.meta['outputdim'], len(idxs2images)).cuda()
|
||||
for i, input in enumerate(loader):
|
||||
poolvecs[:, i] = net(input[0].cuda()).data.squeeze()
|
||||
if (i+1) % self.print_freq == 0 or (i+1) == len(idxs2images):
|
||||
print('\r>>>> {}/{} done...'.format(i+1, len(idxs2images)), end='')
|
||||
print('')
|
||||
|
||||
print('>> Searching for hard negatives...')
|
||||
# compute dot product scores and ranks on GPU
|
||||
scores = torch.mm(poolvecs.t(), qvecs)
|
||||
scores, ranks = torch.sort(scores, dim=0, descending=True)
|
||||
avg_ndist = torch.tensor(0).float().cuda() # for statistics
|
||||
n_ndist = torch.tensor(0).float().cuda() # for statistics
|
||||
# selection of negative examples
|
||||
self.nidxs = []
|
||||
for q in range(len(self.qidxs)):
|
||||
# do not use query cluster,
|
||||
# those images are potentially positive
|
||||
qcluster = self.clusters[self.qidxs[q]]
|
||||
clusters = [qcluster]
|
||||
nidxs = []
|
||||
r = 0
|
||||
while len(nidxs) < self.nnum:
|
||||
potential = idxs2images[ranks[r, q]]
|
||||
# take at most one image from the same cluster
|
||||
if not self.clusters[potential] in clusters:
|
||||
nidxs.append(potential)
|
||||
clusters.append(self.clusters[potential])
|
||||
avg_ndist += torch.pow(qvecs[:,q]-poolvecs[:,ranks[r, q]]+1e-6, 2).sum(dim=0).sqrt()
|
||||
n_ndist += 1
|
||||
r += 1
|
||||
self.nidxs.append(nidxs)
|
||||
print('>>>> Average negative l2-distance: {:.2f}'.format(avg_ndist/n_ndist))
|
||||
print('>>>> Done')
|
||||
|
||||
return (avg_ndist/n_ndist).item() # return average negative l2-distance
|
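The epoch workflow above is driven from cirtorch/examples/train.py further down; a minimal sketch of the intended call order (not part of this commit, and assuming a CUDA machine plus the `net` and `transform` objects built in train.py):

# Hypothetical usage sketch of TuplesDataset, mirroring train.py below.
train_dataset = TuplesDataset(name='retrieval-SfM-120k', mode='train', imsize=362,
                              nnum=5, qsize=2000, poolsize=20000, transform=transform)
train_dataset.create_epoch_tuples(net)        # picks q-p pairs, re-mines hard negatives
loader = torch.utils.data.DataLoader(train_dataset, batch_size=5, shuffle=True,
                                     drop_last=True, collate_fn=collate_tuples)
for input, target in loader:
    # input[q] is a list of 2+nnum image tensors: query, positive, negatives
    # target[q] is tensor([-1., 1., 0., 0., 0., 0., 0.]), matching __getitem__ above
    break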
0
cirtorch/examples/__init__.py
Executable file
266
cirtorch/examples/test.py
Executable file
@ -0,0 +1,266 @@
import argparse
import os
import time
import pickle
import pdb

import numpy as np

import torch
from torch.utils.model_zoo import load_url
from torchvision import transforms

from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
from cirtorch.datasets.datahelpers import cid2filename
from cirtorch.datasets.testdataset import configdataset
from cirtorch.utils.download import download_train, download_test
from cirtorch.utils.whiten import whitenlearn, whitenapply
from cirtorch.utils.evaluate import compute_map_and_print
from cirtorch.utils.general import get_data_root, htime

PRETRAINED = {
    'retrievalSfM120k-vgg16-gem'     : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/retrievalSfM120k-vgg16-gem-b4dcdc6.pth',
    'retrievalSfM120k-resnet101-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/retrievalSfM120k-resnet101-gem-b80fb85.pth',
    # new networks with whitening learned end-to-end
    'rSfM120k-tl-resnet50-gem-w'     : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth',
    'rSfM120k-tl-resnet101-gem-w'    : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth',
    'rSfM120k-tl-resnet152-gem-w'    : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet152-gem-w-f39cada.pth',
    'gl18-tl-resnet50-gem-w'         : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth',
    'gl18-tl-resnet101-gem-w'        : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth',
    'gl18-tl-resnet152-gem-w'        : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet152-gem-w-21278d5.pth',
}

datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
whitening_names = ['retrieval-SfM-30k', 'retrieval-SfM-120k']

parser = argparse.ArgumentParser(description='PyTorch CNN Image Retrieval Testing')

# network
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--network-path', '-npath', metavar='NETWORK',
                   help="pretrained network or network path (destination where network is saved)")
group.add_argument('--network-offtheshelf', '-noff', metavar='NETWORK',
                   help="off-the-shelf network, in the format 'ARCHITECTURE-POOLING' or 'ARCHITECTURE-POOLING-{reg-lwhiten-whiten}'," +
                        " examples: 'resnet101-gem' | 'resnet101-gem-reg' | 'resnet101-gem-whiten' | 'resnet101-gem-lwhiten' | 'resnet101-gem-reg-whiten'")

# test options
parser.add_argument('--datasets', '-d', metavar='DATASETS', default='oxford5k,paris6k',
                    help="comma separated list of test datasets: " +
                         " | ".join(datasets_names) +
                         " (default: 'oxford5k,paris6k')")
parser.add_argument('--image-size', '-imsize', default=1024, type=int, metavar='N',
                    help="maximum size of longer image side used for testing (default: 1024)")
parser.add_argument('--multiscale', '-ms', metavar='MULTISCALE', default='[1]',
                    help="use multiscale vectors for testing, " +
                         " examples: '[1]' | '[1, 1/2**(1/2), 1/2]' | '[1, 2**(1/2), 1/2**(1/2)]' (default: '[1]')")
parser.add_argument('--whitening', '-w', metavar='WHITENING', default=None, choices=whitening_names,
                    help="dataset used to learn whitening for testing: " +
                         " | ".join(whitening_names) +
                         " (default: None)")

# GPU ID
parser.add_argument('--gpu-id', '-g', default='0', metavar='N',
                    help="gpu id used for testing (default: '0')")


def main():
    args = parser.parse_args()

    # check if there are unknown datasets
    for dataset in args.datasets.split(','):
        if dataset not in datasets_names:
            raise ValueError('Unsupported or unknown dataset: {}!'.format(dataset))

    # check if test datasets are downloaded,
    # and download them if they are not
    download_train(get_data_root())
    download_test(get_data_root())

    # setting up the visible GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

    # loading network from path
    if args.network_path is not None:

        print(">> Loading network:\n>>>> '{}'".format(args.network_path))
        if args.network_path in PRETRAINED:
            # pretrained networks (downloaded automatically)
            state = load_url(PRETRAINED[args.network_path], model_dir=os.path.join(get_data_root(), 'networks'))
        else:
            # fine-tuned network from path
            state = torch.load(args.network_path)

        # parsing net params from meta:
        # architecture, pooling, mean, std required;
        # the rest have default values, in case they do not exist
        net_params = {}
        net_params['architecture'] = state['meta']['architecture']
        net_params['pooling'] = state['meta']['pooling']
        net_params['local_whitening'] = state['meta'].get('local_whitening', False)
        net_params['regional'] = state['meta'].get('regional', False)
        net_params['whitening'] = state['meta'].get('whitening', False)
        net_params['mean'] = state['meta']['mean']
        net_params['std'] = state['meta']['std']
        net_params['pretrained'] = False

        # load network
        net = init_network(net_params)
        net.load_state_dict(state['state_dict'])

        # if whitening is precomputed
        if 'Lw' in state['meta']:
            net.meta['Lw'] = state['meta']['Lw']

        print(">>>> loaded network: ")
        print(net.meta_repr())

    # loading off-the-shelf network
    elif args.network_offtheshelf is not None:

        # parse off-the-shelf parameters
        offtheshelf = args.network_offtheshelf.split('-')
        net_params = {}
        net_params['architecture'] = offtheshelf[0]
        net_params['pooling'] = offtheshelf[1]
        net_params['local_whitening'] = 'lwhiten' in offtheshelf[2:]
        net_params['regional'] = 'reg' in offtheshelf[2:]
        net_params['whitening'] = 'whiten' in offtheshelf[2:]
        net_params['pretrained'] = True

        # load off-the-shelf network
        print(">> Loading off-the-shelf network:\n>>>> '{}'".format(args.network_offtheshelf))
        net = init_network(net_params)
        print(">>>> loaded network: ")
        print(net.meta_repr())

    # setting up the multi-scale parameters
    ms = list(eval(args.multiscale))
    if len(ms) > 1 and net.meta['pooling'] == 'gem' and not net.meta['regional'] and not net.meta['whitening']:
        msp = net.pool.p.item()
        print(">> Set-up multiscale:")
        print(">>>> ms: {}".format(ms))
        print(">>>> msp: {}".format(msp))
    else:
        msp = 1

    # moving network to gpu and eval mode
    net.cuda()
    net.eval()

    # set up the transform
    normalize = transforms.Normalize(
        mean=net.meta['mean'],
        std=net.meta['std']
    )
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    # compute whitening
    if args.whitening is not None:
        start = time.time()

        if 'Lw' in net.meta and args.whitening in net.meta['Lw']:

            print('>> {}: Whitening is precomputed, loading it...'.format(args.whitening))

            if len(ms) > 1:
                Lw = net.meta['Lw'][args.whitening]['ms']
            else:
                Lw = net.meta['Lw'][args.whitening]['ss']

        else:

            # if we evaluate networks from path we should save/load whitening,
            # so that we do not have to compute it every time
            if args.network_path is not None:
                whiten_fn = args.network_path + '_{}_whiten'.format(args.whitening)
                if len(ms) > 1:
                    whiten_fn += '_ms'
                whiten_fn += '.pth'
            else:
                whiten_fn = None

            if whiten_fn is not None and os.path.isfile(whiten_fn):
                print('>> {}: Whitening is precomputed, loading it...'.format(args.whitening))
                Lw = torch.load(whiten_fn)

            else:
                print('>> {}: Learning whitening...'.format(args.whitening))

                # loading db
                db_root = os.path.join(get_data_root(), 'train', args.whitening)
                ims_root = os.path.join(db_root, 'ims')
                db_fn = os.path.join(db_root, '{}-whiten.pkl'.format(args.whitening))
                with open(db_fn, 'rb') as f:
                    db = pickle.load(f)
                images = [cid2filename(db['cids'][i], ims_root) for i in range(len(db['cids']))]

                # extract whitening vectors
                print('>> {}: Extracting...'.format(args.whitening))
                wvecs = extract_vectors(net, images, args.image_size, transform, ms=ms, msp=msp)

                # learning whitening
                print('>> {}: Learning...'.format(args.whitening))
                wvecs = wvecs.numpy()
                m, P = whitenlearn(wvecs, db['qidxs'], db['pidxs'])
                Lw = {'m': m, 'P': P}

                # saving whitening if whiten_fn exists
                if whiten_fn is not None:
                    print('>> {}: Saving to {}...'.format(args.whitening, whiten_fn))
                    torch.save(Lw, whiten_fn)

        print('>> {}: elapsed time: {}'.format(args.whitening, htime(time.time()-start)))

    else:
        Lw = None

    # evaluate on test datasets
    datasets = args.datasets.split(',')
    for dataset in datasets:
        start = time.time()

        print('>> {}: Extracting...'.format(dataset))

        # prepare config structure for the test dataset
        cfg = configdataset(dataset, os.path.join(get_data_root(), 'test'))
        images = [cfg['im_fname'](cfg, i) for i in range(cfg['n'])]
        qimages = [cfg['qim_fname'](cfg, i) for i in range(cfg['nq'])]
        try:
            bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])]
        except:
            bbxs = None  # for holidaysmanrot and copydays

        # extract database and query vectors
        print('>> {}: database images...'.format(dataset))
        vecs = extract_vectors(net, images, args.image_size, transform, ms=ms, msp=msp)
        print('>> {}: query images...'.format(dataset))
        qvecs = extract_vectors(net, qimages, args.image_size, transform, bbxs=bbxs, ms=ms, msp=msp)

        print('>> {}: Evaluating...'.format(dataset))

        # convert to numpy
        vecs = vecs.numpy()
        qvecs = qvecs.numpy()

        # search, rank, and print
        scores = np.dot(vecs.T, qvecs)
        ranks = np.argsort(-scores, axis=0)
        compute_map_and_print(dataset, ranks, cfg['gnd'])

        if Lw is not None:
            # whiten the vectors
            vecs_lw = whitenapply(vecs, Lw['m'], Lw['P'])
            qvecs_lw = whitenapply(qvecs, Lw['m'], Lw['P'])

            # search, rank, and print
            scores = np.dot(vecs_lw.T, qvecs_lw)
            ranks = np.argsort(-scores, axis=0)
            compute_map_and_print(dataset + ' + whiten', ranks, cfg['gnd'])

        print('>> {}: elapsed time: {}'.format(dataset, htime(time.time()-start)))


if __name__ == '__main__':
    main()
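The search step in test.py is a plain inner-product ranking of L2-normalized descriptors; a toy illustration of the argsort convention used above (hypothetical 2-D descriptors, not from this commit):

import numpy as np

vecs = np.array([[1.0, 0.0], [0.0, 1.0], [0.7071, 0.7071]]).T  # D x n_db = 2 x 3
qvecs = np.array([[0.9701, 0.2425]]).T                         # D x n_queries = 2 x 1
scores = np.dot(vecs.T, qvecs)        # (n_db, n_queries) similarity matrix
ranks = np.argsort(-scores, axis=0)   # column q lists database ids, best match first
print(ranks[:, 0])                    # [0 2 1]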
145
cirtorch/examples/test_e2e.py
Executable file
@ -0,0 +1,145 @@
import argparse
import os
import time
import pickle
import pdb

import numpy as np

import torch
from torch.utils.model_zoo import load_url
from torchvision import transforms

from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
from cirtorch.datasets.testdataset import configdataset
from cirtorch.utils.download import download_train, download_test
from cirtorch.utils.evaluate import compute_map_and_print
from cirtorch.utils.general import get_data_root, htime

PRETRAINED = {
    'rSfM120k-tl-resnet50-gem-w'  : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth',
    'rSfM120k-tl-resnet101-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth',
    'rSfM120k-tl-resnet152-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet152-gem-w-f39cada.pth',
    'gl18-tl-resnet50-gem-w'      : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth',
    'gl18-tl-resnet101-gem-w'     : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth',
    'gl18-tl-resnet152-gem-w'     : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet152-gem-w-21278d5.pth',
}

datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']

parser = argparse.ArgumentParser(description='PyTorch CNN Image Retrieval Testing End-to-End')

# test options
parser.add_argument('--network', '-n', metavar='NETWORK',
                    help="network to be evaluated: " +
                         " | ".join(PRETRAINED.keys()))
parser.add_argument('--datasets', '-d', metavar='DATASETS', default='roxford5k,rparis6k',
                    help="comma separated list of test datasets: " +
                         " | ".join(datasets_names) +
                         " (default: 'roxford5k,rparis6k')")
parser.add_argument('--image-size', '-imsize', default=1024, type=int, metavar='N',
                    help="maximum size of longer image side used for testing (default: 1024)")
parser.add_argument('--multiscale', '-ms', metavar='MULTISCALE', default='[1]',
                    help="use multiscale vectors for testing, " +
                         " examples: '[1]' | '[1, 1/2**(1/2), 1/2]' | '[1, 2**(1/2), 1/2**(1/2)]' (default: '[1]')")

# GPU ID
parser.add_argument('--gpu-id', '-g', default='0', metavar='N',
                    help="gpu id used for testing (default: '0')")


def main():
    args = parser.parse_args()

    # check if there are unknown datasets
    for dataset in args.datasets.split(','):
        if dataset not in datasets_names:
            raise ValueError('Unsupported or unknown dataset: {}!'.format(dataset))

    # check if test datasets are downloaded,
    # and download them if they are not
    download_train(get_data_root())
    download_test(get_data_root())

    # setting up the visible GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

    # loading network:
    # pretrained networks (downloaded automatically)
    print(">> Loading network:\n>>>> '{}'".format(args.network))
    state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
    # state = torch.load(args.network)
    # parsing net params from meta:
    # architecture, pooling, mean, std required;
    # the rest have default values, in case they do not exist
    net_params = {}
    net_params['architecture'] = state['meta']['architecture']
    net_params['pooling'] = state['meta']['pooling']
    net_params['local_whitening'] = state['meta'].get('local_whitening', False)
    net_params['regional'] = state['meta'].get('regional', False)
    net_params['whitening'] = state['meta'].get('whitening', False)
    net_params['mean'] = state['meta']['mean']
    net_params['std'] = state['meta']['std']
    net_params['pretrained'] = False
    # network initialization
    net = init_network(net_params)
    net.load_state_dict(state['state_dict'])

    print(">>>> loaded network: ")
    print(net.meta_repr())

    # setting up the multi-scale parameters
    ms = list(eval(args.multiscale))
    print(">>>> Evaluating scales: {}".format(ms))

    # moving network to gpu and eval mode
    net.cuda()
    net.eval()

    # set up the transform
    normalize = transforms.Normalize(
        mean=net.meta['mean'],
        std=net.meta['std']
    )
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    # evaluate on test datasets
    datasets = args.datasets.split(',')
    for dataset in datasets:
        start = time.time()

        print('>> {}: Extracting...'.format(dataset))

        # prepare config structure for the test dataset
        cfg = configdataset(dataset, os.path.join(get_data_root(), 'test'))
        images = [cfg['im_fname'](cfg, i) for i in range(cfg['n'])]
        qimages = [cfg['qim_fname'](cfg, i) for i in range(cfg['nq'])]
        try:
            bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])]
        except:
            bbxs = None  # for holidaysmanrot and copydays

        # extract database and query vectors
        print('>> {}: database images...'.format(dataset))
        vecs = extract_vectors(net, images, args.image_size, transform, ms=ms)
        print('>> {}: query images...'.format(dataset))
        qvecs = extract_vectors(net, qimages, args.image_size, transform, bbxs=bbxs, ms=ms)

        print('>> {}: Evaluating...'.format(dataset))

        # convert to numpy
        vecs = vecs.numpy()
        qvecs = qvecs.numpy()

        # search, rank, and print
        scores = np.dot(vecs.T, qvecs)
        ranks = np.argsort(-scores, axis=0)
        compute_map_and_print(dataset, ranks, cfg['gnd'])

        print('>> {}: elapsed time: {}'.format(dataset, htime(time.time()-start)))


if __name__ == '__main__':
    main()
580
cirtorch/examples/train.py
Executable file
@ -0,0 +1,580 @@
import sys
sys.path.append('/home/lc/project/Search_By_Image_Upgrade')
import argparse
import os
import shutil
import time
import math
import pickle
import pdb

import numpy as np

import torch
import torch.nn as nn
import torch.optim
import torch.utils.data

import torchvision.transforms as transforms
import torchvision.models as models

from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
from cirtorch.layers.loss import ContrastiveLoss, TripletLoss
from cirtorch.datasets.datahelpers import collate_tuples, cid2filename
from cirtorch.datasets.traindataset import TuplesDataset
from cirtorch.datasets.testdataset import configdataset
from cirtorch.utils.download import download_train, download_test
from cirtorch.utils.whiten import whitenlearn, whitenapply
from cirtorch.utils.evaluate import compute_map_and_print
from cirtorch.utils.general import get_data_root, htime

training_dataset_names = ['retrieval-SfM-120k']
test_datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
test_whiten_names = ['retrieval-SfM-30k', 'retrieval-SfM-120k']

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))
pool_names = ['mac', 'spoc', 'gem', 'gemmp']
loss_names = ['contrastive', 'triplet']
optimizer_names = ['sgd', 'adam']

parser = argparse.ArgumentParser(description='PyTorch CNN Image Retrieval Training')

# export directory, training and val datasets, test datasets
parser.add_argument('directory', metavar='EXPORT_DIR', default='models',
                    help='destination where trained network should be saved')
parser.add_argument('--training-dataset', '-d', metavar='DATASET', default='retrieval-SfM-120k', choices=training_dataset_names,
                    help='training dataset: ' +
                         ' | '.join(training_dataset_names) +
                         ' (default: retrieval-SfM-120k)')
parser.add_argument('--no-val', dest='val', action='store_false', default=False,
                    help='do not run validation')
parser.add_argument('--test-datasets', '-td', metavar='DATASETS', default='roxford5k,rparis6k',
                    help='comma separated list of test datasets: ' +
                         ' | '.join(test_datasets_names) +
                         ' (default: roxford5k,rparis6k)')
parser.add_argument('--test-whiten', metavar='DATASET', default='', choices=test_whiten_names,
                    help='dataset used to learn whitening for testing: ' +
                         ' | '.join(test_whiten_names) +
                         " (default: '')")
parser.add_argument('--test-freq', default=1, type=int, metavar='N',
                    help='run test evaluation every N epochs (default: 1)')

# network architecture and initialization options
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet50)')
parser.add_argument('--pool', '-p', metavar='POOL', default='gem', choices=pool_names,
                    help='pooling options: ' +
                         ' | '.join(pool_names) +
                         ' (default: gem)')
parser.add_argument('--local-whitening', '-lw', dest='local_whitening', action='store_true',
                    help='train model with learnable local whitening (linear layer) before the pooling')
parser.add_argument('--regional', '-r', dest='regional', action='store_true',
                    help='train model with regional pooling using fixed grid')
parser.add_argument('--whitening', '-w', dest='whitening', action='store_true',
                    help='train model with learnable whitening (linear layer) after the pooling')
parser.add_argument('--not-pretrained', dest='pretrained', action='store_false',
                    help='initialize model with random weights (default: pretrained on imagenet)')
parser.add_argument('--loss', '-l', metavar='LOSS', default='contrastive',
                    choices=loss_names,
                    help='training loss options: ' +
                         ' | '.join(loss_names) +
                         ' (default: contrastive)')
parser.add_argument('--loss-margin', '-lm', metavar='LM', default=0.7, type=float,
                    help='loss margin (default: 0.7)')

# train/val options specific for image retrieval learning
parser.add_argument('--image-size', default=648, type=int, metavar='N',  # 1024
                    help='maximum size of longer image side used for training (default: 648)')
parser.add_argument('--neg-num', '-nn', default=5, type=int, metavar='N',
                    help='number of negative images per train/val tuple (default: 5)')
parser.add_argument('--query-size', '-qs', default=2000, type=int, metavar='N',
                    help='number of queries randomly drawn per one train epoch (default: 2000)')
parser.add_argument('--pool-size', '-ps', default=20000, type=int, metavar='N',
                    help='size of the pool for hard negative mining (default: 20000)')

# standard train/val options
parser.add_argument('--gpu-id', '-g', default='0,1', metavar='N',
                    help="gpu id used for training (default: '0,1')")
parser.add_argument('--workers', '-j', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run (default: 100)')
parser.add_argument('--batch-size', '-b', default=32, type=int, metavar='N',
                    help='number of (q,p,n1,...,nN) tuples in a mini-batch (default: 32)')
parser.add_argument('--update-every', '-u', default=1, type=int, metavar='N',
                    help='update model weights every N batches, used to handle really large batches, ' +
                         'batch_size effectively becomes update_every x batch_size (default: 1)')
parser.add_argument('--optimizer', '-o', metavar='OPTIMIZER', default='adam',
                    choices=optimizer_names,
                    help='optimizer options: ' +
                         ' | '.join(optimizer_names) +
                         ' (default: adam)')
parser.add_argument('--lr', '--learning-rate', default=1e-6, type=float,
                    metavar='LR', help='initial learning rate (default: 1e-6)')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-6, type=float,
                    metavar='W', help='weight decay (default: 1e-6)')
parser.add_argument('--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='FILENAME',
                    help='name of the latest checkpoint (default: None)')

min_loss = float('inf')


def main():
    global args, min_loss
    args = parser.parse_args()

    # manually check if there are unknown test datasets
    for dataset in args.test_datasets.split(','):
        if dataset not in test_datasets_names:
            raise ValueError('Unsupported or unknown test dataset: {}!'.format(dataset))

    # check if test datasets are downloaded,
    # and download them if they are not
    download_train(get_data_root())
    download_test(get_data_root())

    # create export dir if it doesn't exist
    directory = "{}".format(args.training_dataset)
    directory += "_{}".format(args.arch)
    directory += "_{}".format(args.pool)
    if args.local_whitening:
        directory += "_lwhiten"
    if args.regional:
        directory += "_r"
    if args.whitening:
        directory += "_whiten"
    if not args.pretrained:
        directory += "_notpretrained"
    directory += "_{}_m{:.2f}".format(args.loss, args.loss_margin)
    directory += "_{}_lr{:.1e}_wd{:.1e}".format(args.optimizer, args.lr, args.weight_decay)
    directory += "_nnum{}_qsize{}_psize{}".format(args.neg_num, args.query_size, args.pool_size)
    directory += "_bsize{}_uevery{}_imsize{}".format(args.batch_size, args.update_every, args.image_size)

    args.directory = os.path.join(args.directory, directory)
    print(">> Creating directory if it does not exist:\n>> '{}'".format(args.directory))
    if not os.path.exists(args.directory):
        os.makedirs(args.directory)

    # set cuda visible device
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

    # set random seeds
    # TODO: maybe pass as argument in future implementation?
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    np.random.seed(0)

    # initialize model
    if args.pretrained:
        print(">> Using pre-trained model '{}'".format(args.arch))
    else:
        print(">> Using model from scratch (random weights) '{}'".format(args.arch))
    model_params = {}
    model_params['architecture'] = args.arch
    model_params['pooling'] = args.pool
    model_params['local_whitening'] = args.local_whitening
    model_params['regional'] = args.regional
    model_params['whitening'] = args.whitening
    # model_params['mean'] = ...  # will use default
    # model_params['std'] = ...  # will use default
    model_params['pretrained'] = args.pretrained
    model = init_network(model_params)

    # move network to gpu
    model.cuda()

    # define loss function (criterion) and optimizer
    if args.loss == 'contrastive':
        criterion = ContrastiveLoss(margin=args.loss_margin).cuda()
    elif args.loss == 'triplet':
        criterion = TripletLoss(margin=args.loss_margin).cuda()
    else:
        raise(RuntimeError("Loss {} not available!".format(args.loss)))

    # parameters split into features, pool, whitening
    # IMPORTANT: no weight decay for pooling parameter p in GeM or regional-GeM
    parameters = []
    # add feature parameters
    parameters.append({'params': model.features.parameters()})
    # add local whitening if it exists
    if model.lwhiten is not None:
        parameters.append({'params': model.lwhiten.parameters()})
    # add pooling parameters (or regional whitening which is part of the pooling layer!)
    if not args.regional:
        # global pooling: only the pooling parameter p gets weight decay 0
        if args.pool == 'gem':
            parameters.append({'params': model.pool.parameters(), 'lr': args.lr*10, 'weight_decay': 0})
        elif args.pool == 'gemmp':
            parameters.append({'params': model.pool.parameters(), 'lr': args.lr*100, 'weight_decay': 0})
    else:
        # regional pooling: parameter p gets weight decay 0,
        # and we want to add regional whitening if it is there
        if args.pool == 'gem':
            parameters.append({'params': model.pool.rpool.parameters(), 'lr': args.lr*10, 'weight_decay': 0})
        elif args.pool == 'gemmp':
            parameters.append({'params': model.pool.rpool.parameters(), 'lr': args.lr*100, 'weight_decay': 0})
        if model.pool.whiten is not None:
            parameters.append({'params': model.pool.whiten.parameters()})
    # add final whitening if it exists
    if model.whiten is not None:
        parameters.append({'params': model.whiten.parameters()})

    # define optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(parameters, args.lr, weight_decay=args.weight_decay)

    # define learning rate decay schedule
    # TODO: maybe pass as argument in future implementation?
    exp_decay = math.exp(-0.01)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=exp_decay)

    # optionally resume from a checkpoint
    start_epoch = 0
    if args.resume:
        args.resume = os.path.join(args.directory, args.resume)
        if os.path.isfile(args.resume):
            # load checkpoint weights and update model and optimizer
            print(">> Loading checkpoint:\n>> '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            min_loss = checkpoint['min_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(">>>> loaded checkpoint:\n>>>> '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            # important not to forget scheduler updating
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=exp_decay, last_epoch=checkpoint['epoch']-1)
        else:
            print(">> No checkpoint found at '{}'".format(args.resume))

    # Data loading code
    normalize = transforms.Normalize(mean=model.meta['mean'], std=model.meta['std'])
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = TuplesDataset(
        name=args.training_dataset,
        mode='train',
        imsize=args.image_size,
        nnum=args.neg_num,
        qsize=args.query_size,
        poolsize=args.pool_size,
        transform=transform
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, sampler=None,
        drop_last=True, collate_fn=collate_tuples
    )
    if args.val:
        val_dataset = TuplesDataset(
            name=args.training_dataset,
            mode='val',
            imsize=args.image_size,
            nnum=args.neg_num,
            qsize=float('Inf'),
            poolsize=float('Inf'),
            transform=transform
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True,
            drop_last=True, collate_fn=collate_tuples
        )

    # evaluate the network before starting
    # this might not be necessary?
    # test(args.test_datasets, model)

    for epoch in range(start_epoch, args.epochs):

        # set manual seeds per epoch
        np.random.seed(epoch)
        torch.manual_seed(epoch)
        torch.cuda.manual_seed_all(epoch)

        # adjust learning rate for each epoch
        scheduler.step()
        # # debug printing to check if everything ok
        # lr_feat = optimizer.param_groups[0]['lr']
        # lr_pool = optimizer.param_groups[1]['lr']
        # print('>> Features lr: {:.2e}; Pooling lr: {:.2e}'.format(lr_feat, lr_pool))

        # train for one epoch on train set
        loss = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if args.val:
            with torch.no_grad():
                loss = validate(val_loader, model, criterion, epoch)

        # evaluate on test datasets every test_freq epochs
        # if (epoch + 1) % args.test_freq == 0:
        #     with torch.no_grad():
        #         test(args.test_datasets, model)

        # remember best loss and save checkpoint
        is_best = loss < min_loss
        min_loss = min(loss, min_loss)
        if (epoch+1) % 10 == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'meta': model.meta,
                'state_dict': model.state_dict(),
                'min_loss': min_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.directory)


def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # create tuples for training
    avg_neg_distance = train_loader.dataset.create_epoch_tuples(model)

    # switch to train mode
    model.train()
    model.apply(set_batchnorm_eval)

    # zero out gradients
    optimizer.zero_grad()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        nq = len(input)  # number of training tuples
        ni = len(input[0])  # number of images per tuple

        for q in range(nq):
            output = torch.zeros(model.meta['outputdim'], ni).cuda()
            for imi in range(ni):

                # compute output vector for image imi
                output[:, imi] = model(input[q][imi].cuda()).squeeze()

            # reducing memory consumption:
            # compute the loss for this query tuple only,
            # then do the backward pass for this one tuple;
            # gradients of each backward pass are accumulated,
            # and the optimization step is performed for the full batch later
            loss = criterion(output, target[q].cuda())
            losses.update(loss.item())
            loss.backward()

        if (i + 1) % args.update_every == 0:
            # do one step for multiple batches,
            # accumulated gradients are used
            optimizer.step()
            # zero out gradients so we can
            # accumulate new ones over batches
            optimizer.zero_grad()
            # print('>> Train: [{0}][{1}/{2}]\t'
            #       'Weight update performed'.format(
            #        epoch+1, i+1, len(train_loader)))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i+1) % args.print_freq == 0 or i == 0 or (i+1) == len(train_loader):
            print('>> Train: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                   epoch+1, i+1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses))

    return losses.avg


def validate(val_loader, model, criterion, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()

    # create tuples for validation
    avg_neg_distance = val_loader.dataset.create_epoch_tuples(model)

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):

        nq = len(input)  # number of training tuples
        ni = len(input[0])  # number of images per tuple
        output = torch.zeros(model.meta['outputdim'], nq*ni).cuda()

        for q in range(nq):
            for imi in range(ni):

                # compute output vector for image imi of query q
                output[:, q*ni + imi] = model(input[q][imi].cuda()).squeeze()

        # no need to reduce memory consumption (no backward pass):
        # compute loss for the full batch
        loss = criterion(output, torch.cat(target).cuda())

        # record loss
        losses.update(loss.item()/nq, nq)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i+1) % args.print_freq == 0 or i == 0 or (i+1) == len(val_loader):
            print('>> Val: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                   epoch+1, i+1, len(val_loader), batch_time=batch_time, loss=losses))

    return losses.avg


def test(datasets, net):

    print('>> Evaluating network on test datasets...')

    # for testing we use image size of max 1024
    image_size = 1024

    # moving network to gpu and eval mode
    net.cuda()
    net.eval()
    # set up the transform
    normalize = transforms.Normalize(
        mean=net.meta['mean'],
        std=net.meta['std']
    )
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    # compute whitening
    if args.test_whiten:
        start = time.time()

        print('>> {}: Learning whitening...'.format(args.test_whiten))

        # loading db
        db_root = os.path.join(get_data_root(), 'train', args.test_whiten)
        ims_root = os.path.join(db_root, 'ims')
        db_fn = os.path.join(db_root, '{}-whiten.pkl'.format(args.test_whiten))
        with open(db_fn, 'rb') as f:
            db = pickle.load(f)
        images = [cid2filename(db['cids'][i], ims_root) for i in range(len(db['cids']))]

        # extract whitening vectors
        print('>> {}: Extracting...'.format(args.test_whiten))
        wvecs = extract_vectors(net, images, image_size, transform)  # implemented with torch.no_grad

        # learning whitening
        print('>> {}: Learning...'.format(args.test_whiten))
        wvecs = wvecs.numpy()
        m, P = whitenlearn(wvecs, db['qidxs'], db['pidxs'])
        Lw = {'m': m, 'P': P}

        print('>> {}: elapsed time: {}'.format(args.test_whiten, htime(time.time()-start)))
    else:
        Lw = None

    # evaluate on test datasets
    datasets = args.test_datasets.split(',')
    for dataset in datasets:
        start = time.time()

        print('>> {}: Extracting...'.format(dataset))

        # prepare config structure for the test dataset
        cfg = configdataset(dataset, os.path.join(get_data_root(), 'test'))
        images = [cfg['im_fname'](cfg, i) for i in range(cfg['n'])]
        qimages = [cfg['qim_fname'](cfg, i) for i in range(cfg['nq'])]
        bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])]

        # extract database and query vectors
        print('>> {}: database images...'.format(dataset))
        vecs = extract_vectors(net, images, image_size, transform)  # implemented with torch.no_grad
        print('>> {}: query images...'.format(dataset))
        qvecs = extract_vectors(net, qimages, image_size, transform, bbxs)  # implemented with torch.no_grad

        print('>> {}: Evaluating...'.format(dataset))

        # convert to numpy
        vecs = vecs.numpy()
        qvecs = qvecs.numpy()

        # search, rank, and print
        scores = np.dot(vecs.T, qvecs)
        ranks = np.argsort(-scores, axis=0)
        compute_map_and_print(dataset, ranks, cfg['gnd'])

        if Lw is not None:
            # whiten the vectors
            vecs_lw = whitenapply(vecs, Lw['m'], Lw['P'])
            qvecs_lw = whitenapply(qvecs, Lw['m'], Lw['P'])

            # search, rank, and print
            scores = np.dot(vecs_lw.T, qvecs_lw)
            ranks = np.argsort(-scores, axis=0)
            compute_map_and_print(dataset + ' + whiten', ranks, cfg['gnd'])

        print('>> {}: elapsed time: {}'.format(dataset, htime(time.time()-start)))


def save_checkpoint(state, is_best, directory):
    filename = os.path.join(directory, 'model_epoch%d.pth.tar' % state['epoch'])
    torch.save(state, filename)
    if is_best:
        filename_best = os.path.join(directory, 'model_best.pth.tar')
        shutil.copyfile(filename, filename_best)


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def set_batchnorm_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        # freeze running mean and std:
        # we do training one image at a time,
        # so the statistics would not be per batch;
        # hence we choose freezing (i.e. using imagenet statistics)
        m.eval()
        # # freeze parameters:
        # # in fact there is no need to freeze scale and bias,
        # # they can be learned,
        # # that is why the next two lines are commented
        # for p in m.parameters():
        #     p.requires_grad = False


if __name__ == '__main__':
    main()
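The update_every logic in train() above is standard gradient accumulation: several batches backpropagate into the same .grad buffers before one optimizer step, so the effective batch size is update_every x batch_size. A self-contained sketch of the same pattern (toy model, not from this commit):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
update_every = 4                       # effective batch = update_every x 8 samples

optimizer.zero_grad()
for i in range(16):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()                    # gradients accumulate across iterations
    if (i + 1) % update_every == 0:
        optimizer.step()               # one step for update_every batches
        optimizer.zero_grad()          # start accumulating the next group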
0
cirtorch/layers/__init__.py
Executable file
172
cirtorch/layers/functional.py
Executable file
@ -0,0 +1,172 @@
import math
import pdb

import torch
import torch.nn.functional as F

# --------------------------------------
# pooling
# --------------------------------------

def mac(x):
    return F.max_pool2d(x, (x.size(-2), x.size(-1)))
    # return F.adaptive_max_pool2d(x, (1,1))  # alternative


def spoc(x):
    return F.avg_pool2d(x, (x.size(-2), x.size(-1)))
    # return F.adaptive_avg_pool2d(x, (1,1))  # alternative


def gem(x, p=3, eps=1e-6):
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
    # return F.lp_pool2d(F.threshold(x, eps, eps), p, (x.size(-2), x.size(-1)))  # alternative


def rmac(x, L=3, eps=1e-6):
    ovr = 0.4  # desired overlap of neighboring regions
    steps = torch.Tensor([2, 3, 4, 5, 6, 7])  # possible regions for the long dimension

    W = x.size(3)
    H = x.size(2)

    w = min(W, H)
    w2 = math.floor(w/2.0 - 1)

    b = (max(H, W)-w)/(steps-1)
    (tmp, idx) = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0)  # steps(idx) regions for long dimension

    # region overplus per dimension
    Wd = 0
    Hd = 0
    if H < W:
        Wd = idx.item() + 1
    elif H > W:
        Hd = idx.item() + 1

    v = F.max_pool2d(x, (x.size(-2), x.size(-1)))
    v = v / (torch.norm(v, p=2, dim=1, keepdim=True) + eps).expand_as(v)

    for l in range(1, L+1):
        wl = math.floor(2*w/(l+1))
        wl2 = math.floor(wl/2 - 1)

        if l+Wd == 1:
            b = 0
        else:
            b = (W-wl)/(l+Wd-1)
        cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b) - wl2  # center coordinates
        if l+Hd == 1:
            b = 0
        else:
            b = (H-wl)/(l+Hd-1)
        cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b) - wl2  # center coordinates

        for i_ in cenH.tolist():
            for j_ in cenW.tolist():
                if wl == 0:
                    continue
                R = x[:,:,(int(i_)+torch.Tensor(range(wl)).long()).tolist(),:]
                R = R[:,:,:,(int(j_)+torch.Tensor(range(wl)).long()).tolist()]
                vt = F.max_pool2d(R, (R.size(-2), R.size(-1)))
                vt = vt / (torch.norm(vt, p=2, dim=1, keepdim=True) + eps).expand_as(vt)
                v += vt

    return v


def roipool(x, rpool, L=3, eps=1e-6):
    ovr = 0.4  # desired overlap of neighboring regions
    steps = torch.Tensor([2, 3, 4, 5, 6, 7])  # possible regions for the long dimension

    W = x.size(3)
    H = x.size(2)

    w = min(W, H)
    w2 = math.floor(w/2.0 - 1)

    b = (max(H, W)-w)/(steps-1)
    _, idx = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0)  # steps(idx) regions for long dimension

    # region overplus per dimension
    Wd = 0
    Hd = 0
    if H < W:
        Wd = idx.item() + 1
    elif H > W:
        Hd = idx.item() + 1

    vecs = []
    vecs.append(rpool(x).unsqueeze(1))

    for l in range(1, L+1):
        wl = math.floor(2*w/(l+1))
        wl2 = math.floor(wl/2 - 1)

        if l+Wd == 1:
            b = 0
        else:
            b = (W-wl)/(l+Wd-1)
        cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b).int() - wl2  # center coordinates
        if l+Hd == 1:
            b = 0
        else:
            b = (H-wl)/(l+Hd-1)
        cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b).int() - wl2  # center coordinates

        for i_ in cenH.tolist():
            for j_ in cenW.tolist():
                if wl == 0:
                    continue
                vecs.append(rpool(x.narrow(2,i_,wl).narrow(3,j_,wl)).unsqueeze(1))

    return torch.cat(vecs, dim=1)


# --------------------------------------
# normalization
# --------------------------------------

def l2n(x, eps=1e-6):
    return x / (torch.norm(x, p=2, dim=1, keepdim=True) + eps).expand_as(x)


def powerlaw(x, eps=1e-6):
    x = x + eps
    return x.abs().sqrt().mul(x.sign())


# --------------------------------------
# loss
# --------------------------------------

def contrastive_loss(x, label, margin=0.7, eps=1e-6):
    # x is D x N
    dim = x.size(0)  # D
    nq = torch.sum(label.data==-1)  # number of tuples
    S = x.size(1) // nq  # number of images per tuple including query: 1+1+n

    x1 = x[:, ::S].permute(1,0).repeat(1,S-1).view((S-1)*nq,dim).permute(1,0)
    idx = [i for i in range(len(label)) if label.data[i] != -1]
    x2 = x[:, idx]
    lbl = label[label!=-1]

    dif = x1 - x2
    D = torch.pow(dif+eps, 2).sum(dim=0).sqrt()

    y = 0.5*lbl*torch.pow(D,2) + 0.5*(1-lbl)*torch.pow(torch.clamp(margin-D, min=0),2)
    y = torch.sum(y)
    return y


def triplet_loss(x, label, margin=0.1):
    # x is D x N
    dim = x.size(0)  # D
    nq = torch.sum(label.data==-1).item()  # number of tuples
    S = x.size(1) // nq  # number of images per tuple including query: 1+1+n

    xa = x[:, label.data==-1].permute(1,0).repeat(1,S-2).view((S-2)*nq,dim).permute(1,0)
    xp = x[:, label.data==1].permute(1,0).repeat(1,S-2).view((S-2)*nq,dim).permute(1,0)
    xn = x[:, label.data==0]

    dist_pos = torch.sum(torch.pow(xa - xp, 2), dim=0)
    dist_neg = torch.sum(torch.pow(xa - xn, 2), dim=0)

    return torch.sum(torch.clamp(dist_pos - dist_neg + margin, min=0))
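gem() above generalizes the other two poolings: p=1 reduces to average pooling (spoc) and large p approaches max pooling (mac). A quick numeric check (toy feature map, not from this commit):

import torch
from cirtorch.layers.functional import mac, spoc, gem

x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])  # N=1, C=1, H=W=2
print(spoc(x).item())        # 2.5, the average
print(gem(x, p=1).item())    # 2.5, GeM with p=1 equals average pooling
print(gem(x, p=10).item())   # ~3.50, pulled toward the maximum
print(mac(x).item())         # 4.0, the maximum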
48
cirtorch/layers/loss.py
Executable file
@ -0,0 +1,48 @@
import torch
import torch.nn as nn

import cirtorch.layers.functional as LF

# --------------------------------------
# Loss/Error layers
# --------------------------------------

class ContrastiveLoss(nn.Module):
    r"""CONTRASTIVELOSS layer that computes contrastive loss for a batch of images:
        Q query tuples, each packed in the form of (q,p,n1,...,nN)

    Args:
        x: tuples arranged in columns as [q,p,n1,...,nN, ...]
        label: -1 for query, 1 for corresponding positive, 0 for corresponding negative
        margin: contrastive loss margin. Default: 0.7

    >>> contrastive_loss = ContrastiveLoss(margin=0.7)
    >>> input = torch.randn(128, 35, requires_grad=True)
    >>> label = torch.Tensor([-1, 1, 0, 0, 0, 0, 0] * 5)
    >>> output = contrastive_loss(input, label)
    >>> output.backward()
    """

    def __init__(self, margin=0.7, eps=1e-6):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.eps = eps

    def forward(self, x, label):
        return LF.contrastive_loss(x, label, margin=self.margin, eps=self.eps)

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')'


class TripletLoss(nn.Module):

    def __init__(self, margin=0.1):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, x, label):
        return LF.triplet_loss(x, label, margin=self.margin)

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')'
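TripletLoss expects the same column layout as ContrastiveLoss; a usage sketch mirroring the docstring above (random tensors, not from this commit):

import torch
from cirtorch.layers.loss import TripletLoss

triplet_loss = TripletLoss(margin=0.1)
# 5 tuples of (query, positive, 5 negatives) -> 35 columns of D=128 descriptors
input = torch.randn(128, 35, requires_grad=True)
label = torch.Tensor([-1, 1, 0, 0, 0, 0, 0] * 5)
output = triplet_loss(input, label)
output.backward()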
33
cirtorch/layers/normalization.py
Executable file
@ -0,0 +1,33 @@
import torch
import torch.nn as nn

import cirtorch.layers.functional as LF

# --------------------------------------
# Normalization layers
# --------------------------------------

class L2N(nn.Module):

    def __init__(self, eps=1e-6):
        super(L2N, self).__init__()
        self.eps = eps

    def forward(self, x):
        return LF.l2n(x, eps=self.eps)

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')'


class PowerLaw(nn.Module):

    def __init__(self, eps=1e-6):
        super(PowerLaw, self).__init__()
        self.eps = eps

    def forward(self, x):
        return LF.powerlaw(x, eps=self.eps)

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')'
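A quick sanity check for L2N (not from this commit): every row of the output should have unit l2-norm, up to eps.

import torch
from cirtorch.layers.normalization import L2N

x = torch.randn(4, 128)                  # 4 descriptors of dimension 128
print(torch.norm(L2N()(x), p=2, dim=1))  # ~tensor([1., 1., 1., 1.])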
113
cirtorch/layers/pooling.py
Executable file
@ -0,0 +1,113 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import cirtorch.layers.functional as LF
|
||||
from cirtorch.layers.normalization import L2N
|
||||
|
||||
# --------------------------------------
|
||||
# Pooling layers
|
||||
# --------------------------------------
|
||||
|
||||
class MAC(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(MAC,self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return LF.mac(x)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '()'
|
||||
|
||||
|
||||
class SPoC(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(SPoC,self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return LF.spoc(x)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '()'
|
||||
|
||||
|
||||
class GeM(nn.Module):
|
||||
|
||||
def __init__(self, p=3, eps=1e-6):
|
||||
super(GeM,self).__init__()
|
||||
self.p = Parameter(torch.ones(1)*p)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return LF.gem(x, p=self.p, eps=self.eps)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'
|
||||
|
||||
class GeMmp(nn.Module):
|
||||
|
||||
def __init__(self, p=3, mp=1, eps=1e-6):
|
||||
super(GeMmp,self).__init__()
|
||||
self.p = Parameter(torch.ones(mp)*p)
|
||||
self.mp = mp
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return LF.gem(x, p=self.p.unsqueeze(-1).unsqueeze(-1), eps=self.eps)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(' + 'p=' + '[{}]'.format(self.mp) + ', ' + 'eps=' + str(self.eps) + ')'
|
||||
|
||||
class RMAC(nn.Module):
|
||||
|
||||
def __init__(self, L=3, eps=1e-6):
|
||||
super(RMAC,self).__init__()
|
||||
self.L = L
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return LF.rmac(x, L=self.L, eps=self.eps)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(' + 'L=' + '{}'.format(self.L) + ')'
|
||||
|
||||
|
||||
class Rpool(nn.Module):
|
||||
|
||||
def __init__(self, rpool, whiten=None, L=3, eps=1e-6):
|
||||
super(Rpool,self).__init__()
|
||||
self.rpool = rpool
|
||||
self.L = L
|
||||
self.whiten = whiten
|
||||
self.norm = L2N()
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x, aggregate=True):
|
||||
# features -> roipool
|
||||
o = LF.roipool(x, self.rpool, self.L, self.eps) # size: #im, #reg, D, 1, 1
|
||||
|
||||
# concatenate regions from all images in the batch
|
||||
s = o.size()
|
||||
o = o.view(s[0]*s[1], s[2], s[3], s[4]) # size: #im x #reg, D, 1, 1
|
||||
|
||||
# rvecs -> norm
|
||||
o = self.norm(o)
|
||||
|
||||
# rvecs -> whiten -> norm
|
||||
if self.whiten is not None:
|
||||
o = self.norm(self.whiten(o.squeeze(-1).squeeze(-1)))
|
||||
|
||||
# reshape back to regions per image
|
||||
o = o.view(s[0], s[1], s[2], s[3], s[4]) # size: #im, #reg, D, 1, 1
|
||||
|
||||
# aggregate regions into a single global vector per image
|
||||
if aggregate:
|
||||
# rvecs -> sumpool -> norm
|
||||
o = self.norm(o.sum(1, keepdim=False)) # size: #im, D, 1, 1
|
||||
|
||||
return o
|
||||
|
||||
def __repr__(self):
|
||||
return super(Rpool, self).__repr__() + '(' + 'L=' + '{}'.format(self.L) + ')'
|
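
These modules delegate to cirtorch.layers.functional, which this diff does not show; the standard generalized-mean (GeM) formulation they wrap is sketched below, with p=1 recovering SPoC (average pooling) and large p approaching MAC (max pooling):

import torch
import torch.nn.functional as F

def gem(x, p=3, eps=1e-6):
    # clamp for numerical stability, raise to p, spatially average, take the p-th root
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p)

x = torch.randn(2, 2048, 7, 7)  # illustrative feature map: batch 2, 2048 channels, 7x7
print(gem(x).shape)             # torch.Size([2, 2048, 1, 1])
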
0
cirtorch/networks/__init__.py
Executable file
427
cirtorch/networks/imageretrievalnet.py
Executable file
@ -0,0 +1,427 @@
import os
import pdb
import numpy as np

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

import torchvision

from cirtorch.layers.pooling import MAC, SPoC, GeM, GeMmp, RMAC, Rpool
from cirtorch.layers.normalization import L2N, PowerLaw
from cirtorch.datasets.genericdataset import ImagesFromList
from cirtorch.utils.general import get_data_root
from cirtorch.datasets.datahelpers import default_loader, imresize
from PIL import Image
#from ModelHelper.Common.CommonUtils.ImageAugmentation import Padding
import cv2

# for some models, we have imported features (convolutions) from caffe because the image retrieval performance is higher for them
FEATURES = {
    'vgg16': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-vgg16-features-d369c8e.pth',
    'resnet50': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet50-features-ac468af.pth',
    'resnet101': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet101-features-10a101d.pth',
    'resnet152': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet152-features-1011020.pth',
}

# TODO: pre-compute for more architectures and properly test variations (pre l2norm, post l2norm)
# pre-computed local pca whitening that can be applied before the pooling layer
L_WHITENING = {
    'resnet101': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-9f830ef.pth',  # no pre l2 norm
    # 'resnet101' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-da5c935.pth', # with pre l2 norm
}

# possible global pooling layers, each one of these can be made regional
POOLING = {
    'mac': MAC,
    'spoc': SPoC,
    'gem': GeM,
    'gemmp': GeMmp,
    'rmac': RMAC,
}

# TODO: pre-compute for: resnet50-gem-r, resnet50-mac-r, vgg16-mac-r, alexnet-mac-r
# pre-computed regional whitening, for most commonly used architectures and pooling methods
R_WHITENING = {
    'alexnet-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-rwhiten-c8cf7e2.pth',
    'vgg16-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-rwhiten-19b204e.pth',
    'resnet101-mac-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-rwhiten-7f1ed8c.pth',
    'resnet101-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-rwhiten-adace84.pth',
}

# TODO: pre-compute for more architectures
# pre-computed final (global) whitening, for most commonly used architectures and pooling methods
WHITENING = {
    'alexnet-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-whiten-454ad53.pth',
    'alexnet-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-whiten-4c9126b.pth',
    'vgg16-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-whiten-eaa6695.pth',
    'vgg16-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-whiten-83582df.pth',
    'resnet50-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet50-gem-whiten-f15da7b.pth',
    'resnet101-mac-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-whiten-9df41d3.pth',
    'resnet101-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-whiten-22ab0c1.pth',
    'resnet101-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-whiten-b379c0a.pth',
    'resnet101-gemmp': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gemmp-whiten-770f53c.pth',
    'resnet152-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet152-gem-whiten-abe7b93.pth',
    'densenet121-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet121-gem-whiten-79e3eea.pth',
    'densenet169-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet169-gem-whiten-6b2a76a.pth',
    'densenet201-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet201-gem-whiten-22ea45c.pth',
}

# output dimensionality for supported architectures
OUTPUT_DIM = {
    'alexnet': 256,
    'vgg11': 512,
    'vgg13': 512,
    'vgg16': 512,
    'vgg19': 512,
    'resnet18': 512,
    'resnet34': 512,
    'resnet50': 2048,
    'resnet101': 2048,
    'resnet152': 2048,
    'densenet121': 1024,
    'densenet169': 1664,
    'densenet201': 1920,
    'densenet161': 2208,  # largest densenet
    'squeezenet1_0': 512,
    'squeezenet1_1': 512,
}


class ImageRetrievalNet(nn.Module):

    def __init__(self, features, lwhiten, pool, whiten, meta):
        super(ImageRetrievalNet, self).__init__()
        self.features = nn.Sequential(*features)
        self.lwhiten = lwhiten
        self.pool = pool
        self.whiten = whiten
        self.norm = L2N()
        self.meta = meta

    def forward(self, x):
        # x -> features
        o = self.features(x)

        # TODO: properly test (with pre-l2norm and/or post-l2norm)
        # if lwhiten exist: features -> local whiten
        if self.lwhiten is not None:
            # o = self.norm(o)
            s = o.size()
            o = o.permute(0, 2, 3, 1).contiguous().view(-1, s[1])
            o = self.lwhiten(o)
            o = o.view(s[0], s[2], s[3], self.lwhiten.out_features).permute(0, 3, 1, 2)
            # o = self.norm(o)

        # features -> pool -> norm
        o = self.norm(self.pool(o)).squeeze(-1).squeeze(-1)

        # if whiten exist: pooled features -> whiten -> norm
        if self.whiten is not None:
            o = self.norm(self.whiten(o))

        # permute so that it is Dx1 column vector per image (DxN if many images)
        return o.permute(1, 0)

    def __repr__(self):
        tmpstr = super(ImageRetrievalNet, self).__repr__()[:-1]
        tmpstr += self.meta_repr()
        tmpstr = tmpstr + ')'
        return tmpstr

    def meta_repr(self):
        tmpstr = ' (' + 'meta' + '): dict( \n'  # + self.meta.__repr__() + '\n'
        tmpstr += ' architecture: {}\n'.format(self.meta['architecture'])
        tmpstr += ' local_whitening: {}\n'.format(self.meta['local_whitening'])
        tmpstr += ' pooling: {}\n'.format(self.meta['pooling'])
        tmpstr += ' regional: {}\n'.format(self.meta['regional'])
        tmpstr += ' whitening: {}\n'.format(self.meta['whitening'])
        tmpstr += ' outputdim: {}\n'.format(self.meta['outputdim'])
        tmpstr += ' mean: {}\n'.format(self.meta['mean'])
        tmpstr += ' std: {}\n'.format(self.meta['std'])
        tmpstr = tmpstr + ' )\n'
        return tmpstr


def init_network(params):
    # parse params with default values
    architecture = params.get('architecture', 'resnet101')
    local_whitening = params.get('local_whitening', False)
    pooling = params.get('pooling', 'gem')
    regional = params.get('regional', False)
    whitening = params.get('whitening', False)
    mean = params.get('mean', [0.485, 0.456, 0.406])
    std = params.get('std', [0.229, 0.224, 0.225])
    pretrained = params.get('pretrained', True)

    # get output dimensionality size
    dim = OUTPUT_DIM[architecture]

    # loading network from torchvision
    if pretrained:
        if architecture not in FEATURES:
            # initialize with network pretrained on imagenet in pytorch
            net_in = getattr(torchvision.models, architecture)(pretrained=True)
        else:
            # initialize with random weights, later on we will fill features with custom pretrained network
            net_in = getattr(torchvision.models, architecture)(pretrained=False)
    else:
        # initialize with random weights
        net_in = getattr(torchvision.models, architecture)(pretrained=False)

    # initialize features
    # take only convolutions for features,
    # always ends with ReLU to make last activations non-negative
    if architecture.startswith('alexnet'):
        features = list(net_in.features.children())[:-1]
    elif architecture.startswith('vgg'):
        features = list(net_in.features.children())[:-1]
    elif architecture.startswith('resnet'):
        features = list(net_in.children())[:-2]
    elif architecture.startswith('densenet'):
        features = list(net_in.features.children())
        features.append(nn.ReLU(inplace=True))
    elif architecture.startswith('squeezenet'):
        features = list(net_in.features.children())
    else:
        raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))

    # initialize local whitening
    if local_whitening:
        lwhiten = nn.Linear(dim, dim, bias=True)
        # TODO: lwhiten with possible dimensionality reduce

        if pretrained:
            lw = architecture
            if lw in L_WHITENING:
                print(">> {}: for '{}' custom computed local whitening '{}' is used"
                      .format(os.path.basename(__file__), lw, os.path.basename(L_WHITENING[lw])))
                whiten_dir = os.path.join(get_data_root(), 'whiten')
                lwhiten.load_state_dict(model_zoo.load_url(L_WHITENING[lw], model_dir=whiten_dir))
            else:
                print(">> {}: for '{}' there is no local whitening computed, random weights are used"
                      .format(os.path.basename(__file__), lw))

    else:
        lwhiten = None

    # initialize pooling
    if pooling == 'gemmp':
        pool = POOLING[pooling](mp=dim)
    else:
        pool = POOLING[pooling]()

    # initialize regional pooling
    if regional:
        rpool = pool
        rwhiten = nn.Linear(dim, dim, bias=True)
        # TODO: rwhiten with possible dimensionality reduce

        if pretrained:
            rw = '{}-{}-r'.format(architecture, pooling)
            if rw in R_WHITENING:
                print(">> {}: for '{}' custom computed regional whitening '{}' is used"
                      .format(os.path.basename(__file__), rw, os.path.basename(R_WHITENING[rw])))
                whiten_dir = os.path.join(get_data_root(), 'whiten')
                rwhiten.load_state_dict(model_zoo.load_url(R_WHITENING[rw], model_dir=whiten_dir))
            else:
                print(">> {}: for '{}' there is no regional whitening computed, random weights are used"
                      .format(os.path.basename(__file__), rw))

        pool = Rpool(rpool, rwhiten)

    # initialize whitening
    if whitening:
        whiten = nn.Linear(dim, dim, bias=True)
        # TODO: whiten with possible dimensionality reduce

        if pretrained:
            w = architecture
            if local_whitening:
                w += '-lw'
            w += '-' + pooling
            if regional:
                w += '-r'
            if w in WHITENING:
                print(">> {}: for '{}' custom computed whitening '{}' is used"
                      .format(os.path.basename(__file__), w, os.path.basename(WHITENING[w])))
                whiten_dir = os.path.join(get_data_root(), 'whiten')
                whiten.load_state_dict(model_zoo.load_url(WHITENING[w], model_dir=whiten_dir))
            else:
                print(">> {}: for '{}' there is no whitening computed, random weights are used"
                      .format(os.path.basename(__file__), w))
    else:
        whiten = None

    # create meta information to be stored in the network
    meta = {
        'architecture': architecture,
        'local_whitening': local_whitening,
        'pooling': pooling,
        'regional': regional,
        'whitening': whitening,
        'mean': mean,
        'std': std,
        'outputdim': dim,
    }

    # create a generic image retrieval network
    net = ImageRetrievalNet(features, lwhiten, pool, whiten, meta)

    # initialize features with custom pretrained network if needed
    if pretrained and architecture in FEATURES:
        print(">> {}: for '{}' custom pretrained features '{}' are used"
              .format(os.path.basename(__file__), architecture, os.path.basename(FEATURES[architecture])))
        model_dir = os.path.join(get_data_root(), 'networks')
        net.features.load_state_dict(model_zoo.load_url(FEATURES[architecture], model_dir=model_dir))

    return net


def extract_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10):
    # moving network to gpu and eval mode
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # creating dataset loader
    loader = torch.utils.data.DataLoader(
        ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform),
        batch_size=1, shuffle=False, num_workers=1, pin_memory=True
    )

    # extracting vectors
    with torch.no_grad():
        vecs = torch.zeros(net.meta['outputdim'], len(images))
        img_paths = list()
        for i, (input, path) in enumerate(loader):
            #print(i)
            if torch.cuda.is_available():
                input = input.cuda()

            if len(ms) == 1 and ms[0] == 1:
                vecs[:, i] = extract_ss(net, input)
            else:
                vecs[:, i] = extract_ms(net, input, ms, msp)
            img_paths.append(path)

            if (i + 1) % print_freq == 0 or (i + 1) == len(images):
                print('\r>>>> {}/{} done...'.format((i + 1), len(images)), end='')
    # flatten the per-batch path tuples into a single list
    imgs = list()
    for one in img_paths:
        imgs += one
    return vecs, imgs


def extract_vectors_o(net, image, size, transform, bbxs=None, ms=[1], msp=1, print_freq=10):
    if torch.cuda.is_available():
        net.cuda()
    net.eval()
    #image = cv2.resize(image, (size, size))
    if isinstance(image, np.ndarray):
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    image = imresize(image, size)
    image = transform(image)
    image = image.unsqueeze(0)
    #print('image>>>>>>>', image)
    #print('image>>>>>>>', image.shape)
    with torch.no_grad():
        #vecs = torch.zeros(net.meta['outputdim'], len(image))
        if torch.cuda.is_available():
            image = image.cuda()
        if len(ms) == 1 and ms[0] == 1:
            vecs = extract_ss(net, image)
        else:
            vecs = extract_ms(net, image, ms, msp)
    return vecs


def extract_ss(net, input):
    # keep the descriptor on CPU: vecs in extract_vectors is a CPU buffer and
    # extract_ms also returns CPU data, so this path works on CUDA-less machines too
    return net(input).cpu().data.squeeze()


def extract_ms(net, input, ms, msp):
    v = torch.zeros(net.meta['outputdim'])

    for s in ms:
        if s == 1:
            input_t = input.clone()
        else:
            input_t = nn.functional.interpolate(input, scale_factor=s, mode='bilinear', align_corners=False)
        v += net(input_t).pow(msp).cpu().data.squeeze()

    v /= len(ms)
    v = v.pow(1. / msp)
    v /= v.norm()

    return v


def extract_regional_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10):
    # moving network to gpu and eval mode
    net.cuda()
    net.eval()

    # creating dataset loader
    loader = torch.utils.data.DataLoader(
        ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform),
        batch_size=1, shuffle=False, num_workers=8, pin_memory=True
    )

    # extracting vectors
    with torch.no_grad():
        vecs = []
        for i, input in enumerate(loader):
            input = input.cuda()

            if len(ms) == 1:
                vecs.append(extract_ssr(net, input))
            else:
                # TODO: not implemented yet
                # vecs.append(extract_msr(net, input, ms, msp))
                raise NotImplementedError

            if (i + 1) % print_freq == 0 or (i + 1) == len(images):
                print('\r>>>> {}/{} done...'.format((i + 1), len(images)), end='')
        print('')

    return vecs


def extract_ssr(net, input):
    return net.pool(net.features(input), aggregate=False).squeeze(0).squeeze(-1).squeeze(-1).permute(1, 0).cpu().data


def extract_local_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10):
    # moving network to gpu and eval mode
    net.cuda()
    net.eval()

    # creating dataset loader
    loader = torch.utils.data.DataLoader(
        ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform),
        batch_size=1, shuffle=False, num_workers=8, pin_memory=True
    )

    # extracting vectors
    with torch.no_grad():
        vecs = []
        for i, input in enumerate(loader):
            input = input.cuda()

            if len(ms) == 1:
                vecs.append(extract_ssl(net, input))
            else:
                # TODO: not implemented yet
                # vecs.append(extract_msl(net, input, ms, msp))
                raise NotImplementedError

            if (i + 1) % print_freq == 0 or (i + 1) == len(images):
                print('\r>>>> {}/{} done...'.format((i + 1), len(images)), end='')
        print('')

    return vecs


def extract_ssl(net, input):
    return net.norm(net.features(input)).squeeze(0).view(net.meta['outputdim'], -1).cpu().data
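
Put together, a retrieval pass through this module looks roughly like the sketch below; the image paths, the image size, and the transform built from net.meta are illustrative assumptions:

import torchvision.transforms as transforms
from cirtorch.networks.imageretrievalnet import init_network, extract_vectors

net = init_network({'architecture': 'resnet101', 'pooling': 'gem'})
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=net.meta['mean'], std=net.meta['std']),
])
# hypothetical image list; extract_vectors returns D x N descriptor columns plus the paths
vecs, paths = extract_vectors(net, ['a.jpg', 'b.jpg'], 1024, transform)
scores = vecs.t() @ vecs  # cosine similarities, since each column is L2-normalized
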
392
cirtorch/networks/imageretrievalnet_cpu.py
Normal file
@ -0,0 +1,392 @@
import os
import pdb

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

import torchvision

from cirtorch.layers.pooling import MAC, SPoC, GeM, GeMmp, RMAC, Rpool
from cirtorch.layers.normalization import L2N, PowerLaw
from cirtorch.datasets.genericdataset import ImagesFromList
from cirtorch.utils.general import get_data_root

# for some models, we have imported features (convolutions) from caffe because the image retrieval performance is higher for them
FEATURES = {
    'vgg16': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-vgg16-features-d369c8e.pth',
    'resnet50': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet50-features-ac468af.pth',
    'resnet101': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet101-features-10a101d.pth',
    'resnet152': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet152-features-1011020.pth',
}

# TODO: pre-compute for more architectures and properly test variations (pre l2norm, post l2norm)
# pre-computed local pca whitening that can be applied before the pooling layer
L_WHITENING = {
    'resnet101': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-9f830ef.pth',  # no pre l2 norm
    # 'resnet101' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-da5c935.pth', # with pre l2 norm
}

# possible global pooling layers, each one of these can be made regional
POOLING = {
    'mac': MAC,
    'spoc': SPoC,
    'gem': GeM,
    'gemmp': GeMmp,
    'rmac': RMAC,
}

# TODO: pre-compute for: resnet50-gem-r, resnet50-mac-r, vgg16-mac-r, alexnet-mac-r
# pre-computed regional whitening, for most commonly used architectures and pooling methods
R_WHITENING = {
    'alexnet-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-rwhiten-c8cf7e2.pth',
    'vgg16-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-rwhiten-19b204e.pth',
    'resnet101-mac-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-rwhiten-7f1ed8c.pth',
    'resnet101-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-rwhiten-adace84.pth',
}

# TODO: pre-compute for more architectures
# pre-computed final (global) whitening, for most commonly used architectures and pooling methods
WHITENING = {
    'alexnet-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-whiten-454ad53.pth',
    'alexnet-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-whiten-4c9126b.pth',
    'vgg16-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-whiten-eaa6695.pth',
    'vgg16-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-whiten-83582df.pth',
    'resnet50-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet50-gem-whiten-f15da7b.pth',
    'resnet101-mac-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-whiten-9df41d3.pth',
    'resnet101-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-whiten-22ab0c1.pth',
    'resnet101-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-whiten-b379c0a.pth',
    'resnet101-gemmp': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gemmp-whiten-770f53c.pth',
    'resnet152-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet152-gem-whiten-abe7b93.pth',
    'densenet121-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet121-gem-whiten-79e3eea.pth',
    'densenet169-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet169-gem-whiten-6b2a76a.pth',
    'densenet201-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet201-gem-whiten-22ea45c.pth',
}

# output dimensionality for supported architectures
OUTPUT_DIM = {
    'alexnet': 256,
    'vgg11': 512,
    'vgg13': 512,
    'vgg16': 512,
    'vgg19': 512,
    'resnet18': 512,
    'resnet34': 512,
    'resnet50': 2048,
    'resnet101': 2048,
    'resnet152': 2048,
    'densenet121': 1024,
    'densenet169': 1664,
    'densenet201': 1920,
    'densenet161': 2208,  # largest densenet
    'squeezenet1_0': 512,
    'squeezenet1_1': 512,
}


class ImageRetrievalNet(nn.Module):

    def __init__(self, features, lwhiten, pool, whiten, meta):
        super(ImageRetrievalNet, self).__init__()
        self.features = nn.Sequential(*features)
        self.lwhiten = lwhiten
        self.pool = pool
        self.whiten = whiten
        self.norm = L2N()
        self.meta = meta

    def forward(self, x):
        # x -> features
        o = self.features(x)

        # TODO: properly test (with pre-l2norm and/or post-l2norm)
        # if lwhiten exist: features -> local whiten
        if self.lwhiten is not None:
            # o = self.norm(o)
            s = o.size()
            o = o.permute(0, 2, 3, 1).contiguous().view(-1, s[1])
            o = self.lwhiten(o)
            o = o.view(s[0], s[2], s[3], self.lwhiten.out_features).permute(0, 3, 1, 2)
            # o = self.norm(o)

        # features -> pool -> norm
        o = self.norm(self.pool(o)).squeeze(-1).squeeze(-1)

        # if whiten exist: pooled features -> whiten -> norm
        if self.whiten is not None:
            o = self.norm(self.whiten(o))

        # permute so that it is Dx1 column vector per image (DxN if many images)
        return o.permute(1, 0)

    def __repr__(self):
        tmpstr = super(ImageRetrievalNet, self).__repr__()[:-1]
        tmpstr += self.meta_repr()
        tmpstr = tmpstr + ')'
        return tmpstr

    def meta_repr(self):
        tmpstr = ' (' + 'meta' + '): dict( \n'  # + self.meta.__repr__() + '\n'
        tmpstr += ' architecture: {}\n'.format(self.meta['architecture'])
        tmpstr += ' local_whitening: {}\n'.format(self.meta['local_whitening'])
        tmpstr += ' pooling: {}\n'.format(self.meta['pooling'])
        tmpstr += ' regional: {}\n'.format(self.meta['regional'])
        tmpstr += ' whitening: {}\n'.format(self.meta['whitening'])
        tmpstr += ' outputdim: {}\n'.format(self.meta['outputdim'])
        tmpstr += ' mean: {}\n'.format(self.meta['mean'])
        tmpstr += ' std: {}\n'.format(self.meta['std'])
        tmpstr = tmpstr + ' )\n'
        return tmpstr


def init_network(params):
    # parse params with default values
    architecture = params.get('architecture', 'resnet101')
    local_whitening = params.get('local_whitening', False)
    pooling = params.get('pooling', 'gem')
    regional = params.get('regional', False)
    whitening = params.get('whitening', False)
    mean = params.get('mean', [0.485, 0.456, 0.406])
    std = params.get('std', [0.229, 0.224, 0.225])
    pretrained = params.get('pretrained', True)

    # get output dimensionality size
    dim = OUTPUT_DIM[architecture]

    # loading network from torchvision
    if pretrained:
        if architecture not in FEATURES:
            # initialize with network pretrained on imagenet in pytorch
            net_in = getattr(torchvision.models, architecture)(pretrained=True)
        else:
            # initialize with random weights, later on we will fill features with custom pretrained network
            net_in = getattr(torchvision.models, architecture)(pretrained=False)
    else:
        # initialize with random weights
        net_in = getattr(torchvision.models, architecture)(pretrained=False)

    # initialize features
    # take only convolutions for features,
    # always ends with ReLU to make last activations non-negative
    if architecture.startswith('alexnet'):
        features = list(net_in.features.children())[:-1]
    elif architecture.startswith('vgg'):
        features = list(net_in.features.children())[:-1]
    elif architecture.startswith('resnet'):
        features = list(net_in.children())[:-2]
    elif architecture.startswith('densenet'):
        features = list(net_in.features.children())
        features.append(nn.ReLU(inplace=True))
    elif architecture.startswith('squeezenet'):
        features = list(net_in.features.children())
    else:
        raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))

    # initialize local whitening
    if local_whitening:
        lwhiten = nn.Linear(dim, dim, bias=True)
        # TODO: lwhiten with possible dimensionality reduce

        if pretrained:
            lw = architecture
            if lw in L_WHITENING:
                print(">> {}: for '{}' custom computed local whitening '{}' is used"
                      .format(os.path.basename(__file__), lw, os.path.basename(L_WHITENING[lw])))
                whiten_dir = os.path.join(get_data_root(), 'whiten')
                lwhiten.load_state_dict(model_zoo.load_url(L_WHITENING[lw], model_dir=whiten_dir))
            else:
                print(">> {}: for '{}' there is no local whitening computed, random weights are used"
                      .format(os.path.basename(__file__), lw))

    else:
        lwhiten = None

    # initialize pooling
    if pooling == 'gemmp':
        pool = POOLING[pooling](mp=dim)
    else:
        pool = POOLING[pooling]()

    # initialize regional pooling
    if regional:
        rpool = pool
        rwhiten = nn.Linear(dim, dim, bias=True)
        # TODO: rwhiten with possible dimensionality reduce

        if pretrained:
            rw = '{}-{}-r'.format(architecture, pooling)
            if rw in R_WHITENING:
                print(">> {}: for '{}' custom computed regional whitening '{}' is used"
                      .format(os.path.basename(__file__), rw, os.path.basename(R_WHITENING[rw])))
                whiten_dir = os.path.join(get_data_root(), 'whiten')
                rwhiten.load_state_dict(model_zoo.load_url(R_WHITENING[rw], model_dir=whiten_dir))
            else:
                print(">> {}: for '{}' there is no regional whitening computed, random weights are used"
                      .format(os.path.basename(__file__), rw))

        pool = Rpool(rpool, rwhiten)

    # initialize whitening
    if whitening:
        whiten = nn.Linear(dim, dim, bias=True)
        # TODO: whiten with possible dimensionality reduce

        if pretrained:
            w = architecture
            if local_whitening:
                w += '-lw'
            w += '-' + pooling
            if regional:
                w += '-r'
            if w in WHITENING:
                print(">> {}: for '{}' custom computed whitening '{}' is used"
                      .format(os.path.basename(__file__), w, os.path.basename(WHITENING[w])))
                whiten_dir = os.path.join(get_data_root(), 'whiten')
                whiten.load_state_dict(model_zoo.load_url(WHITENING[w], model_dir=whiten_dir))
            else:
                print(">> {}: for '{}' there is no whitening computed, random weights are used"
                      .format(os.path.basename(__file__), w))
    else:
        whiten = None

    # create meta information to be stored in the network
    meta = {
        'architecture': architecture,
        'local_whitening': local_whitening,
        'pooling': pooling,
        'regional': regional,
        'whitening': whitening,
        'mean': mean,
        'std': std,
        'outputdim': dim,
    }

    # create a generic image retrieval network
    net = ImageRetrievalNet(features, lwhiten, pool, whiten, meta)

    # initialize features with custom pretrained network if needed
    if pretrained and architecture in FEATURES:
        print(">> {}: for '{}' custom pretrained features '{}' are used"
              .format(os.path.basename(__file__), architecture, os.path.basename(FEATURES[architecture])))
        model_dir = os.path.join(get_data_root(), 'networks')
        net.features.load_state_dict(model_zoo.load_url(FEATURES[architecture], model_dir=model_dir))

    return net


def extract_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10):
    # moving network to gpu and eval mode
    net.cuda()
    net.eval()

    # creating dataset loader
    loader = torch.utils.data.DataLoader(
        ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform),
        batch_size=1, shuffle=False, num_workers=8, pin_memory=True
    )

    # extracting vectors
    with torch.no_grad():
        vecs = torch.zeros(net.meta['outputdim'], len(images))
        for i, input in enumerate(loader):
            input = input.cuda()

            if len(ms) == 1 and ms[0] == 1:
                vecs[:, i] = extract_ss(net, input)
            else:
                vecs[:, i] = extract_ms(net, input, ms, msp)

            if (i+1) % print_freq == 0 or (i+1) == len(images):
                print('\r>>>> {}/{} done...'.format((i+1), len(images)), end='')
        print('')

    return vecs


def extract_ss(net, input):
    return net(input).cpu().data.squeeze()


def extract_ms(net, input, ms, msp):
    v = torch.zeros(net.meta['outputdim'])

    for s in ms:
        if s == 1:
            input_t = input.clone()
        else:
            input_t = nn.functional.interpolate(input, scale_factor=s, mode='bilinear', align_corners=False)
        v += net(input_t).pow(msp).cpu().data.squeeze()

    v /= len(ms)
    v = v.pow(1./msp)
    v /= v.norm()

    return v


def extract_regional_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10):
    # moving network to gpu and eval mode
    net.cuda()
    net.eval()

    # creating dataset loader
    loader = torch.utils.data.DataLoader(
        ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform),
        batch_size=1, shuffle=False, num_workers=8, pin_memory=True
    )

    # extracting vectors
    with torch.no_grad():
        vecs = []
        for i, input in enumerate(loader):
            input = input.cuda()

            if len(ms) == 1:
                vecs.append(extract_ssr(net, input))
            else:
                # TODO: not implemented yet
                # vecs.append(extract_msr(net, input, ms, msp))
                raise NotImplementedError

            if (i+1) % print_freq == 0 or (i+1) == len(images):
                print('\r>>>> {}/{} done...'.format((i+1), len(images)), end='')
        print('')

    return vecs


def extract_ssr(net, input):
    return net.pool(net.features(input), aggregate=False).squeeze(0).squeeze(-1).squeeze(-1).permute(1, 0).cpu().data


def extract_local_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10):
    # moving network to gpu and eval mode
    net.cuda()
    net.eval()

    # creating dataset loader
    loader = torch.utils.data.DataLoader(
        ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform),
        batch_size=1, shuffle=False, num_workers=8, pin_memory=True
    )

    # extracting vectors
    with torch.no_grad():
        vecs = []
        for i, input in enumerate(loader):
            input = input.cuda()

            if len(ms) == 1:
                vecs.append(extract_ssl(net, input))
            else:
                # TODO: not implemented yet
                # vecs.append(extract_msl(net, input, ms, msp))
                raise NotImplementedError

            if (i+1) % print_freq == 0 or (i+1) == len(images):
                print('\r>>>> {}/{} done...'.format((i+1), len(images)), end='')
        print('')

    return vecs


def extract_ssl(net, input):
    return net.norm(net.features(input)).squeeze(0).view(net.meta['outputdim'], -1).cpu().data
0
cirtorch/utils/__init__.py
Executable file
154
cirtorch/utils/download.py
Executable file
@ -0,0 +1,154 @@
import os

def download_test(data_dir):
    """
    DOWNLOAD_TEST Checks, and, if required, downloads the necessary datasets for the testing.

        download_test(DATA_ROOT) checks if the data necessary for running the example script exist.
        If not it downloads it in the folder structure:
            DATA_ROOT/test/oxford5k/  : folder with Oxford images and ground truth file
            DATA_ROOT/test/paris6k/   : folder with Paris images and ground truth file
            DATA_ROOT/test/roxford5k/ : folder with Oxford images and revisited ground truth file
            DATA_ROOT/test/rparis6k/  : folder with Paris images and revisited ground truth file
    """

    # Create data folder if it does not exist
    if not os.path.isdir(data_dir):
        os.mkdir(data_dir)

    # Create datasets folder if it does not exist
    datasets_dir = os.path.join(data_dir, 'test')
    print('***************', os.path.exists(datasets_dir))
    #print(not os.path.isdir(datasets_dir))
    if not os.path.exists(datasets_dir):
        os.mkdir(datasets_dir)

    # Download datasets folders test/DATASETNAME/
    datasets = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
    for di in range(len(datasets)):
        dataset = datasets[di]

        if dataset == 'oxford5k':
            src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings'
            dl_files = ['oxbuild_images.tgz']
        elif dataset == 'paris6k':
            src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings'
            dl_files = ['paris_1.tgz', 'paris_2.tgz']
        elif dataset == 'roxford5k':
            src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings'
            dl_files = ['oxbuild_images.tgz']
        elif dataset == 'rparis6k':
            src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings'
            dl_files = ['paris_1.tgz', 'paris_2.tgz']
        else:
            raise ValueError('Unknown dataset: {}!'.format(dataset))

        dst_dir = os.path.join(datasets_dir, dataset, 'jpg')
        print('%%%%%%%%%%%%%%%%', dst_dir, dataset)
        if not os.path.exists(dst_dir):
            # for oxford and paris download images
            if dataset == 'oxford5k' or dataset == 'paris6k':
                print('>> Dataset {} directory does not exist. Creating: {}'.format(dataset, dst_dir))
                os.makedirs(dst_dir)
                for dli in range(len(dl_files)):
                    dl_file = dl_files[dli]
                    src_file = os.path.join(src_dir, dl_file)
                    dst_file = os.path.join(dst_dir, dl_file)
                    print('>> Downloading dataset {} archive {}...'.format(dataset, dl_file))
                    os.system('wget {} -O {}'.format(src_file, dst_file))
                    print('>> Extracting dataset {} archive {}...'.format(dataset, dl_file))
                    # create tmp folder
                    dst_dir_tmp = os.path.join(dst_dir, 'tmp')
                    os.system('mkdir {}'.format(dst_dir_tmp))
                    # extract in tmp folder
                    os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir_tmp))
                    # remove all (possible) subfolders by moving only files in dst_dir
                    os.system('find {} -type f -exec mv -i {{}} {} \\;'.format(dst_dir_tmp, dst_dir))
                    # remove tmp folder
                    os.system('rm -rf {}'.format(dst_dir_tmp))
                    print('>> Extracted, deleting dataset {} archive {}...'.format(dataset, dl_file))
                    os.system('rm {}'.format(dst_file))

            # for roxford and rparis just make sym links
            elif dataset == 'roxford5k' or dataset == 'rparis6k':
                print('>> Dataset {} directory does not exist. Creating: {}'.format(dataset, dst_dir))
                dataset_old = dataset[1:]
                dst_dir_old = os.path.join(datasets_dir, dataset_old, 'jpg')
                os.mkdir(os.path.join(datasets_dir, dataset))
                os.system('ln -s {} {}'.format(dst_dir_old, dst_dir))
                print('>> Created symbolic link from {} jpg to {} jpg'.format(dataset_old, dataset))

        gnd_src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'test', dataset)
        gnd_dst_dir = os.path.join(datasets_dir, dataset)
        gnd_dl_file = 'gnd_{}.pkl'.format(dataset)
        gnd_src_file = os.path.join(gnd_src_dir, gnd_dl_file)
        gnd_dst_file = os.path.join(gnd_dst_dir, gnd_dl_file)
        if not os.path.exists(gnd_dst_file):
            print('>> Downloading dataset {} ground truth file...'.format(dataset))
            os.system('wget {} -O {}'.format(gnd_src_file, gnd_dst_file))


def download_train(data_dir):
    """
    DOWNLOAD_TRAIN Checks, and, if required, downloads the necessary datasets for the training.

        download_train(DATA_ROOT) checks if the data necessary for running the example script exist.
        If not it downloads it in the folder structure:
            DATA_ROOT/train/retrieval-SfM-120k/ : folder with rsfm120k images and db files
            DATA_ROOT/train/retrieval-SfM-30k/  : folder with rsfm30k images and db files
    """

    # Create data folder if it does not exist
    if not os.path.isdir(data_dir):
        os.mkdir(data_dir)

    # Create datasets folder if it does not exist
    datasets_dir = os.path.join(data_dir, 'train')
    if not os.path.isdir(datasets_dir):
        os.mkdir(datasets_dir)

    # Download folder train/retrieval-SfM-120k/
    src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'ims')
    dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
    dl_file = 'ims.tar.gz'
    if not os.path.isdir(dst_dir):
        src_file = os.path.join(src_dir, dl_file)
        dst_file = os.path.join(dst_dir, dl_file)
        print('>> Image directory does not exist. Creating: {}'.format(dst_dir))
        os.makedirs(dst_dir)
        print('>> Downloading ims.tar.gz...')
        os.system('wget {} -O {}'.format(src_file, dst_file))
        print('>> Extracting {}...'.format(dst_file))
        os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir))
        print('>> Extracted, deleting {}...'.format(dst_file))
        os.system('rm {}'.format(dst_file))

    # Create symlink for train/retrieval-SfM-30k/
    dst_dir_old = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
    dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-30k', 'ims')
    if not os.path.exists(dst_dir):
        os.makedirs(os.path.join(datasets_dir, 'retrieval-SfM-30k'))
        os.system('ln -s {} {}'.format(dst_dir_old, dst_dir))
        print('>> Created symbolic link from retrieval-SfM-120k/ims to retrieval-SfM-30k/ims')

    # Download db files
    src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'dbs')
    datasets = ['retrieval-SfM-120k', 'retrieval-SfM-30k']
    for dataset in datasets:
        dst_dir = os.path.join(datasets_dir, dataset)
        if dataset == 'retrieval-SfM-120k':
            dl_files = ['{}.pkl'.format(dataset), '{}-whiten.pkl'.format(dataset)]
        elif dataset == 'retrieval-SfM-30k':
            dl_files = ['{}-whiten.pkl'.format(dataset)]

        if not os.path.isdir(dst_dir):
            print('>> Dataset directory does not exist. Creating: {}'.format(dst_dir))
            os.mkdir(dst_dir)

        for i in range(len(dl_files)):
            src_file = os.path.join(src_dir, dl_files[i])
            dst_file = os.path.join(dst_dir, dl_files[i])
            if not os.path.isfile(dst_file):
                print('>> DB file {} does not exist. Downloading...'.format(dl_files[i]))
                os.system('wget {} -O {}'.format(src_file, dst_file))
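
A minimal sketch of how these helpers are invoked, with the data root taken from cirtorch.utils.general:

from cirtorch.utils.general import get_data_root
from cirtorch.utils.download import download_test, download_train

data_root = get_data_root()  # <repo>/data
download_test(data_root)     # oxford5k/paris6k images + (revisited) ground truth files
download_train(data_root)    # retrieval-SfM-120k/-30k images and db files
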
152
cirtorch/utils/download_win.py
Executable file
@ -0,0 +1,152 @@
import os

def download_test(data_dir):
    """
    DOWNLOAD_TEST Checks, and, if required, downloads the necessary datasets for the testing.

        download_test(DATA_ROOT) checks if the data necessary for running the example script exist.
        If not it downloads it in the folder structure:
            DATA_ROOT/test/oxford5k/  : folder with Oxford images and ground truth file
            DATA_ROOT/test/paris6k/   : folder with Paris images and ground truth file
            DATA_ROOT/test/roxford5k/ : folder with Oxford images and revisited ground truth file
            DATA_ROOT/test/rparis6k/  : folder with Paris images and revisited ground truth file
    """

    # Create data folder if it does not exist
    if not os.path.isdir(data_dir):
        os.mkdir(data_dir)

    # Create datasets folder if it does not exist
    datasets_dir = os.path.join(data_dir, 'test')
    if not os.path.isdir(datasets_dir):
        os.mkdir(datasets_dir)

    # Download datasets folders test/DATASETNAME/
    datasets = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
    for di in range(len(datasets)):
        dataset = datasets[di]

        if dataset == 'oxford5k':
            src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings'
            dl_files = ['oxbuild_images.tgz']
        elif dataset == 'paris6k':
            src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings'
            dl_files = ['paris_1.tgz', 'paris_2.tgz']
        elif dataset == 'roxford5k':
            src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings'
            dl_files = ['oxbuild_images.tgz']
        elif dataset == 'rparis6k':
            src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings'
            dl_files = ['paris_1.tgz', 'paris_2.tgz']
        else:
            raise ValueError('Unknown dataset: {}!'.format(dataset))

        dst_dir = os.path.join(datasets_dir, dataset, 'jpg')
        if not os.path.isdir(dst_dir):

            # for oxford and paris download images
            if dataset == 'oxford5k' or dataset == 'paris6k':
                print('>> Dataset {} directory does not exist. Creating: {}'.format(dataset, dst_dir))
                os.makedirs(dst_dir)
                for dli in range(len(dl_files)):
                    dl_file = dl_files[dli]
                    src_file = os.path.join(src_dir, dl_file)
                    dst_file = os.path.join(dst_dir, dl_file)
                    print('>> Downloading dataset {} archive {}...'.format(dataset, dl_file))
                    os.system('wget {} -O {}'.format(src_file, dst_file))
                    print('>> Extracting dataset {} archive {}...'.format(dataset, dl_file))
                    # create tmp folder
                    dst_dir_tmp = os.path.join(dst_dir, 'tmp')
                    os.system('mkdir {}'.format(dst_dir_tmp))
                    # extract in tmp folder
                    os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir_tmp))
                    # remove all (possible) subfolders by moving only files in dst_dir
                    os.system('find {} -type f -exec mv -i {{}} {} \\;'.format(dst_dir_tmp, dst_dir))
                    # remove tmp folder; /s /q removes recursively without prompting (rm -rf equivalent)
                    os.system('rd /s /q {}'.format(dst_dir_tmp))
                    print('>> Extracted, deleting dataset {} archive {}...'.format(dataset, dl_file))
                    os.system('del {}'.format(dst_file))

            # for roxford and rparis just make sym links
            elif dataset == 'roxford5k' or dataset == 'rparis6k':
                print('>> Dataset {} directory does not exist. Creating: {}'.format(dataset, dst_dir))
                dataset_old = dataset[1:]
                dst_dir_old = os.path.join(datasets_dir, dataset_old, 'jpg')
                os.mkdir(os.path.join(datasets_dir, dataset))
                # mklink expects <link> before <target>
                os.system('cmd /c mklink /D {} {}'.format(dst_dir, dst_dir_old))
                print('>> Created symbolic link from {} jpg to {} jpg'.format(dataset_old, dataset))

        gnd_src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'test', dataset)
        gnd_dst_dir = os.path.join(datasets_dir, dataset)
        gnd_dl_file = 'gnd_{}.pkl'.format(dataset)
        gnd_src_file = os.path.join(gnd_src_dir, gnd_dl_file)
        gnd_dst_file = os.path.join(gnd_dst_dir, gnd_dl_file)
        if not os.path.exists(gnd_dst_file):
            print('>> Downloading dataset {} ground truth file...'.format(dataset))
            os.system('wget {} -O {}'.format(gnd_src_file, gnd_dst_file))


def download_train(data_dir):
    """
    DOWNLOAD_TRAIN Checks, and, if required, downloads the necessary datasets for the training.

        download_train(DATA_ROOT) checks if the data necessary for running the example script exist.
        If not it downloads it in the folder structure:
            DATA_ROOT/train/retrieval-SfM-120k/ : folder with rsfm120k images and db files
            DATA_ROOT/train/retrieval-SfM-30k/  : folder with rsfm30k images and db files
    """

    # Create data folder if it does not exist
    if not os.path.isdir(data_dir):
        os.mkdir(data_dir)
    print(data_dir)
    # Create datasets folder if it does not exist
    datasets_dir = os.path.join(data_dir, 'train')
    if not os.path.isdir(datasets_dir):
        os.mkdir(datasets_dir)

    # Download folder train/retrieval-SfM-120k/
    src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'ims')
    dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
    dl_file = 'ims.tar.gz'
    if not os.path.isdir(dst_dir):
        src_file = os.path.join(src_dir, dl_file)
        dst_file = os.path.join(dst_dir, dl_file)
        print('>> Image directory does not exist. Creating: {}'.format(dst_dir))
        os.makedirs(dst_dir)
        print('>> Downloading ims.tar.gz...')
        # NOTE: the download step is commented out in this Windows variant; ims.tar.gz must already be present
        # os.system('wget {} -O {}'.format(src_file, dst_file))
        print('>> Extracting {}...'.format(dst_file))
        os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir))
        print('>> Extracted, deleting {}...'.format(dst_file))
        os.system('del {}'.format(dst_file))

    # Create symlink for train/retrieval-SfM-30k/
    dst_dir_old = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
    dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-30k', 'ims')
    if not os.path.isdir(dst_dir):
        # create only the parent folder; mklink creates the ims link itself
        os.makedirs(os.path.join(datasets_dir, 'retrieval-SfM-30k'))
        os.system('cmd /c mklink /D {} {}'.format(dst_dir, dst_dir_old))
        print('>> Created symbolic link from retrieval-SfM-120k/ims to retrieval-SfM-30k/ims')

    # Download db files
    src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'dbs')
    datasets = ['retrieval-SfM-120k', 'retrieval-SfM-30k']
    for dataset in datasets:
        dst_dir = os.path.join(datasets_dir, dataset)
        if dataset == 'retrieval-SfM-120k':
            dl_files = ['{}.pkl'.format(dataset), '{}-whiten.pkl'.format(dataset)]
        elif dataset == 'retrieval-SfM-30k':
            dl_files = ['{}-whiten.pkl'.format(dataset)]

        if not os.path.isdir(dst_dir):
            print('>> Dataset directory does not exist. Creating: {}'.format(dst_dir))
            os.mkdir(dst_dir)

        for i in range(len(dl_files)):
            src_file = os.path.join(src_dir, dl_files[i])
            dst_file = os.path.join(dst_dir, dl_files[i])
            if not os.path.isfile(dst_file):
                print('>> DB file {} does not exist. Downloading...'.format(dl_files[i]))
                os.system('wget {} -O {}'.format(src_file, dst_file))
149
cirtorch/utils/evaluate.py
Executable file
149
cirtorch/utils/evaluate.py
Executable file
@ -0,0 +1,149 @@
import numpy as np


def compute_ap(ranks, nres):
    """
    Computes average precision for given ranked indexes.

    Arguments
    ---------
    ranks : zero-based ranks of positive images
    nres  : number of positive images

    Returns
    -------
    ap    : average precision
    """

    # number of images ranked by the system
    nimgranks = len(ranks)

    # accumulate trapezoids in PR-plot
    ap = 0

    recall_step = 1. / nres

    for j in np.arange(nimgranks):
        rank = ranks[j]

        if rank == 0:
            precision_0 = 1.
        else:
            precision_0 = float(j) / rank

        precision_1 = float(j + 1) / (rank + 1)

        ap += (precision_0 + precision_1) * recall_step / 2.

    return ap


def compute_map(ranks, gnd, kappas=[]):
    """
    Computes the mAP for a given set of returned results.

    Usage:
        map = compute_map(ranks, gnd)
            computes mean average precision (map) only

        map, aps, pr, prs = compute_map(ranks, gnd, kappas)
            computes mean average precision (map), average precision (aps) for each query
            computes mean precision at kappas (pr), precision at kappas (prs) for each query

    Notes:
    1) ranks starts from 0, ranks.shape = db_size X #queries
    2) The junk results (e.g., the query itself) should be declared in the gnd struct array
    3) If there are no positive images for some query, that query is excluded from the evaluation
    """

    map = 0.
    nq = len(gnd)  # number of queries
    aps = np.zeros(nq)
    pr = np.zeros(len(kappas))
    prs = np.zeros((nq, len(kappas)))
    nempty = 0

    for i in np.arange(nq):
        qgnd = np.array(gnd[i]['ok'])

        # no positive images, skip from the average
        if qgnd.shape[0] == 0:
            aps[i] = float('nan')
            prs[i, :] = float('nan')
            nempty += 1
            continue

        try:
            qgndj = np.array(gnd[i]['junk'])
        except KeyError:
            qgndj = np.empty(0)

        # sorted positions of positive and junk images (0 based)
        pos = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgnd)]
        junk = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgndj)]

        k = 0
        ij = 0
        if len(junk):
            # decrease positions of positives based on the number of
            # junk images appearing before them
            ip = 0
            while ip < len(pos):
                while ij < len(junk) and pos[ip] > junk[ij]:
                    k += 1
                    ij += 1
                pos[ip] = pos[ip] - k
                ip += 1

        # compute ap
        ap = compute_ap(pos, len(qgnd))
        map = map + ap
        aps[i] = ap

        # compute precision @ k
        pos += 1  # get it to 1-based
        for j in np.arange(len(kappas)):
            kq = min(max(pos), kappas[j])
            prs[i, j] = (pos <= kq).sum() / kq
        pr = pr + prs[i, :]

    map = map / (nq - nempty)
    pr = pr / (nq - nempty)

    return map, aps, pr, prs


def compute_map_and_print(dataset, ranks, gnd, kappas=[1, 5, 10]):

    # old evaluation protocol
    if dataset.startswith('oxford5k') or dataset.startswith('paris6k'):
        map, aps, _, _ = compute_map(ranks, gnd)
        print('>> {}: mAP {:.2f}'.format(dataset, np.around(map*100, decimals=2)))

    # new evaluation protocol
    elif dataset.startswith('roxford5k') or dataset.startswith('rparis6k'):

        gnd_t = []
        for i in range(len(gnd)):
            g = {}
            g['ok'] = np.concatenate([gnd[i]['easy']])
            g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['hard']])
            gnd_t.append(g)
        mapE, apsE, mprE, prsE = compute_map(ranks, gnd_t, kappas)

        gnd_t = []
        for i in range(len(gnd)):
            g = {}
            g['ok'] = np.concatenate([gnd[i]['easy'], gnd[i]['hard']])
            g['junk'] = np.concatenate([gnd[i]['junk']])
            gnd_t.append(g)
        mapM, apsM, mprM, prsM = compute_map(ranks, gnd_t, kappas)

        gnd_t = []
        for i in range(len(gnd)):
            g = {}
            g['ok'] = np.concatenate([gnd[i]['hard']])
            g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['easy']])
            gnd_t.append(g)
        mapH, apsH, mprH, prsH = compute_map(ranks, gnd_t, kappas)

        print('>> {}: mAP E: {}, M: {}, H: {}'.format(dataset, np.around(mapE*100, decimals=2), np.around(mapM*100, decimals=2), np.around(mapH*100, decimals=2)))
        print('>> {}: mP@k{} E: {}, M: {}, H: {}'.format(dataset, kappas, np.around(mprE*100, decimals=2), np.around(mprM*100, decimals=2), np.around(mprH*100, decimals=2)))
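For reference, a minimal toy run of compute_map (one query over a three-image database; the values below are illustrative, not part of the committed file):

import numpy as np
from cirtorch.utils.evaluate import compute_map

ranks = np.array([[2], [0], [1]])   # ranks[r, q]: index of the image ranked r-th for query q
gnd = [{'ok': np.array([2]), 'junk': np.array([])}]  # image 2 is the only positive
mAP, aps, pr, prs = compute_map(ranks, gnd, kappas=[1, 3])
print(mAP, pr)  # 1.0 [1. 1.]: the single positive lands at rank 0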
34
cirtorch/utils/general.py
Executable file
@ -0,0 +1,34 @@
import os
import hashlib


def get_root():
    return os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))


def get_data_root():
    return os.path.join(get_root(), 'data')


def htime(c):
    c = round(c)

    days = c // 86400
    hours = c // 3600 % 24
    minutes = c // 60 % 60
    seconds = c % 60

    if days > 0:
        return '{:d}d {:d}h {:d}m {:d}s'.format(days, hours, minutes, seconds)
    if hours > 0:
        return '{:d}h {:d}m {:d}s'.format(hours, minutes, seconds)
    if minutes > 0:
        return '{:d}m {:d}s'.format(minutes, seconds)
    return '{:d}s'.format(seconds)


def sha256_hash(filename, block_size=65536, length=8):
    sha256 = hashlib.sha256()
    with open(filename, 'rb') as f:
        for block in iter(lambda: f.read(block_size), b''):
            sha256.update(block)
    return sha256.hexdigest()[:length-1]
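Quick illustrative checks for the helpers above (file name and output are hypothetical):

from cirtorch.utils.general import htime, sha256_hash

print(htime(90061))   # '1d 1h 1m 1s'
# Note: sha256_hash truncates to length-1 hex characters, i.e. 7 by default.
print(sha256_hash('requirements.txt'))  # prints something like 'a94a8fe'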
65
cirtorch/utils/whiten.py
Executable file
@ -0,0 +1,65 @@
import os
import numpy as np


def whitenapply(X, m, P, dimensions=None):

    if not dimensions:
        dimensions = P.shape[0]

    X = np.dot(P[:dimensions, :], X-m)
    X = X / (np.linalg.norm(X, ord=2, axis=0, keepdims=True) + 1e-6)

    return X


def pcawhitenlearn(X):

    N = X.shape[1]

    # Learning PCA w/o annotations
    m = X.mean(axis=1, keepdims=True)
    Xc = X - m
    Xcov = np.dot(Xc, Xc.T)
    Xcov = (Xcov + Xcov.T) / (2*N)
    eigval, eigvec = np.linalg.eig(Xcov)
    order = eigval.argsort()[::-1]
    eigval = eigval[order]
    eigvec = eigvec[:, order]

    P = np.dot(np.linalg.inv(np.sqrt(np.diag(eigval))), eigvec.T)

    return m, P


def whitenlearn(X, qidxs, pidxs):

    # Learning Lw with annotations
    m = X[:, qidxs].mean(axis=1, keepdims=True)
    df = X[:, qidxs] - X[:, pidxs]
    S = np.dot(df, df.T) / df.shape[1]
    P = np.linalg.inv(cholesky(S))
    df = np.dot(P, X-m)
    D = np.dot(df, df.T)
    eigval, eigvec = np.linalg.eig(D)
    order = eigval.argsort()[::-1]
    eigval = eigval[order]
    eigvec = eigvec[:, order]

    P = np.dot(eigvec.T, P)

    return m, P


def cholesky(S):
    # Cholesky decomposition
    # with adding a small value on the diagonal
    # until matrix is positive definite
    alpha = 0
    while 1:
        try:
            L = np.linalg.cholesky(S + alpha*np.eye(*S.shape))
            return L
        except np.linalg.LinAlgError:
            if alpha == 0:
                alpha = 1e-10
            else:
                alpha *= 10
            print(">>>> {}::cholesky: Matrix is not positive definite, adding {:.0e} on the diagonal"
                  .format(os.path.basename(__file__), alpha))
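A minimal sketch of learning PCA whitening on random D x N descriptors and applying it (random data, purely illustrative):

import numpy as np
from cirtorch.utils.whiten import pcawhitenlearn, whitenapply

D, N = 128, 1000
X = np.random.randn(D, N)                 # columns are descriptors
m, P = pcawhitenlearn(X)
Xw = whitenapply(X, m, P, dimensions=64)  # keep the top 64 whitened dimensions
print(Xw.shape)                           # (64, 1000); columns come out L2-normalized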
96
ieemoo-ai-search.py
Executable file
@ -0,0 +1,96 @@
import sys
import argparse
#from utils.retrieval_index import EvaluteMap
from utils.tools import EvaluteMap
from utils.retrieval_feature import AntiFraudFeatureDataset
from utils.monitor import Moniting
from utils.updateObs import *
from utils.config import cfg
from utils.tools import createNet
from flask import request, Flask
from utils.forsegmentation import analysis
from gevent.pywsgi import WSGIServer
import os, base64, stat, shutil, json, time
sys.path.append('RAFT')
sys.path.append('RAFT/core')
sys.path.append('RAFT/core/utils')
from RAFT.analysis_video import *
import logging.config
from skywalking import agent, config
from threading import Thread

SW_SERVER = os.environ.get('SW_AGENT_COLLECTOR_BACKEND_SERVICES')
SW_SERVICE_NAME = os.environ.get('SW_AGENT_NAME')
if SW_SERVER and SW_SERVICE_NAME:
    config.init()  # collector address and a name for this service
    #config.init(collector="123.60.56.51:11800", service='ieemoo-ai-search')
    agent.start()
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

app = Flask(__name__)

parser = argparse.ArgumentParser()
#parser.add_argument('--model', default='../module/ieemoo-ai-search/model/now/raft-things.pth', help="restore checkpoint")
parser.add_argument('--model', default='../module/ieemoo-ai-search/model/now/raft-small.pth', help="restore checkpoint")
#parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--small', type=bool, default=True, help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
opt, unknown = parser.parse_known_args()

'''
status codes
00: video parsing failed (video clipping error)
01: not on the monitoring list
02: no product detected
03: abnormal output
04: correctly recognized
'''
status = ['00', '01', '02', '03', '04']
net, transform, ms = createNet()
raft_model = raft_init_model(opt)

def setup_logging(path):
    if os.path.exists(path):
        with open(path, 'r') as f:
            config = json.load(f)
        logging.config.dictConfig(config)
    logger = logging.getLogger("root")
    return logger

logger = setup_logging('utils/logging.json')

@app.route('/search', methods=['POST'])
def search():
    pre_status = False
    try:
        video_name = request.form.get('video_name')
        logger.info('get video ' + video_name)
        ocr_file_path = os.sep.join([cfg.Ocrtxt, video_name.split('.')[0] + '.txt'])
        video_extra_info = request.form.get('video_extra_info')
        if video_extra_info is not None:
            with open(ocr_file_path, 'w') as f:
                f.write(video_extra_info)
        video_data = request.files['video']
        videoPath = os.sep.join([cfg.VIDEOPATH, video_name])
        video_data.save(videoPath)
        uuid_barcode = video_name.split('.')[0]
        barcode_name = uuid_barcode.split('_')[-1]
        if Moniting(barcode_name).search() == 'nomatch':
            state = status[1]
            analysis_video(raft_model, videoPath, '', uuid_barcode, None, net=net, transform=transform, ms=ms, match=False)
        else:
            state = analysis_video(raft_model, videoPath, '', uuid_barcode, None, net=net, transform=transform, ms=ms, match=True)
        result = uuid_barcode + '_' + state  # response payload: <uuid_barcode>_<status>
    except Exception as e:
        logger.warning(e)  # on any failure, report status 03
        # note: videoPath/uuid_barcode may be unbound if the failure happened before the upload was saved
        thread = Thread(target=AddObs, kwargs={'file_path': videoPath, 'status': status[3]})
        thread.start()
        return uuid_barcode + '_' + status[3]
    thread = Thread(target=AddObs, kwargs={'file_path': videoPath, 'status': state})
    thread.start()
    logger.info(result)
    print('result >>>>> {}'.format(result))
    return result

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8085)
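For reference, a hypothetical client-side call against the /search endpoint above (host, port, and file name are illustrative assumptions):

import requests

video_name = '0001_6921234567890.mp4'  # hypothetical <uuid>_<barcode> naming
with open(video_name, 'rb') as f:
    r = requests.post('http://127.0.0.1:8085/search',
                      data={'video_name': video_name},
                      files={'video': f})
print(r.text)  # e.g. '0001_6921234567890_04' -> status 04, correctly recognized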
4
init.sh
Normal file
@ -0,0 +1,4 @@
/opt/miniconda3/bin/conda activate ieemoo

/opt/miniconda3/envs/ieemoo/bin/pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
246
interface.py
Normal file
@ -0,0 +1,246 @@
# coding=utf-8
# /usr/bin/env python

import torch
from torch.utils.model_zoo import load_url
from torchvision import transforms
from cirtorch.datasets.testdataset import configdataset
from cirtorch.utils.download import download_train, download_test
from cirtorch.utils.evaluate import compute_map_and_print
from cirtorch.utils.general import get_data_root, htime
from cirtorch.networks.imageretrievalnet_cpu import init_network, extract_vectors
from cirtorch.datasets.datahelpers import imresize

from PIL import Image
import numpy as np
import pandas as pd
from flask import Flask, request
import json, io, sys, time, traceback, argparse, logging, subprocess, pickle, os, yaml, shutil
import cv2
import pdb
from werkzeug.utils import cached_property
from apscheduler.schedulers.background import BackgroundScheduler
from multiprocessing import Pool

app = Flask(__name__)

@app.route("/")
def index():
    return ""

@app.route("/images/*", methods=['GET', 'POST'])
def accInsurance():
    """
    flask request process handle
    :return:
    """
    try:
        if request.method == 'GET':
            return json.dumps({'err': 1, 'msg': 'POST only'})
        else:
            app.logger.debug("print headers------")
            headers = request.headers
            headers_info = ""
            for k, v in headers.items():
                headers_info += "{}: {}\n".format(k, v)
            app.logger.debug(headers_info)

            app.logger.debug("print forms------")
            forms_info = ""
            for k, v in request.form.items():
                forms_info += "{}: {}\n".format(k, v)
            app.logger.debug(forms_info)

            if 'query' not in request.files:
                return json.dumps({'err': 2, 'msg': 'query image is empty'})

            if 'sig' not in request.form:
                return json.dumps({'err': 3, 'msg': 'sig is empty'})

            if 'q_no' not in request.form:
                return json.dumps({'err': 4, 'msg': 'no is empty'})

            if 'q_did' not in request.form:
                return json.dumps({'err': 5, 'msg': 'did is empty'})

            if 'q_id' not in request.form:
                return json.dumps({'err': 6, 'msg': 'id is empty'})

            if 'type' not in request.form:
                return json.dumps({'err': 7, 'msg': 'type is empty'})

            img_name = request.files['query'].filename
            img_bytes = request.files['query'].read()
            img = request.files['query']
            sig = request.form['sig']
            q_no = request.form['q_no']
            q_did = request.form['q_did']
            q_id = request.form['q_id']
            type = request.form['type']

            if str(type) not in types:
                return json.dumps({'err': 8, 'msg': 'type does not exist'})

            if img_bytes is None:
                return json.dumps({'err': 10, 'msg': 'img is none'})

            results = imageRetrieval().retrieval_online_v0(img, q_no, q_did, q_id, type)

            data = dict()
            data['query'] = img_name
            data['sig'] = sig
            data['type'] = type
            data['q_no'] = q_no
            data['q_did'] = q_did
            data['q_id'] = q_id
            data['results'] = results

            return json.dumps({'err': 0, 'msg': 'success', 'data': data})

    except Exception:
        app.logger.exception(sys.exc_info())
        return json.dumps({'err': 9, 'msg': 'unknown error'})


class imageRetrieval():
    def __init__(self):
        pass

    def cosine_dist(self, x, y):
        return 100 * float(np.dot(x, y)) / (np.dot(x, x) * np.dot(y, y)) ** 0.5

    def inference(self, img):
        try:
            input = Image.open(img).convert("RGB")
            input = imresize(input, 224)
            input = transform(input).unsqueeze(0)  # fixed: apply the composed transform and add a batch dim
            with torch.no_grad():
                vect = net(input)
            return vect
        except Exception:
            print('cannot identify image')

    def retrieval_online_v0(self, img, q_no, q_did, q_id, type):
        # load model
        query_vect = self.inference(img)
        query_vect = list(query_vect.detach().numpy().T[0])

        lsh = lsh_dict[str(type)]
        response = lsh.query(query_vect, num_results=1, distance_func="cosine")

        try:
            similar_path = response[0][0][1]
            score = np.rint(self.cosine_dist(list(query_vect), list(response[0][0][0])))
            rank_list = similar_path.split("/")
            s_id, s_did, s_no = rank_list[-1].split("_")[-1].split(".")[0], rank_list[-1].split("_")[0], rank_list[-2]
            results = [{"s_no": s_no, "r_did": s_did, "s_id": s_id, "score": score}]
        except Exception:
            results = []

        img_path = "/{}/{}_{}".format(q_no, q_did, q_id)
        lsh.index(query_vect, extra_data=img_path)
        lsh_dict[str(type)] = lsh

        return results


class initModel():
    def __init__(self):
        pass

    def init_model(self, network, model_dir, types):
        print(">> Loading network:\n>>>> '{}'".format(network))
        # state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
        state = torch.load(network)
        # parsing net params from meta
        # architecture, pooling, mean, std required
        # the rest has default values, in case they don't exist
        net_params = {}
        net_params['architecture'] = state['meta']['architecture']
        net_params['pooling'] = state['meta']['pooling']
        net_params['local_whitening'] = state['meta'].get('local_whitening', False)
        net_params['regional'] = state['meta'].get('regional', False)
        net_params['whitening'] = state['meta'].get('whitening', False)
        net_params['mean'] = state['meta']['mean']
        net_params['std'] = state['meta']['std']
        net_params['pretrained'] = False
        # network initialization
        net = init_network(net_params)
        net.load_state_dict(state['state_dict'])
        print(">>>> loaded network: ")
        print(net.meta_repr())
        # moving network to gpu and eval mode
        # net.cuda()
        net.eval()

        # set up the transform
        normalize = transforms.Normalize(
            mean=net.meta['mean'],
            std=net.meta['std']
        )
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

        lsh_dict = dict()
        for type in types:
            with open(os.path.join(model_dir, "dataset_index_{}.pkl".format(str(type))), "rb") as f:
                lsh = pickle.load(f)

            lsh_dict[str(type)] = lsh

        return net, lsh_dict, transform  # fixed: return the composed transform, not the torchvision module

    def init(self):
        with open('config.yaml', 'r') as f:
            conf = yaml.safe_load(f)

        app.logger.info(conf)
        host = conf['website']['host']
        port = conf['website']['port']
        network = conf['model']['network']
        model_dir = conf['model']['model_dir']
        types = conf['model']['type']

        net, lsh_dict, transform = self.init_model(network, model_dir, types)

        return host, port, net, lsh_dict, transform, model_dir, types


def job():
    for type in types:
        with open(os.path.join(model_dir, "dataset_index_{}_v0.pkl".format(str(type))), "wb") as f:
            pickle.dump(lsh_dict[str(type)], f)


if __name__ == "__main__":
    """
    start app from ssh
    """
    scheduler = BackgroundScheduler()
    host, port, net, lsh_dict, transform, model_dir, types = initModel().init()
    print("start server {}:{}".format(host, port))

    # schedule the periodic index snapshot before the blocking app.run call
    scheduler.add_job(job, 'interval', seconds=30)
    scheduler.start()

    app.run(host=host, port=port, debug=True)

else:
    """
    start app from gunicorn
    """
    scheduler = BackgroundScheduler()
    gunicorn_logger = logging.getLogger("gunicorn.error")
    app.logger.handlers = gunicorn_logger.handlers
    app.logger.setLevel(gunicorn_logger.level)

    host, port, net, lsh_dict, transform, model_dir, types = initModel().init()
    app.logger.info("started from gunicorn...")

    scheduler.add_job(job, 'interval', seconds=30)
    scheduler.start()
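A hypothetical request sketch for the retrieval endpoint above. The route is registered as the literal path /images/* (Flask does not expand the asterisk), so the sketch posts to exactly that path; host, port, and field values are assumptions:

import requests

with open('query.jpg', 'rb') as f:
    r = requests.post('http://127.0.0.1:8080/images/*',
                      files={'query': f},
                      data={'sig': 'abc123', 'q_no': '1001', 'q_did': 'd01',
                            'q_id': 'x01', 'type': '0'})
print(r.json())  # {'err': 0, 'msg': 'success', 'data': {...}} on success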
16
main.py
Normal file
@ -0,0 +1,16 @@
# This is a sample Python script.

# Press Shift+F10 to execute it or replace it with your code.
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.


def print_hi(name):
    # Use a breakpoint in the code line below to debug your script.
    print(f'Hi, {name}')  # Press Ctrl+F8 to toggle the breakpoint.


# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    print_hi('PyCharm')

# See PyCharm help at https://www.jetbrains.com/help/pycharm/
33
nts/README.md
Normal file
@ -0,0 +1,33 @@
# NTS-Net

This is a PyTorch implementation of the ECCV2018 paper "Learning to Navigate for Fine-grained Classification" (Ze Yang, Tiange Luo, Dong Wang, Zhiqiang Hu, Jun Gao, Liwei Wang).

## Requirements
- python 3+
- pytorch 0.4+
- numpy
- datetime

## Datasets
Download the [CUB-200-2011](http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz) dataset and put it in the root directory named **CUB_200_2011**. You can also try other fine-grained datasets.

## Train the model
If you want to train the NTS-Net, just run ``python train.py``. You may need to change the configurations in ``config.py``. The parameter ``PROPOSAL_NUM`` is ``M`` in the original paper and the parameter ``CAT_NUM`` is ``K`` in the original paper. During training, the log file and checkpoint file will be saved in the ``save_dir`` directory. You can change the parameter ``resume`` to choose the checkpoint model to resume from.

## Test the model
If you want to test the NTS-Net, just run ``python test.py``. You need to specify the ``test_model`` in ``config.py`` to choose the checkpoint model for testing.

## Model
We also provide a checkpoint model trained by ourselves; you can download it from [here](https://drive.google.com/file/d/1F-eKqPRjlya5GH2HwTlLKNSPEUaxCu9H/view?usp=sharing). If you test on our provided model, you will get an 87.6% test accuracy.

## Reference
If you are interested in our work and want to cite it, please acknowledge the following paper:

```
@inproceedings{Yang2018Learning,
    author = {Yang, Ze and Luo, Tiange and Wang, Dong and Hu, Zhiqiang and Gao, Jun and Wang, Liwei},
    title = {Learning to Navigate for Fine-grained Classification},
    booktitle = {ECCV},
    year = {2018}
}
```
10
nts/config.py
Normal file
@ -0,0 +1,10 @@
BATCH_SIZE = 16
PROPOSAL_NUM = 6
CAT_NUM = 4
INPUT_SIZE = (448, 448)  # (w, h)
LR = 0.001
WD = 1e-4
SAVE_FREQ = 1
resume = ''
test_model = 'model.ckpt'
save_dir = '/data_4t/yangz/models/'
100
nts/core/anchors.py
Normal file
@ -0,0 +1,100 @@
import numpy as np
from config import INPUT_SIZE

_default_anchors_setting = (
    dict(layer='p3', stride=32, size=48, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
    dict(layer='p4', stride=64, size=96, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
    dict(layer='p5', stride=128, size=192, scale=[1, 2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
)


def generate_default_anchor_maps(anchors_setting=None, input_shape=INPUT_SIZE):
    """
    generate default anchors

    :param anchors_setting: all anchor settings
    :param input_shape: shape of input images, e.g. (h, w)
    :return: center_anchors: # anchors * 4 (oy, ox, h, w)
             edge_anchors: # anchors * 4 (y0, x0, y1, x1)
             anchor_area: # anchors * 1 (area)
    """
    if anchors_setting is None:
        anchors_setting = _default_anchors_setting

    center_anchors = np.zeros((0, 4), dtype=np.float32)
    edge_anchors = np.zeros((0, 4), dtype=np.float32)
    anchor_areas = np.zeros((0,), dtype=np.float32)
    input_shape = np.array(input_shape, dtype=int)

    for anchor_info in anchors_setting:

        stride = anchor_info['stride']
        size = anchor_info['size']
        scales = anchor_info['scale']
        aspect_ratios = anchor_info['aspect_ratio']

        output_map_shape = np.ceil(input_shape.astype(np.float32) / stride)
        output_map_shape = output_map_shape.astype(int)  # np.int is deprecated; use the builtin
        output_shape = tuple(output_map_shape) + (4,)
        ostart = stride / 2.
        oy = np.arange(ostart, ostart + stride * output_shape[0], stride)
        oy = oy.reshape(output_shape[0], 1)
        ox = np.arange(ostart, ostart + stride * output_shape[1], stride)
        ox = ox.reshape(1, output_shape[1])
        center_anchor_map_template = np.zeros(output_shape, dtype=np.float32)
        center_anchor_map_template[:, :, 0] = oy
        center_anchor_map_template[:, :, 1] = ox
        for scale in scales:
            for aspect_ratio in aspect_ratios:
                center_anchor_map = center_anchor_map_template.copy()
                center_anchor_map[:, :, 2] = size * scale / float(aspect_ratio) ** 0.5
                center_anchor_map[:, :, 3] = size * scale * float(aspect_ratio) ** 0.5

                edge_anchor_map = np.concatenate((center_anchor_map[..., :2] - center_anchor_map[..., 2:4] / 2.,
                                                  center_anchor_map[..., :2] + center_anchor_map[..., 2:4] / 2.),
                                                 axis=-1)
                anchor_area_map = center_anchor_map[..., 2] * center_anchor_map[..., 3]
                center_anchors = np.concatenate((center_anchors, center_anchor_map.reshape(-1, 4)))
                edge_anchors = np.concatenate((edge_anchors, edge_anchor_map.reshape(-1, 4)))
                anchor_areas = np.concatenate((anchor_areas, anchor_area_map.reshape(-1)))

    return center_anchors, edge_anchors, anchor_areas


def hard_nms(cdds, topn=10, iou_thresh=0.25):
    if not (type(cdds).__module__ == 'numpy' and len(cdds.shape) == 2 and cdds.shape[1] >= 5):
        raise TypeError('edge_box_map should be N * 5+ ndarray')

    cdds = cdds.copy()
    indices = np.argsort(cdds[:, 0])
    cdds = cdds[indices]
    cdd_results = []

    res = cdds

    while res.any():
        cdd = res[-1]
        cdd_results.append(cdd)
        if len(cdd_results) == topn:
            return np.array(cdd_results)
        res = res[:-1]

        start_max = np.maximum(res[:, 1:3], cdd[1:3])
        end_min = np.minimum(res[:, 3:5], cdd[3:5])
        lengths = end_min - start_max
        intersec_map = lengths[:, 0] * lengths[:, 1]
        intersec_map[np.logical_or(lengths[:, 0] < 0, lengths[:, 1] < 0)] = 0
        iou_map_cur = intersec_map / ((res[:, 3] - res[:, 1]) * (res[:, 4] - res[:, 2]) + (cdd[3] - cdd[1]) * (
                cdd[4] - cdd[2]) - intersec_map)
        res = res[iou_map_cur < iou_thresh]

    return np.array(cdd_results)


if __name__ == '__main__':
    a = hard_nms(np.array([
        [0.4, 1, 10, 12, 20],
        [0.5, 1, 11, 11, 20],
        [0.55, 20, 30, 40, 50]
    ]), topn=100, iou_thresh=0.4)
    print(a)
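A quick shape check for the default anchor maps (run from the nts/ directory; the counts below assume INPUT_SIZE = (448, 448) from config.py):

from core.anchors import generate_default_anchor_maps

center_anchors, edge_anchors, anchor_areas = generate_default_anchor_maps()
# 14*14*6 anchors from p3 + 7*7*6 from p4 + 4*4*9 from p5 = 1614
print(center_anchors.shape, edge_anchors.shape, anchor_areas.shape)
# (1614, 4) (1614, 4) (1614,)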
77
nts/core/dataset.py
Normal file
@ -0,0 +1,77 @@
import numpy as np
import imageio  # scipy.misc.imread was removed in SciPy >= 1.2; imageio.imread is the drop-in replacement
import os
from PIL import Image
from torchvision import transforms
from config import INPUT_SIZE


class CUB():
    def __init__(self, root, is_train=True, data_len=None):
        self.root = root
        self.is_train = is_train
        img_txt_file = open(os.path.join(self.root, 'images.txt'))
        label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
        train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
        img_name_list = []
        for line in img_txt_file:
            img_name_list.append(line[:-1].split(' ')[-1])
        label_list = []
        for line in label_txt_file:
            label_list.append(int(line[:-1].split(' ')[-1]) - 1)
        train_test_list = []
        for line in train_val_file:
            train_test_list.append(int(line[:-1].split(' ')[-1]))
        train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
        test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
        if self.is_train:
            self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
                              train_file_list[:data_len]]
            self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
        if not self.is_train:
            self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
                             test_file_list[:data_len]]
            self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]

    def __getitem__(self, index):
        if self.is_train:
            img, target = self.train_img[index], self.train_label[index]
            if len(img.shape) == 2:
                img = np.stack([img] * 3, 2)
            img = Image.fromarray(img, mode='RGB')
            img = transforms.Resize((600, 600), Image.BILINEAR)(img)
            img = transforms.RandomCrop(INPUT_SIZE)(img)
            img = transforms.RandomHorizontalFlip()(img)
            img = transforms.ToTensor()(img)
            img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)

        else:
            img, target = self.test_img[index], self.test_label[index]
            if len(img.shape) == 2:
                img = np.stack([img] * 3, 2)
            img = Image.fromarray(img, mode='RGB')
            img = transforms.Resize((600, 600), Image.BILINEAR)(img)
            img = transforms.CenterCrop(INPUT_SIZE)(img)
            img = transforms.ToTensor()(img)
            img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)

        return img, target

    def __len__(self):
        if self.is_train:
            return len(self.train_label)
        else:
            return len(self.test_label)


if __name__ == '__main__':
    dataset = CUB(root='./CUB_200_2011')
    print(len(dataset.train_img))
    print(len(dataset.train_label))
    for data in dataset:
        print(data[0].size(), data[1])
    dataset = CUB(root='./CUB_200_2011', is_train=False)
    print(len(dataset.test_img))
    print(len(dataset.test_label))
    for data in dataset:
        print(data[0].size(), data[1])
96
nts/core/model.py
Normal file
@ -0,0 +1,96 @@
from torch import nn
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from core import resnet
import numpy as np
from core.anchors import generate_default_anchor_maps, hard_nms
from config import CAT_NUM, PROPOSAL_NUM


class ProposalNet(nn.Module):
    def __init__(self):
        super(ProposalNet, self).__init__()
        self.down1 = nn.Conv2d(2048, 128, 3, 1, 1)
        self.down2 = nn.Conv2d(128, 128, 3, 2, 1)
        self.down3 = nn.Conv2d(128, 128, 3, 2, 1)
        self.ReLU = nn.ReLU()
        self.tidy1 = nn.Conv2d(128, 6, 1, 1, 0)
        self.tidy2 = nn.Conv2d(128, 6, 1, 1, 0)
        self.tidy3 = nn.Conv2d(128, 9, 1, 1, 0)

    def forward(self, x):
        batch_size = x.size(0)
        d1 = self.ReLU(self.down1(x))
        d2 = self.ReLU(self.down2(d1))
        d3 = self.ReLU(self.down3(d2))
        t1 = self.tidy1(d1).view(batch_size, -1)
        t2 = self.tidy2(d2).view(batch_size, -1)
        t3 = self.tidy3(d3).view(batch_size, -1)
        return torch.cat((t1, t2, t3), dim=1)


class attention_net(nn.Module):
    def __init__(self, topN=4):
        super(attention_net, self).__init__()
        self.pretrained_model = resnet.resnet50(pretrained=True)
        self.pretrained_model.avgpool = nn.AdaptiveAvgPool2d(1)
        self.pretrained_model.fc = nn.Linear(512 * 4, 200)
        self.proposal_net = ProposalNet()
        self.topN = topN
        self.concat_net = nn.Linear(2048 * (CAT_NUM + 1), 200)
        self.partcls_net = nn.Linear(512 * 4, 200)
        _, edge_anchors, _ = generate_default_anchor_maps()
        self.pad_side = 224
        self.edge_anchors = (edge_anchors + 224).astype(int)  # np.int is deprecated

    def forward(self, x):
        resnet_out, rpn_feature, feature = self.pretrained_model(x)
        x_pad = F.pad(x, (self.pad_side, self.pad_side, self.pad_side, self.pad_side), mode='constant', value=0)
        batch = x.size(0)
        # we will reshape rpn to shape: batch * nb_anchor
        rpn_score = self.proposal_net(rpn_feature.detach())
        all_cdds = [
            np.concatenate((x.reshape(-1, 1), self.edge_anchors.copy(), np.arange(0, len(x)).reshape(-1, 1)), axis=1)
            for x in rpn_score.data.cpu().numpy()]
        top_n_cdds = [hard_nms(x, topn=self.topN, iou_thresh=0.25) for x in all_cdds]
        top_n_cdds = np.array(top_n_cdds)
        top_n_index = top_n_cdds[:, :, -1].astype(int)
        top_n_index = torch.from_numpy(top_n_index).cuda()
        top_n_prob = torch.gather(rpn_score, dim=1, index=top_n_index)
        part_imgs = torch.zeros([batch, self.topN, 3, 224, 224]).cuda()
        for i in range(batch):
            for j in range(self.topN):
                [y0, x0, y1, x1] = top_n_cdds[i][j, 1:5].astype(int)
                part_imgs[i:i + 1, j] = F.interpolate(x_pad[i:i + 1, :, y0:y1, x0:x1], size=(224, 224), mode='bilinear',
                                                      align_corners=True)
        part_imgs = part_imgs.view(batch * self.topN, 3, 224, 224)
        _, _, part_features = self.pretrained_model(part_imgs.detach())
        part_feature = part_features.view(batch, self.topN, -1)
        part_feature = part_feature[:, :CAT_NUM, ...].contiguous()
        part_feature = part_feature.view(batch, -1)
        # concat_logits have the shape: B*200
        concat_out = torch.cat([part_feature, feature], dim=1)
        concat_logits = self.concat_net(concat_out)
        raw_logits = resnet_out
        # part_logits have the shape: B*N*200
        part_logits = self.partcls_net(part_features).view(batch, self.topN, -1)
        return [raw_logits, concat_logits, part_logits, top_n_index, top_n_prob]


def list_loss(logits, targets):
    temp = F.log_softmax(logits, -1)
    loss = [-temp[i][targets[i].item()] for i in range(logits.size(0))]
    return torch.stack(loss)


def ranking_loss(score, targets, proposal_num=PROPOSAL_NUM):
    loss = Variable(torch.zeros(1).cuda())
    batch_size = score.size(0)
    for i in range(proposal_num):
        targets_p = (targets > targets[:, i].unsqueeze(1)).type(torch.cuda.FloatTensor)
        pivot = score[:, i].unsqueeze(1)
        loss_p = (1 - pivot + score) * targets_p
        loss_p = torch.sum(F.relu(loss_p))
        loss += loss_p
    return loss / batch_size
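A forward-pass sketch for attention_net (assumes a CUDA device and that the ImageNet-pretrained ResNet-50 weights can be downloaded; run from the nts/ directory):

import torch
from config import PROPOSAL_NUM
from core.model import attention_net

net = attention_net(topN=PROPOSAL_NUM).cuda().eval()
x = torch.randn(2, 3, 448, 448).cuda()
with torch.no_grad():
    raw_logits, concat_logits, part_logits, top_n_index, top_n_prob = net(x)
print(concat_logits.shape)  # torch.Size([2, 200])
print(part_logits.shape)    # torch.Size([2, 6, 200]) with PROPOSAL_NUM = 6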
212
nts/core/resnet.py
Normal file
@ -0,0 +1,212 @@
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        feature1 = x
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        # use the functional form so dropout is disabled in eval mode
        # (an inline nn.Dropout(p=0.5)(x) would stay active even after net.eval())
        x = nn.functional.dropout(x, p=0.5, training=self.training)
        feature2 = x
        x = self.fc(x)

        return x, feature1, feature2


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model
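The forward pass above differs from stock torchvision in that it also returns intermediate features. An illustrative shape check (224x224 input so the fixed 7x7 average pool lines up; attention_net swaps in AdaptiveAvgPool2d(1) so it can accept 448x448 crops):

import torch
from core.resnet import resnet50

net = resnet50(pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
logits, feature1, feature2 = net(x)
print(feature1.shape)  # torch.Size([1, 2048, 7, 7]) - layer4 map, fed to ProposalNet
print(feature2.shape)  # torch.Size([1, 2048]) - pooled vector used as the global feature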
104
nts/core/utils.py
Normal file
@ -0,0 +1,104 @@
from __future__ import print_function
import os
import sys
import time
import logging

_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)

TOTAL_BAR_LENGTH = 40.
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width - int(TOTAL_BAR_LENGTH / 2)):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    if current < total - 1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()


def format_time(seconds):
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds * 1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f


def init_log(output_dir):
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(message)s',
                        datefmt='%Y%m%d-%H:%M:%S',
                        filename=os.path.join(output_dir, 'log.log'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    return logging


if __name__ == '__main__':
    pass
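Illustrative use of the console progress bar (needs a real terminal, since the module reads `stty size` at import time):

import time
from core.utils import progress_bar, format_time

total = 20
for i in range(total):
    time.sleep(0.05)            # stand-in for real work
    progress_bar(i, total, msg='demo')
print(format_time(3661.5))      # '1h1m' - only the two largest units are kept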
74
nts/test.py
Normal file
@ -0,0 +1,74 @@
import os
from torch.autograd import Variable
import torch.utils.data
from torch.nn import DataParallel
from config import BATCH_SIZE, PROPOSAL_NUM, test_model
from core import model, dataset
from core.utils import progress_bar

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
if not test_model:
    raise NameError('please set the test_model file to choose the checkpoint!')
# read dataset
trainset = dataset.CUB(root='./CUB_200_2011', is_train=True, data_len=None)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=8, drop_last=False)
testset = dataset.CUB(root='./CUB_200_2011', is_train=False, data_len=None)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=8, drop_last=False)
# define model
net = model.attention_net(topN=PROPOSAL_NUM)
ckpt = torch.load(test_model)
net.load_state_dict(ckpt['net_state_dict'])
net = net.cuda()
net = DataParallel(net)
criterion = torch.nn.CrossEntropyLoss()

# evaluate on train set
train_loss = 0
train_correct = 0
total = 0
net.eval()

for i, data in enumerate(trainloader):
    with torch.no_grad():
        img, label = data[0].cuda(), data[1].cuda()
        batch_size = img.size(0)
        _, concat_logits, _, _, _ = net(img)
        # calculate loss
        concat_loss = criterion(concat_logits, label)
        # calculate accuracy
        _, concat_predict = torch.max(concat_logits, 1)
        total += batch_size
        train_correct += torch.sum(concat_predict.data == label.data)
        train_loss += concat_loss.item() * batch_size
        progress_bar(i, len(trainloader), 'eval on train set')

train_acc = float(train_correct) / total
train_loss = train_loss / total
print('train set loss: {:.3f} and train set acc: {:.3f} total sample: {}'.format(train_loss, train_acc, total))


# evaluate on test set
test_loss = 0
test_correct = 0
total = 0
for i, data in enumerate(testloader):
    with torch.no_grad():
        img, label = data[0].cuda(), data[1].cuda()
        batch_size = img.size(0)
        _, concat_logits, _, _, _ = net(img)
        # calculate loss
        concat_loss = criterion(concat_logits, label)
        # calculate accuracy
        _, concat_predict = torch.max(concat_logits, 1)
        total += batch_size
        test_correct += torch.sum(concat_predict.data == label.data)
        test_loss += concat_loss.item() * batch_size
        progress_bar(i, len(testloader), 'eval on test set')

test_acc = float(test_correct) / total
test_loss = test_loss / total
print('test set loss: {:.3f} and test set acc: {:.3f} total sample: {}'.format(test_loss, test_acc, total))

print('finishing testing')
152
nts/train.py
Normal file
@ -0,0 +1,152 @@
import os
import torch.utils.data
from torch.nn import DataParallel
from datetime import datetime
from torch.optim.lr_scheduler import MultiStepLR
from config import BATCH_SIZE, PROPOSAL_NUM, SAVE_FREQ, LR, WD, resume, save_dir
from core import model, dataset
from core.utils import init_log, progress_bar

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
start_epoch = 1
save_dir = os.path.join(save_dir, datetime.now().strftime('%Y%m%d_%H%M%S'))
if os.path.exists(save_dir):
    raise NameError('model dir exists!')
os.makedirs(save_dir)
logging = init_log(save_dir)
_print = logging.info

# read dataset
trainset = dataset.CUB(root='./CUB_200_2011', is_train=True, data_len=None)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=8, drop_last=False)
testset = dataset.CUB(root='./CUB_200_2011', is_train=False, data_len=None)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=8, drop_last=False)
# define model
net = model.attention_net(topN=PROPOSAL_NUM)
if resume:
    ckpt = torch.load(resume)
    net.load_state_dict(ckpt['net_state_dict'])
    start_epoch = ckpt['epoch'] + 1
criterion = torch.nn.CrossEntropyLoss()

# define optimizers
raw_parameters = list(net.pretrained_model.parameters())
part_parameters = list(net.proposal_net.parameters())
concat_parameters = list(net.concat_net.parameters())
partcls_parameters = list(net.partcls_net.parameters())

raw_optimizer = torch.optim.SGD(raw_parameters, lr=LR, momentum=0.9, weight_decay=WD)
concat_optimizer = torch.optim.SGD(concat_parameters, lr=LR, momentum=0.9, weight_decay=WD)
part_optimizer = torch.optim.SGD(part_parameters, lr=LR, momentum=0.9, weight_decay=WD)
partcls_optimizer = torch.optim.SGD(partcls_parameters, lr=LR, momentum=0.9, weight_decay=WD)
schedulers = [MultiStepLR(raw_optimizer, milestones=[60, 100], gamma=0.1),
              MultiStepLR(concat_optimizer, milestones=[60, 100], gamma=0.1),
              MultiStepLR(part_optimizer, milestones=[60, 100], gamma=0.1),
              MultiStepLR(partcls_optimizer, milestones=[60, 100], gamma=0.1)]
net = net.cuda()
net = DataParallel(net)

for epoch in range(start_epoch, 500):
    for scheduler in schedulers:
        scheduler.step()

    # begin training
    _print('--' * 50)
    net.train()
    for i, data in enumerate(trainloader):
        img, label = data[0].cuda(), data[1].cuda()
        batch_size = img.size(0)
        raw_optimizer.zero_grad()
        part_optimizer.zero_grad()
        concat_optimizer.zero_grad()
        partcls_optimizer.zero_grad()

        raw_logits, concat_logits, part_logits, _, top_n_prob = net(img)
        part_loss = model.list_loss(part_logits.view(batch_size * PROPOSAL_NUM, -1),
                                    label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1)).view(batch_size, PROPOSAL_NUM)
        raw_loss = criterion(raw_logits, label)
        concat_loss = criterion(concat_logits, label)
        rank_loss = model.ranking_loss(top_n_prob, part_loss)
        partcls_loss = criterion(part_logits.view(batch_size * PROPOSAL_NUM, -1),
                                 label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1))

        total_loss = raw_loss + rank_loss + concat_loss + partcls_loss
        total_loss.backward()
        raw_optimizer.step()
        part_optimizer.step()
        concat_optimizer.step()
        partcls_optimizer.step()
        progress_bar(i, len(trainloader), 'train')

    if epoch % SAVE_FREQ == 0:
        train_loss = 0
        train_correct = 0
        total = 0
        net.eval()
        for i, data in enumerate(trainloader):
            with torch.no_grad():
                img, label = data[0].cuda(), data[1].cuda()
                batch_size = img.size(0)
                _, concat_logits, _, _, _ = net(img)
                # calculate loss
                concat_loss = criterion(concat_logits, label)
                # calculate accuracy
                _, concat_predict = torch.max(concat_logits, 1)
                total += batch_size
                train_correct += torch.sum(concat_predict.data == label.data)
                train_loss += concat_loss.item() * batch_size
                progress_bar(i, len(trainloader), 'eval train set')

        train_acc = float(train_correct) / total
        train_loss = train_loss / total

        _print(
            'epoch:{} - train loss: {:.3f} and train acc: {:.3f} total sample: {}'.format(
                epoch,
                train_loss,
                train_acc,
                total))

        # evaluate on test set
        test_loss = 0
        test_correct = 0
        total = 0
        for i, data in enumerate(testloader):
            with torch.no_grad():
                img, label = data[0].cuda(), data[1].cuda()
                batch_size = img.size(0)
                _, concat_logits, _, _, _ = net(img)
                # calculate loss
                concat_loss = criterion(concat_logits, label)
                # calculate accuracy
                _, concat_predict = torch.max(concat_logits, 1)
                total += batch_size
                test_correct += torch.sum(concat_predict.data == label.data)
                test_loss += concat_loss.item() * batch_size
                progress_bar(i, len(testloader), 'eval test set')

        test_acc = float(test_correct) / total
        test_loss = test_loss / total
        _print(
            'epoch:{} - test loss: {:.3f} and test acc: {:.3f} total sample: {}'.format(
                epoch,
                test_loss,
                test_acc,
                total))

        # save model
        net_state_dict = net.module.state_dict()
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        torch.save({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'test_loss': test_loss,
            'test_acc': test_acc,
            'net_state_dict': net_state_dict},
            os.path.join(save_dir, '%03d.ckpt' % epoch))

print('finishing training')
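A tiny numeric sketch of ranking_loss (assumes a CUDA device, since the loss allocates CUDA tensors). Proposals whose classification loss is higher than the pivot's should score lower than the pivot by a margin of 1, so the score/loss mismatch below is penalized:

import torch
from core import model

score = torch.tensor([[0.9, 0.1]]).cuda()    # RPN confidences for 2 proposals
targets = torch.tensor([[0.2, 0.7]]).cuda()  # per-proposal classification losses
print(model.ranking_loss(score, targets, proposal_num=2))
# tensor([0.2000]): the hinge max(0, 1 - 0.9 + 0.1) for the (proposal 0, proposal 1) pair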
22
requirements.txt
Normal file
@ -0,0 +1,22 @@
esdk_obs_python==3.21.8
Flask==2.0.0
gevent==21.1.2
matplotlib==3.4.1
numpy==1.20.2
opencv_python==4.5.5.64
opencv-contrib-python==4.5.5.64
Pillow==9.1.0
scipy==1.6.2
setuptools==49.6.0
coremltools==5.2.0
onnx==1.7.0
pandas==1.2.4
pycocotools==2.0.2
PyYAML==6.0
requests==2.25.1
seaborn==0.11.1
thop==0.0.31.post2005241907
tqdm==4.60.0
ml-collections==0.1.1
apache-skywalking
112
sample_server.py
Executable file
@ -0,0 +1,112 @@
import sys
import argparse
import requests          # used by get_video (may also be re-exported by the star imports below)
import urllib.request    # used by get_video
#from utils.retrieval_index import EvaluteMap
from utils.tools import EvaluteMap
from utils.retrieval_feature import AntiFraudFeatureDataset
from utils.monitor import Moniting
from utils.updateObs import *
#from utils.decide import Decide
from utils.config import cfg
from utils.tools import createNet
from flask import request, Flask, jsonify
from utils.forsegmentation import analysis
from gevent.pywsgi import WSGIServer
import os, base64, stat, shutil
import pdb
sys.path.append('RAFT')
sys.path.append('RAFT/core')
sys.path.append('RAFT/core/utils')
from RAFT.analysis_video import *
import time
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
app = Flask(__name__)
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='RAFT/models/raft-things.pth', help="restore checkpoint")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
opt, unknown = parser.parse_known_args()
'''
status codes
00: video parsing failed (video clipping error)
01: not on the monitoring list
02: no product detected
03: abnormal output
04: correctly recognized
'''
status = ['00', '01', '02', '03', '04']
net, transform, ms = createNet()
raft_model = raft_init_model(opt)

def get_video():
    url = "https://api.ieemoo.com/emoo-train/collection/getVideoCollectByTime.do"
    data = {"startTime": "2022-01-25", "endTime": "2022-01-26"}
    r = requests.post(url=url, data=data)
    videonames = []
    filename = cfg.SAVIDEOPATH
    for dictdata in r.json()['data']:
        urlpath = dictdata["videoPath"]
        videonames.append(urlpath)
    for urlname in videonames:
        videoname = os.path.basename(urlname)
        savepath = os.sep.join([filename, videoname])
        filepath, _ = urllib.request.urlretrieve(urlname, savepath, _progress)

def search(video_name):
    #get_video()
    T1 = time.time()
    pre_status = False
    try:
        video_path = os.sep.join([cfg.SAVIDEOPATH, video_name])
        uuid_barcode = video_name.split('.')[0]
        barcode_name = uuid_barcode.split('_')[-1]
        #pdb.set_trace()
        photo_nu = analysis_video(raft_model, video_path, cfg.SAMPLEIMGS, uuid_barcode)
        if not Moniting(barcode_name).search() == 'nomatch':
            if photo_nu == 0:
                deleteimg(uuid_barcode)
                return uuid_barcode + '_0.90_!' + status[0] + '_' + video_name
            #Addimg(uuid_barcode)
            feature_dict = AntiFraudFeatureDataset(uuid_barcode, cfg.SAMPLEIMGS, 'sample').extractFeature(net, transform, ms)
            res = EvaluteMap().match_images(feature_dict, barcode_name)
            if res < cfg.THRESHOLD:
                pre_status = status[2]
            else:
                pre_status = status[4]
        else:
            pre_status = status[1]
            res = '0.90'
    except:
        return uuid_barcode + '_0.90_!' + '_' + status[3] + '_' + video_name
    data = uuid_barcode + '_' + str(res) + '_!'
    print(data)
    if pre_status == '04':  # drop the sample images for correctly recognized videos
        deleteimg(uuid_barcode)
    result = data + '_' + pre_status + '_' + video_name
    T2 = time.time()
    print('Total runtime: %s s' % (T2 - T1))
    print(result)
    return result

def match():
    n = 0
    total = len(os.listdir(cfg.SAVIDEOPATH))
    f = open('tmp.txt', 'a')
    for video_name in os.listdir(cfg.SAVIDEOPATH):
        result = search(video_name)
        score = result.split('!')[0].split('_')[-2]
        if float(score) > cfg.THRESHOLD:
            if not float(score) == 0.90:
                #print('video_name', video_name)
                f.write(result + '\n')
                n += 1
        else:
            total -= 1
    if not n == 0:
        print(n / total)
    f.close()

def deleteimg(uuid_barcode):
    for img_name in os.listdir(cfg.SAMPLEIMGS):
        if uuid_barcode in img_name:
            os.remove(os.sep.join([cfg.SAMPLEIMGS, img_name]))

if __name__ == '__main__':
    match()
85
server.py
Executable file
@ -0,0 +1,85 @@
import sys
import argparse
import math
#from utils.retrieval_index import EvaluteMap
from utils.tools import EvaluteMap
from utils.retrieval_feature import AntiFraudFeatureDataset
from utils.monitor import Moniting
from utils.updateObs import *
from utils.config import cfg
from utils.tools import createNet
from flask import request, Flask, jsonify
from utils.forsegmentation import analysis
from gevent.pywsgi import WSGIServer
import os, base64, stat, shutil
sys.path.append('RAFT')
sys.path.append('RAFT/core')
sys.path.append('RAFT/core/utils')
from RAFT.analysis_video import *

os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
app = Flask(__name__)
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='RAFT/models/raft-things.pth', help="restore checkpoint")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
opt, unknown = parser.parse_known_args()

'''
status codes
00: video could not be parsed (video capture error)
01: barcode not in the monitoring list
02: no product detected
03: abnormal output
04: correctly recognized
'''
status = ['00', '01', '02', '03', '04']
net, transform, ms = createNet()
raft_model = raft_init_model(opt)

@app.route('/search', methods=['POST'])
def search():
    pre_status = False
    # parse the request up front so the error paths below can reference
    # uuid_barcode and video_path safely
    video_name = request.form.get('video_name')
    video_path = os.sep.join([cfg.VIDEOPATH, video_name])
    uuid_barcode = video_name.split('.')[0]
    barcode_name = uuid_barcode.split('_')[-1]
    try:
        video_data = request.files['video']
        video_data.save(video_path)
        photo_nu = analysis_video(raft_model, video_path, cfg.TEST_IMG_DIR, uuid_barcode)
        Addimg(uuid_barcode)
        if not Moniting(barcode_name).search() == 'nomatch':
            if photo_nu == 0:
                deleteimg(uuid_barcode)
                AddObs(video_path, status[0])
                return uuid_barcode + '_0.90_!_' + status[0] + '_' + video_name
            #Addimg(uuid_barcode)
            feature_dict = AntiFraudFeatureDataset(uuid_barcode).extractFeature(net, transform, ms)
            res = EvaluteMap().match_images(feature_dict, barcode_name)
            # the original compared res == 'nan' (always False for a float);
            # math.isnan reflects the evident intent
            if isinstance(res, float) and math.isnan(res):
                res = '0.90'
                pre_status = status[1]
            elif res < cfg.THRESHOLD:
                pre_status = status[2]
            else:
                pre_status = status[4]
        else:
            pre_status = status[1]
            res = '0.90'
    except Exception:
        AddObs(video_path, status[3])
        deleteimg(uuid_barcode)
        return uuid_barcode + '_0.90_!_' + status[3] + '_' + video_name
    data = uuid_barcode + '_' + str(res) + '_!'
    if pre_status == '04':
        deleteimg(uuid_barcode)
    AddObs(video_path, pre_status)
    print('result:', data)
    return data + '_' + pre_status + '_' + video_name

def deleteimg(uuid_barcode):
    for img_name in os.listdir(cfg.TEST_IMG_DIR):
        if uuid_barcode in img_name:
            os.remove(os.sep.join([cfg.TEST_IMG_DIR, img_name]))

if __name__ == '__main__':
    # http_server = WSGIServer(('192.168.1.142', 6001), app)
    # http_server.serve_forever()
    app.run()
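# Minimal client sketch for the /search endpoint above (assumes the default
# Flask dev server on 127.0.0.1:5000 and a locally available test video whose
# name follows the <timestamp>_<uuid>_<barcode>.mp4 convention):
import requests

video = '20210908-150524_e8e24395-fc7b-42f2-a50a-068c4ac73ee9_6921168509256.mp4'
with open(video, 'rb') as fh:
    resp = requests.post('http://127.0.0.1:5000/search',
                         data={'video_name': video},
                         files={'video': fh})
print(resp.text)  # '<uuid_barcode>_<score>_!_<status>_<video_name>'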
73
updatefile.py
Normal file
@ -0,0 +1,73 @@
import sys
import argparse
import requests
from utils.tools import EvaluteMap
from utils.retrieval_feature import AntiFraudFeatureDataset
from utils.tools import createNet, rotate_bound
from utils.monitor import Moniting
from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
from utils.config import cfg
from utils.updateObs import *
import numpy as np
import cv2
import json
import os, base64, stat, shutil
import pdb
import socket

sys.path.append('RAFT')
sys.path.append('RAFT/core')
sys.path.append('RAFT/core/utils')
from RAFT.analysis_video import *

class update:
    def __init__(self, img_dir=cfg.TEST_IMG_DIR):
        # img_dir: directory walked by ImageProcess; the original never set
        # self.img_dir, so cfg.TEST_IMG_DIR is an assumed default
        self.img_dir = img_dir
        self.net, self.transform, self.ms = createNet()
        self.updateMonitor = Moniting()

    def ImageProcess(self):
        # group image paths by the barcode suffix in the file name
        dicts = {}
        for root, dirs, files in os.walk(self.img_dir):
            if len(dirs) == 0:
                for file in files:
                    name = file.split('.')[0].split('_')[-1]
                    if not name in dicts:
                        dicts[name] = [os.sep.join([root, file])]
                    else:
                        dicts[name] = dicts[name] + [os.sep.join([root, file])]
        return dicts

    def updateDataCenter(self):
        # augment every image with the configured rotations, re-extract
        # features, then clean the staging directory
        for name in os.listdir(cfg.IMG_DIR_TOTAL):
            img = cv2.imread(os.sep.join([cfg.IMG_DIR_TOTAL, name]))
            for an in cfg.ANGLES:
                image = rotate_bound(img, an)
                cv2.imwrite(os.sep.join([cfg.IMG_DIR_TOTAL, str(an) + '_' + name]), image)
        self.updateMonitor.update(self.net, self.transform, self.ms)
        for name in os.listdir(cfg.IMG_DIR_TOTAL):
            os.remove(os.sep.join([cfg.IMG_DIR_TOTAL, name]))

    def disposaldata(self):
        with open(cfg.SAMPLE, 'r', encoding='utf-8') as jsonpath:
            loadDict = json.load(jsonpath)
            allsample = set(loadDict['monitor'])
        for imgname in os.listdir(cfg.TEST_IMG_DIR):
            imgbarcode = imgname.split('.')[0].split('_')[-1]
            if not os.path.isdir(cfg.DATA_POOLING):
                os.mkdir(cfg.DATA_POOLING)
            if not imgbarcode in allsample:
                # shutil.move(
                #     os.sep.join([cfg.TEST_IMG_DIR, imgname]),
                #     os.sep.join([cfg.DATA_POOLING, imgname]))
                os.remove(os.sep.join([cfg.TEST_IMG_DIR, imgname]))
            else:
                shutil.move(
                    os.sep.join([cfg.TEST_IMG_DIR, imgname]),
                    os.sep.join([cfg.SAMPLEDIR, imgname]))
        for vname in os.listdir(cfg.VIDEOPATH):
            os.remove(os.sep.join([cfg.VIDEOPATH, vname]))

if __name__ == '__main__':
    Update = update()
    Update.disposaldata()
    #Update.updateDataCenter()
19
upfile.py
Normal file
@ -0,0 +1,19 @@
from obs import ObsClient  # missing in the original; same huaweicloud OBS SDK client as utils/updateObs.py
from utils.config import cfg
import os

obsClient = ObsClient(
    access_key_id='LHXJC7GIC2NNUUHHTNVI',
    secret_access_key='sVWvEItrFKWPp5DxeMvX8jLFU69iXPpzkjuMX3iM',
    server='https://obs.cn-east-3.myhuaweicloud.com'
)
bucketName = 'ieemoo-ai'

def addobs():
    # Upload every staged video under cfg.temp. The original referenced
    # objectkey, time and file_path before assigning them; the definitions
    # below follow the '<status>_<date-time>_<uuid>_<barcode>.<ext>' naming
    # used by utils/updateObs.py.
    for name in os.listdir(cfg.temp):
        status = name.split('_')[0]
        file_path = os.sep.join([cfg.temp, name])
        if name.split('.')[-1] in ['avi', 'mp4']:
            time = name.split('_')[1].split('-')[0]
            objectkey = 'videos/' + time + '/' + status + '/' + name
            resp = obsClient.putFile(bucketName, objectkey, file_path)

if __name__ == '__main__':
    addobs()
BIN
utils/.DS_Store
vendored
Normal file
Binary file not shown.
4
utils/celeryconfig.py
Normal file
@ -0,0 +1,4 @@
broker_url = 'redis://127.0.0.1:6379/0'
result_backend = 'redis://127.0.0.1:6379/1'
include = ['ieemoo-ai-search.ieemoo-ai-search']  # was misspelled "includ"
main = 'ieemoo-ai-search.ieemoo-ai-search'
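# Sketch of wiring this config into a worker (assumes a module that creates
# the Celery app; the task below is illustrative):
from celery import Celery

app = Celery('ieemoo-ai-search.ieemoo-ai-search')
app.config_from_object('utils.celeryconfig')

@app.task
def ping():
    return 'pong'

# start with: celery -A <module_with_app> worker --loglevel=INFO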
582
utils/classify.py
Normal file
@ -0,0 +1,582 @@
# coding=utf-8
# /usr/bin/env python

'''
Author: yinhao
Email: yinhao_x@163.com
Wechat: xss_yinhao
Github: http://github.com/yinhaoxs
data: 2019-11-23 18:29
desc:
'''

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.autograd import Variable
import cv2
import shutil
import numpy as np
import pandas as pd
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
import time
from collections import OrderedDict

# config.py
BATCH_SIZE = 16
PROPOSAL_NUM = 6
CAT_NUM = 4
INPUT_SIZE = (448, 448)  # (w, h)
DROP_OUT = 0.5
CLASS_NUM = 37

# resnet.py
def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        feature1 = x
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = nn.Dropout(p=0.5)(x)
        feature2 = x
        x = self.fc(x)

        return x, feature1, feature2

# model.py
class ProposalNet(nn.Module):
    def __init__(self):
        super(ProposalNet, self).__init__()
        self.down1 = nn.Conv2d(2048, 128, 3, 1, 1)
        self.down2 = nn.Conv2d(128, 128, 3, 2, 1)
        self.down3 = nn.Conv2d(128, 128, 3, 2, 1)
        self.ReLU = nn.ReLU()
        self.tidy1 = nn.Conv2d(128, 6, 1, 1, 0)
        self.tidy2 = nn.Conv2d(128, 6, 1, 1, 0)
        self.tidy3 = nn.Conv2d(128, 9, 1, 1, 0)

    def forward(self, x):
        batch_size = x.size(0)
        d1 = self.ReLU(self.down1(x))
        d2 = self.ReLU(self.down2(d1))
        d3 = self.ReLU(self.down3(d2))
        t1 = self.tidy1(d1).view(batch_size, -1)
        t2 = self.tidy2(d2).view(batch_size, -1)
        t3 = self.tidy3(d3).view(batch_size, -1)
        return torch.cat((t1, t2, t3), dim=1)

class AttentionNet(nn.Module):
    def __init__(self, topN=4):
        super(AttentionNet, self).__init__()  # was super(attention_net, ...), a NameError
        self.pretrained_model = ResNet(Bottleneck, [3, 4, 6, 3])
        self.pretrained_model.avgpool = nn.AdaptiveAvgPool2d(1)
        self.pretrained_model.fc = nn.Linear(512 * 4, 200)
        self.proposal_net = ProposalNet()
        self.topN = topN
        self.concat_net = nn.Linear(2048 * (CAT_NUM + 1), 200)
        self.partcls_net = nn.Linear(512 * 4, 200)
        _, edge_anchors, _ = generate_default_anchor_maps()
        self.pad_side = 224
        self.edge_anchors = (edge_anchors + 224).astype(int)  # np.int was removed from numpy

    def forward(self, x):
        resnet_out, rpn_feature, feature = self.pretrained_model(x)
        x_pad = F.pad(x, (self.pad_side, self.pad_side, self.pad_side, self.pad_side), mode='constant', value=0)
        batch = x.size(0)
        # we will reshape rpn to shape: batch * nb_anchor
        rpn_score = self.proposal_net(rpn_feature.detach())
        all_cdds = [
            np.concatenate((x.reshape(-1, 1), self.edge_anchors.copy(), np.arange(0, len(x)).reshape(-1, 1)), axis=1)
            for x in rpn_score.data.cpu().numpy()]
        top_n_cdds = [hard_nms(x, topn=self.topN, iou_thresh=0.25) for x in all_cdds]
        top_n_cdds = np.array(top_n_cdds)
        top_n_index = top_n_cdds[:, :, -1].astype(int)
        top_n_index = torch.from_numpy(top_n_index).cuda()
        top_n_prob = torch.gather(rpn_score, dim=1, index=top_n_index)
        part_imgs = torch.zeros([batch, self.topN, 3, 224, 224]).cuda()
        for i in range(batch):
            for j in range(self.topN):
                [y0, x0, y1, x1] = top_n_cdds[i][j, 1:5].astype(int)
                part_imgs[i:i + 1, j] = F.interpolate(x_pad[i:i + 1, :, y0:y1, x0:x1], size=(224, 224), mode='bilinear',
                                                      align_corners=True)
        part_imgs = part_imgs.view(batch * self.topN, 3, 224, 224)
        _, _, part_features = self.pretrained_model(part_imgs.detach())
        part_feature = part_features.view(batch, self.topN, -1)
        part_feature = part_feature[:, :CAT_NUM, ...].contiguous()
        part_feature = part_feature.view(batch, -1)
        # concat_logits have the shape: B*200
        concat_out = torch.cat([part_feature, feature], dim=1)
        concat_logits = self.concat_net(concat_out)
        raw_logits = resnet_out
        # part_logits have the shape: B*N*200
        part_logits = self.partcls_net(part_features).view(batch, self.topN, -1)
        return [raw_logits, concat_logits, part_logits, top_n_index, top_n_prob]

def list_loss(logits, targets):
    temp = F.log_softmax(logits, -1)
    loss = [-temp[i][targets[i].item()] for i in range(logits.size(0))]
    return torch.stack(loss)

def ranking_loss(score, targets, proposal_num=PROPOSAL_NUM):
    loss = Variable(torch.zeros(1).cuda())
    batch_size = score.size(0)
    for i in range(proposal_num):
        targets_p = (targets > targets[:, i].unsqueeze(1)).type(torch.cuda.FloatTensor)
        pivot = score[:, i].unsqueeze(1)
        loss_p = (1 - pivot + score) * targets_p
        loss_p = torch.sum(F.relu(loss_p))
        loss += loss_p
    return loss / batch_size

# anchors.py
_default_anchors_setting = (
    dict(layer='p3', stride=32, size=48, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
    dict(layer='p4', stride=64, size=96, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
    dict(layer='p5', stride=128, size=192, scale=[1, 2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
)

def generate_default_anchor_maps(anchors_setting=None, input_shape=INPUT_SIZE):
    """
    generate default anchors
    :param anchors_setting: all information about the anchors
    :param input_shape: shape of input images, e.g. (h, w)
    :return: center_anchors: # anchors * 4 (oy, ox, h, w)
             edge_anchors: # anchors * 4 (y0, x0, y1, x1)
             anchor_area: # anchors * 1 (area)
    """
    if anchors_setting is None:
        anchors_setting = _default_anchors_setting

    center_anchors = np.zeros((0, 4), dtype=np.float32)
    edge_anchors = np.zeros((0, 4), dtype=np.float32)
    anchor_areas = np.zeros((0,), dtype=np.float32)
    input_shape = np.array(input_shape, dtype=int)

    for anchor_info in anchors_setting:

        stride = anchor_info['stride']
        size = anchor_info['size']
        scales = anchor_info['scale']
        aspect_ratios = anchor_info['aspect_ratio']

        output_map_shape = np.ceil(input_shape.astype(np.float32) / stride)
        output_map_shape = output_map_shape.astype(int)
        output_shape = tuple(output_map_shape) + (4,)
        ostart = stride / 2.
        oy = np.arange(ostart, ostart + stride * output_shape[0], stride)
        oy = oy.reshape(output_shape[0], 1)
        ox = np.arange(ostart, ostart + stride * output_shape[1], stride)
        ox = ox.reshape(1, output_shape[1])
        center_anchor_map_template = np.zeros(output_shape, dtype=np.float32)
        center_anchor_map_template[:, :, 0] = oy
        center_anchor_map_template[:, :, 1] = ox
        for scale in scales:
            for aspect_ratio in aspect_ratios:
                center_anchor_map = center_anchor_map_template.copy()
                center_anchor_map[:, :, 2] = size * scale / float(aspect_ratio) ** 0.5
                center_anchor_map[:, :, 3] = size * scale * float(aspect_ratio) ** 0.5

                edge_anchor_map = np.concatenate((center_anchor_map[..., :2] - center_anchor_map[..., 2:4] / 2.,
                                                  center_anchor_map[..., :2] + center_anchor_map[..., 2:4] / 2.),
                                                 axis=-1)
                anchor_area_map = center_anchor_map[..., 2] * center_anchor_map[..., 3]
                center_anchors = np.concatenate((center_anchors, center_anchor_map.reshape(-1, 4)))
                edge_anchors = np.concatenate((edge_anchors, edge_anchor_map.reshape(-1, 4)))
                anchor_areas = np.concatenate((anchor_areas, anchor_area_map.reshape(-1)))

    return center_anchors, edge_anchors, anchor_areas

def hard_nms(cdds, topn=10, iou_thresh=0.25):
    if not (type(cdds).__module__ == 'numpy' and len(cdds.shape) == 2 and cdds.shape[1] >= 5):
        raise TypeError('edge_box_map should be N * 5+ ndarray')

    cdds = cdds.copy()
    indices = np.argsort(cdds[:, 0])
    cdds = cdds[indices]
    cdd_results = []

    res = cdds

    while res.any():
        cdd = res[-1]
        cdd_results.append(cdd)
        if len(cdd_results) == topn:
            return np.array(cdd_results)
        res = res[:-1]

        start_max = np.maximum(res[:, 1:3], cdd[1:3])
        end_min = np.minimum(res[:, 3:5], cdd[3:5])
        lengths = end_min - start_max
        intersec_map = lengths[:, 0] * lengths[:, 1]
        intersec_map[np.logical_or(lengths[:, 0] < 0, lengths[:, 1] < 0)] = 0
        iou_map_cur = intersec_map / ((res[:, 3] - res[:, 1]) * (res[:, 4] - res[:, 2]) + (cdd[3] - cdd[1]) * (
                cdd[4] - cdd[2]) - intersec_map)
        res = res[iou_map_cur < iou_thresh]

    return np.array(cdd_results)
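# Quick self-contained check of hard_nms above: two heavily overlapping boxes
# and one disjoint box; rows are [score, y0, x0, y1, x1] (guarded so it only
# runs when this file is executed as a script):
if __name__ == '__main__':
    _cdds = np.array([[0.9, 0, 0, 10, 10],
                      [0.8, 1, 1, 11, 11],      # IoU with the first box > 0.25
                      [0.7, 50, 50, 60, 60]], dtype=np.float32)
    print(hard_nms(_cdds, topn=10, iou_thresh=0.25)[:, 0])  # -> [0.9 0.7]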
#### ------------------------- batch loading helpers -------------------------
# default image loader
def default_loader(path):
    try:
        img = Image.open(path).convert("RGB")
        if img is not None:
            return img
    except Exception:
        print("error image:{}".format(path))

def opencv_isvalid(img_path):
    # decode + color-convert; cv2.cvtColor raises if decoding failed, which is
    # what IsValidImage relies on below
    img_bgr = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), -1)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    return img_bgr

# check whether an image is invalid
def IsValidImage(img_path):
    valid = True
    if img_path.endswith(".tif") or img_path.endswith(".tiff"):
        valid = False
        return valid
    try:
        img = opencv_isvalid(img_path)
        if img is None:
            valid = False
            return valid
    except Exception:
        valid = False
    return valid

class MyDataset(Dataset):
    def __init__(self, dir_path, transform=None, loader=default_loader):
        fh, imgs = list(), list()
        num = 0
        for root, dirs, files in os.walk(dir_path):
            for file in files:
                try:
                    img_path = os.path.join(root + os.sep, file)
                    num += 1
                    if IsValidImage(img_path):
                        fh.append(img_path)
                    else:
                        os.remove(img_path)
                except Exception:
                    print("image is broken")
        print("total images is:{}".format(num))

        for line in fh:
            line = line.strip()
            imgs.append(line)

        self.imgs = imgs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, item):
        fh = self.imgs[item]
        img = self.loader(fh)
        if self.transform is not None:
            img = self.transform(img)
        return fh, img

    def __len__(self):
        return len(self.imgs)

#### ------------------------- blur detection -------------------------
def variance_of_laplacian(image):
    return cv2.Laplacian(image, cv2.CV_64F).var()  # was cv2.CV_64f, an AttributeError

## interface function
def imgQualJudge(img, QA_THRESHOLD):
    '''
    :param img:
    :param QA_THRESHOLD: higher means sharper
    :return: quality flag: 0 = OK, '10001' = blurry, '10002' = too small
    '''

    norheight = 1707
    norwidth = 1280
    flag = 0
    # size filter
    if max(img.shape[0], img.shape[1]) < 320:
        flag = '10002'
        return flag

    # blur filter
    if img.shape[0] <= img.shape[1]:
        size1 = (norheight, norwidth)
        timage = cv2.resize(img, size1)
    else:
        size2 = (norwidth, norheight)
        timage = cv2.resize(img, size2)

    tgray = cv2.cvtColor(timage, cv2.COLOR_BGR2GRAY)
    halfgray = tgray[0:int(tgray.shape[0] / 2), 0:tgray.shape[1]]
    norgrayImg = np.zeros(halfgray.shape, np.int8)
    cv2.normalize(halfgray, norgrayImg, 0, 255, cv2.NORM_MINMAX)
    fm = variance_of_laplacian(norgrayImg)  # blur metric
    if fm < QA_THRESHOLD:
        flag = '10001'
        return flag
    return flag

def process(img_path):
    img = Image.open(img_path).convert("RGB")
    valid = True
    low_quality = "10001"
    size_error = "10002"

    flag = imgQualJudge(np.array(img), 5)
    if flag == low_quality or flag == size_error or not img or 0 in np.asarray(img).shape[:2]:
        valid = False

    return valid

def build_dict():
    dict_club = dict()
    dict_club[0] = ["身份证", 0.999999]  # ID card
    dict_club[1] = ["校园卡", 0.890876]  # campus card
    return dict_club

class Classifier():
    def __init__(self):
        self.device = torch.device('cuda')
        self.class_id_name_dict = build_dict()
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.input_size = 448
        self.use_cuda = torch.cuda.is_available()
        self.model = AttentionNet(topN=4)
        self.model.eval()

        checkpoint = torch.load("./.ckpt")
        newweights = checkpoint['net_state_dict']

        # convert a multi-GPU checkpoint for single-GPU use
        new_state_dic = OrderedDict()
        for k, v in newweights.items():
            name = k[7:] if k.startswith("module.") else k  # was startwith
            new_state_dic[name] = v

        self.model.load_state_dict(new_state_dic)
        self.model = self.model.to(self.device)

    def evalute(self, dir_path):
        data = MyDataset(dir_path, transform=self.preprocess)
        dataloader = DataLoader(dataset=data, batch_size=32, num_workers=8)

        self.model.eval()
        with torch.no_grad():
            num = 0
            for i, (path, data) in enumerate(dataloader, 1):  # __getitem__ yields (path, image)
                data = data.to(self.device)
                output = self.model(data)
                for j in range(len(data)):
                    img_path = path[j]
                    img_output = output[1][j]
                    score, label, type = self.postprocess(img_output)
                    out_dict, score = self.process(score, label, type)
                    class_id = out_dict["results"]["class2"]["code"]
                    num += 1
                    if class_id != '00038':
                        os.remove(img_path)
                    else:
                        continue

    def preprocess(self, img):
        img = transforms.Resize((600, 600), Image.BILINEAR)(img)
        img = transforms.CenterCrop(self.input_size)(img)
        img = transforms.ToTensor()(img)
        img = transforms.Normalize(self.mean, self.std)(img)  # was assigned, never applied
        return img                                            # the return was missing

    def postprocess(self, output):
        pred_logits = F.softmax(output, dim=0)
        score, label = pred_logits.max(0)
        score = score.item()
        label = label.item()
        type = self.class_id_name_dict[label][0]
        return score, label, type

    def process(self, score, label, type):
        success_code = "200"
        lower_conf_code = "10008"

        threshold = float(self.class_id_name_dict[label][1])
        if threshold > 0.99:
            threshold = 0.99
        if threshold < 0.9:
            threshold = 0.9
        ## use a lower threshold for survey images
        if label == 38:
            threshold = 0.5

        if score > threshold:
            status_code = success_code
            pred_label = str(label).zfill(5)
            print("pred_label:", pred_label)
            return {"code:": status_code, "message": '图像分类成功',  # classification succeeded
                    "results": {"class2": {'code': pred_label, 'name': type}}}, score
        else:
            status_code = lower_conf_code
            return {"code:": status_code, "message": '图像分类置信度低,不返回结果',  # low confidence, no result returned
                    "results": {"class2": {'code': '', 'name': ''}}}, score

def class_results(img_dir):
    Classifier().evalute(img_dir)

if __name__ == "__main__":
    pass
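# Example (sketch): exercising generate_default_anchor_maps() standalone; the
# three returned arrays share their first dimension:
if __name__ == "__main__":
    center, edge, areas = generate_default_anchor_maps()
    print(center.shape, edge.shape, areas.shape)  # (N, 4) (N, 4) (N,)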
59
utils/config.py
Normal file
@ -0,0 +1,59 @@
from yacs.config import CfgNode

_C = CfgNode()
cfg = _C

_C.RESIZE = 648

# Monitoring table of the commodity identification system
_C.MONITORPATH = '../module/ieemoo-ai-search/model/now/monitor.json'

_C.VIDEOPATH = '../module/ieemoo-ai-search/videos'

# The number of results retrieved
_C.NUM_RESULT = 5

# RAFT image count (the real number of frames is NUM_RAFT - 1)
_C.NUM_RAFT = 10

# Path for storing video frames
_C.TEST_IMG_DIR = '../module/ieemoo-ai-search/imgs'
_C.ORIIMG = '../module/ieemoo-ai-search/imgs'

_C.Tempimg = '../module/ieemoo-ai-search/data'           # display images of scanned goods
_C.Tempvideos = '../module/ieemoo-ai-search/tempvideos'  # staged videos of scanned goods

_C.DIM = 2048
_C.flag = False

# angles of rotation
_C.ANGLES = [45, 90, 270, 315]  # 45,90,135,180,225,270,315

# Weights for feature extraction
#_C.NETWORK = '../../module/ieemoo-ai-search/model/now/model_best.pth'  # retrieval_feature test
_C.NETWORK = '../module/ieemoo-ai-search/model/now/model_best.pth'

# Weights for RAFT
_C.RAFTMODEL = '../module/ieemoo-ai-search/model/now/raft-things.pth'

_C.DEVICE = 0

# Similarity threshold
_C.THRESHOLD = 0.9

# mask img
_C.MASKIMG = '../module/ieemoo-ai-search/model/now/masking.jpg'
_C.MASKIMG_old = '../module/ieemoo-ai-search/model/now/masking_old.jpg'

# fgbg mask img
_C.fgbgmask = '../module/ieemoo-ai-search/model/now/ori.jpg'
_C.fgbgmask_old = '../module/ieemoo-ai-search/model/now/ori_old.jpg'

_C.URL = 'https://api.ieemoo.com/emoo-api/intelligence'  # online
#_C.URL = 'http://api.test.ieemoo.com/emoo-api/intelligence'

_C.Vre = 'http://api.test.ieemoo.com/emoo-api/intelligence/queryVideoCompareResult.do'
#_C.Vre = 'http://192.168.1.98:8088/emoo-api/intelligence/queryVideoCompareResult.do'

_C.Ocrimg = '../module/ieemoo-ai-ocr/imgs'      # posted OCR images
_C.Ocrtxt = '../module/ieemoo-ai-ocr/document'  # posted OCR texts
_C.Ocrvideo = '../module/ieemoo-ai-ocr/videos'  # posted OCR videos

# NOTE: other modules in this commit also read keys that are not defined here
# (SAVIDEOPATH, SAMPLEIMGS, SAMPLE, SAMPLEDIR, DATA_POOLING, IMG_DIR_TOTAL,
# temp); they must be added before those scripts can run.
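# Usage sketch: cfg is a yacs CfgNode, so values read as attributes and can be
# overridden before the services start (merge_from_list is standard yacs API):
from utils.config import cfg

print(cfg.THRESHOLD)                      # 0.9
cfg.merge_from_list(['THRESHOLD', 0.85])  # override for an experiment
print(cfg.THRESHOLD)                      # 0.85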
121
utils/forsegmentation.py
Normal file
@ -0,0 +1,121 @@
import requests
from base64 import b64encode
from json import dumps
import os
import cv2
import numpy as np
import pdb

max_nu_area = 0

def get_object_mask(frame, ori_mask, mask_path, all_nu, result_path):
    global max_nu_area
    maskimg = cv2.imread(mask_path, 0)
    kernel = np.ones((5, 5), np.uint8)
    dst = ori_mask
    dst = cv2.erode(dst, kernel)
    dst = cv2.dilate(dst, kernel)
    dst = cv2.medianBlur(dst, 3)
    if (cv2.__version__).split('.')[0] == '3':
        _, contours, _ = cv2.findContours(dst, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    else:
        contours, _ = cv2.findContours(dst, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    area, mask_area = 0, 0
    if len(contours) == 1 or len(contours) == 0:
        return None
    for contour in contours:
        area_now = cv2.contourArea(contour)
        dst = cv2.fillPoly(dst, [contour], (0))
        if area_now > area:
            area = area_now
            Matrix = contour
            max_area = area_now
    if area == 0:  # every contour degenerate; Matrix/max_area would be unbound below
        return None
    if max_area > max_nu_area:
        max_nu_area = max_area  # track the largest region seen so far (keeps the top five)
        flag = False
    else:
        flag = True
    (x, y, w, h) = cv2.boundingRect(Matrix)
    dst = cv2.fillPoly(dst, [Matrix], (255))
    coordination = [x, y, x + w, y + h]
    if max_area / (w * h) < 0.3:
        return None
    #print('masking', maskimg.shape)
    #print('dst', dst.shape)
    #pdb.set_trace()
    mask_dst = cv2.bitwise_and(maskimg, dst)
    if w < 350 or h < 350:
        return None
    if (cv2.__version__).split('.')[0] == '3':
        _, contours, _ = cv2.findContours(mask_dst, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    else:
        contours, _ = cv2.findContours(mask_dst, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    for contour in contours:
        mask_area_now = cv2.contourArea(contour)
        if mask_area_now > mask_area:
            mask_area = mask_area_now
    proportion = float(mask_area / max_area)
    if proportion < 1.05 and proportion > 0.5:
        #print(coordination)
        A, B, C = mask_dst, mask_dst, mask_dst
        mask_dst = cv2.merge([A, B, C])
        img = cv2.bitwise_and(mask_dst, frame)
        img = img[coordination[1]:coordination[3], coordination[0]:coordination[2]]
        frame = frame[coordination[1]:coordination[3], coordination[0]:coordination[2]]
        #cv2.imshow('dst', img)
        #cv2.waitKey(1000)
        #print(all_nu)
        #if all_nu > 4: return 'True'
        ratio = (w / h if h > w else h / w)
        if ratio < 0.5:
            return None
        if all_nu < 5:
            savenu = all_nu
        elif all_nu >= 5 and not flag:
            savenu = all_nu % 5
            print(savenu)
        else:
            return 'True'
        cv2.imwrite('images/' + str(savenu) + '.jpg', img)
        cv2.imwrite('images/' + 'ori' + str(savenu) + '.jpg', frame)
        #cv2.imwrite(os.sep.join([result_path, str(all_nu) + '.jpg']), img)
        return 'con'
    else:
        return None

def get_object_location(file_test, mask_path, result_path):
    cap = cv2.VideoCapture(file_test)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    fgbg = cv2.createBackgroundSubtractorMOG2(detectShadows=False)  # Gaussian-mixture background model
    nu = 10
    working_nu = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.medianBlur(frame, ksize=3)
        frame_motion = frame.copy()
        fgmask = fgbg.apply(frame_motion)
        mask = cv2.threshold(fgmask, 25, 255, cv2.THRESH_BINARY)[1]  # binarize
        mask = cv2.dilate(mask, kernel, iterations=1)
        if nu <= 30:
            res = get_object_mask(frame, mask, mask_path, working_nu, result_path)
            #print(res)
            if res == 'True':
                break
            elif res == 'con':
                working_nu += 1
            else:
                continue
        else:
            break
        nu += 1

def analysis(file_test, mask_path, result_path):
    #mask_path = 'mask.jpg'
    #result_path = 'result'
    if not (os.path.exists(result_path)):
        os.mkdir(result_path)
    get_object_location(file_test, mask_path, result_path)

if __name__ == '__main__':
    mask_path = 'mask2.jpg'
    result_path = 'images'
    file_test = "ftp/anonymous/20210908-150524_e8e24395-fc7b-42f2-a50a-068c4ac73ee9_6921168509256.mp4"
    analysis(file_test, mask_path, result_path)
45
utils/logging.json
Normal file
@ -0,0 +1,45 @@
{
    "version": 1,
    "disable_existing_loggers": false,
    "formatters": {
        "simple": {
            "format": "%(asctime)s - %(module)s - %(thread)d - %(levelname)s : %(message)s"
        }
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "simple",
            "stream": "ext://sys.stdout"
        },
        "info_file_handler": {
            "class": "logging.handlers.RotatingFileHandler",
            "level": "INFO",
            "formatter": "simple",
            "filename": "../log/ieemoo-ai-search-biz.log",
            "maxBytes": 10485760,
            "backupCount": 20,
            "encoding": "utf8"
        },
        "error_file_handler": {
            "class": "logging.handlers.RotatingFileHandler",
            "level": "ERROR",
            "formatter": "simple",
            "filename": "../log/ieemoo-ai-search-biz.log",
            "maxBytes": 10485760,
            "backupCount": 20,
            "encoding": "utf8"
        }
    },
    "loggers": {
        "my_module": {
            "level": "ERROR",
            "handlers": ["info_file_handler"],
            "propagate": false
        }
    },
    "root": {
        "level": "INFO",
        "handlers": ["console", "info_file_handler", "error_file_handler"]
    }
}
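# Sketch of loading this config at service start-up (assumes the process runs
# from the repo root so 'utils/logging.json' and the relative ../log paths
# resolve, and that the log directory already exists):
import json
import logging.config

with open('utils/logging.json', 'r', encoding='utf8') as f:
    logging.config.dictConfig(json.load(f))
logging.getLogger('my_module').error('logging configured')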
61
utils/monitor.py
Normal file
@ -0,0 +1,61 @@
import os
import json
import h5py
import numpy as np
from utils.config import cfg
from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
from utils.tools import createNet, ManagingFeature

class Moniting:
    def __init__(self, barcode=''):
        self.barcode = barcode
        self.jsonpath = cfg.MONITORPATH
        self.MF = ManagingFeature()
        if not os.path.exists(self.jsonpath):
            jsontext = {"monitor": []}
            jsondata = json.dumps(jsontext)
            f = open(self.jsonpath, 'w+', encoding='utf-8')
            f.write(jsondata)
            f.close()

    def add(self):
        # self.barcode must be a list of barcodes here
        with open(self.jsonpath, 'r', encoding='utf-8') as add_f:
            jsonfile = json.load(add_f)
        data = list(set(jsonfile['monitor'] + self.barcode))
        jsonfile['monitor'] = data
        with open(self.jsonpath, 'w') as add_f:
            json.dump(jsonfile, add_f)

    def search(self):
        with open(self.jsonpath, 'r', encoding='utf-8') as f:
            jsonfile = json.load(f)
        data = set(jsonfile['monitor'])
        if self.barcode in data:
            return 'success'
        else:
            return 'nomatch'

    def update(self, net, transform, ms):  # update monitor.json
        Dict = {}
        path = []
        collectbarcode = set()
        for name in os.listdir(cfg.IMG_DIR_TOTAL):
            barcode = name.split('_')[-1].split('.')[0]
            path.append(os.sep.join([cfg.IMG_DIR_TOTAL, name]))
            collectbarcode.add(barcode)
        vecs, img_paths = extract_vectors(net, path, cfg.RESIZE, transform, ms=ms)
        data = list(vecs.detach().cpu().numpy().T)
        for code, feature in zip(img_paths, data):
            barcode = code.split('_')[-1].split('.')[0]
            feature = feature.tolist()
            self.MF.addfeature(barcode, feature)
        Moniting(list(collectbarcode)).add()

if __name__ == '__main__':
    barcode = '1'
    mo = Moniting(barcode)
    print(mo.search())
    Moniting([barcode]).add()  # add() takes no arguments and expects a barcode list
71
utils/retrieval_feature.py
Normal file
@ -0,0 +1,71 @@
# coding=utf-8
# /usr/bin/env python

import sys
sys.path.append('..')
import os
from PIL import Image
from cirtorch.networks.imageretrievalnet import extract_vectors, extract_vectors_o
from utils.config import cfg
from utils.monitor import Moniting
import cv2 as cv
# setting up the visible GPU
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

class ImageProcess():
    def __init__(self, img_dir):
        self.img_dir = img_dir

    def process(self, uuid_barcode):
        imgs = list()
        nu = 0
        for root, dirs, files in os.walk(self.img_dir):
            for file in files:
                img_path = os.path.join(root + os.sep, file)
                try:
                    image = Image.open(img_path)
                    if max(image.size) / min(image.size) < 5:
                        if uuid_barcode is None:
                            imgs.append(img_path)
                            print('\r>>>> {}/{} Train done...'.format((nu + 1),
                                                                      len(os.listdir(self.img_dir))),
                                  end='')
                            nu += 1
                        else:
                            if uuid_barcode in img_path:
                                imgs.append(img_path)
                except Exception:
                    print("image is broken or unreadable")
        return imgs

class AntiFraudFeatureDataset():
    def __init__(self, uuid_barcode=None, test_img_dir=cfg.TEST_IMG_DIR, model='work'):
        # the third parameter was commented out in the original, but the batch
        # script calls AntiFraudFeatureDataset(uuid, dir, 'sample') with three
        # positional arguments, so it is restored here
        self.uuid_barcode = uuid_barcode
        self.TestImgDir = test_img_dir
        self.model = model

    def extractFeature_o(self, net, image, transform, ms):
        size = cfg.RESIZE
        #image = cv.resize(image, (size, size))
        vecs = extract_vectors_o(net, image, size, transform, ms=ms)
        feature_dict = list(vecs.detach().cpu().numpy().T)
        return feature_dict

    def extractFeature(self, net, transform, ms):
        # extract database and query vectors
        print('>> database images...')
        images = ImageProcess(self.TestImgDir).process(self.uuid_barcode)
        #print('ori', images)
        vecs, img_paths = extract_vectors(net, images, cfg.RESIZE, transform, ms=ms)
        feature_dict = list(vecs.detach().cpu().numpy().T)
        return feature_dict

if __name__ == '__main__':
    from utils.tools import createNet
    net, transform, ms = createNet()
    path = '../data/imgs/1.jpg'
    image = cv.imread(path)
    affd = AntiFraudFeatureDataset()
    feature = affd.extractFeature_o(net, image, transform, ms)
    print(len(feature))
124
utils/tools.py
Normal file
@ -0,0 +1,124 @@
#from config import cfg
from utils.config import cfg
from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
import torch
from torchvision import transforms
import cv2
import numpy as np
import requests
import os, math
import http.client
#http.client.HTTPConnection._http_vsn = 10
#http.client.HTTPConnection._http_vsn_str = 'HTTP/1.0'

def rotate_bound(image, angle):
    # rotate without clipping: enlarge the canvas to fit the rotated image
    (h, w) = image.shape[:2]
    (cX, cY) = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D((cX, cY), -angle, 1.0)
    cos = np.abs(M[0, 0])
    sin = np.abs(M[0, 1])
    nW = int((h * sin) + (w * cos))
    nH = int((h * cos) + (w * sin))
    M[0, 2] += (nW / 2) - cX
    M[1, 2] += (nH / 2) - cY
    return cv2.warpAffine(image, M, (nW, nH))
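# Quick check of rotate_bound above: a 45-degree rotation enlarges the canvas
# so no corner is clipped (guarded so importing utils.tools stays side-effect
# free):
if __name__ == '__main__':
    dummy = np.zeros((100, 200, 3), dtype=np.uint8)
    print(rotate_bound(dummy, 45).shape)  # roughly (212, 212, 3)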
def createNet():  # load model
    multiscale = '[1]'
    print(">> Loading network:\n>>>> '{}'".format(cfg.NETWORK))
    state = torch.load(cfg.NETWORK)
    net_params = {}
    net_params['architecture'] = state['meta']['architecture']
    net_params['pooling'] = state['meta']['pooling']
    net_params['local_whitening'] = state['meta'].get('local_whitening', False)
    net_params['regional'] = state['meta'].get('regional', False)
    net_params['whitening'] = state['meta'].get('whitening', False)
    net_params['mean'] = state['meta']['mean']
    net_params['std'] = state['meta']['std']
    net_params['pretrained'] = False
    net = init_network(net_params)
    net.load_state_dict(state['state_dict'])
    print(">>>> loaded network: ")
    print(net.meta_repr())
    ms = list(eval(multiscale))
    print(">>>> Evaluating scales: {}".format(ms))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()
    normalize = transforms.Normalize(
        mean=net.meta['mean'],
        std=net.meta['std']
    )
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    return net, transform, ms

class ManagingFeature:  # add / delete / fetch stored features
    def __init__(self):
        pass

    def addfeature(self, code, feature):
        url = os.sep.join([cfg.URL, 'addImageFeatureInfo.do'])
        json = {'code': code,
                'featureVal': feature}
        r = requests.post(url=url, data=json)
        return r.text

    def deletefeature(self, code, timeStamp):  # eg: "timeStamp":"2022/02/10 17:59:59"
        url = os.sep.join([cfg.URL, 'deletImageFeatureInfo.do'])
        json = {'code': code,
                'timeStamp': timeStamp}
        r = requests.get(url=url, params=json)
        return r.json()  # was r.json (the bound method, never called)

    def getfeature(self, code):
        try:
            url = os.sep.join([cfg.URL, 'getImageFeatureInfo.do'])
            json = {'code': code}
            r = requests.get(url=url, params=json)
            data = r.json()['data']
            return data
        except Exception as e:
            print('>>>>>get feature error<<<<<<')

class EvaluteMap():
    def __init__(self):
        self.MF = ManagingFeature()

    def match_feature(self, features, search_r):
        alldict = []
        for feature in features:
            for r in search_r:
                dist = np.linalg.norm(feature - r)
                #alldict.append(math.pow(dist, 2))
                alldict.append(dist)
        meandist = np.mean(sorted(alldict)[0:5])  # scipy.mean was removed from SciPy
        return meandist

    def match_feature_single(self, feature, search_r):
        alldict = []
        for r in search_r:
            r = np.array(r)
            feature = np.array(feature)
            dist = np.linalg.norm(feature - r)
            alldict.append(dist)
        meandist = np.mean(sorted(alldict)[0:2])
        return meandist

    def match_images(self, feature_dict, barcode, choose=False, mod='batch'):
        # fetch the stored features for this barcode once, then compare; the
        # original passed the bare barcode through while this fetch sat
        # commented out inside match_feature
        search_r = self.MF.getfeature(barcode)
        if mod == 'batch':
            result = self.match_feature(feature_dict, search_r)
            return result
        else:
            result = self.match_feature_single(feature_dict, search_r)
            return result

if __name__ == '__main__':
    pass
    # MF = ManagingFeature()
    # result = MF.getfeature('7613035075443')
    # print(result)
71
utils/updateObs.py
Normal file
@ -0,0 +1,71 @@
from obs import ObsClient
from datetime import datetime
#from config import cfg
from utils.config import cfg
import os, threading
import time as ti
import base64, requests, cv2, shutil

lock = threading.Lock()
obsClient = ObsClient(
    access_key_id='LHXJC7GIC2NNUUHHTNVI',
    secret_access_key='sVWvEItrFKWPp5DxeMvX8jLFU69iXPpzkjuMX3iM',
    server='https://obs.cn-east-3.myhuaweicloud.com'
)
bucketName = 'ieemoo-ai'

def AddObs(file_path, status):
    with lock:
        if not cfg.flag:
            addobs(file_path, status)
        else:
            if status == '02':
                addobs(file_path, status)
            else:
                objectkey = os.path.basename(file_path)
                f_dst = os.sep.join([cfg.Tempvideos, status + '_' + objectkey])  # was cfg.tempvideos / f_dist, both undefined
                shutil.move(file_path, f_dst)

def addobs(file_path, status):  # save videos
    ti.sleep(5)
    videoUuid = os.path.basename(file_path).split('_')[1]
    json_data = {'videoUuid': videoUuid}
    resp = requests.post(url=cfg.Vre,
                         data=json_data)
    status = resp.json()
    try:
        status = status['data']
    except Exception as e:
        ti.sleep(5)
        resp = requests.post(url=cfg.Vre, data=json_data)
        status = resp.json()
        status = status['data']

    objectkey = os.path.basename(file_path)
    time = os.path.basename(file_path).split('-')[0]
    if objectkey.split('.')[-1] in ['avi', 'mp4']:
        objectkey = 'videos/' + time + '/' + status + '/' + status + '_' + objectkey
        resp = obsClient.putFile(bucketName, objectkey, file_path)
    os.remove(file_path)

def Addimg(uuid_barcode):
    time = uuid_barcode.split('-')[0].split('_')[-1]
    objectkey = 'imgs/' + time + '/' + uuid_barcode + '.jpg'
    file_path = os.sep.join([cfg.Tempimg, '5_' + uuid_barcode + '.jpg'])
    if not os.path.exists(file_path):
        file_path = os.sep.join([cfg.Tempimg, '3_' + uuid_barcode + '.jpg'])
    if not os.path.exists(file_path):
        file_path = os.sep.join([cfg.Tempimg, 'ex_' + uuid_barcode + '.jpg'])
    resp = obsClient.putFile(bucketName, objectkey, file_path)

def Addimg_content(uuid_barcode, context):
    success, encoded_image = cv2.imencode(".jpg", context)
    context = encoded_image.tobytes()
    time = uuid_barcode.split('-')[0]
    objectkey = 'imgs/' + time + '/' + uuid_barcode + '.jpg'
    resp = obsClient.putContent(bucketName, objectkey, context)

if __name__ == '__main__':
    context = cv2.imread('/home/lc/project/ieemoo-ai-search/data/imgs/20230625-094651_37dd99b0-520d-457b-8615-efdb7f53b5b4_6907992825762.jpg')
    uuid_barcode = '20230625-094651_37dd99b0-520d-457b-8615-efdb7f53b5b4'
    Addimg_content(uuid_barcode, context)