diff --git a/.gitignore b/.gitignore
index a81c8ee..a6d4548 100644
--- a/.gitignore
+++ b/.gitignore
@@ -136,3 +136,8 @@ dmypy.json
 
 # Cython debug symbols
 cython_debug/
+
+*.pth
+*.jpg
+*.mp4
+*.h5
diff --git a/floder/config.py b/floder/config.py
new file mode 100644
index 0000000..b1caa13
--- /dev/null
+++ b/floder/config.py
@@ -0,0 +1,17 @@
+from yacs.config import CfgNode
+_C = CfgNode()
+cfg = _C
+
+_C.RESIZE = 648
+
+# Paths for the commodity identification (monitoring table) system
+_C.model_path = './checkpoint/best_model.pth'
+_C.maskImg = './imgs/mask.jpg'
+_C.maskImg1 = './imgs/mask1.jpg'
+
+_C.maskAreaImg = './imgs/maskAreaImg.jpg'
+_C.maskAreaImg1 = './imgs/maskAreaImg1.jpg'
+
+_C.streamModel = './checkpoint/raft-things.pth'
+_C.hFile = './floder/tempdata'
+_C.videoPath = './floder/tempvideos'
diff --git a/ieemoo-ai-conpurchase.py b/ieemoo-ai-conpurchase.py
new file mode 100644
index 0000000..62af9d5
--- /dev/null
+++ b/ieemoo-ai-conpurchase.py
@@ -0,0 +1,39 @@
+from flask import request, Flask
+import os
+import argparse
+from network.vanalysis_video import vanalysis, raft_init_model
+from floder.config import cfg
+from utils.detect import opvideo
+from utils.embedding import DataProcessing as dp
+
+app = Flask(__name__)
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model', default='./checkpoint/raft-small.pth', help="restore checkpoint")
+parser.add_argument('--checkpoint', default='mobilevit', help="embedding backbone checkpoint")
+parser.add_argument('--device', default='cuda', help="device")
+parser.add_argument('--small', type=bool, default=True, help='use small model')
+parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
+parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
+opt, unknown = parser.parse_known_args()
+
+flowmodel = raft_init_model(opt)
+dps = dp(opt.checkpoint, cfg.model_path, opt.device)
+opv = opvideo(flowmodel, dps)
+
+@app.route('/conpurchase', methods=['POST', 'GET'])
+def conpurchase():
+    flag = request.form.get('flag')
+    num_id = request.form.get('num')
+    video_name = request.form.get('uuid')
+    video_data = request.files['video']
+    videoPath = os.sep.join([cfg.videoPath, video_name])
+    video_data.save(videoPath)
+    # no flag: register features for the video; flag set: match against stored features
+    if not flag:
+        opv.addFreature(video_name, num_id, videoPath)  # was `uuid`, an undefined name
+        return 'ok'  # a Flask view must return a response; this branch returned None
+    else:
+        result = opv.opFreature(video_name, num_id, videoPath)
+        return result
+
+if __name__ == '__main__':
+    app.run('0.0.0.0', 8898)
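
Note: the /conpurchase handler above reads multipart form data. A minimal client-side sketch of a call to it (an illustration only — the host, port, file name, and field values are assumptions based on the handler, not part of this patch):

    import requests  # hypothetical client, not part of this change

    with open('test.mp4', 'rb') as f:
        resp = requests.post(
            'http://127.0.0.1:8898/conpurchase',
            data={'flag': '1', 'num': '3', 'uuid': 'demo.mp4'},  # form fields read by the handler
            files={'video': f})  # saved under cfg.videoPath
    print(resp.text)
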
diff --git a/network/BaseNet.py b/network/BaseNet.py
new file mode 100644
index 0000000..56484b9
--- /dev/null
+++ b/network/BaseNet.py
@@ -0,0 +1,181 @@
+import torch
+from torch import nn
+import torch.nn.init as init
+import torch.nn.functional as F
+import torchvision.models as models
+from PIL import Image
+import torchvision.transforms as transforms
+#from network import GeM as gem
+
+class channelAttention(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(channelAttention, self).__init__()
+        self.Maxpooling = nn.AdaptiveMaxPool2d(1)
+        self.Avepooling = nn.AdaptiveAvgPool2d(1)
+        self.ca = nn.Sequential()
+        self.ca.add_module('conv1', nn.Conv2d(channel, channel//reduction, 1, bias=False))
+        self.ca.add_module('Relu', nn.ReLU())
+        self.ca.add_module('conv2', nn.Conv2d(channel//reduction, channel, 1, bias=False))
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        M_out = self.Maxpooling(x)
+        A_out = self.Avepooling(x)
+        M_out = self.ca(M_out)
+        A_out = self.ca(A_out)
+        out = self.sigmoid(M_out + A_out)
+        return out
+
+class SpatialAttention(nn.Module):
+    def __init__(self, kernel_size=7):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=kernel_size // 2)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        max_result, _ = torch.max(x, dim=1, keepdim=True)
+        avg_result = torch.mean(x, dim=1, keepdim=True)
+        result = torch.cat([max_result, avg_result], dim=1)
+        output = self.conv(result)
+        output = self.sigmoid(output)
+        return output
+
+class CBAM(nn.Module):
+    def __init__(self, channel=512, reduction=16, kernel_size=7):
+        super().__init__()
+        self.ca = channelAttention(channel, reduction)
+        self.sa = SpatialAttention(kernel_size)
+
+    def init_weights(self):
+        for m in self.modules():  # weight initialization
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, mode='fan_out')
+                if m.bias is not None:
+                    init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                init.constant_(m.weight, 1)
+                init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                init.normal_(m.weight, std=0.001)
+                if m.bias is not None:
+                    init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        # b,c_,_ = x.size()
+        # residual = x
+        out = x*self.ca(x)
+        out = out*self.sa(out)
+        return out
+
+class GeM(nn.Module):
+    def __init__(self, p=3, eps=1e-6):
+        super(GeM, self).__init__()
+        self.p = nn.Parameter(torch.ones(1) * p)
+        self.eps = eps
+
+    def forward(self, x):
+        return self.gem(x, p=self.p, eps=self.eps)
+
+    def gem(self, x, p=3, eps=1e-6):
+        #return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p)
+        return F.avg_pool2d(x.clamp(min=eps).pow(p), (7, 7)).pow(1. / p)
+
+    def __repr__(self):
+        return self.__class__.__name__ + \
+            '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
+            ', ' + 'eps=' + str(self.eps) + ')'
+
+class ResnetFpn(nn.Module):
+    def __init__(self):
+        super(ResnetFpn, self).__init__()
+        self.model = models.resnet50()
+        # lateral 1x1 convolutions mapping each stage to 256 channels
+        self.conv1 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=1, stride=1, padding=0)
+        self.conv2 = nn.Conv2d(1024, 256, 1, 1, 0)
+        self.conv3 = nn.Conv2d(512, 256, 1, 1, 0)
+        self.conv4 = nn.Conv2d(256, 256, 1, 1, 0)
+        self.fpn_convs = nn.Conv2d(256, 256, 3, 1, 1)
+        self.pool = nn.AvgPool2d(7, 7, padding=2)
+        #self.gem = GeM()
+        #self.in_channel = 64
+        self.cbam_layer1 = CBAM(256)
+        self.cbam_layer2 = CBAM(512)
+        self.cbam_layer3 = CBAM(1024)
+        self.cbam_layer4 = CBAM(2048)
+        self.fc = nn.Linear(in_features=20736, out_features=2048)
+        self.fc1 = nn.Linear(2048, 1024)
+        self.fc2 = nn.Linear(1024, 512)
+        self.fc3 = nn.Linear(512, 128)
+
+    def forward(self, x):
+        x = self.model.conv1(x)
+        x = self.model.bn1(x)
+        x = self.model.relu(x)
+        x = self.model.maxpool(x)
+
+        layer1 = self.model.layer1(x)
+        layer1 = self.cbam_layer1(layer1)
+        #print('layer1 >>> {}'.format(layer1.shape))
+
+        layer2 = self.model.layer2(layer1)
+        layer2 = self.cbam_layer2(layer2)
+        #print('layer2 >>> {}'.format(layer2.shape))
+
+        layer3 = self.model.layer3(layer2)
+        layer3 = self.cbam_layer3(layer3)
+        #print('layer3 >>> {}'.format(layer3.shape))
+
+        layer4 = self.model.layer4(layer3)  # channels: 256 512 1024 2048
+        layer4 = self.cbam_layer4(layer4)
+        #print('layer4 >>> {}'.format(layer4.shape))
+
+        P5 = self.conv1(layer4)
+        P4_ = self.conv2(layer3)
+        P3_ = self.conv3(layer2)
+        P2_ = self.conv4(layer1)
+
+        size4 = P4_.shape[2:]
+        size3 = P3_.shape[2:]
+        size2 = P2_.shape[2:]
+
+        # FPN top-down pathway: upsample and add lateral features
+        P4 = P4_ + F.interpolate(P5, size=size4, mode='nearest')
+        P3 = P3_ + F.interpolate(P4, size=size3, mode='nearest')
+        P2 = P2_ + F.interpolate(P3, size=size2, mode='nearest')
+
+        P5 = self.fpn_convs(P5)
+        P4 = self.fpn_convs(P4)
+        P3 = self.fpn_convs(P3)
+        P2 = self.fpn_convs(P2)
+
+        output = self.pool(P2)
+        #output = self.gem(P2)
+        output = output.contiguous().view(output.size(0), -1)
+
+        output = self.fc(output)
+        output = self.fc1(output)
+        output = self.fc2(output)
+        output = self.fc3(output)
+        return output
+
+if __name__ == '__main__':
+    img_path = '600.jpg'
+    img = Image.open(img_path)
+    # if img.mode != 'L':
+    #     img = img.convert('L')
+    transform = transforms.Compose([transforms.Resize((256, 256)),
+                                    transforms.ToTensor()])
+    img = transform(img)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = ResnetFpn().to(device)
+    model.eval()
+    # move the image to the model's device (was an unconditional .cuda())
+    img = torch.unsqueeze(img, dim=0).float().to(device)
+    result = model(img)
+    #print('result >>> {} >>{}'.format(result, result.size()))
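
ResnetFpn reduces an image to a 128-d embedding, so two product crops can be compared directly. A minimal sketch of such a comparison (illustrative only — the helper name and the 0.6 threshold are assumptions, not part of this patch):

    import torch
    import torch.nn.functional as F

    def same_product(model, img1, img2, thresh=0.6):  # hypothetical helper
        # embed both crops and compare with cosine similarity
        with torch.no_grad():
            e1, e2 = model(img1), model(img2)
        return F.cosine_similarity(e1, e2).item() > thresh
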
diff --git a/network/LICENSE b/network/LICENSE
new file mode 100755
index 0000000..ed13d84
--- /dev/null
+++ b/network/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2020, princeton-vl
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/network/__init__.py b/network/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/network/alt_cuda_corr/correlation.cpp b/network/alt_cuda_corr/correlation.cpp
new file mode 100755
index 0000000..b01584d
--- /dev/null
+++ b/network/alt_cuda_corr/correlation.cpp
@@ -0,0 +1,54 @@
+#include <torch/extension.h>
+#include <vector>
+
+// CUDA forward declarations
+std::vector<torch::Tensor> corr_cuda_forward(
+  torch::Tensor fmap1,
+  torch::Tensor fmap2,
+  torch::Tensor coords,
+  int radius);
+
+std::vector<torch::Tensor> corr_cuda_backward(
+  torch::Tensor fmap1,
+  torch::Tensor fmap2,
+  torch::Tensor coords,
+  torch::Tensor corr_grad,
+  int radius);
+
+// C++ interface
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+std::vector<torch::Tensor> corr_forward(
+  torch::Tensor fmap1,
+  torch::Tensor fmap2,
+  torch::Tensor coords,
+  int radius) {
+  CHECK_INPUT(fmap1);
+  CHECK_INPUT(fmap2);
+  CHECK_INPUT(coords);
+
+  return corr_cuda_forward(fmap1, fmap2, coords, radius);
+}
+
+
+std::vector<torch::Tensor> corr_backward(
+  torch::Tensor fmap1,
+  torch::Tensor fmap2,
+  torch::Tensor coords,
+  torch::Tensor corr_grad,
+  int radius) {
+  CHECK_INPUT(fmap1);
+  CHECK_INPUT(fmap2);
+  CHECK_INPUT(coords);
+  CHECK_INPUT(corr_grad);
+
+  return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &corr_forward, "CORR forward");
+  m.def("backward", &corr_backward, "CORR backward");
+}
\ No newline at end of file
diff --git a/network/alt_cuda_corr/correlation_kernel.cu b/network/alt_cuda_corr/correlation_kernel.cu
new file mode 100755
index 0000000..145e580
--- /dev/null
+++ b/network/alt_cuda_corr/correlation_kernel.cu
@@ -0,0 +1,324 @@
+#include <torch/extension.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <vector>
+
+
+#define BLOCK_H 4
+#define BLOCK_W 8
+#define BLOCK_HW BLOCK_H * BLOCK_W
+#define CHANNEL_STRIDE 32
+
+
+__forceinline__ __device__
+bool within_bounds(int h, int w, int H, int W) {
+  return h >= 0 && h < H && w >= 0 && w < W;
+}
+
+template <typename scalar_t>
+__global__ void corr_forward_kernel(
+    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
+    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
+    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
+    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
+    int r)
+{
+  const int b = blockIdx.x;
+  const int h0 = blockIdx.y * blockDim.x;
+  const int w0 = blockIdx.z * blockDim.y;
+  const int tid = threadIdx.x * blockDim.y + threadIdx.y;
+
+  const int H1 = fmap1.size(1);
+  const int W1 = fmap1.size(2);
+  const int H2 = fmap2.size(1);
+  const int W2 = fmap2.size(2);
+  const int N = coords.size(1);
+  const int C = fmap1.size(3);
+
+  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
+  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
+  __shared__ scalar_t x2s[BLOCK_HW];
+  __shared__ scalar_t y2s[BLOCK_HW];
+
+  // iterate over the channel dimension in strides of CHANNEL_STRIDE
+  for (int c=0; c<C; c+=CHANNEL_STRIDE) {
+    // stage the fmap1 tile for this block in shared memory
+    for (int k1=0; k1<BLOCK_HW; k1++) {
+      int h1 = h0 + k1 / BLOCK_W;
+      int w1 = w0 + k1 % BLOCK_W;
+      int c1 = tid % CHANNEL_STRIDE;
+
+      auto fptr = fmap1[b][h1][w1];
+      if (within_bounds(h1, w1, H1, W1))
+        f1[c1][k1] = fptr[c+c1];
+      else
+        f1[c1][k1] = 0.0;
+    }
+
+    __syncthreads();
+
+    for (int n=0; n<N; n++) {
+      int h1 = h0 + threadIdx.x;
+      int w1 = w0 + threadIdx.y;
+      if (within_bounds(h1, w1, H1, W1)) {
+        x2s[tid] = coords[b][n][h1][w1][0];
+        y2s[tid] = coords[b][n][h1][w1][1];
+      }
+
+      scalar_t dx = x2s[tid] - floor(x2s[tid]);
+      scalar_t dy = y2s[tid] - floor(y2s[tid]);
+
+      int rd = 2*r + 1;
+      for (int iy=0; iy<rd+1; iy++) {
+        for (int ix=0; ix<rd+1; ix++) {
+          for (int k1=0; k1<BLOCK_HW; k1++) {
+            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
+            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
+            int c2 = tid % CHANNEL_STRIDE;
+
+            auto fptr = fmap2[b][h2][w2];
+            if (within_bounds(h2, w2, H2, W2))
+              f2[c2][k1] = fptr[c+c2];
+            else
+              f2[c2][k1] = 0.0;
+          }
+
+          __syncthreads();
+
+          // dot product between the staged feature columns
+          scalar_t s = 0.0;
+          for (int k=0; k<CHANNEL_STRIDE; k++)
+            s += f1[k][tid] * f2[k][tid];
+
+          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
+          int ix_ne = H1*W1*((iy-1) + rd*ix);
+          int ix_sw = H1*W1*(iy + rd*(ix-1));
+          int ix_se = H1*W1*(iy + rd*ix);
+
+          // bilinear weights for the four neighboring correlation bins
+          scalar_t nw = s * (dy) * (dx);
+          scalar_t ne = s * (dy) * (1-dx);
+          scalar_t sw = s * (1-dy) * (dx);
+          scalar_t se = s * (1-dy) * (1-dx);
+
+          scalar_t* corr_ptr = &corr[b][n][0][h1][w1];
+
+          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
+            *(corr_ptr + ix_nw) += nw;
+
+          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
+            *(corr_ptr + ix_ne) += ne;
+
+          if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
+            *(corr_ptr + ix_sw) += sw;
+
+          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
+            *(corr_ptr + ix_se) += se;
+        }
+      }
+    }
+  }
+}
+
+
+template <typename scalar_t>
+__global__ void corr_backward_kernel(
+    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
+    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
+    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
+    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
+    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
+    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
+    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
+    int r)
+{
+
+  const int b = blockIdx.x;
+  const int h0 = blockIdx.y * blockDim.x;
+  const int w0 = blockIdx.z * blockDim.y;
+  const int tid = threadIdx.x * blockDim.y + threadIdx.y;
+
+  const int H1 = fmap1.size(1);
+  const int W1 = fmap1.size(2);
+  const int H2 = fmap2.size(1);
+  const int W2 = fmap2.size(2);
+  const int N = coords.size(1);
+  const int C = fmap1.size(3);
+
+  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
+  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
+
+  __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
+  __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];
+
+  __shared__ scalar_t x2s[BLOCK_HW];
+  __shared__ scalar_t y2s[BLOCK_HW];
+
+  for (int c=0; c<C; c+=CHANNEL_STRIDE) {
+    // stage the fmap1 tile and zero its gradient accumulator
+    for (int k1=0; k1<BLOCK_HW; k1++) {
+      int h1 = h0 + k1 / BLOCK_W;
+      int w1 = w0 + k1 % BLOCK_W;
+      int c1 = tid % CHANNEL_STRIDE;
+
+      auto fptr = fmap1[b][h1][w1];
+      if (within_bounds(h1, w1, H1, W1))
+        f1[c1][k1] = fptr[c+c1];
+      else
+        f1[c1][k1] = 0.0;
+
+      f1_grad[c1][k1] = 0.0;
+    }
+
+    __syncthreads();
+
+    for (int n=0; n<N; n++) {
+      int h1 = h0 + threadIdx.x;
+      int w1 = w0 + threadIdx.y;
+      if (within_bounds(h1, w1, H1, W1)) {
+        x2s[tid] = coords[b][n][h1][w1][0];
+        y2s[tid] = coords[b][n][h1][w1][1];
+      }
+
+      scalar_t dx = x2s[tid] - floor(x2s[tid]);
+      scalar_t dy = y2s[tid] - floor(y2s[tid]);
+
+      int rd = 2*r + 1;
+      for (int iy=0; iy<rd+1; iy++) {
+        for (int ix=0; ix<rd+1; ix++) {
+          for (int k1=0; k1<BLOCK_HW; k1++) {
+            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
+            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
+            int c2 = tid % CHANNEL_STRIDE;
+
+            auto fptr = fmap2[b][h2][w2];
+            if (within_bounds(h2, w2, H2, W2))
+              f2[c2][k1] = fptr[c+c2];
+            else
+              f2[c2][k1] = 0.0;
+
+            f2_grad[c2][k1] = 0.0;
+          }
+
+          __syncthreads();
+
+          const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
+          scalar_t g = 0.0;
+
+          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
+          int ix_ne = H1*W1*((iy-1) + rd*ix);
+          int ix_sw = H1*W1*(iy + rd*(ix-1));
+          int ix_se = H1*W1*(iy + rd*ix);
+
+          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
+            g += *(grad_ptr + ix_nw) * dy * dx;
+
+          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
+            g += *(grad_ptr + ix_ne) * dy * (1-dx);
+
+          if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
+            g += *(grad_ptr + ix_sw) * (1-dy) * dx;
+
+          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
+            g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);
+
+          // accumulate feature gradients in shared memory
+          for (int k=0; k<CHANNEL_STRIDE; k++) {
+            f1_grad[k][tid] += g * f2[k][tid];
+            f2_grad[k][tid] += g * f1[k][tid];
+          }
+
+          __syncthreads();
+
+          // scatter fmap2 gradients back to global memory
+          for (int k1=0; k1<BLOCK_HW; k1++) {
+            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
+            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
+            int c2 = tid % CHANNEL_STRIDE;
+
+            scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
+            if (within_bounds(h2, w2, H2, W2))
+              atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
+          }
+        }
+      }
+    }
+    __syncthreads();
+
+    // write back the accumulated fmap1 gradients
+    for (int k=0; k<BLOCK_HW; k++) {
+      int h1 = h0 + k / BLOCK_W;
+      int w1 = w0 + k % BLOCK_W;
+      int c1 = tid % CHANNEL_STRIDE;
+
+      scalar_t* fptr = &fmap1_grad[b][h1][w1][0];
+      if (within_bounds(h1, w1, H1, W1))
+        atomicAdd(fptr+c+c1, f1_grad[c1][k]);
+    }
+  }
+}
+
+std::vector<torch::Tensor> corr_cuda_forward(
+  torch::Tensor fmap1,
+  torch::Tensor fmap2,
+  torch::Tensor coords,
+  int radius)
+{
+  const auto B = coords.size(0);
+  const auto N = coords.size(1);
+  const auto H = coords.size(2);
+  const auto W = coords.size(3);
+
+  const auto rd = 2 * radius + 1;
+  auto opts = fmap1.options();
+  auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);
+
+  const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
+  const dim3 threads(BLOCK_H, BLOCK_W);
+
+  corr_forward_kernel<float><<<blocks, threads>>>(
+    fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
+    fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
+    coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
+    corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
+    radius);
+
+  return {corr};
+}
+
+std::vector<torch::Tensor> corr_cuda_backward(
+  torch::Tensor fmap1,
+  torch::Tensor fmap2,
+  torch::Tensor coords,
+  torch::Tensor corr_grad,
+  int radius)
+{
+  const auto B = coords.size(0);
+  const auto N = coords.size(1);
+
+  const auto H1 = fmap1.size(1);
+  const auto W1 = fmap1.size(2);
+  const auto H2 = fmap2.size(1);
+  const auto W2 = fmap2.size(2);
+  const auto C = fmap1.size(3);
+
+  auto opts = fmap1.options();
+  auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
+  auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
+  auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);
+
+  const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
+  const dim3 threads(BLOCK_H, BLOCK_W);
+
+
+  corr_backward_kernel<float><<<blocks, threads>>>(
+    fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
+    fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
+    coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
+    corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
+    fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
+    fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
+    coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
+    radius);
+
+  return {fmap1_grad, fmap2_grad, coords_grad};
+}
\ No newline at end of file
diff --git a/network/alt_cuda_corr/setup.py b/network/alt_cuda_corr/setup.py
new file mode 100755
index 0000000..c0207ff
--- /dev/null
+++ b/network/alt_cuda_corr/setup.py
@@ -0,0 +1,15 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+
+setup(
+    name='correlation',
+    ext_modules=[
+        CUDAExtension('alt_cuda_corr',
+            sources=['correlation.cpp', 'correlation_kernel.cu'],
+            extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
+
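
setup.py above builds alt_cuda_corr ahead of time (python setup.py install). For quick experiments the same extension can instead be JIT-compiled at import time; a sketch, assuming it is run from network/alt_cuda_corr/:

    from torch.utils.cpp_extension import load

    # compiles on first call and caches the binary; later calls reuse it
    alt_cuda_corr = load(
        name='alt_cuda_corr',
        sources=['correlation.cpp', 'correlation_kernel.cu'],
        extra_cuda_cflags=['-O3'])  # mirrors the nvcc flag in setup.py
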
diff --git a/network/core/__init__.py b/network/core/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/network/core/corr.py b/network/core/corr.py
new file mode 100755
index 0000000..c09d1e2
--- /dev/null
+++ b/network/core/corr.py
@@ -0,0 +1,94 @@
+import torch
+import torch.nn.functional as F
+import sys
+sys.path.append('utils')
+#from utils.utils import bilinear_sampler, coords_grid
+from network.core.utils.utils import bilinear_sampler, coords_grid
+
+try:
+    import alt_cuda_corr
+except ImportError:
+    # alt_cuda_corr is not compiled; only the pure-PyTorch CorrBlock is usable
+    pass
+
+
+class CorrBlock:
+    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+        self.num_levels = num_levels
+        self.radius = radius
+        self.corr_pyramid = []
+
+        # all-pairs correlation
+        corr = CorrBlock.corr(fmap1, fmap2)
+
+        batch, h1, w1, dim, h2, w2 = corr.shape
+        corr = corr.reshape(batch*h1*w1, dim, h2, w2)
+
+        self.corr_pyramid.append(corr)
+        for i in range(self.num_levels-1):
+            corr = F.avg_pool2d(corr, 2, stride=2)
+            self.corr_pyramid.append(corr)
+
+    def __call__(self, coords):
+        r = self.radius
+        coords = coords.permute(0, 2, 3, 1)
+        batch, h1, w1, _ = coords.shape
+
+        out_pyramid = []
+        for i in range(self.num_levels):
+            corr = self.corr_pyramid[i]
+            dx = torch.linspace(-r, r, 2*r+1)
+            dy = torch.linspace(-r, r, 2*r+1)
+            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
+
+            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
+            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
+            coords_lvl = centroid_lvl + delta_lvl
+
+            corr = bilinear_sampler(corr, coords_lvl)
+            corr = corr.view(batch, h1, w1, -1)
+            out_pyramid.append(corr)
+
+        out = torch.cat(out_pyramid, dim=-1)
+        return out.permute(0, 3, 1, 2).contiguous().float()
+
+    @staticmethod
+    def corr(fmap1, fmap2):
+        batch, dim, ht, wd = fmap1.shape
+        fmap1 = fmap1.view(batch, dim, ht*wd)
+        fmap2 = fmap2.view(batch, dim, ht*wd)
+
+        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
+        corr = corr.view(batch, ht, wd, 1, ht, wd)
+        return corr / torch.sqrt(torch.tensor(dim).float())
+
+
+class AlternateCorrBlock:
+    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+        self.num_levels = num_levels
+        self.radius = radius
+
+        self.pyramid = [(fmap1, fmap2)]
+        for i in range(self.num_levels):
+            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
+            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
+            self.pyramid.append((fmap1, fmap2))
+
+    def __call__(self, coords):
+        coords = coords.permute(0, 2, 3, 1)
+        B, H, W, _ = coords.shape
+        dim = self.pyramid[0][0].shape[1]
+
+        corr_list = []
+        for i in
range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / torch.sqrt(torch.tensor(dim).float()) diff --git a/network/core/datasets.py b/network/core/datasets.py new file mode 100755 index 0000000..3411fda --- /dev/null +++ b/network/core/datasets.py @@ -0,0 +1,235 @@ +# Data loading based on https://github.com/NVIDIA/flownet2-pytorch + +import numpy as np +import torch +import torch.utils.data as data +import torch.nn.functional as F + +import os +import math +import random +from glob import glob +import os.path as osp + +from utils import frame_utils +from utils.augmentor import FlowAugmentor, SparseFlowAugmentor + + +class FlowDataset(data.Dataset): + def __init__(self, aug_params=None, sparse=False): + self.augmentor = None + self.sparse = sparse + if aug_params is not None: + if sparse: + self.augmentor = SparseFlowAugmentor(**aug_params) + else: + self.augmentor = FlowAugmentor(**aug_params) + + self.is_test = False + self.init_seed = False + self.flow_list = [] + self.image_list = [] + self.extra_info = [] + + def __getitem__(self, index): + + if self.is_test: + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + img1 = np.array(img1).astype(np.uint8)[..., :3] + img2 = np.array(img2).astype(np.uint8)[..., :3] + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + return img1, img2, self.extra_info[index] + + if not self.init_seed: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + torch.manual_seed(worker_info.id) + np.random.seed(worker_info.id) + random.seed(worker_info.id) + self.init_seed = True + + index = index % len(self.image_list) + valid = None + if self.sparse: + flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) + else: + flow = frame_utils.read_gen(self.flow_list[index]) + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = np.array(flow).astype(np.float32) + img1 = np.array(img1).astype(np.uint8) + img2 = np.array(img2).astype(np.uint8) + + # grayscale images + if len(img1.shape) == 2: + img1 = np.tile(img1[...,None], (1, 1, 3)) + img2 = np.tile(img2[...,None], (1, 1, 3)) + else: + img1 = img1[..., :3] + img2 = img2[..., :3] + + if self.augmentor is not None: + if self.sparse: + img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) + else: + img1, img2, flow = self.augmentor(img1, img2, flow) + + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + flow = torch.from_numpy(flow).permute(2, 0, 1).float() + + if valid is not None: + valid = torch.from_numpy(valid) + else: + valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) + + return img1, img2, flow, valid.float() + + + def __rmul__(self, v): + self.flow_list = v * self.flow_list + self.image_list = v * self.image_list + return self + + def __len__(self): + return len(self.image_list) + + +class MpiSintel(FlowDataset): + def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): + super(MpiSintel, 
self).__init__(aug_params) + flow_root = osp.join(root, split, 'flow') + image_root = osp.join(root, split, dstype) + + if split == 'test': + self.is_test = True + + for scene in os.listdir(image_root): + image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) + for i in range(len(image_list)-1): + self.image_list += [ [image_list[i], image_list[i+1]] ] + self.extra_info += [ (scene, i) ] # scene and frame_id + + if split != 'test': + self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) + + +class FlyingChairs(FlowDataset): + def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): + super(FlyingChairs, self).__init__(aug_params) + + images = sorted(glob(osp.join(root, '*.ppm'))) + flows = sorted(glob(osp.join(root, '*.flo'))) + assert (len(images)//2 == len(flows)) + + split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) + for i in range(len(flows)): + xid = split_list[i] + if (split=='training' and xid==1) or (split=='validation' and xid==2): + self.flow_list += [ flows[i] ] + self.image_list += [ [images[2*i], images[2*i+1]] ] + + +class FlyingThings3D(FlowDataset): + def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): + super(FlyingThings3D, self).__init__(aug_params) + + for cam in ['left']: + for direction in ['into_future', 'into_past']: + image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) + image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) + + flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) + flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) + + for idir, fdir in zip(image_dirs, flow_dirs): + images = sorted(glob(osp.join(idir, '*.png')) ) + flows = sorted(glob(osp.join(fdir, '*.pfm')) ) + for i in range(len(flows)-1): + if direction == 'into_future': + self.image_list += [ [images[i], images[i+1]] ] + self.flow_list += [ flows[i] ] + elif direction == 'into_past': + self.image_list += [ [images[i+1], images[i]] ] + self.flow_list += [ flows[i+1] ] + + +class KITTI(FlowDataset): + def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): + super(KITTI, self).__init__(aug_params, sparse=True) + if split == 'testing': + self.is_test = True + + root = osp.join(root, split) + images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) + images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) + + for img1, img2 in zip(images1, images2): + frame_id = img1.split('/')[-1] + self.extra_info += [ [frame_id] ] + self.image_list += [ [img1, img2] ] + + if split == 'training': + self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) + + +class HD1K(FlowDataset): + def __init__(self, aug_params=None, root='datasets/HD1k'): + super(HD1K, self).__init__(aug_params, sparse=True) + + seq_ix = 0 + while 1: + flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) + images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) + + if len(flows) == 0: + break + + for i in range(len(flows)-1): + self.flow_list += [flows[i]] + self.image_list += [ [images[i], images[i+1]] ] + + seq_ix += 1 + + +def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): + """ Create the data loader for the corresponding trainign set """ + + if args.stage == 'chairs': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} + train_dataset = FlyingChairs(aug_params, split='training') + + elif args.stage == 'things': + aug_params = 
{'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} + clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') + final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') + train_dataset = clean_dataset + final_dataset + + elif args.stage == 'sintel': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} + things = FlyingThings3D(aug_params, dstype='frames_cleanpass') + sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') + sintel_final = MpiSintel(aug_params, split='training', dstype='final') + + if TRAIN_DS == 'C+T+K+S+H': + kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) + hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) + train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things + + elif TRAIN_DS == 'C+T+K/S': + train_dataset = 100*sintel_clean + 100*sintel_final + things + + elif args.stage == 'kitti': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} + train_dataset = KITTI(aug_params, split='training') + + train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, + pin_memory=False, shuffle=True, num_workers=4, drop_last=True) + + print('Training with %d image pairs' % len(train_dataset)) + return train_loader + diff --git a/network/core/extractor.py b/network/core/extractor.py new file mode 100755 index 0000000..9a9c759 --- /dev/null +++ b/network/core/extractor.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, 
kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + 
super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/network/core/raft.py b/network/core/raft.py new file mode 100755 index 0000000..56a22a0 --- /dev/null +++ b/network/core/raft.py @@ -0,0 +1,145 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from update import BasicUpdateBlock, SmallUpdateBlock +from extractor import BasicEncoder, SmallEncoder +from corr import CorrBlock, AlternateCorrBlock +from network.core.utils.utils import bilinear_sampler, coords_grid, upflow8 + +try: + autocast = torch.cuda.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + def __enter__(self): + pass + def __exit__(self, *args): + pass + + +class RAFT(nn.Module): + def __init__(self, args): + super(RAFT, self).__init__() + self.args = args + #args.small = True + + if args.small: + self.hidden_dim = hdim = 96 + self.context_dim = cdim = 64 + args.corr_levels = 4 + args.corr_radius = 3 + + else: + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args.corr_levels = 4 + args.corr_radius = 4 + + if 'dropout' not in self.args: + self.args.dropout = 0 + + if 'alternate_corr' not in self.args: + self.args.alternate_corr = False + + # feature network, context network, and update block + if args.small: + self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) + self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + + else: + self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) + self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', 
dropout=args.dropout) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H//8, W//8).to(img.device) + coords1 = coords_grid(N, H//8, W//8).to(img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3,3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8*H, 8*W) + + + def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): + """ Estimate optical flow between pair of frames """ + + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + with autocast(enabled=self.args.mixed_precision): + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + if self.args.alternate_corr: + corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + else: + corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + with autocast(enabled=self.args.mixed_precision): + cnet = self.cnet(image1) + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + + coords0, coords1 = self.initialize_flow(image1) + + if flow_init is not None: + coords1 = coords1 + flow_init + + flow_predictions = [] + for itr in range(iters): + coords1 = coords1.detach() + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 + with autocast(enabled=self.args.mixed_precision): + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if test_mode: + return coords1 - coords0, flow_up + + return flow_predictions diff --git a/network/core/update.py b/network/core/update.py new file mode 100755 index 0000000..f940497 --- /dev/null +++ b/network/core/update.py @@ -0,0 +1,139 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, 
hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + +class SmallMotionEncoder(nn.Module): + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class SmallUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + return net, None, delta_flow + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + 
nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + + + diff --git a/network/core/utils/__init__.py b/network/core/utils/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/network/core/utils/augmentor.py b/network/core/utils/augmentor.py new file mode 100755 index 0000000..e81c4f2 --- /dev/null +++ b/network/core/utils/augmentor.py @@ -0,0 +1,246 @@ +import numpy as np +import random +import math +from PIL import Image + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import torch.nn.functional as F + + +class FlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): + + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + """ Photometric augmentation """ + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def eraser_transform(self, img1, img2, bounds=[50, 100]): + """ Occlusion augmentation """ + + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(bounds[0], bounds[1]) + dy = np.random.randint(bounds[0], bounds[1]) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def spatial_transform(self, img1, img2, flow): + # randomly sample scale + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + + scale_x = np.clip(scale_x, min_scale, None) + scale_y = np.clip(scale_y, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = flow * [scale_x, 
scale_y] + + if self.do_flip: + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow = self.spatial_transform(img1, img2, flow) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + + return img1, img2, flow + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid>=1] + flow0 = flow[valid>=1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:,0]).astype(np.int32) + yy = np.round(coords1[:,1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), + (self.crop_size[1] + 1) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < 
self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + return img1, img2, flow, valid + + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid diff --git a/network/core/utils/flow_viz.py b/network/core/utils/flow_viz.py new file mode 100755 index 0000000..7c75024 --- /dev/null +++ b/network/core/utils/flow_viz.py @@ -0,0 +1,132 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. 
+ + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
diff --git a/network/core/utils/frame_utils.py b/network/core/utils/frame_utils.py
new file mode 100755
index 0000000..6c49113
--- /dev/null
+++ b/network/core/utils/frame_utils.py
@@ -0,0 +1,137 @@
+import numpy as np
+from PIL import Image
+from os.path import *
+import re
+
+import cv2
+cv2.setNumThreads(0)
+cv2.ocl.setUseOpenCL(False)
+
+TAG_CHAR = np.array([202021.25], np.float32)
+
+def readFlow(fn):
+    """ Read .flo file in Middlebury format"""
+    # Code adapted from:
+    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
+
+    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
+    # print 'fn = %s'%(fn)
+    with open(fn, 'rb') as f:
+        magic = np.fromfile(f, np.float32, count=1)
+        if 202021.25 != magic:
+            print('Magic number incorrect. Invalid .flo file')
+            return None
+        else:
+            w = np.fromfile(f, np.int32, count=1)
+            h = np.fromfile(f, np.int32, count=1)
+            # print 'Reading %d x %d flo file\n' % (w, h)
+            data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
+            # Reshape data into 3D array (columns, rows, bands)
+            # The reshape here is for visualization, the original code is (w,h,2)
+            return np.resize(data, (int(h), int(w), 2))
+
+def readPFM(file):
+    file = open(file, 'rb')
+
+    color = None
+    width = None
+    height = None
+    scale = None
+    endian = None
+
+    header = file.readline().rstrip()
+    if header == b'PF':
+        color = True
+    elif header == b'Pf':
+        color = False
+    else:
+        raise Exception('Not a PFM file.')
+
+    dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
+    if dim_match:
+        width, height = map(int, dim_match.groups())
+    else:
+        raise Exception('Malformed PFM header.')
+
+    scale = float(file.readline().rstrip())
+    if scale < 0: # little-endian
+        endian = '<'
+        scale = -scale
+    else:
+        endian = '>' # big-endian
+
+    data = np.fromfile(file, endian + 'f')
+    shape = (height, width, 3) if color else (height, width)
+
+    data = np.reshape(data, shape)
+    data = np.flipud(data)
+    return data
+
+def writeFlow(filename,uv,v=None):
+    """ Write optical flow to file.
+
+    If v is None, uv is assumed to contain both u and v channels,
+    stacked in depth.
+    Original code by Deqing Sun, adapted from Daniel Scharstein.
+    """
+    nBands = 2
+
+    if v is None:
+        assert(uv.ndim == 3)
+        assert(uv.shape[2] == 2)
+        u = uv[:,:,0]
+        v = uv[:,:,1]
+    else:
+        u = uv
+
+    assert(u.shape == v.shape)
+    height,width = u.shape
+    f = open(filename,'wb')
+    # write the header
+    f.write(TAG_CHAR)
+    np.array(width).astype(np.int32).tofile(f)
+    np.array(height).astype(np.int32).tofile(f)
+    # arrange into matrix form
+    tmp = np.zeros((height, width*nBands))
+    tmp[:,np.arange(width)*2] = u
+    tmp[:,np.arange(width)*2 + 1] = v
+    tmp.astype(np.float32).tofile(f)
+    f.close()
+
+
+def readFlowKITTI(filename):
+    flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
+    flow = flow[:,:,::-1].astype(np.float32)
+    flow, valid = flow[:, :, :2], flow[:, :, 2]
+    flow = (flow - 2**15) / 64.0
+    return flow, valid
+
+def readDispKITTI(filename):
+    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
+    valid = disp > 0.0
+    flow = np.stack([-disp, np.zeros_like(disp)], -1)
+    return flow, valid
+
+
+def writeFlowKITTI(filename, uv):
+    uv = 64.0 * uv + 2**15
+    valid = np.ones([uv.shape[0], uv.shape[1], 1])
+    uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
+    cv2.imwrite(filename, uv[..., ::-1])
+
+
+def read_gen(file_name, pil=False):
+    ext = splitext(file_name)[-1]
+    if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
+        return Image.open(file_name)
+    elif ext == '.bin' or ext == '.raw':
+        return np.load(file_name)
+    elif ext == '.flo':
+        return readFlow(file_name).astype(np.float32)
+    elif ext == '.pfm':
+        flow = readPFM(file_name).astype(np.float32)
+        if len(flow.shape) == 2:
+            return flow
+        else:
+            return flow[:, :, :-1]
+    return []
\ No newline at end of file
+ """ + nBands = 2 + + if v is None: + assert(uv.ndim == 3) + assert(uv.shape[2] == 2) + u = uv[:,:,0] + v = uv[:,:,1] + else: + u = uv + + assert(u.shape == v.shape) + height,width = u.shape + f = open(filename,'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width*nBands)) + tmp[:,np.arange(width)*2] = u + tmp[:,np.arange(width)*2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) + flow = flow[:,:,::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid + +def readDispKITTI(filename): + disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 + valid = disp > 0.0 + flow = np.stack([-disp, np.zeros_like(disp)], -1) + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2**15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + return Image.open(file_name) + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return readFlow(file_name).astype(np.float32) + elif ext == '.pfm': + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] \ No newline at end of file diff --git a/network/core/utils/utils.py b/network/core/utils/utils.py new file mode 100755 index 0000000..5f32d28 --- /dev/null +++ b/network/core/utils/utils.py @@ -0,0 +1,82 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 
diff --git a/network/createNet.py b/network/createNet.py
new file mode 100644
index 0000000..1085b8f
--- /dev/null
+++ b/network/createNet.py
@@ -0,0 +1,85 @@
+import torch
+from torchvision import models
+import torch.nn as nn
+import torch.nn.functional as F
+from network.BaseNet import *
+from network.mobilevit import *
+
+def initnet(flag='resnet50', mvit=128):
+    if flag == 'resnet50':
+        model_ft = models.resnet50(pretrained=True)
+        for param in model_ft.parameters():
+            param.requires_grad = False  # freeze the pretrained backbone
+        num_ftrs = model_ft.fc.in_features
+        model_ft.fc = nn.Linear(num_ftrs, 2048)
+        return model_ft
+    elif flag == 'resnet50_fpn':
+        model_ft = ResnetFpn()
+        return model_ft
+    elif flag == 'mobilevit':
+        model_ft = mobilevit_s(mvit)
+        return model_ft
+    else:
+        raise ValueError("Please select the correct model.")
+
+class L2N(nn.Module):
+
+    def __init__(self, eps=1e-6):
+        super(L2N, self).__init__()
+        self.eps = eps
+
+    def forward(self, x):
+        return x / (torch.norm(x, p=2, dim=1, keepdim=True) + self.eps).expand_as(x)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')'
+
+
+class TripletNet(nn.Module):
+    def __init__(self, initnet):
+        super(TripletNet, self).__init__()
+        self.initnet = initnet
+
+    def forward(self, x1, x2, x3):
+        output1 = self.initnet(x1)
+        output2 = self.initnet(x2)
+        output3 = self.initnet(x3)
+        return output1, output2, output3
+
+    def get_ininet(self, x):
+        return self.initnet(x)
+
+class extractNet(nn.Module):
+    def __init__(self, initnet, norm):
+        super(extractNet, self).__init__()
+        self.initnet = initnet
+        self.norm = norm
+
+    def forward(self, x):
+        output = self.initnet(x)
+        output = self.norm(output).squeeze(-1).squeeze(-1)
+        return output
+
+    def get_ininet(self, x):
+        return self.initnet(x)
+
+class GeM(nn.Module):
+    def __init__(self, p=3, eps=1e-6):
+        super(GeM, self).__init__()
+        self.p = nn.Parameter(torch.ones(1) * p)
+        self.eps = eps
+
+    def forward(self, x):
+        return self.gem(x, p=self.p, eps=self.eps)
+
+    def gem(self, x, p=3, eps=1e-6):
+        # generalized-mean pooling: avg_pool(x^p)^(1/p)
+        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p)
+
+    def __repr__(self):
+        return self.__class__.__name__ + \
+            '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
+            ', ' + 'eps=' + str(self.eps) + ')'
+
+if __name__ == '__main__':
+    print(initnet())
+
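+if __name__ == '__main__':
+    # Usage sketch (an assumption, mirroring utils/embedding.py): wrap the
+    # MobileViT backbone in extractNet to get L2-normalised 128-d embeddings.
+    net = initnet('mobilevit', 128)
+    extractor = extractNet(net, L2N())
+    emb = extractor(torch.randn(2, 3, 256, 256))
+    print(emb.shape)  # torch.Size([2, 128])
+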
diff --git a/network/demo.py b/network/demo.py
new file mode 100755
index 0000000..5abc1da
--- /dev/null
+++ b/network/demo.py
@@ -0,0 +1,75 @@
+import sys
+sys.path.append('core')
+
+import argparse
+import os
+import cv2
+import glob
+import numpy as np
+import torch
+from PIL import Image
+
+from raft import RAFT
+from utils import flow_viz
+from utils.utils import InputPadder
+
+
+DEVICE = 'cuda'
+
+def load_image(imfile):
+    img = np.array(Image.open(imfile)).astype(np.uint8)
+    img = torch.from_numpy(img).permute(2, 0, 1).float()
+    return img[None].to(DEVICE)
+
+
+def viz(img, flo):
+    img = img[0].permute(1,2,0).cpu().numpy()
+    flo = flo[0].permute(1,2,0).cpu().numpy()
+
+    # map flow to rgb image
+    flo = flow_viz.flow_to_image(flo)
+    img_flo = np.concatenate([img, flo], axis=0)
+
+    # import matplotlib.pyplot as plt
+    # plt.imshow(img_flo / 255.0)
+    # plt.show()
+
+    cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
+    cv2.waitKey()
+
+
+def demo(args):
+    model = torch.nn.DataParallel(RAFT(args))
+    model.load_state_dict(torch.load(args.model))
+
+    model = model.module
+    model.to(DEVICE)
+    model.eval()
+
+    with torch.no_grad():
+        images = glob.glob(os.path.join(args.path, '*.png')) + \
+                 glob.glob(os.path.join(args.path, '*.jpg'))
+
+        # pair consecutive frames and visualize the predicted flow
+        images = sorted(images)
+        for imfile1, imfile2 in zip(images[:-1], images[1:]):
+            image1 = load_image(imfile1)
+            image2 = load_image(imfile2)
+
+            padder = InputPadder(image1.shape)
+            image1, image2 = padder.pad(image1, image2)
+
+            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
+            viz(image1, flow_up)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', help="restore checkpoint")
+    parser.add_argument('--path', help="dataset for evaluation")
+    parser.add_argument('--small', action='store_true', help='use small model')
+    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
+    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
+    args = parser.parse_args()
+
+    demo(args)
diff --git a/network/download_models.sh b/network/download_models.sh
new file mode 100755
index 0000000..7b6ed7e
--- /dev/null
+++ b/network/download_models.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip
+unzip models.zip
diff --git a/network/evaluate.py b/network/evaluate.py
new file mode 100755
index 0000000..431a0f5
--- /dev/null
+++ b/network/evaluate.py
@@ -0,0 +1,197 @@
+import sys
+sys.path.append('core')
+
+from PIL import Image
+import argparse
+import os
+import time
+import numpy as np
+import torch
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+
+import datasets
+from utils import flow_viz
+from utils import frame_utils
+
+from raft import RAFT
+from utils.utils import InputPadder, forward_interpolate
+
+
+@torch.no_grad()
+def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
+    """ Create submission for the Sintel leaderboard """
+    model.eval()
+    for dstype in ['clean', 'final']:
+        test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)
+
+        flow_prev, sequence_prev = None, None
+        for test_id in range(len(test_dataset)):
+            image1, image2, (sequence, frame) = test_dataset[test_id]
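+            # A new sequence begins here: reset the warm-start flow carried
+            # over from the previous clip.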
+            if sequence != sequence_prev:
+                flow_prev = None
+
+            padder = InputPadder(image1.shape)
+            image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
+
+            flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
+            flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
+
+            if warm_start:
+                flow_prev = forward_interpolate(flow_low[0])[None].cuda()
+
+            output_dir = os.path.join(output_path, dstype, sequence)
+            output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))
+
+            if not os.path.exists(output_dir):
+                os.makedirs(output_dir)
+
+            frame_utils.writeFlow(output_file, flow)
+            sequence_prev = sequence
+
+
+@torch.no_grad()
+def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
+    """ Create submission for the KITTI leaderboard """
+    model.eval()
+    test_dataset = datasets.KITTI(split='testing', aug_params=None)
+
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+
+    for test_id in range(len(test_dataset)):
+        image1, image2, (frame_id, ) = test_dataset[test_id]
+        padder = InputPadder(image1.shape, mode='kitti')
+        image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
+
+        _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
+        flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
+
+        output_filename = os.path.join(output_path, frame_id)
+        frame_utils.writeFlowKITTI(output_filename, flow)
+
+
+@torch.no_grad()
+def validate_chairs(model, iters=24):
+    """ Perform evaluation on the FlyingChairs (test) split """
+    model.eval()
+    epe_list = []
+
+    val_dataset = datasets.FlyingChairs(split='validation')
+    for val_id in range(len(val_dataset)):
+        image1, image2, flow_gt, _ = val_dataset[val_id]
+        image1 = image1[None].cuda()
+        image2 = image2[None].cuda()
+
+        _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
+        epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
+        epe_list.append(epe.view(-1).numpy())
+
+    epe = np.mean(np.concatenate(epe_list))
+    print("Validation Chairs EPE: %f" % epe)
+    return {'chairs': epe}
+
+
+@torch.no_grad()
+def validate_sintel(model, iters=32):
+    """ Perform validation using the Sintel (train) split """
+    model.eval()
+    results = {}
+    for dstype in ['clean', 'final']:
+        val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
+        epe_list = []
+
+        for val_id in range(len(val_dataset)):
+            image1, image2, flow_gt, _ = val_dataset[val_id]
+            image1 = image1[None].cuda()
+            image2 = image2[None].cuda()
+
+            padder = InputPadder(image1.shape)
+            image1, image2 = padder.pad(image1, image2)
+
+            flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
+            flow = padder.unpad(flow_pr[0]).cpu()
+
+            epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
+            epe_list.append(epe.view(-1).numpy())
+
+        epe_all = np.concatenate(epe_list)
+        epe = np.mean(epe_all)
+        px1 = np.mean(epe_all<1)
+        px3 = np.mean(epe_all<3)
+        px5 = np.mean(epe_all<5)
+
+        print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
+        results[dstype] = np.mean(epe_list)
+
+    return results
+
+
+@torch.no_grad()
+def validate_kitti(model, iters=24):
+    """ Perform validation using the KITTI-2015 (train) split """
+    model.eval()
+    val_dataset = datasets.KITTI(split='training')
+
+    out_list, epe_list = [], []
+    for val_id in range(len(val_dataset)):
+        image1, image2, flow_gt, valid_gt = val_dataset[val_id]
+        image1 = image1[None].cuda()
+        image2 = image2[None].cuda()
+
+        padder = InputPadder(image1.shape, mode='kitti')
+        image1, image2 = padder.pad(image1, image2)
+
+        flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
+        flow = padder.unpad(flow_pr[0]).cpu()
+
+        epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
+        mag = torch.sum(flow_gt**2, dim=0).sqrt()
+
+        epe = epe.view(-1)
+        mag = mag.view(-1)
+        val = valid_gt.view(-1) >= 0.5
+
+        out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
+        epe_list.append(epe[val].mean().item())
+        out_list.append(out[val].cpu().numpy())
+
+    epe_list = np.array(epe_list)
+    out_list = np.concatenate(out_list)
+
+    epe = np.mean(epe_list)
+    f1 = 100 * np.mean(out_list)
+
+    print("Validation KITTI: %f, %f" % (epe, f1))
+    return {'kitti-epe': epe, 'kitti-f1': f1}
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', help="restore checkpoint")
+    parser.add_argument('--dataset', help="dataset for evaluation")
+    parser.add_argument('--small', action='store_true', help='use small model')
+    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
+    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
+    args = parser.parse_args()
+
+    model = torch.nn.DataParallel(RAFT(args))
+    model.load_state_dict(torch.load(args.model))
+
+    model.cuda()
+    model.eval()
+
+    # create_sintel_submission(model.module, warm_start=True)
+    # create_kitti_submission(model.module)
+
+    with torch.no_grad():
+        if args.dataset == 'chairs':
+            validate_chairs(model.module)
+
+        elif args.dataset == 'sintel':
+            validate_sintel(model.module)
+
+        elif args.dataset == 'kitti':
+            validate_kitti(model.module)
+
+
diff --git a/network/mobilevit.py b/network/mobilevit.py
new file mode 100644
index 0000000..ae00c46
--- /dev/null
+++ b/network/mobilevit.py
@@ -0,0 +1,259 @@
+import torch
+import torch.nn as nn
+
+from einops import rearrange
+
+
+def conv_1x1_bn(inp, oup):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+        nn.BatchNorm2d(oup),
+        nn.SiLU()
+    )
+
+
+def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
+        nn.BatchNorm2d(oup),
+        nn.SiLU()
+    )
+
+
+class PreNorm(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        return self.fn(self.norm(x), **kwargs)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout=0.):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_dim),
+            nn.SiLU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        project_out = not (heads == 1 and dim_head == dim)
+
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+
+        self.attend = nn.Softmax(dim = -1)
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        ) if project_out else nn.Identity()
+
+    def forward(self, x):
+        qkv = self.to_qkv(x).chunk(3, dim=-1)
+        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h = self.heads), qkv)
+
+        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+        attn = self.attend(dots)
+        out = torch.matmul(attn, v)
+        out = rearrange(out, 'b p h n d -> b p n (h d)')
+        return self.to_out(out)
+
+
+class Transformer(nn.Module):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
+                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
+            ]))
+
+    def forward(self, x):
+        for attn, ff in self.layers:
+            x = attn(x) + x
+            x = ff(x) + x
+        return x
+
+
+class MV2Block(nn.Module):
+    def __init__(self, inp, oup, stride=1, expansion=4):
+        super().__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+
+        hidden_dim = int(inp * expansion)
+        self.use_res_connect = self.stride == 1 and inp == oup
+
+        if expansion == 1:
+            self.conv = nn.Sequential(
+                # dw
+                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                nn.SiLU(),
+                # pw-linear
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(oup),
+            )
+        else:
+            self.conv = nn.Sequential(
+                # pw
+                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                nn.SiLU(),
+                # dw
+                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                nn.SiLU(),
+                # pw-linear
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(oup),
+            )
+
+    def forward(self, x):
+        if self.use_res_connect:
+            return x + self.conv(x)
+        else:
+            return self.conv(x)
+
+
+class MobileViTBlock(nn.Module):
+    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
+        super().__init__()
+        self.ph, self.pw = patch_size
+
+        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
+        self.conv2 = conv_1x1_bn(channel, dim)
+
+        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
+
+        self.conv3 = conv_1x1_bn(dim, channel)
+        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
+
+    def forward(self, x):
+        y = x.clone()
+
+        # Local representations
+        x = self.conv1(x)
+        x = self.conv2(x)
+
+        # Global representations
+        _, _, h, w = x.shape
+        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
+        x = self.transformer(x)
+        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
+
+        # Fusion
+        x = self.conv3(x)
+        x = torch.cat((x, y), 1)
+        x = self.conv4(x)
+        return x
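+
+# MobileViTBlock in short: the two convs build local representations, the
+# feature map is rearranged into ph*pw patch sequences so the Transformer can
+# mix information globally, then it is folded back and fused with the input.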
+
+
+class MobileViT(nn.Module):
+    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
+        super().__init__()
+        ih, iw = image_size
+        ph, pw = patch_size
+        assert ih % ph == 0 and iw % pw == 0
+
+        L = [2, 4, 3]
+
+        self.conv1 = conv_nxn_bn(3, channels[0], stride=2)
+
+        self.mv2 = nn.ModuleList([])
+        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
+        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
+        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
+        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))   # Repeat
+        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
+        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
+        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
+
+        self.mvit = nn.ModuleList([])
+        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0]*2)))
+        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1]*4)))
+        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2]*4)))
+
+        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
+
+        self.pool = nn.AvgPool2d(ih//32, 1)
+        self.fc = nn.Linear(channels[-1], num_classes, bias=False)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.mv2[0](x)
+
+        x = self.mv2[1](x)
+        x = self.mv2[2](x)
+        x = self.mv2[3](x)      # Repeat
+
+        x = self.mv2[4](x)
+        x = self.mvit[0](x)
+
+        x = self.mv2[5](x)
+        x = self.mvit[1](x)
+
+        x = self.mv2[6](x)
+        x = self.mvit[2](x)
+        x = self.conv2(x)
+
+        x = self.pool(x).view(-1, x.shape[1])
+        x = self.fc(x)
+        return x
+
+
+def mobilevit_xxs(num_classes=1000):
+    dims = [64, 80, 96]
+    channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
+    return MobileViT((256, 256), dims, channels, num_classes, expansion=2)
+
+
+def mobilevit_xs(num_classes=1000):
+    dims = [96, 120, 144]
+    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
+    return MobileViT((256, 256), dims, channels, num_classes)
+
+
+def mobilevit_s(num_classes=1000):
+    #print('num_classes >>> {}'.format(num_classes))
+    dims = [144, 192, 240]
+    channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
+    return MobileViT((256, 256), dims, channels, num_classes)
+
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+if __name__ == '__main__':
+    img = torch.randn(5, 3, 256, 256)
+
+    vit = mobilevit_xxs()
+    out = vit(img)
+    print(out.shape)
+    print(count_parameters(vit))
+
+    vit = mobilevit_xs()
+    out = vit(img)
+    print(out.shape)
+    print(count_parameters(vit))
+
+    vit = mobilevit_s()
+    out = vit(img)
+    print(out.shape)
+    print(count_parameters(vit))
diff --git a/network/network.py b/network/network.py
new file mode 100644
index 0000000..ae9a696
--- /dev/null
+++ b/network/network.py
@@ -0,0 +1,84 @@
+# Note: this module mirrors network/createNet.py.
+import torch
+from torchvision import models
+import torch.nn as nn
+import torch.nn.functional as F
+from network.BaseNet import *
+from network.mobilevit import *
+
+def initnet(flag='resnet50', mvit=128):
+    if flag == 'resnet50':
+        model_ft = models.resnet50(pretrained=True)
+        for param in model_ft.parameters():
+            param.requires_grad = False  # freeze the pretrained backbone
+        num_ftrs = model_ft.fc.in_features
+        model_ft.fc = nn.Linear(num_ftrs, 2048)
+        return model_ft
+    elif flag == 'resnet50_fpn':
+        model_ft = ResnetFpn()
+        return model_ft
+    elif flag == 'mobilevit':
+        model_ft = mobilevit_s(mvit)
+        return model_ft
+    else:
+        raise ValueError("Please select the correct model.")
+
+class L2N(nn.Module):
+
+    def __init__(self, eps=1e-6):
+        super(L2N, self).__init__()
+        self.eps = eps
+
+    def forward(self, x):
+        return x / (torch.norm(x, p=2, dim=1, keepdim=True) + self.eps).expand_as(x)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')'
+
+
+class TripletNet(nn.Module):
+    def __init__(self, initnet):
+        super(TripletNet, self).__init__()
+        self.initnet = initnet
+
+    def forward(self, x1, x2, x3):
+        output1 = self.initnet(x1)
+        output2 = self.initnet(x2)
+        output3 = self.initnet(x3)
+        return output1, output2, output3
+
+    def get_ininet(self, x):
+        return self.initnet(x)
+
+class extractNet(nn.Module):
+    def __init__(self, initnet, norm):
+        super(extractNet, self).__init__()
+        self.initnet = initnet
+        self.norm = norm
+
+    def forward(self, x):
+        output = self.initnet(x)
+        output = self.norm(output).squeeze(-1).squeeze(-1)
+        return output
+
+    def get_ininet(self, x):
+        return self.initnet(x)
+
+class GeM(nn.Module):
+    def __init__(self, p=3, eps=1e-6):
+        super(GeM, self).__init__()
+        self.p = nn.Parameter(torch.ones(1) * p)
+        self.eps = eps
+
+    def forward(self, x):
+        return self.gem(x, p=self.p, eps=self.eps)
+
+    def gem(self, x, p=3, eps=1e-6):
+        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p)
+
+    def __repr__(self):
+        return self.__class__.__name__ + \
+            '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
+            ', ' + 'eps=' + str(self.eps) + ')'
+
+if __name__ == '__main__':
+    print(initnet())
diff --git a/network/vanalysis_video.py b/network/vanalysis_video.py
new file mode 100755
index 0000000..24c658f
--- /dev/null
+++ b/network/vanalysis_video.py
@@ -0,0 +1,156 @@
+import argparse
+import glob, cv2, os, pdb, time, sys
+sys.path.append('utils')
+sys.path.append('network/core')
+import numpy as np
+import torch
+from network.core.raft import RAFT
+from network.core.utils import flow_viz
+from network.core.utils.utils import InputPadder
+from floder.config import cfg
+#from utils.retrieval_feature import AntiFraudFeatureDataset
+
+DEVICE = 'cuda'
+pre_area = 0
+kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))    # dilation structuring element
+kernel1 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2))   # erosion structuring element
+
+def load_image(imfile):
+    #img = np.array(Image.open(imfile)).astype(np.uint8)
+    img = np.array(imfile).astype(np.uint8)
+    img = torch.from_numpy(img).permute(2, 0, 1).float()
+    return img[None].to(DEVICE)
+
+def viz(img, flo):
+    img = img[0].permute(1,2,0).cpu().numpy()
+    flo = flo[0].permute(1,2,0).cpu().numpy()
+    flo = flow_viz.flow_to_image(flo)
+    return flo
+
+def raft_init_model(args):
+    model = torch.nn.DataParallel(RAFT(args))
+    model.load_state_dict(torch.load(args.model))
+    model = model.module
+    model.to(DEVICE)
+    model.eval()
+    return model
+
+def vanalysis(model, imgsList):
+    imfile1, imfile2 = None, None
+    re = []
+    with torch.no_grad():
+        for nn, frame in enumerate(imgsList):
+            #print('frame {}>>>{}'.format(nn, type(frame)))
+            imfile2 = frame
+            if imfile1 is None:
+                imfile1 = imfile2
+                continue
+            # pair consecutive frames, estimate optical flow with RAFT and
+            # crop the moving target out of the flow visualization
+            image1 = load_image(imfile1)
+            image2 = load_image(imfile2)
+            padder = InputPadder(image1.shape)
+            image1, image2 = padder.pad(image1, image2)
+            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
+            flo = viz(image1, flow_up)
+            result, coordination = get_target(flo, imfile2, cfg.maskAreaImg)
+            #print('result >>> {} coordination >>>{}'.format(result, coordination))
+            if result is not None:
+                re.append(result)
+                #cv2.imwrite('./imgs/tmpimgs/'+str(nn)+'.jpg', result)
+            imfile1 = imfile2
+    return re[:10]   # --> list of cropped target images
+
+#def get_target(path, img, ori_img, nu, ori_mask, MASKIMG):
+def get_target(img, ori_img, MASKIMG):
+    global pre_area
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    ret, mask = cv2.threshold(img, 249, 255, cv2.THRESH_BINARY)
+    mask_max_area, mask_max_contour = 0, 0
+    mask = cv2.bitwise_not(mask)
+    mask_image = np.zeros((ori_img.shape[0], ori_img.shape[1], 1), np.uint8)
+    if (cv2.__version__).split('.')[0] == '3':
+        _, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+    else:
+        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
+    if len(contours) > 100:
+        return None, '_'
+    for contour in contours:
+        mask_area_now = cv2.contourArea(contour)
+        if mask_area_now > mask_max_area:
+            mask_max_area = mask_area_now
+            mask_max_contour = contour
+    if mask_max_area == 0: return None, '_'   # mask_max_area is the area of the target region
+    (x, y, w, h) = cv2.boundingRect(mask_max_contour)
+    if (w*h)/(img.shape[0]*img.shape[1]) > 0.80:
+        return None, '_'
+    if min(w,h) < 100 or max(w,h) > 1000:
+        return None, '_'
+    coordination = [x, y, x + w, y + h]
+    mask_image = cv2.fillPoly(mask_image, [mask_max_contour], (255))
+    if pre_area == 0:
+        pre_area = mask_max_area
+        return None, '_'
+    else:
+        if abs(mask_max_area-pre_area)/pre_area > 0.4:
+            pre_area = mask_max_area
+            #print('abs:', abs(mask_max_area-pre_area)/pre_area)
+            return None, '_'
+        else:
+            pre_area = mask_max_area
+    A, B, C = mask_image, mask_image, mask_image
+    mask_image = cv2.merge([A, B, C])
+
+    # this step removes interference outside the target frame
+    if not get_iou_ratio(mask_image, MASKIMG):
+        return None, '_'
+
+    show = cv2.bitwise_and(ori_img, mask_image)
+    show = ori_img[coordination[1]:coordination[3], coordination[0]:coordination[2]]
+    return show, coordination
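+
+# A candidate region survives get_target only if its contour passes the size
+# and frame-to-frame area-stability checks above and overlaps the checkout
+# zone mask by at least 10% (see get_iou_ratio below).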
+
+def get_iou_ratio(oimg, MASKIMG):
+    mimg = cv2.imread(MASKIMG)
+    iimg = cv2.bitwise_and(oimg, mimg)
+    iimgarea = get_area(iimg)
+    oimgarea = get_area(oimg)
+    if iimgarea/oimgarea < 0.1:
+        return False
+    else: return True
+
+def get_area(img):
+    kernel = np.ones((3, 3), dtype=np.uint8)
+    img = cv2.dilate(img, kernel, 1)
+    img = cv2.erode(img, kernel, 1)
+    maxcontour, nu = 0, 0
+    contours, _ = cv2.findContours(img[:,:,1], cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
+    if len(contours) == 0:
+        return 0
+    for i in range(len(contours)):
+        if maxcontour < len(contours[i]):
+            maxcontour = len(contours[i])
+            nu = i
+    area = cv2.contourArea(contours[nu])
+    return area
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', default='./checkpoint/raft-small.pth', help="restore checkpoint")
+    parser.add_argument('--small', type=bool, default=True, help='use small model')
+    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
+    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
+    opt, unknown = parser.parse_known_args()
+    model = raft_init_model(opt)
+    #uuid_barcode = '6907992825762'
+    imgslist = []
+    for imgname in os.listdir('test_imgs'):
+        imgslist.append(cv2.imread(os.sep.join(['test_imgs', imgname])))
+    analysis = vanalysis(model, imgslist)
+#    analysis_video(model, video_path, result_path)
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..83420db
--- /dev/null
+++ b/test.py
@@ -0,0 +1,40 @@
+from utils.embedding import DataProcessing as dp
+from floder.config import cfg
+import cv2 as cv
+from utils.filter import filt
+from network.vanalysis_video import vanalysis, raft_init_model
+import argparse
+from utils.detect import opvideo
+
+parser = argparse.ArgumentParser()
+#parser.add_argument('--model', default='../module/ieemoo-ai-search/model/now/raft-things.pth', help="restore checkpoint")
+parser.add_argument('--model', default='./checkpoint/raft-small.pth', help="restore checkpoint")
+#parser.add_argument('--small', action='store_true', help='use small model')
+parser.add_argument('--small', type=bool, default=True, help='use small model')
+parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
+parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
+opt, unknown = parser.parse_known_args()
+
+backbone = 'mobilevit'
+modelPath = cfg.model_path
+device = 'cuda'
+dps = dp(backbone, modelPath, device)
+flowmodel = raft_init_model(opt)
+opv = opvideo(flowmodel, dps)
+if __name__ == '__main__':
+    video1 = './imgs/1.mp4'
+    video2 = './imgs/2.mp4'
+    video3 = './imgs/3.mp4'
+    video4 = './imgs/4.mp4'
+    opv.addFeature('test', 1, video1)
+    opv.addFeature('test', 2, video2)
+    opv.addFeature('test', 3, video3)
+    opv.opFeature('test', 4, video4)
+    #imglist = filt(video)
+    #model = raft_init_model(opt)
+    #imgs = vanalysis(model, imglist)
+    #print('>>>>>>>>>>>>>>>>>> {}'.format(type(imgs)))
+    #
+    #re = dps.getFeatures(imgs)
+
+    #print(re)
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/detect.py b/utils/detect.py
new file mode 100644
index 0000000..d0b80e0
--- /dev/null
+++ b/utils/detect.py
@@ -0,0 +1,47 @@
+from network.vanalysis_video import vanalysis
+from utils.filter import filt
+import cv2 as cv
+from floder.config import cfg
+from utils.embedding import DataProcessing as dp
+from utils.opfile import writef, readf
+import numpy as np
+from random import choice
+
+class opvideo:
+    def __init__(self, flowmodel, dps):
+        self.flowmodel = flowmodel
+        self.dps = dps
+
+    def addFeature(self, uuid, num_id, video):
+        imglist = filt(video)
+        imgs = vanalysis(self.flowmodel, imglist)
+        feature = self.dps.getFeatures(imgs)
+        writef(uuid, num_id, feature)
+
+    def opFeature(self, uuid, finalnum, video):
+        videoFeature = []
+        self.addFeature(uuid, finalnum, video)
+        for num_id in range(0, finalnum):
+            feature = readf(uuid, num_id)
+            videoFeature.append(feature)
+        redic = self.opVideoFeature(videoFeature)
+        #print(redic)
+        return redic
+
+    def opVideoFeature(self, videoFeature):
+        # compare every video's features against all the others; a video is
+        # marked True as soon as one other video matches it everywhere with
+        # cosine similarity >= 0.86, otherwise False
+        redic = {}
+        stalist = list(range(0, len(videoFeature)))
+        for nu in stalist:
+            dylist = list(range(0, len(videoFeature)))
+            dylist.remove(nu)
+            for nn in dylist:
+                cosin_re = self.dps.cal_cosine(
+                    videoFeature[nu],
+                    videoFeature[nn])
+                if sum(i < 0.86 for i in cosin_re) > 0:
+                    redic[nu] = False
+                else:
+                    redic[nu] = True
+                    break
+        return redic
diff --git a/utils/embedding.py b/utils/embedding.py
new file mode 100644
index 0000000..0f2c323
--- /dev/null
+++ b/utils/embedding.py
@@ -0,0 +1,63 @@
+from network.createNet import initnet
+import cv2, torch
+import numpy as np
+
+class DataProcessing():
+    def __init__(self, backbone, model_path, device):
+        model = initnet(backbone)
+        model.load_state_dict(torch.load(model_path))
+        model.to(torch.device(device))
+        model.eval()
+        self.model = model
+        self.device = device
+
+    def cosin_metric(self, x1, x2):
+        if not len(x1) == len(x2):
+            return 100
+        return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
+
+    def load_image(self, image):
+        #image = cv2.imread(image)
+        if image is None:
+            return None
+        image = cv2.resize(image, (256, 256))
+        image = image.transpose((2, 0, 1))
+        image = image[np.newaxis, :, :, :]
+        image = image.astype(np.float32, copy=False)
+        return image
+
+    def getFeatures(self, imgs):   # << input is a list of np images
+        features = []
+        assert (type(imgs) is list), 'Err input need list'
+        for i, img in enumerate(imgs):
+            image = self.load_image(img)
+            if image is None:
+                print('read image {} error'.format(i))
+            else:
+                data = torch.from_numpy(image)
+                data = data.to(torch.device(self.device))
+                output = self.model(data)
+                output = output.data.cpu().numpy()
+                features.append(output)
+        return features   # >> return type is list
+
+    def cal_cosine(self, t_features, m_features):   # calculate the cosine angular distance
+        if not isinstance(m_features, (list, np.ndarray)):
+            return 'Err m_features need list or ndarray'
+        elif isinstance(t_features, (list, np.ndarray)):
+            cosin_re = []
+            for tf in t_features:
+                for mf in m_features:
+                    if type(mf) is list:
+                        cosin_re.append(self.cosin_metric(tf.reshape(-1), mf))
+                    else:
+                        cosin_re.append(self.cosin_metric(tf.reshape(-1), mf.reshape(-1)))
+        else:
+            cosin_re = []
+            for mf in m_features:
+                cosin_re.append(self.cosin_metric(t_features.reshape(-1), mf.reshape(-1)))
+        return cosin_re
+
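+if __name__ == '__main__':
+    # Sanity check (not part of the original file): cosin_metric returns 1.0
+    # for identical vectors and 0.0 for orthogonal ones. __new__ skips the
+    # model loading that the normal constructor performs.
+    dps = DataProcessing.__new__(DataProcessing)
+    a, b = np.array([1.0, 0.0]), np.array([0.0, 1.0])
+    print(dps.cosin_metric(a, a), dps.cosin_metric(a, b))   # 1.0 0.0
+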
diff --git a/utils/filter.py b/utils/filter.py
new file mode 100644
index 0000000..9029d39
--- /dev/null
+++ b/utils/filter.py
@@ -0,0 +1,55 @@
+import cv2
+from floder.config import cfg
+
+def filt(video_path):
+    #mask_path = '../../module/ieemoo-ai-search/model/now/ori_old.jpg'
+    maskimg = cv2.imread(cfg.maskImg)
+    fgbg = cv2.createBackgroundSubtractorMOG2(history=500, varThreshold=20, detectShadows=False)
+    capture = cv2.VideoCapture(video_path)
+    ret, frame = capture.read(0)
+    if frame.shape[0] < frame.shape[1]:
+        maskimg = cv2.imread(cfg.maskImg1)   # landscape videos use the alternate mask
+    #print('>>>> {}'.format(frame.shape))
+    imglist = []
+    re = False
+    nn = 0
+    while True:
+        ret, frame = capture.read()
+        nn += 1
+        #print('>>>>{}'.format(nn))
+        if not ret: break
+        if not re:
+            re = img_filter(frame, maskimg, fgbg)
+        else:
+            imglist.append(frame)
+            #cv2.imwrite('./imgs/tmpimgs/'+str(nn)+'.jpg', frame)
+        if len(imglist) > 30:
+            break
+    return imglist   # --> list of frames
+
+def img_filter(frame, maskimg, fgbg):
+    dic, dics = {}, {}
+    iouArea = 0
+
+    height, width = frame.shape[:2]
+    frame = cv2.resize(frame, (int(width/2), int(height/2)), interpolation=cv2.INTER_CUBIC)
+
+    fgmask = fgbg.apply(frame)
+    draw1 = cv2.threshold(fgmask, 25, 255, cv2.THRESH_BINARY)[1]
+
+    draw1 = cv2.bitwise_and(maskimg[:, :, 0], draw1)
+    contours_m, hierarchy_m = cv2.findContours(draw1.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    for contour in contours_m:
+        dics[len(contour)] = contour
+    if len(dics.keys()) > 0:
+        cc = sorted(dics.keys())
+        iouArea = cv2.contourArea(dics[cc[-1]])
+    if iouArea > 3000 and iouArea < 50000:
+        return True
+    return False
+
+if __name__ == '__main__':
+    videoName = 'filterImg.mp4'
+    filt(videoName)
diff --git a/utils/initialize.py b/utils/initialize.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/opfile.py b/utils/opfile.py
new file mode 100644
index 0000000..b7c3bf0
--- /dev/null
+++ b/utils/opfile.py
@@ -0,0 +1,24 @@
+import h5py, os
+from floder.config import cfg
+
+def writef(uuid, num_id, feature):
+    fname = os.sep.join([cfg.hFile, uuid+'.h5'])
+    if not os.path.exists(fname):
+        f = h5py.File(fname, 'w')
+        f[str(num_id)] = feature
+    else:
+        f = h5py.File(fname, 'a')
+        f[str(num_id)] = feature
+        print('>>>>>>> feature has been written')
+    f.close()
+
+def readf(uuid, num_id):
+    fname = os.sep.join([cfg.hFile, uuid+'.h5'])
+    f = h5py.File(fname, 'r')
+    value = f[str(num_id)][:]
+    f.close()
+    return value
+
+def removef(uuid):
+    fname = os.sep.join([cfg.hFile, uuid+'.h5'])
+    os.remove(fname)
diff --git a/utils/save.py b/utils/save.py
new file mode 100644
index 0000000..e69de29