Compare commits
33 Commits
Author | SHA1 | Date | |
---|---|---|---|
e044c85a04 | |||
798c596acc | |||
183299c06b | |||
0ccfd0151f | |||
f14faa323e | |||
9b5b135fa3 | |||
0efe8892f3 | |||
b657be729b | |||
64248b1557 | |||
bfe7bc0fd5 | |||
744fb7b7b2 | |||
a50f777839 | |||
3d13b0d9c5 | |||
661489120b | |||
7e13e0f5b4 | |||
dac3b3f2b6 | |||
39f94c7bd4 | |||
afd033b965 | |||
1e6c5deee4 | |||
8bbee310ba | |||
c47894ddc0 | |||
5ecc1285d4 | |||
dfb2272a15 | |||
390c5d2d94 | |||
09e92d63b3 | |||
7309dec166 | |||
f978d4174f | |||
e00fb46847 | |||
0cc36ba920 | |||
5109400a57 | |||
27d57b21d4 | |||
16543107f3 | |||
e986ec060b |
1
.gitignore
vendored
1
.gitignore
vendored
@ -23,6 +23,7 @@
|
||||
|
||||
*.rar
|
||||
*.pkl
|
||||
*.pickle
|
||||
*.npy
|
||||
*.csv
|
||||
|
||||
|
BIN
__pycache__/event_time_specify.cpython-39.pyc
Normal file
BIN
__pycache__/event_time_specify.cpython-39.pyc
Normal file
Binary file not shown.
BIN
__pycache__/export.cpython-312.pyc
Normal file
BIN
__pycache__/export.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/imgs_inference.cpython-39.pyc
Normal file
BIN
__pycache__/imgs_inference.cpython-39.pyc
Normal file
Binary file not shown.
BIN
__pycache__/move_detect.cpython-312.pyc
Normal file
BIN
__pycache__/move_detect.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/move_detect.cpython-39.pyc
Normal file
BIN
__pycache__/move_detect.cpython-39.pyc
Normal file
Binary file not shown.
BIN
__pycache__/pipeline_01.cpython-312.pyc
Normal file
BIN
__pycache__/pipeline_01.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/pipeline_01.cpython-39.pyc
Normal file
BIN
__pycache__/pipeline_01.cpython-39.pyc
Normal file
Binary file not shown.
BIN
__pycache__/track_reid.cpython-312.pyc
Normal file
BIN
__pycache__/track_reid.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/track_reid.cpython-39.pyc
Normal file
BIN
__pycache__/track_reid.cpython-39.pyc
Normal file
Binary file not shown.
359
bakeup/pipeline.py
Normal file
359
bakeup/pipeline.py
Normal file
@ -0,0 +1,359 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Sun Sep 29 08:59:21 2024
|
||||
|
||||
@author: ym
|
||||
"""
|
||||
import os
|
||||
# import sys
|
||||
import cv2
|
||||
import pickle
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from scipy.spatial.distance import cdist
|
||||
from track_reid import yolo_resnet_tracker, yolov10_resnet_tracker
|
||||
|
||||
from tracking.dotrack.dotracks_back import doBackTracks
|
||||
from tracking.dotrack.dotracks_front import doFrontTracks
|
||||
from tracking.utils.drawtracks import plot_frameID_y2, draw_all_trajectories
|
||||
from utils.getsource import get_image_pairs, get_video_pairs
|
||||
from tracking.utils.read_data import read_similar
|
||||
|
||||
|
||||
def save_subimgs(imgdict, boxes, spath, ctype, featdict = None):
|
||||
'''
|
||||
当前 box 特征和该轨迹前一个 box 特征的相似度,可用于和跟踪序列中的相似度进行比较
|
||||
'''
|
||||
boxes = boxes[np.argsort(boxes[:, 7])]
|
||||
for i in range(len(boxes)):
|
||||
simi = None
|
||||
tid, fid, bid = int(boxes[i, 4]), int(boxes[i, 7]), int(boxes[i, 8])
|
||||
|
||||
if i>0:
|
||||
_, fid0, bid0 = int(boxes[i-1, 4]), int(boxes[i-1, 7]), int(boxes[i-1, 8])
|
||||
if f"{fid0}_{bid0}" in featdict.keys() and f"{fid}_{bid}" in featdict.keys():
|
||||
feat0 = featdict[f"{fid0}_{bid0}"]
|
||||
feat1 = featdict[f"{fid}_{bid}"]
|
||||
simi = 1 - np.maximum(0.0, cdist(feat0[None, :], feat1[None, :], "cosine"))[0][0]
|
||||
|
||||
img = imgdict[f"{fid}_{bid}"]
|
||||
imgpath = spath / f"{ctype}_tid{tid}-{fid}-{bid}.png"
|
||||
if simi is not None:
|
||||
imgpath = spath / f"{ctype}_tid{tid}-{fid}-{bid}_sim{simi:.2f}.png"
|
||||
|
||||
cv2.imwrite(imgpath, img)
|
||||
|
||||
|
||||
def save_subimgs_1(imgdict, boxes, spath, ctype, simidict = None):
|
||||
'''
|
||||
当前 box 特征和该轨迹 smooth_feat 特征的相似度, yolo_resnet_tracker 函数中,
|
||||
采用该方式记录特征相似度
|
||||
'''
|
||||
for i in range(len(boxes)):
|
||||
tid, fid, bid = int(boxes[i, 4]), int(boxes[i, 7]), int(boxes[i, 8])
|
||||
|
||||
key = f"{fid}_{bid}"
|
||||
img = imgdict[key]
|
||||
imgpath = spath / f"{ctype}_tid{tid}-{fid}-{bid}.png"
|
||||
if simidict is not None and key in simidict.keys():
|
||||
imgpath = spath / f"{ctype}_tid{tid}-{fid}-{bid}_sim{simidict[key]:.2f}.png"
|
||||
|
||||
cv2.imwrite(imgpath, img)
|
||||
|
||||
|
||||
def pipeline(
|
||||
eventpath,
|
||||
savepath,
|
||||
SourceType,
|
||||
weights,
|
||||
YoloVersion="V5"
|
||||
):
|
||||
'''
|
||||
eventpath: 单个事件的存储路径
|
||||
|
||||
'''
|
||||
optdict = {}
|
||||
optdict["weights"] = weights
|
||||
|
||||
if SourceType == "video":
|
||||
vpaths = get_video_pairs(eventpath)
|
||||
elif SourceType == "image":
|
||||
vpaths = get_image_pairs(eventpath)
|
||||
event_tracks = []
|
||||
|
||||
## 构造购物事件字典
|
||||
evtname = Path(eventpath).stem
|
||||
barcode = evtname.split('_')[-1] if len(evtname.split('_'))>=2 \
|
||||
and len(evtname.split('_')[-1])>=8 \
|
||||
and evtname.split('_')[-1].isdigit() else ''
|
||||
'''事件结果存储文件夹'''
|
||||
if not savepath:
|
||||
savepath = Path(__file__).resolve().parents[0] / "events_result"
|
||||
|
||||
savepath_pipeline = Path(savepath) / Path("Yolos_Tracking") / evtname
|
||||
|
||||
|
||||
"""ShoppingDict pickle 文件保存地址 """
|
||||
savepath_spdict = Path(savepath) / "ShoppingDict_pkfile"
|
||||
if not savepath_spdict.exists():
|
||||
savepath_spdict.mkdir(parents=True, exist_ok=True)
|
||||
pf_path = Path(savepath_spdict) / Path(str(evtname)+".pickle")
|
||||
|
||||
# if pf_path.exists():
|
||||
# print(f"Pickle file have saved: {evtname}.pickle")
|
||||
# return
|
||||
|
||||
'''====================== 构造 ShoppingDict 模块 ======================='''
|
||||
ShoppingDict = {"eventPath": eventpath,
|
||||
"eventName": evtname,
|
||||
"barcode": barcode,
|
||||
"eventType": '', # "input", "output", "other"
|
||||
"frontCamera": {},
|
||||
"backCamera": {},
|
||||
"one2n": [] #
|
||||
}
|
||||
yrtDict = {}
|
||||
|
||||
|
||||
procpath = Path(eventpath).joinpath('process.data')
|
||||
if procpath.is_file():
|
||||
SimiDict = read_similar(procpath)
|
||||
ShoppingDict["one2n"] = SimiDict['one2n']
|
||||
|
||||
|
||||
for vpath in vpaths:
|
||||
'''================= 1. 构造相机事件字典 ================='''
|
||||
CameraEvent = {"cameraType": '', # "front", "back"
|
||||
"videoPath": '',
|
||||
"imagePaths": [],
|
||||
"yoloResnetTracker": [],
|
||||
"tracking": [],
|
||||
}
|
||||
|
||||
if isinstance(vpath, list):
|
||||
CameraEvent["imagePaths"] = vpath
|
||||
bname = os.path.basename(vpath[0])
|
||||
if not isinstance(vpath, list):
|
||||
CameraEvent["videoPath"] = vpath
|
||||
bname = os.path.basename(vpath).split('.')[0]
|
||||
if bname.split('_')[0] == "0" or bname.find('back')>=0:
|
||||
CameraEvent["cameraType"] = "back"
|
||||
if bname.split('_')[0] == "1" or bname.find('front')>=0:
|
||||
CameraEvent["cameraType"] = "front"
|
||||
|
||||
'''================= 2. 事件结果存储文件夹 ================='''
|
||||
if isinstance(vpath, list):
|
||||
savepath_pipeline_imgs = savepath_pipeline / Path("images")
|
||||
else:
|
||||
savepath_pipeline_imgs = savepath_pipeline / Path(str(Path(vpath).stem))
|
||||
|
||||
if not savepath_pipeline_imgs.exists():
|
||||
savepath_pipeline_imgs.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
savepath_pipeline_subimgs = savepath_pipeline / Path("subimgs")
|
||||
if not savepath_pipeline_subimgs.exists():
|
||||
savepath_pipeline_subimgs.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
'''================= 3. Yolo + Resnet + Tracker ================='''
|
||||
optdict["source"] = vpath
|
||||
optdict["save_dir"] = savepath_pipeline_imgs
|
||||
optdict["is_save_img"] = True
|
||||
optdict["is_save_video"] = True
|
||||
|
||||
|
||||
if YoloVersion == "V5":
|
||||
yrtOut = yolo_resnet_tracker(**optdict)
|
||||
elif YoloVersion == "V10":
|
||||
yrtOut = yolov10_resnet_tracker(**optdict)
|
||||
|
||||
|
||||
yrtOut_save = []
|
||||
for frdict in yrtOut:
|
||||
fr_dict = {}
|
||||
for k, v in frdict.items():
|
||||
if k != "imgs":
|
||||
fr_dict[k]=v
|
||||
yrtOut_save.append(fr_dict)
|
||||
CameraEvent["yoloResnetTracker"] = yrtOut_save
|
||||
|
||||
# CameraEvent["yoloResnetTracker"] = yrtOut
|
||||
|
||||
'''================= 4. tracking ================='''
|
||||
'''(1) 生成用于 tracking 模块的 boxes、feats'''
|
||||
bboxes = np.empty((0, 6), dtype=np.float64)
|
||||
trackerboxes = np.empty((0, 9), dtype=np.float64)
|
||||
trackefeats = {}
|
||||
for frameDict in yrtOut:
|
||||
tboxes = frameDict["tboxes"]
|
||||
ffeats = frameDict["feats"]
|
||||
|
||||
boxes = frameDict["bboxes"]
|
||||
bboxes = np.concatenate((bboxes, np.array(boxes)), axis=0)
|
||||
trackerboxes = np.concatenate((trackerboxes, np.array(tboxes)), axis=0)
|
||||
for i in range(len(tboxes)):
|
||||
fid, bid = int(tboxes[i, 7]), int(tboxes[i, 8])
|
||||
trackefeats.update({f"{fid}_{bid}": ffeats[f"{fid}_{bid}"]})
|
||||
|
||||
|
||||
'''(2) tracking, 后摄'''
|
||||
if CameraEvent["cameraType"] == "back":
|
||||
vts = doBackTracks(trackerboxes, trackefeats)
|
||||
vts.classify()
|
||||
event_tracks.append(("back", vts))
|
||||
|
||||
CameraEvent["tracking"] = vts
|
||||
ShoppingDict["backCamera"] = CameraEvent
|
||||
|
||||
yrtDict["backyrt"] = yrtOut
|
||||
|
||||
'''(2) tracking, 前摄'''
|
||||
if CameraEvent["cameraType"] == "front":
|
||||
vts = doFrontTracks(trackerboxes, trackefeats)
|
||||
vts.classify()
|
||||
event_tracks.append(("front", vts))
|
||||
|
||||
CameraEvent["tracking"] = vts
|
||||
ShoppingDict["frontCamera"] = CameraEvent
|
||||
|
||||
yrtDict["frontyrt"] = yrtOut
|
||||
|
||||
'''========================== 保存模块 ================================='''
|
||||
'''(1) 保存 ShoppingDict 事件'''
|
||||
with open(str(pf_path), 'wb') as f:
|
||||
pickle.dump(ShoppingDict, f)
|
||||
|
||||
'''(2) 保存 Tracking 输出的运动轨迹子图,并记录相似度'''
|
||||
for CamerType, vts in event_tracks:
|
||||
if len(vts.tracks)==0: continue
|
||||
if CamerType == 'front':
|
||||
# yolos = ShoppingDict["frontCamera"]["yoloResnetTracker"]
|
||||
|
||||
yolos = yrtDict["frontyrt"]
|
||||
ctype = 1
|
||||
if CamerType == 'back':
|
||||
# yolos = ShoppingDict["backCamera"]["yoloResnetTracker"]
|
||||
|
||||
yolos = yrtDict["backyrt"]
|
||||
ctype = 0
|
||||
|
||||
imgdict, featdict, simidict = {}, {}, {}
|
||||
for y in yolos:
|
||||
imgdict.update(y["imgs"])
|
||||
featdict.update(y["feats"])
|
||||
simidict.update(y["featsimi"])
|
||||
|
||||
for track in vts.Residual:
|
||||
if isinstance(track, np.ndarray):
|
||||
save_subimgs(imgdict, track, savepath_pipeline_subimgs, ctype, featdict)
|
||||
else:
|
||||
save_subimgs(imgdict, track.slt_boxes, savepath_pipeline_subimgs, ctype, featdict)
|
||||
|
||||
'''(3) 轨迹显示与保存'''
|
||||
illus = [None, None]
|
||||
for CamerType, vts in event_tracks:
|
||||
if len(vts.tracks)==0: continue
|
||||
|
||||
if CamerType == 'front':
|
||||
edgeline = cv2.imread("./tracking/shopcart/cart_tempt/board_ftmp_line.png")
|
||||
|
||||
h, w = edgeline.shape[:2]
|
||||
# nh, nw = h//2, w//2
|
||||
# edgeline = cv2.resize(edgeline, (nw, nh), interpolation=cv2.INTER_AREA)
|
||||
|
||||
img_tracking = draw_all_trajectories(vts, edgeline, savepath_pipeline, CamerType, draw5p=True)
|
||||
illus[0] = img_tracking
|
||||
|
||||
plt = plot_frameID_y2(vts)
|
||||
plt.savefig(os.path.join(savepath_pipeline, "front_y2.png"))
|
||||
|
||||
if CamerType == 'back':
|
||||
edgeline = cv2.imread("./tracking/shopcart/cart_tempt/edgeline.png")
|
||||
|
||||
h, w = edgeline.shape[:2]
|
||||
# nh, nw = h//2, w//2
|
||||
# edgeline = cv2.resize(edgeline, (nw, nh), interpolation=cv2.INTER_AREA)
|
||||
|
||||
img_tracking = draw_all_trajectories(vts, edgeline, savepath_pipeline, CamerType, draw5p=True)
|
||||
illus[1] = img_tracking
|
||||
|
||||
illus = [im for im in illus if im is not None]
|
||||
if len(illus):
|
||||
img_cat = np.concatenate(illus, axis = 1)
|
||||
if len(illus)==2:
|
||||
H, W = img_cat.shape[:2]
|
||||
cv2.line(img_cat, (int(W/2), 0), (int(W/2), int(H)), (128, 128, 255), 3)
|
||||
|
||||
trajpath = os.path.join(savepath_pipeline, "trajectory.png")
|
||||
cv2.imwrite(trajpath, img_cat)
|
||||
|
||||
def execute_pipeline(evtdir = r"D:\datasets\ym\后台数据\unzip",
|
||||
source_type = "video", # video, image,
|
||||
save_path = r"D:\work\result_pipeline",
|
||||
yolo_ver = "V10", # V10, V5
|
||||
|
||||
weight_yolo_v5 = r'./ckpts/best_cls10_0906.pt' ,
|
||||
weight_yolo_v10 = r'./ckpts/best_v10s_width0375_1205.pt',
|
||||
k=0
|
||||
):
|
||||
'''
|
||||
运行函数 pipeline(),遍历事件文件夹,每个文件夹是一个事件
|
||||
'''
|
||||
parmDict = {}
|
||||
parmDict["SourceType"] = source_type
|
||||
parmDict["savepath"] = save_path
|
||||
parmDict["YoloVersion"] = yolo_ver
|
||||
if parmDict["YoloVersion"] == "V5":
|
||||
parmDict["weights"] = weight_yolo_v5
|
||||
elif parmDict["YoloVersion"] == "V10":
|
||||
parmDict["weights"] = weight_yolo_v10
|
||||
|
||||
evtdir = Path(evtdir)
|
||||
errEvents = []
|
||||
for item in evtdir.iterdir():
|
||||
if item.is_dir():
|
||||
item = evtdir/Path("20250310-175352-741")
|
||||
parmDict["eventpath"] = item
|
||||
pipeline(**parmDict)
|
||||
# try:
|
||||
# pipeline(**parmDict)
|
||||
# except Exception as e:
|
||||
# errEvents.append(str(item))
|
||||
k+=1
|
||||
if k==1:
|
||||
break
|
||||
|
||||
errfile = os.path.join(parmDict["savepath"], 'error_events.txt')
|
||||
with open(errfile, 'w', encoding='utf-8') as f:
|
||||
for line in errEvents:
|
||||
f.write(line + '\n')
|
||||
|
||||
if __name__ == "__main__":
|
||||
execute_pipeline()
|
||||
|
||||
# spath_v10 = r"D:\work\result_pipeline_v10"
|
||||
# spath_v5 = r"D:\work\result_pipeline_v5"
|
||||
# execute_pipeline(save_path=spath_v10, yolo_ver="V10")
|
||||
# execute_pipeline(save_path=spath_v5, yolo_ver="V5")
|
||||
|
||||
datapath = r'/home/wqg/dataset/test_dataset/base_dataset/single_event/source/'
|
||||
savepath = r'/home/wqg/dataset/pipeline/contrast/single_event_V5'
|
||||
|
||||
|
||||
|
||||
|
||||
execute_pipeline(evtdir = datapath,
|
||||
DataType = "raw", # raw, pkl
|
||||
kk=1,
|
||||
source_type = "video", # video, image,
|
||||
save_path = savepath,
|
||||
yolo_ver = "V10", # V10, V5
|
||||
weight_yolo_v5 = r'./ckpts/best_cls10_0906.pt' ,
|
||||
weight_yolo_v10 = r'./ckpts/best_v10s_width0375_1205.pt',
|
||||
saveimages = False
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
50
bclass.py
Normal file
50
bclass.py
Normal file
@ -0,0 +1,50 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Fri Nov 15 16:23:03 2024
|
||||
|
||||
@author: ym
|
||||
"""
|
||||
|
||||
class CamEvent:
|
||||
def __init__(self, datapath):
|
||||
self.data_path = datapath
|
||||
self.bboxes = None
|
||||
self.bfeats = None
|
||||
self.tboxes = None
|
||||
self.tfeats = None
|
||||
|
||||
|
||||
|
||||
class ShopEvent:
|
||||
def __init__(self, eventpath, stdpath):
|
||||
self.barcode = ""
|
||||
self.event_path = eventpath
|
||||
self.event_type = self.get_event_type(eventpath)
|
||||
|
||||
self.FrontEvent = ""
|
||||
self.BackEvent = ""
|
||||
self.fusion_boxes = None
|
||||
self.fusion_feats = None
|
||||
self.stdfeats = self.get_stdfeats(stdpath)
|
||||
self.weight = None
|
||||
self.imu = None
|
||||
|
||||
def get_event_type(self, eventpath):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
def get_stdfeats(self, stdpath):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
7
contrast/__init__.py
Normal file
7
contrast/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Thu Sep 26 08:53:58 2024
|
||||
|
||||
@author: ym
|
||||
"""
|
||||
|
BIN
contrast/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
contrast/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
contrast/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
contrast/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/__pycache__/config.cpython-39.pyc
Normal file
BIN
contrast/__pycache__/config.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/__pycache__/event_test.cpython-312.pyc
Normal file
BIN
contrast/__pycache__/event_test.cpython-312.pyc
Normal file
Binary file not shown.
BIN
contrast/__pycache__/event_test.cpython-39.pyc
Normal file
BIN
contrast/__pycache__/event_test.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/__pycache__/feat_inference.cpython-39.pyc
Normal file
BIN
contrast/__pycache__/feat_inference.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/__pycache__/genfeats.cpython-312.pyc
Normal file
BIN
contrast/__pycache__/genfeats.cpython-312.pyc
Normal file
Binary file not shown.
BIN
contrast/__pycache__/genfeats.cpython-39.pyc
Normal file
BIN
contrast/__pycache__/genfeats.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/__pycache__/one2n_contrast.cpython-312.pyc
Normal file
BIN
contrast/__pycache__/one2n_contrast.cpython-312.pyc
Normal file
Binary file not shown.
BIN
contrast/__pycache__/one2n_contrast.cpython-39.pyc
Normal file
BIN
contrast/__pycache__/one2n_contrast.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/__pycache__/test_ori.cpython-39.pyc
Normal file
BIN
contrast/__pycache__/test_ori.cpython-39.pyc
Normal file
Binary file not shown.
374
contrast/event_test.py
Normal file
374
contrast/event_test.py
Normal file
@ -0,0 +1,374 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Mon Dec 16 18:56:18 2024
|
||||
|
||||
@author: ym
|
||||
"""
|
||||
import os
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
|
||||
from matplotlib import rcParams
|
||||
from matplotlib.font_manager import FontProperties
|
||||
from scipy.spatial.distance import cdist
|
||||
from utils.event import ShoppingEvent, save_data
|
||||
from utils.calsimi import calsimi_vs_stdfeat_new, get_topk_percent, cluster
|
||||
from utils.tools import get_evtList
|
||||
import pickle
|
||||
|
||||
rcParams['font.sans-serif'] = ['SimHei'] # 用黑体显示中文
|
||||
rcParams['axes.unicode_minus'] = False # 正确显示负号
|
||||
|
||||
'''*********** USearch ***********'''
|
||||
def read_usearch():
|
||||
stdFeaturePath = r"D:\contrast\stdlib\v11_test.json"
|
||||
stdBarcode = []
|
||||
stdlib = {}
|
||||
with open(stdFeaturePath, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
for dic in data['total']:
|
||||
barcode = dic['key']
|
||||
feature = np.array(dic['value'])
|
||||
stdBarcode.append(barcode)
|
||||
stdlib[barcode] = feature
|
||||
|
||||
return stdlib
|
||||
|
||||
def get_eventlist_errortxt(evtpaths):
|
||||
'''
|
||||
读取一次测试中的错误事件
|
||||
'''
|
||||
text1 = "one_2_Small_n_Error.txt"
|
||||
text2 = "one_2_Big_N_Error.txt"
|
||||
events = []
|
||||
text = (text1, text2)
|
||||
for txt in text:
|
||||
txtfile = os.path.join(evtpaths, txt)
|
||||
with open(txtfile, "r") as f:
|
||||
lines = f.readlines()
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
if line:
|
||||
fpath=os.path.join(evtpaths, line)
|
||||
events.append(fpath)
|
||||
|
||||
|
||||
|
||||
events = list(set(events))
|
||||
|
||||
return events
|
||||
|
||||
def save_eventdata():
|
||||
evtpaths = r"/home/wqg/dataset/test_dataset/performence_dataset/"
|
||||
events = get_eventlist_errortxt(evtpaths)
|
||||
|
||||
'''定义当前事件存储地址及生成相应文件件'''
|
||||
resultPath = r"\\192.168.1.28\share\测试视频数据以及日志\算法全流程测试\202412\result\single_event"
|
||||
for evtpath in events:
|
||||
event = ShoppingEvent(evtpath)
|
||||
save_data(event, resultPath)
|
||||
|
||||
print(event.evtname)
|
||||
|
||||
|
||||
|
||||
# def get_topk_percent(data, k):
|
||||
# """
|
||||
# 获取数据中最大的 k% 的元素
|
||||
# """
|
||||
# # 将数据转换为 NumPy 数组
|
||||
# if isinstance(data, list):
|
||||
# data = np.array(data)
|
||||
|
||||
# percentile = np.percentile(data, 100-k)
|
||||
# top_k_percent = data[data >= percentile]
|
||||
|
||||
# return top_k_percent
|
||||
# def cluster(data, thresh=0.15):
|
||||
# # data = np.array([0.1, 0.13, 0.7, 0.2, 0.8, 0.52, 0.3, 0.7, 0.85, 0.58])
|
||||
# # data = np.array([0.1, 0.13, 0.2, 0.3])
|
||||
# # data = np.array([0.1])
|
||||
|
||||
# if isinstance(data, list):
|
||||
# data = np.array(data)
|
||||
|
||||
# data1 = np.sort(data)
|
||||
# cluter, Cluters, = [data1[0]], []
|
||||
# for i in range(1, len(data1)):
|
||||
# if data1[i] - data1[i-1]< thresh:
|
||||
# cluter.append(data1[i])
|
||||
# else:
|
||||
# Cluters.append(cluter)
|
||||
# cluter = [data1[i]]
|
||||
# Cluters.append(cluter)
|
||||
|
||||
# clt_center = []
|
||||
# for clt in Cluters:
|
||||
# ## 是否应该在此处限制一个聚类中的最小轨迹样本数,应该将该因素放在轨迹分析中
|
||||
# # if len(clt)>=3:
|
||||
# # clt_center.append(np.mean(clt))
|
||||
# clt_center.append(np.mean(clt))
|
||||
|
||||
# # print(clt_center)
|
||||
|
||||
# return clt_center
|
||||
|
||||
# def calsimi_vs_stdfeat_new(event, stdfeat):
|
||||
# '''事件与标准库的对比策略
|
||||
# 该比对策略是否可以拓展到事件与事件的比对?
|
||||
# '''
|
||||
|
||||
|
||||
# def calsiml(feat1, feat2, topkp=75, cluth=0.15):
|
||||
# '''轨迹样本和标准特征集样本相似度的选择策略'''
|
||||
# matrix = 1 - cdist(feat1, feat2, 'cosine')
|
||||
# simi_max = []
|
||||
# for i in range(len(matrix)):
|
||||
# sim = np.mean(get_topk_percent(matrix[i, :], topkp))
|
||||
# simi_max.append(sim)
|
||||
# cltc_max = cluster(simi_max, cluth)
|
||||
# Simi = max(cltc_max)
|
||||
|
||||
# ## cltc_max为空属于编程考虑不周,应予以排查解决
|
||||
# # if len(cltc_max):
|
||||
# # Simi = max(cltc_max)
|
||||
# # else:
|
||||
# # Simi = 0 #不应该走到该处
|
||||
|
||||
|
||||
# return Simi
|
||||
|
||||
|
||||
# front_boxes = np.empty((0, 9), dtype=np.float64) ##和类doTracks兼容
|
||||
# front_feats = np.empty((0, 256), dtype=np.float64) ##和类doTracks兼容
|
||||
# for i in range(len(event.front_boxes)):
|
||||
# front_boxes = np.concatenate((front_boxes, event.front_boxes[i]), axis=0)
|
||||
# front_feats = np.concatenate((front_feats, event.front_feats[i]), axis=0)
|
||||
|
||||
# back_boxes = np.empty((0, 9), dtype=np.float64) ##和类doTracks兼容
|
||||
# back_feats = np.empty((0, 256), dtype=np.float64) ##和类doTracks兼容
|
||||
# for i in range(len(event.back_boxes)):
|
||||
# back_boxes = np.concatenate((back_boxes, event.back_boxes[i]), axis=0)
|
||||
# back_feats = np.concatenate((back_feats, event.back_feats[i]), axis=0)
|
||||
|
||||
# if len(front_feats):
|
||||
# front_simi = calsiml(front_feats, stdfeat)
|
||||
# if len(back_feats):
|
||||
# back_simi = calsiml(back_feats, stdfeat)
|
||||
|
||||
# '''前后摄相似度融合策略'''
|
||||
# if len(front_feats) and len(back_feats):
|
||||
# diff_simi = abs(front_simi - back_simi)
|
||||
# if diff_simi>0.15:
|
||||
# Similar = max([front_simi, back_simi])
|
||||
# else:
|
||||
# Similar = (front_simi+back_simi)/2
|
||||
# elif len(front_feats) and len(back_feats)==0:
|
||||
# Similar = front_simi
|
||||
# elif len(front_feats)==0 and len(back_feats):
|
||||
# Similar = back_simi
|
||||
# else:
|
||||
# Similar = None # 在event.front_feats和event.back_feats同时为空时
|
||||
|
||||
# return Similar
|
||||
|
||||
|
||||
|
||||
|
||||
def simi_matrix():
|
||||
evtpaths = r"/home/wqg/dataset/pipeline/contrast/single_event_V10/evtobjs/"
|
||||
|
||||
stdfeatPath = r"/home/wqg/dataset/test_dataset/total_barcode/features_json/v11_barcode_0304/"
|
||||
resultPath = r"/home/wqg/dataset/performence_dataset/result/"
|
||||
|
||||
evt_paths, bcdSet = get_evtList(evtpaths)
|
||||
|
||||
## read std features
|
||||
stdDict={}
|
||||
evtDict = {}
|
||||
for barcode in bcdSet:
|
||||
stdpath = os.path.join(stdfeatPath, f"{barcode}.json")
|
||||
if not os.path.isfile(stdpath):
|
||||
continue
|
||||
|
||||
with open(stdpath, 'r', encoding='utf-8') as f:
|
||||
stddata = json.load(f)
|
||||
feat = np.array(stddata["value"])
|
||||
stdDict[barcode] = feat
|
||||
|
||||
for evtpath in evt_paths:
|
||||
barcode = Path(evtpath).stem.split("_")[-1]
|
||||
|
||||
if barcode not in stdDict.keys():
|
||||
continue
|
||||
|
||||
# try:
|
||||
# with open(evtpath, 'rb') as f:
|
||||
# evtdata = pickle.load(f)
|
||||
# except Exception as e:
|
||||
# print(evtname)
|
||||
|
||||
with open(evtpath, 'rb') as f:
|
||||
event = pickle.load(f)
|
||||
|
||||
stdfeat = stdDict[barcode]
|
||||
|
||||
Similar = calsimi_vs_stdfeat_new(event, stdfeat)
|
||||
|
||||
# 构造 boxes 子图存储路径
|
||||
subimgpath = os.path.join(resultPath, f"{event.evtname}", "subimg")
|
||||
if not os.path.exists(subimgpath):
|
||||
os.makedirs(subimgpath)
|
||||
histpath = os.path.join(resultPath, "simi_hist")
|
||||
if not os.path.exists(histpath):
|
||||
os.makedirs(histpath)
|
||||
|
||||
mean_values, max_values = [], []
|
||||
cameras = ('front', 'back')
|
||||
fig, ax = plt.subplots(2, 3, figsize=(16, 9), dpi=100)
|
||||
kpercent = 25
|
||||
for camera in cameras:
|
||||
boxes = np.empty((0, 9), dtype=np.float64) ##和类doTracks兼容
|
||||
evtfeat = np.empty((0, 256), dtype=np.float64) ##和类doTracks兼容
|
||||
if camera == 'front':
|
||||
for i in range(len(event.front_boxes)):
|
||||
boxes = np.concatenate((boxes, event.front_boxes[i]), axis=0)
|
||||
evtfeat = np.concatenate((evtfeat, event.front_feats[i]), axis=0)
|
||||
imgpaths = event.front_imgpaths
|
||||
|
||||
else:
|
||||
for i in range(len(event.back_boxes)):
|
||||
boxes = np.concatenate((boxes, event.back_boxes[i]), axis=0)
|
||||
evtfeat = np.concatenate((evtfeat, event.back_feats[i]), axis=0)
|
||||
imgpaths = event.back_imgpaths
|
||||
|
||||
assert len(boxes)==len(evtfeat), f"Please check the Event: {event.evtname}"
|
||||
if len(boxes)==0: continue
|
||||
print(event.evtname)
|
||||
|
||||
matrix = 1 - cdist(evtfeat, stdfeat, 'cosine')
|
||||
simi_1d = matrix.flatten()
|
||||
simi_mean = np.mean(matrix, axis=1)
|
||||
# simi_max = np.max(matrix, axis=1)
|
||||
|
||||
'''以相似度矩阵每一行最大的 k% 的相似度做均值计算'''
|
||||
simi_max = []
|
||||
for i in range(len(matrix)):
|
||||
sim = np.mean(get_topk_percent(matrix[i, :], kpercent))
|
||||
simi_max.append(sim)
|
||||
|
||||
|
||||
mean_values.append(np.mean(matrix))
|
||||
max_values.append(np.mean(simi_max))
|
||||
|
||||
diff_max_mean = np.mean(simi_max) - np.mean(matrix)
|
||||
|
||||
'''相似度统计特性图示'''
|
||||
k =0
|
||||
if camera == 'front': k = 1
|
||||
|
||||
'''********************* 相似度全体数据 *********************'''
|
||||
ax[k, 0].hist(simi_1d, bins=60, range=(-0.2, 1), edgecolor='black')
|
||||
ax[k, 0].set_xlim([-0.2, 1])
|
||||
ax[k, 0].set_title(camera)
|
||||
|
||||
_, y_max = ax[k, 0].get_ylim() # 获取y轴范围
|
||||
'''相似度变动范围'''
|
||||
ax[k, 0].text(-0.1, 0.15*y_max, f"rng:{max(simi_1d)-min(simi_1d):.3f}", fontsize=18, color='b')
|
||||
|
||||
'''********************* 均值********************************'''
|
||||
ax[k, 1].hist(simi_mean, bins=24, range=(-0.2, 1), edgecolor='black')
|
||||
ax[k, 1].set_xlim([-0.2, 1])
|
||||
ax[k, 1].set_title("mean")
|
||||
_, y_max = ax[k, 1].get_ylim() # 获取y轴范围
|
||||
'''相似度变动范围'''
|
||||
ax[k, 1].text(-0.1, 0.15*y_max, f"rng:{max(simi_mean)-min(simi_mean):.3f}", fontsize=18, color='b')
|
||||
|
||||
|
||||
'''********************* 最大值 ******************************'''
|
||||
ax[k, 2].hist(simi_max, bins=24, range=(-0.2, 1), edgecolor='black')
|
||||
ax[k, 2].set_xlim([-0.2, 1])
|
||||
ax[k, 2].set_title("max")
|
||||
_, y_max = ax[k, 2].get_ylim() # 获取y轴范围
|
||||
'''相似度变动范围'''
|
||||
ax[k, 2].text(-0.1, 0.15*y_max, f"rng:{max(simi_max)-min(simi_max):.3f}", fontsize=18, color='b')
|
||||
|
||||
|
||||
'''绘制聚类中心'''
|
||||
cltc_mean = cluster(simi_mean)
|
||||
for value in cltc_mean:
|
||||
ax[k, 1].axvline(x=value, color='m', linestyle='--', linewidth=3)
|
||||
|
||||
cltc_max = cluster(simi_max)
|
||||
for value in cltc_max:
|
||||
ax[k, 2].axvline(x=value, color='m', linestyle='--', linewidth=3)
|
||||
|
||||
'''绘制相似度均值与最大值均值'''
|
||||
ax[k, 1].axvline(x=np.mean(matrix), color='r', linestyle='-', linewidth=3)
|
||||
ax[k, 2].axvline(x=np.mean(simi_max), color='g', linestyle='-', linewidth=3)
|
||||
|
||||
'''绘制相似度最大值均值 - 均值'''
|
||||
_, y_max = ax[k, 2].get_ylim() # 获取y轴范围
|
||||
ax[k, 2].text(-0.1, 0.05*y_max, f"g-r={diff_max_mean:.3f}", fontsize=18, color='m')
|
||||
|
||||
plt.show()
|
||||
|
||||
# for i, box in enumerate(boxes):
|
||||
# x1, y1, x2, y2, tid, score, cls, fid, bid = box
|
||||
# imgpath = imgpaths[int(fid-1)]
|
||||
# image = cv2.imread(imgpath)
|
||||
# subimg = image[int(y1/2):int(y2/2), int(x1/2):int(x2/2), :]
|
||||
# camerType, timeTamp, _, frameID = os.path.basename(imgpath).split('.')[0].split('_')
|
||||
# subimgName = f"cam{camerType}_{i}_tid{int(tid)}_fid({int(fid)}, {frameID})_{simi_mean[i]:.3f}.png"
|
||||
# imgpairs.append((subimgName, subimg))
|
||||
# spath = os.path.join(subimgpath, subimgName)
|
||||
# cv2.imwrite(spath, subimg)
|
||||
|
||||
# oldname = f"cam{camerType}_{i}_tid{int(tid)}_fid({int(fid)}, {frameID}).png"
|
||||
# oldpath = os.path.join(subimgpath, oldname)
|
||||
# if os.path.exists(oldpath):
|
||||
# os.remove(oldpath)
|
||||
|
||||
|
||||
if len(mean_values)==2:
|
||||
mean_diff = abs(mean_values[1]-mean_values[0])
|
||||
ax[0, 1].set_title(f"mean diff: {mean_diff:.3f}")
|
||||
if len(max_values)==2:
|
||||
max_diff = abs(max_values[1]-max_values[0])
|
||||
ax[0, 2].set_title(f"max diff: {max_diff:.3f}")
|
||||
try:
|
||||
fig.suptitle(f"Similar: {Similar:.3f}", fontsize=16)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(f"Similar: {Similar}")
|
||||
pltpath = os.path.join(subimgpath, f"hist_max_{kpercent}%_.png")
|
||||
plt.savefig(pltpath)
|
||||
|
||||
pltpath1 = os.path.join(histpath, f"{event.evtname}_.png")
|
||||
plt.savefig(pltpath1)
|
||||
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def main():
|
||||
simi_matrix()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# cluster()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
BIN
contrast/feat_extract/__pycache__/config.cpython-312.pyc
Normal file
BIN
contrast/feat_extract/__pycache__/config.cpython-312.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/__pycache__/config.cpython-39.pyc
Normal file
BIN
contrast/feat_extract/__pycache__/config.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/__pycache__/inference.cpython-312.pyc
Normal file
BIN
contrast/feat_extract/__pycache__/inference.cpython-312.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/__pycache__/inference.cpython-39.pyc
Normal file
BIN
contrast/feat_extract/__pycache__/inference.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/checkpoints/resnet18_0515/best.rknn
Normal file
BIN
contrast/feat_extract/checkpoints/resnet18_0515/best.rknn
Normal file
Binary file not shown.
88
contrast/feat_extract/config.py
Normal file
88
contrast/feat_extract/config.py
Normal file
@ -0,0 +1,88 @@
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
|
||||
|
||||
class Config:
|
||||
# network settings
|
||||
backbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5]
|
||||
metric = 'arcface' # [cosface, arcface]
|
||||
cbam = True
|
||||
embedding_size = 256
|
||||
drop_ratio = 0.5
|
||||
img_size = 224
|
||||
|
||||
batch_size = 8
|
||||
|
||||
# data preprocess
|
||||
# input_shape = [1, 128, 128]
|
||||
"""transforms.RandomCrop(size),
|
||||
transforms.RandomVerticalFlip(p=0.5),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
RandomRotate(15, 0.3),
|
||||
# RandomGaussianBlur()"""
|
||||
|
||||
train_transform = T.Compose([
|
||||
T.ToTensor(),
|
||||
T.Resize((img_size, img_size)),
|
||||
# T.RandomCrop(img_size),
|
||||
# T.RandomHorizontalFlip(p=0.5),
|
||||
T.RandomRotation(180),
|
||||
T.ColorJitter(brightness=0.5),
|
||||
T.ConvertImageDtype(torch.float32),
|
||||
T.Normalize(mean=[0.5], std=[0.5]),
|
||||
])
|
||||
test_transform = T.Compose([
|
||||
T.ToTensor(),
|
||||
T.Resize((img_size, img_size)),
|
||||
T.ConvertImageDtype(torch.float32),
|
||||
T.Normalize(mean=[0.5], std=[0.5]),
|
||||
])
|
||||
|
||||
# dataset
|
||||
train_root = './data/2250_train/train' # 初始筛选过一次的数据集
|
||||
# train_root = './data/0612_train/train'
|
||||
test_root = "./data/2250_train/val/"
|
||||
# test_root = "./data/0612_train/val"
|
||||
test_list = "./data/2250_train/val_pair.txt"
|
||||
|
||||
test_group_json = "./2250_train/cross_same_0508.json"
|
||||
|
||||
|
||||
# test_list = "./data/test_data_100/val_pair.txt"
|
||||
|
||||
# training settings
|
||||
checkpoints = "checkpoints/resnet18_0613/" # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3]
|
||||
restore = False
|
||||
# restore_model = "checkpoints/renet18_2250_0315/best_resnet18_2250_0315.pth" # best_resnet18_1491_0306.pth
|
||||
restore_model = "checkpoints/resnet18_0515/best.pth" # best_resnet18_1491_0306.pth
|
||||
|
||||
# test_model = "checkpoints/renet18_2250_0314/best_resnet18_2250_0314.pth"
|
||||
testbackbone = 'resnet18' # [resnet18, mobilevit_s, mobilenet_v2, mobilenetv3_small, mobilenetv3_large, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5]
|
||||
test_val = "D:/比对/cl"
|
||||
# test_val = "./data/test_data_100"
|
||||
|
||||
test_model = "checkpoints/best_20250228.pth"
|
||||
# test_model = "checkpoints/zhanting_res_801.pth"
|
||||
# test_model = "checkpoints/zhanting_res_abroad_8021.pth"
|
||||
|
||||
|
||||
|
||||
train_batch_size = 512 # 256
|
||||
test_batch_size = 256 # 256
|
||||
|
||||
epoch = 300
|
||||
optimizer = 'sgd' # ['sgd', 'adam']
|
||||
lr = 1.5e-2 # 1e-2
|
||||
lr_step = 5 # 10
|
||||
lr_decay = 0.95 # 0.98
|
||||
weight_decay = 5e-4
|
||||
loss = 'cross_entropy' # ['focal_loss', 'cross_entropy']
|
||||
# device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
|
||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
pin_memory = True # if memory is large, set it True to speed up a bit
|
||||
num_workers = 4 # dataloader
|
||||
|
||||
group_test = True
|
||||
|
||||
config = Config()
|
605
contrast/feat_extract/inference.py
Normal file
605
contrast/feat_extract/inference.py
Normal file
@ -0,0 +1,605 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
|
||||
@author: LiChen
|
||||
"""
|
||||
# import pdb
|
||||
# import shutil
|
||||
import torch.nn as nn
|
||||
# import statistics
|
||||
import os
|
||||
import numpy as np
|
||||
from scipy.spatial.distance import cdist
|
||||
import torch
|
||||
import os.path as osp
|
||||
from PIL import Image
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
# import sys
|
||||
# sys.path.append(r"D:\DetectTracking")
|
||||
# from contrast.config import config as conf
|
||||
# from contrast.model import resnet18
|
||||
|
||||
from .config import config as conf
|
||||
from .model import resnet18
|
||||
|
||||
# from model import (mobilevit_s, resnet14, resnet18, resnet34, resnet50, mobilenet_v2,
|
||||
# MobileNetV3_Small, mobilenet_v1, PPLCNET_x1_0, PPLCNET_x0_5, PPLCNET_x2_5)
|
||||
|
||||
curpath = Path(__file__).resolve().parents[0]
|
||||
|
||||
class FeatsInterface:
|
||||
def __init__(self, conf):
|
||||
self.device = conf.device
|
||||
|
||||
# if conf.backbone == 'resnet18':
|
||||
# model = resnet18().to(conf.device)
|
||||
|
||||
model = resnet18().to(conf.device)
|
||||
self.transform = conf.test_transform
|
||||
self.batch_size = conf.batch_size
|
||||
self.embedding_size = conf.embedding_size
|
||||
|
||||
if conf.test_model.find("zhanting") == -1:
|
||||
model = nn.DataParallel(model).to(conf.device)
|
||||
self.model = model
|
||||
|
||||
modpath = os.path.join(curpath, conf.test_model)
|
||||
self.model.load_state_dict(torch.load(modpath, map_location=conf.device))
|
||||
self.model.eval()
|
||||
# print('load model {} '.format(conf.testbackbone))
|
||||
|
||||
def inference(self, images, detections=None):
|
||||
'''
|
||||
如果是BGR,需要转变为RGB格式
|
||||
'''
|
||||
if isinstance(images, np.ndarray):
|
||||
imgs, features = self.inference_image(images, detections)
|
||||
return imgs, features
|
||||
|
||||
batch_patches = []
|
||||
patches = []
|
||||
for i, img in enumerate(images):
|
||||
img = img.copy()
|
||||
|
||||
## 对 img 进行补黑边,生成新的图像new_img
|
||||
width, height = img.size
|
||||
new_size = max(width, height)
|
||||
new_img = Image.new("RGB", (new_size, new_size), (0, 0, 0))
|
||||
paste_x = (new_size - width) // 2
|
||||
paste_y = (new_size - height) // 2
|
||||
new_img.paste(img, (paste_x, paste_y))
|
||||
|
||||
patch = self.transform(new_img)
|
||||
patch = patch.to(device=self.device)
|
||||
# if str(self.device) != "cpu":
|
||||
# patch = patch.to(device=self.device).half()
|
||||
# else:
|
||||
# patch = patch.to(device=self.device)
|
||||
|
||||
patches.append(patch)
|
||||
if (i + 1) % self.batch_size == 0:
|
||||
patches = torch.stack(patches, dim=0)
|
||||
batch_patches.append(patches)
|
||||
patches = []
|
||||
|
||||
if len(patches):
|
||||
patches = torch.stack(patches, dim=0)
|
||||
batch_patches.append(patches)
|
||||
|
||||
features = np.zeros((0, self.embedding_size))
|
||||
for patches in batch_patches:
|
||||
pred=self.model(patches)
|
||||
pred[torch.isinf(pred)] = 1.0
|
||||
feat = pred.cpu().data.numpy()
|
||||
features = np.vstack((features, feat))
|
||||
return features
|
||||
|
||||
def inference_image(self, image, detections):
|
||||
H, W, _ = np.shape(image)
|
||||
|
||||
batch_patches = []
|
||||
patches = []
|
||||
imgs = []
|
||||
for d in range(np.size(detections, 0)):
|
||||
tlbr = detections[d, :4].astype(np.int_)
|
||||
tlbr[0] = max(0, tlbr[0])
|
||||
tlbr[1] = max(0, tlbr[1])
|
||||
tlbr[2] = min(W - 1, tlbr[2])
|
||||
tlbr[3] = min(H - 1, tlbr[3])
|
||||
img = image[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2], :]
|
||||
|
||||
imgs.append(img)
|
||||
|
||||
|
||||
img1 = img[:, :, ::-1].copy() # the model expects RGB inputs
|
||||
patch = self.transform(img1)
|
||||
|
||||
# patch = patch.to(device=self.device).half()
|
||||
# if str(self.device) != "cpu":
|
||||
# patch = patch.to(device=self.device).half()
|
||||
# patch = patch.to(device=self.device)
|
||||
# else:
|
||||
# patch = patch.to(device=self.device)
|
||||
patch = patch.to(device=self.device)
|
||||
|
||||
patches.append(patch)
|
||||
if (d + 1) % self.batch_size == 0:
|
||||
patches = torch.stack(patches, dim=0)
|
||||
batch_patches.append(patches)
|
||||
patches = []
|
||||
|
||||
if len(patches):
|
||||
patches = torch.stack(patches, dim=0)
|
||||
batch_patches.append(patches)
|
||||
|
||||
features = np.zeros((0, self.embedding_size))
|
||||
for patches in batch_patches:
|
||||
pred = self.model(patches)
|
||||
pred[torch.isinf(pred)] = 1.0
|
||||
feat = pred.cpu().data.numpy()
|
||||
features = np.vstack((features, feat))
|
||||
|
||||
return imgs, features
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def unique_image(pair_list) -> set:
|
||||
"""Return unique image path in pair_list.txt"""
|
||||
with open(pair_list, 'r') as fd:
|
||||
pairs = fd.readlines()
|
||||
unique = set()
|
||||
for pair in pairs:
|
||||
id1, id2, _ = pair.split()
|
||||
unique.add(id1)
|
||||
unique.add(id2)
|
||||
return unique
|
||||
|
||||
|
||||
def group_image(images: set, batch) -> list:
|
||||
"""Group image paths by batch size"""
|
||||
images = list(images)
|
||||
size = len(images)
|
||||
res = []
|
||||
for i in range(0, size, batch):
|
||||
end = min(batch + i, size)
|
||||
res.append(images[i: end])
|
||||
return res
|
||||
|
||||
|
||||
def _preprocess(images: list, transform) -> torch.Tensor:
|
||||
res = []
|
||||
for img in images:
|
||||
im = Image.open(img)
|
||||
im = transform(im)
|
||||
res.append(im)
|
||||
# data = torch.cat(res, dim=0) # shape: (batch, 128, 128)
|
||||
# data = data[:, None, :, :] # shape: (batch, 1, 128, 128)
|
||||
data = torch.stack(res)
|
||||
return data
|
||||
|
||||
|
||||
def test_preprocess(images: list, transform) -> torch.Tensor:
|
||||
res = []
|
||||
for img in images:
|
||||
im = Image.open(img)
|
||||
im = transform(im)
|
||||
res.append(im)
|
||||
# data = torch.cat(res, dim=0) # shape: (batch, 128, 128)
|
||||
# data = data[:, None, :, :] # shape: (batch, 1, 128, 128)
|
||||
data = torch.stack(res)
|
||||
return data
|
||||
|
||||
|
||||
def featurize(images: list, transform, net, device, train=False) -> dict:
|
||||
"""featurize each image and save into a dictionary
|
||||
Args:
|
||||
images: image paths
|
||||
transform: test transform
|
||||
net: pretrained model
|
||||
device: cpu or cuda
|
||||
Returns:
|
||||
Dict (key: imagePath, value: feature)
|
||||
"""
|
||||
if train:
|
||||
data = _preprocess(images, transform)
|
||||
data = data.to(device)
|
||||
net = net.to(device)
|
||||
with torch.no_grad():
|
||||
features = net(data)
|
||||
res = {img: feature for (img, feature) in zip(images, features)}
|
||||
else:
|
||||
data = test_preprocess(images, transform)
|
||||
data = data.to(device)
|
||||
net = net.to(device)
|
||||
with torch.no_grad():
|
||||
features = net(data)
|
||||
res = {img: feature for (img, feature) in zip(images, features)}
|
||||
return res
|
||||
|
||||
# def inference_image(images: list, transform, net, device, bs=16, embedding_size=256) -> dict:
|
||||
# batch_patches = []
|
||||
# patches = []
|
||||
# for d, img in enumerate(images):
|
||||
# img = Image.open(img)
|
||||
# patch = transform(img)
|
||||
|
||||
# if str(device) != "cpu":
|
||||
# patch = patch.to(device).half()
|
||||
# else:
|
||||
# patch = patch.to(device)
|
||||
|
||||
# patches.append(patch)
|
||||
# if (d + 1) % bs == 0:
|
||||
# patches = torch.stack(patches, dim=0)
|
||||
# batch_patches.append(patches)
|
||||
# patches = []
|
||||
|
||||
# if len(patches):
|
||||
# patches = torch.stack(patches, dim=0)
|
||||
# batch_patches.append(patches)
|
||||
|
||||
# features = np.zeros((0, embedding_size), dtype=np.float32)
|
||||
# for patches in batch_patches:
|
||||
# pred = net(patches)
|
||||
# pred[torch.isinf(pred)] = 1.0
|
||||
# feat = pred.cpu().data.numpy()
|
||||
# features = np.vstack((features, feat))
|
||||
|
||||
|
||||
|
||||
# return features
|
||||
|
||||
|
||||
|
||||
def featurize_1(images: list, transform, net, device, train=False) -> dict:
|
||||
"""featurize each image and save into a dictionary
|
||||
Args:
|
||||
images: image paths
|
||||
transform: test transform
|
||||
net: pretrained model
|
||||
device: cpu or cuda
|
||||
Returns:
|
||||
Dict (key: imagePath, value: feature)
|
||||
"""
|
||||
|
||||
data = test_preprocess(images, transform)
|
||||
data = data.to(device)
|
||||
net = net.to(device)
|
||||
with torch.no_grad():
|
||||
features = net(data).data.numpy()
|
||||
|
||||
return features
|
||||
|
||||
|
||||
|
||||
|
||||
def cosin_metric(x1, x2):
|
||||
return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
|
||||
|
||||
|
||||
def threshold_search(y_score, y_true):
|
||||
y_score = np.asarray(y_score)
|
||||
y_true = np.asarray(y_true)
|
||||
best_acc = 0
|
||||
best_th = 0
|
||||
for i in range(len(y_score)):
|
||||
th = y_score[i]
|
||||
y_test = (y_score >= th)
|
||||
acc = np.mean((y_test == y_true).astype(int))
|
||||
if acc > best_acc:
|
||||
best_acc = acc
|
||||
best_th = th
|
||||
return best_acc, best_th
|
||||
|
||||
|
||||
def showgrid(recall, recall_TN, PrecisePos, PreciseNeg):
|
||||
x = np.linspace(start=-1.0, stop=1.0, num=50, endpoint=True).tolist()
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(x, recall, color='red', label='recall')
|
||||
plt.plot(x, recall_TN, color='black', label='recall_TN')
|
||||
plt.plot(x, PrecisePos, color='blue', label='PrecisePos')
|
||||
plt.plot(x, PreciseNeg, color='green', label='PreciseNeg')
|
||||
plt.legend()
|
||||
plt.xlabel('threshold')
|
||||
# plt.ylabel('Similarity')
|
||||
plt.grid(True, linestyle='--', alpha=0.5)
|
||||
plt.savefig('accuracy_recall_grid.png')
|
||||
plt.show()
|
||||
plt.close()
|
||||
|
||||
|
||||
def compute_accuracy_recall(score, labels):
|
||||
th = 0.1
|
||||
squence = np.linspace(-1, 1, num=50)
|
||||
# squence = [0.4]
|
||||
recall, PrecisePos, PreciseNeg, recall_TN = [], [], [], []
|
||||
for th in squence:
|
||||
t_score = (score > th)
|
||||
t_labels = (labels == 1)
|
||||
# print(t_score)
|
||||
# print(t_labels)
|
||||
TP = np.sum(np.logical_and(t_score, t_labels))
|
||||
FN = np.sum(np.logical_and(np.logical_not(t_score), t_labels))
|
||||
f_score = (score < th)
|
||||
f_labels = (labels == 0)
|
||||
TN = np.sum(np.logical_and(f_score, f_labels))
|
||||
FP = np.sum(np.logical_and(np.logical_not(f_score), f_labels))
|
||||
print("Threshold:{} TP:{},FP:{},TN:{},FN:{}".format(th, TP, FP, TN, FN))
|
||||
|
||||
PrecisePos.append(0 if TP / (TP + FP) == 'nan' else TP / (TP + FP))
|
||||
PreciseNeg.append(0 if TN == 0 else TN / (TN + FN))
|
||||
recall.append(0 if TP == 0 else TP / (TP + FN))
|
||||
recall_TN.append(0 if TN == 0 else TN / (TN + FP))
|
||||
showgrid(recall, recall_TN, PrecisePos, PreciseNeg)
|
||||
|
||||
|
||||
def compute_accuracy(feature_dict, pair_list, test_root):
|
||||
with open(pair_list, 'r') as f:
|
||||
pairs = f.readlines()
|
||||
|
||||
similarities = []
|
||||
labels = []
|
||||
for pair in pairs:
|
||||
img1, img2, label = pair.split()
|
||||
img1 = osp.join(test_root, img1)
|
||||
img2 = osp.join(test_root, img2)
|
||||
feature1 = feature_dict[img1].cpu().numpy()
|
||||
feature2 = feature_dict[img2].cpu().numpy()
|
||||
label = int(label)
|
||||
|
||||
similarity = cosin_metric(feature1, feature2)
|
||||
similarities.append(similarity)
|
||||
labels.append(label)
|
||||
|
||||
accuracy, threshold = threshold_search(similarities, labels)
|
||||
# print('similarities >> {}'.format(similarities))
|
||||
# print('labels >> {}'.format(labels))
|
||||
compute_accuracy_recall(np.array(similarities), np.array(labels))
|
||||
return accuracy, threshold
|
||||
|
||||
|
||||
def deal_group_pair(pairList1, pairList2):
|
||||
allsimilarity = []
|
||||
one_similarity = []
|
||||
for pair1 in pairList1:
|
||||
for pair2 in pairList2:
|
||||
similarity = cosin_metric(pair1.cpu().numpy(), pair2.cpu().numpy())
|
||||
one_similarity.append(similarity)
|
||||
allsimilarity.append(max(one_similarity)) # 最大值
|
||||
# allsimilarity.append(sum(one_similarity)/len(one_similarity)) # 均值
|
||||
# allsimilarity.append(statistics.median(one_similarity)) # 中位数
|
||||
# print(allsimilarity)
|
||||
# print(labels)
|
||||
return allsimilarity
|
||||
|
||||
def compute_group_accuracy(content_list_read):
|
||||
allSimilarity, allLabel= [], []
|
||||
for data_loaded in content_list_read:
|
||||
one_group_list = []
|
||||
for i in range(2):
|
||||
images = [osp.join(conf.test_val, img) for img in data_loaded[i]]
|
||||
group = group_image(images, conf.test_batch_size)
|
||||
d = featurize(group[0], conf.test_transform, model, conf.device)
|
||||
one_group_list.append(d.values())
|
||||
similarity = deal_group_pair(one_group_list[0], one_group_list[1])
|
||||
allLabel.append(data_loaded[-1])
|
||||
allSimilarity.extend(similarity)
|
||||
# print(allSimilarity)
|
||||
# print(allLabel)
|
||||
return allSimilarity, allLabel
|
||||
|
||||
def compute_contrast_accuracy(content_list_read):
|
||||
|
||||
npairs = 50
|
||||
|
||||
same_folder_pairs = content_list_read['same_folder_pairs']
|
||||
cross_folder_pairs = content_list_read['cross_folder_pairs']
|
||||
|
||||
npairs = min((len(same_folder_pairs), len(cross_folder_pairs)))
|
||||
|
||||
Encoder = FeatsInterface(conf)
|
||||
|
||||
same_pairs = same_folder_pairs[:npairs]
|
||||
cross_pairs = cross_folder_pairs[:npairs]
|
||||
|
||||
same_pairs_similarity = []
|
||||
for i in range(len(same_pairs)):
|
||||
images_a = [osp.join(conf.test_val, img) for img in same_pairs[i][0]]
|
||||
images_b = [osp.join(conf.test_val, img) for img in same_pairs[i][1]]
|
||||
|
||||
feats_a = Encoder.inference(images_a)
|
||||
feats_b = Encoder.inference(images_b)
|
||||
# matrix = 1- np.maximum(0.0, cdist(feats_a, feats_b, 'cosine'))
|
||||
matrix = 1 - cdist(feats_a, feats_b, 'cosine')
|
||||
|
||||
feats_am = np.mean(feats_a, axis=0, keepdims=True)
|
||||
feats_bm = np.mean(feats_b, axis=0, keepdims=True)
|
||||
matrixm = 1- np.maximum(0.0, cdist(feats_am, feats_bm, 'cosine'))
|
||||
|
||||
same_pairs_similarity.append(np.mean(matrix))
|
||||
|
||||
'''保存相同 Barcode 图像对'''
|
||||
# foldi = os.path.join('./result/same', f'{i}')
|
||||
# if os.path.exists(foldi):
|
||||
# shutil.rmtree(foldi)
|
||||
# os.makedirs(foldi)
|
||||
# else:
|
||||
# os.makedirs(foldi)
|
||||
# for ipt in range(len(images_a)):
|
||||
# source_path = images_a[ipt]
|
||||
# destination_path = os.path.join(foldi, f'a_{ipt}.png')
|
||||
# shutil.copy2(source_path, destination_path)
|
||||
# for ipt in range(len(images_b)):
|
||||
# source_path = images_b[ipt]
|
||||
# destination_path = os.path.join(foldi, f'b_{ipt}.png')
|
||||
# shutil.copy2(source_path, destination_path)
|
||||
|
||||
cross_pairs_similarity = []
|
||||
for i in range(len(cross_pairs)):
|
||||
images_a = [osp.join(conf.test_val, img) for img in cross_pairs[i][0]]
|
||||
images_b = [osp.join(conf.test_val, img) for img in cross_pairs[i][1]]
|
||||
|
||||
feats_a = Encoder.inference(images_a)
|
||||
feats_b = Encoder.inference(images_b)
|
||||
# matrix = 1- np.maximum(0.0, cdist(feats_a, feats_b, 'cosine'))
|
||||
matrix = 1 - cdist(feats_a, feats_b, 'cosine')
|
||||
|
||||
feats_am = np.mean(feats_a, axis=0, keepdims=True)
|
||||
feats_bm = np.mean(feats_b, axis=0, keepdims=True)
|
||||
matrixm = 1- np.maximum(0.0, cdist(feats_am, feats_bm, 'cosine'))
|
||||
|
||||
cross_pairs_similarity.append(np.mean(matrix))
|
||||
|
||||
'''保存不同 Barcode 图像对'''
|
||||
# foldi = os.path.join('./result/cross', f'{i}')
|
||||
# if os.path.exists(foldi):
|
||||
# shutil.rmtree(foldi)
|
||||
# os.makedirs(foldi)
|
||||
# else:
|
||||
# os.makedirs(foldi)
|
||||
# for ipt in range(len(images_a)):
|
||||
# source_path = images_a[ipt]
|
||||
# destination_path = os.path.join(foldi, f'a_{ipt}.png')
|
||||
# shutil.copy2(source_path, destination_path)
|
||||
# for ipt in range(len(images_b)):
|
||||
# source_path = images_b[ipt]
|
||||
# destination_path = os.path.join(foldi, f'b_{ipt}.png')
|
||||
# shutil.copy2(source_path, destination_path)
|
||||
|
||||
|
||||
Thresh = np.linspace(-0.2, 1, 100)
|
||||
|
||||
Same = np.array(same_pairs_similarity)
|
||||
Cross = np.array(cross_pairs_similarity)
|
||||
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].hist(Same, bins=60, edgecolor='black')
|
||||
axs[0].set_xlim([-0.2, 1])
|
||||
axs[0].set_title('Same Barcode')
|
||||
|
||||
axs[1].hist(Cross, bins=60, edgecolor='black')
|
||||
axs[1].set_xlim([-0.2, 1])
|
||||
axs[1].set_title('Cross Barcode')
|
||||
|
||||
TPFN = len(Same)
|
||||
TNFP = len(Cross)
|
||||
Recall_Pos, Recall_Neg = [], []
|
||||
Precision_Pos, Precision_Neg = [], []
|
||||
Correct = []
|
||||
for th in Thresh:
|
||||
TP = np.sum(Same > th)
|
||||
FN = TPFN - TP
|
||||
TN = np.sum(Cross < th)
|
||||
FP = TNFP - TN
|
||||
|
||||
Recall_Pos.append(TP/TPFN)
|
||||
Recall_Neg.append(TN/TNFP)
|
||||
Precision_Pos.append(TP/(TP+FP))
|
||||
Precision_Neg.append(TN/(TN+FN))
|
||||
Correct.append((TN+TP)/(TPFN+TNFP))
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.plot(Thresh, Correct, 'r', label='Correct: (TN+TP)/(TPFN+TNFP)')
|
||||
ax.plot(Thresh, Recall_Pos, 'b', label='Recall_Pos: TP/TPFN')
|
||||
ax.plot(Thresh, Recall_Neg, 'g', label='Recall_Neg: TN/TNFP')
|
||||
ax.plot(Thresh, Precision_Pos, 'c', label='Precision_Pos: TP/(TP+FP)')
|
||||
ax.plot(Thresh, Precision_Neg, 'm', label='Precision_Neg: TN/(TN+FN)')
|
||||
|
||||
ax.set_xlim([0, 1])
|
||||
ax.set_ylim([0, 1])
|
||||
ax.grid(True)
|
||||
ax.set_title('PrecisePos & PreciseNeg')
|
||||
ax.legend()
|
||||
plt.show()
|
||||
|
||||
print("Haved done!!!")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# Network Setup
|
||||
if conf.testbackbone == 'resnet18':
|
||||
# model = ResIRSE(conf.img_size, conf.embedding_size, conf.drop_ratio).to(conf.device)
|
||||
model = resnet18().to(conf.device)
|
||||
# elif conf.testbackbone == 'resnet34':
|
||||
# model = resnet34().to(conf.device)
|
||||
# elif conf.testbackbone == 'resnet50':
|
||||
# model = resnet50().to(conf.device)
|
||||
# elif conf.testbackbone == 'mobilevit_s':
|
||||
# model = mobilevit_s().to(conf.device)
|
||||
# elif conf.testbackbone == 'mobilenetv3':
|
||||
# model = MobileNetV3_Small().to(conf.device)
|
||||
# elif conf.testbackbone == 'mobilenet_v1':
|
||||
# model = mobilenet_v1().to(conf.device)
|
||||
# elif conf.testbackbone == 'PPLCNET_x1_0':
|
||||
# model = PPLCNET_x1_0().to(conf.device)
|
||||
# elif conf.testbackbone == 'PPLCNET_x0_5':
|
||||
# model = PPLCNET_x0_5().to(conf.device)
|
||||
# elif conf.backbone == 'PPLCNET_x2_5':
|
||||
# model = PPLCNET_x2_5().to(conf.device)
|
||||
# elif conf.testbackbone == 'mobilenet_v2':
|
||||
# model = mobilenet_v2().to(conf.device)
|
||||
# elif conf.testbackbone == 'resnet14':
|
||||
# model = resnet14().to(conf.device)
|
||||
else:
|
||||
raise ValueError('Have not model {}'.format(conf.backbone))
|
||||
|
||||
print('load model {} '.format(conf.testbackbone))
|
||||
# model = nn.DataParallel(model).to(conf.device)
|
||||
model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
|
||||
model.eval()
|
||||
if not conf.group_test:
|
||||
images = unique_image(conf.test_list)
|
||||
images = [osp.join(conf.test_val, img) for img in images]
|
||||
|
||||
groups = group_image(images, conf.test_batch_size) ##根据batch_size取图片
|
||||
|
||||
feature_dict = dict()
|
||||
for group in groups:
|
||||
d = featurize(group, conf.test_transform, model, conf.device)
|
||||
feature_dict.update(d)
|
||||
# print('feature_dict', feature_dict)
|
||||
accuracy, threshold = compute_accuracy(feature_dict, conf.test_list, conf.test_val)
|
||||
|
||||
print(
|
||||
f"Test Model: {conf.test_model}\n"
|
||||
f"Accuracy: {accuracy:.3f}\n"
|
||||
f"Threshold: {threshold:.3f}\n"
|
||||
)
|
||||
elif conf.group_test:
|
||||
"""
|
||||
conf.test_val: 测试数据集地址
|
||||
conf.test_group_json:测试数据分组配置文件
|
||||
"""
|
||||
filename = conf.test_group_json
|
||||
|
||||
filename = "../cl/images_1.json"
|
||||
with open(filename, 'r', encoding='utf-8') as file:
|
||||
content_list_read = json.load(file)
|
||||
|
||||
|
||||
compute_contrast_accuracy(content_list_read)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Similarity, Label = compute_group_accuracy(content_list_read)
|
||||
# print('allSimilarity >> {}'.format(Similarity))
|
||||
# print('allLabel >> {}'.format(Label))
|
||||
# compute_accuracy_recall(np.array(Similarity), np.array(Label))
|
||||
# # compute_group_accuracy(data_loaded)
|
||||
#
|
||||
# =============================================================================
|
88
contrast/feat_extract/model/BAM.py
Normal file
88
contrast/feat_extract/model/BAM.py
Normal file
@ -0,0 +1,88 @@
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from torch.nn import init
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, x):
|
||||
return x.view(x.shape[0], -1)
|
||||
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
def __int__(self, channel, reduction, num_layers):
|
||||
super(ChannelAttention, self).__init__()
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
gate_channels = [channel]
|
||||
gate_channels += [len(channel) // reduction] * num_layers
|
||||
gate_channels += [channel]
|
||||
|
||||
self.ca = nn.Sequential()
|
||||
self.ca.add_module('flatten', Flatten())
|
||||
for i in range(len(gate_channels) - 2):
|
||||
self.ca.add_module('', nn.Linear(gate_channels[i], gate_channels[i + 1]))
|
||||
self.ca.add_module('', nn.BatchNorm1d(gate_channels[i + 1]))
|
||||
self.ca.add_module('', nn.ReLU())
|
||||
self.ca.add_module('', nn.Linear(gate_channels[-2], gate_channels[-1]))
|
||||
|
||||
def forward(self, x):
|
||||
res = self.avgpool(x)
|
||||
res = self.ca(res)
|
||||
res = res.unsqueeze(-1).unsqueeze(-1).expand_as(x)
|
||||
return res
|
||||
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
def __int__(self, channel, reduction=16, num_lay=3, dilation=2):
|
||||
super(SpatialAttention).__init__()
|
||||
self.sa = nn.Sequential()
|
||||
self.sa.add_module('', nn.Conv2d(kernel_size=1, in_channels=channel, out_channels=(channel // reduction) * 3))
|
||||
self.sa.add_module('', nn.BatchNorm2d(num_features=(channel // reduction)))
|
||||
self.sa.add_module('', nn.ReLU())
|
||||
for i in range(num_lay):
|
||||
self.sa.add_module('', nn.Conv2d(kernel_size=3,
|
||||
in_channels=(channel // reduction),
|
||||
out_channels=(channel // reduction),
|
||||
padding=1,
|
||||
dilation=2))
|
||||
self.sa.add_module('', nn.BatchNorm2d(channel // reduction))
|
||||
self.sa.add_module('', nn.ReLU())
|
||||
self.sa.add_module('', nn.Conv2d(channel // reduction, 1, kernel_size=1))
|
||||
|
||||
def forward(self, x):
|
||||
res = self.sa(x)
|
||||
res = res.expand_as(x)
|
||||
return res
|
||||
|
||||
|
||||
class BAMblock(nn.Module):
|
||||
def __init__(self, channel=512, reduction=16, dia_val=2):
|
||||
super(BAMblock, self).__init__()
|
||||
self.ca = ChannelAttention(channel, reduction)
|
||||
self.sa = SpatialAttention(channel, reduction, dia_val)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal(m.weight, mode='fan_out')
|
||||
if m.bais is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.001)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, _, _ = x.size()
|
||||
sa_out = self.sa(x)
|
||||
ca_out = self.ca(x)
|
||||
weight = self.sigmoid(sa_out + ca_out)
|
||||
out = (1 + weight) * x
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(512 // 14)
|
70
contrast/feat_extract/model/CBAM.py
Normal file
70
contrast/feat_extract/model/CBAM.py
Normal file
@ -0,0 +1,70 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
|
||||
class channelAttention(nn.Module):
|
||||
def __init__(self, channel, reduction=16):
|
||||
super(channelAttention, self).__init__()
|
||||
self.Maxpooling = nn.AdaptiveMaxPool2d(1)
|
||||
self.Avepooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.ca = nn.Sequential()
|
||||
self.ca.add_module('conv1',nn.Conv2d(channel, channel//reduction, 1, bias=False))
|
||||
self.ca.add_module('Relu', nn.ReLU())
|
||||
self.ca.add_module('conv2',nn.Conv2d(channel//reduction, channel, 1, bias=False))
|
||||
self.sigmod = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
M_out = self.Maxpooling(x)
|
||||
A_out = self.Avepooling(x)
|
||||
M_out = self.ca(M_out)
|
||||
A_out = self.ca(A_out)
|
||||
out = self.sigmod(M_out+A_out)
|
||||
return out
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
def __init__(self, kernel_size=7):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=kernel_size // 2)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
max_result, _ = torch.max(x, dim=1, keepdim=True)
|
||||
avg_result = torch.mean(x, dim=1, keepdim=True)
|
||||
result = torch.cat([max_result, avg_result], dim=1)
|
||||
output = self.conv(result)
|
||||
output = self.sigmoid(output)
|
||||
return output
|
||||
|
||||
class CBAM(nn.Module):
|
||||
def __init__(self, channel, reduction=16, kernel_size=7):
|
||||
super().__init__()
|
||||
self.ca = channelAttention(channel, reduction)
|
||||
self.sa = SpatialAttention(kernel_size)
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():#权重初始化
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.001)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
# b,c_,_ = x.size()
|
||||
# residual = x
|
||||
out = x*self.ca(x)
|
||||
out = out*self.sa(out)
|
||||
return out
|
||||
|
||||
if __name__ == '__main__':
|
||||
input=torch.randn(50,512,7,7)
|
||||
kernel_size=input.shape[2]
|
||||
cbam = CBAM(channel=512,reduction=16,kernel_size=kernel_size)
|
||||
output=cbam(input)
|
||||
print(output.shape)
|
33
contrast/feat_extract/model/Tool.py
Normal file
33
contrast/feat_extract/model/Tool.py
Normal file
@ -0,0 +1,33 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
class GeM(nn.Module):
|
||||
def __init__(self, p=3, eps=1e-6):
|
||||
super(GeM, self).__init__()
|
||||
self.p = nn.Parameter(torch.ones(1) * p)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return self.gem(x, p=self.p, eps=self.eps, stride = 2)
|
||||
|
||||
def gem(self, x, p=3, eps=1e-6, stride = 2):
|
||||
return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1)), stride=2).pow(1. / p)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + \
|
||||
'(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
|
||||
', ' + 'eps=' + str(self.eps) + ')'
|
||||
|
||||
class TripletLoss(nn.Module):
|
||||
def __init__(self, margin):
|
||||
super(TripletLoss, self).__init__()
|
||||
self.margin = margin
|
||||
|
||||
def forward(self, anchor, positive, negative, size_average = True):
|
||||
distance_positive = (anchor-positive).pow(2).sum(1)
|
||||
distance_negative = (anchor-negative).pow(2).sum(1)
|
||||
losses = F.relu(distance_negative-distance_positive+self.margin)
|
||||
return losses.mean() if size_average else losses.sum()
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('')
|
11
contrast/feat_extract/model/__init__.py
Normal file
11
contrast/feat_extract/model/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
from .fmobilenet import FaceMobileNet
|
||||
from .resnet_face import ResIRSE
|
||||
from .mobilevit import mobilevit_s
|
||||
from .metric import ArcFace, CosFace
|
||||
from .loss import FocalLoss
|
||||
from .resbam import resnet
|
||||
from .resnet_pre import resnet18, resnet34, resnet50, resnet14
|
||||
from .mobilenet_v2 import mobilenet_v2
|
||||
from .mobilenet_v3 import MobileNetV3_Small, MobileNetV3_Large
|
||||
# from .mobilenet_v1 import mobilenet_v1
|
||||
from .lcnet import PPLCNET_x0_25, PPLCNET_x0_35, PPLCNET_x0_5, PPLCNET_x0_75, PPLCNET_x1_0, PPLCNET_x1_5, PPLCNET_x2_0, PPLCNET_x2_5
|
BIN
contrast/feat_extract/model/__pycache__/BAM.cpython-38.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/BAM.cpython-38.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/CBAM.cpython-312.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/CBAM.cpython-312.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/CBAM.cpython-38.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/CBAM.cpython-38.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/CBAM.cpython-39.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/CBAM.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/Tool.cpython-312.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/Tool.cpython-312.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/Tool.cpython-38.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/Tool.cpython-38.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/Tool.cpython-39.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/Tool.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/lcnet.cpython-312.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/lcnet.cpython-312.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/lcnet.cpython-38.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/lcnet.cpython-38.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/lcnet.cpython-39.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/lcnet.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/loss.cpython-312.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/loss.cpython-312.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/loss.cpython-38.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/loss.cpython-38.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/loss.cpython-39.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/loss.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/metric.cpython-312.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/metric.cpython-312.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/metric.cpython-38.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/metric.cpython-38.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/metric.cpython-39.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/metric.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/mobilevit.cpython-38.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/mobilevit.cpython-38.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/mobilevit.cpython-39.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/mobilevit.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/resbam.cpython-312.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/resbam.cpython-312.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/resbam.cpython-38.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/resbam.cpython-38.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/resbam.cpython-39.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/resbam.cpython-39.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/resnet.cpython-310.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/resnet.cpython-310.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/resnet.cpython-38.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/resnet.cpython-38.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/utils.cpython-312.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/utils.cpython-38.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/utils.cpython-38.pyc
Normal file
Binary file not shown.
BIN
contrast/feat_extract/model/__pycache__/utils.cpython-39.pyc
Normal file
BIN
contrast/feat_extract/model/__pycache__/utils.cpython-39.pyc
Normal file
Binary file not shown.
124
contrast/feat_extract/model/fmobilenet.py
Normal file
124
contrast/feat_extract/model/fmobilenet.py
Normal file
@ -0,0 +1,124 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, x):
|
||||
return x.view(x.shape[0], -1)
|
||||
|
||||
class ConvBn(nn.Module):
|
||||
|
||||
def __init__(self, in_c, out_c, kernel=(1, 1), stride=1, padding=0, groups=1):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
|
||||
nn.BatchNorm2d(out_c)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class ConvBnPrelu(nn.Module):
|
||||
|
||||
def __init__(self, in_c, out_c, kernel=(1, 1), stride=1, padding=0, groups=1):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
ConvBn(in_c, out_c, kernel, stride, padding, groups),
|
||||
nn.PReLU(out_c)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class DepthWise(nn.Module):
|
||||
|
||||
def __init__(self, in_c, out_c, kernel=(3, 3), stride=2, padding=1, groups=1):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
ConvBnPrelu(in_c, groups, kernel=(1, 1), stride=1, padding=0),
|
||||
ConvBnPrelu(groups, groups, kernel=kernel, stride=stride, padding=padding, groups=groups),
|
||||
ConvBn(groups, out_c, kernel=(1, 1), stride=1, padding=0),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class DepthWiseRes(nn.Module):
|
||||
"""DepthWise with Residual"""
|
||||
|
||||
def __init__(self, in_c, out_c, kernel=(3, 3), stride=2, padding=1, groups=1):
|
||||
super().__init__()
|
||||
self.net = DepthWise(in_c, out_c, kernel, stride, padding, groups)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
|
||||
class MultiDepthWiseRes(nn.Module):
|
||||
|
||||
def __init__(self, num_block, channels, kernel=(3, 3), stride=1, padding=1, groups=1):
|
||||
super().__init__()
|
||||
|
||||
self.net = nn.Sequential(*[
|
||||
DepthWiseRes(channels, channels, kernel, stride, padding, groups)
|
||||
for _ in range(num_block)
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class FaceMobileNet(nn.Module):
|
||||
|
||||
def __init__(self, embedding_size):
|
||||
super().__init__()
|
||||
self.conv1 = ConvBnPrelu(1, 64, kernel=(3, 3), stride=2, padding=1)
|
||||
self.conv2 = ConvBn(64, 64, kernel=(3, 3), stride=1, padding=1, groups=64)
|
||||
self.conv3 = DepthWise(64, 64, kernel=(3, 3), stride=2, padding=1, groups=128)
|
||||
self.conv4 = MultiDepthWiseRes(num_block=4, channels=64, kernel=3, stride=1, padding=1, groups=128)
|
||||
self.conv5 = DepthWise(64, 128, kernel=(3, 3), stride=2, padding=1, groups=256)
|
||||
self.conv6 = MultiDepthWiseRes(num_block=6, channels=128, kernel=(3, 3), stride=1, padding=1, groups=256)
|
||||
self.conv7 = DepthWise(128, 128, kernel=(3, 3), stride=2, padding=1, groups=512)
|
||||
self.conv8 = MultiDepthWiseRes(num_block=2, channels=128, kernel=(3, 3), stride=1, padding=1, groups=256)
|
||||
self.conv9 = ConvBnPrelu(128, 512, kernel=(1, 1))
|
||||
self.conv10 = ConvBn(512, 512, groups=512, kernel=(7, 7))
|
||||
self.flatten = Flatten()
|
||||
self.linear = nn.Linear(2048, embedding_size, bias=False)
|
||||
self.bn = nn.BatchNorm1d(embedding_size)
|
||||
|
||||
def forward(self, x):
|
||||
#print('x',x.shape)
|
||||
out = self.conv1(x)
|
||||
out = self.conv2(out)
|
||||
out = self.conv3(out)
|
||||
out = self.conv4(out)
|
||||
out = self.conv5(out)
|
||||
out = self.conv6(out)
|
||||
out = self.conv7(out)
|
||||
out = self.conv8(out)
|
||||
out = self.conv9(out)
|
||||
out = self.conv10(out)
|
||||
out = self.flatten(out)
|
||||
out = self.linear(out)
|
||||
out = self.bn(out)
|
||||
return out
|
||||
|
||||
if __name__ == "__main__":
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
x = Image.open("../samples/009.jpg").convert('L')
|
||||
x = x.resize((128, 128))
|
||||
x = np.asarray(x, dtype=np.float32)
|
||||
x = x[None, None, ...]
|
||||
x = torch.from_numpy(x)
|
||||
net = FaceMobileNet(512)
|
||||
net.eval()
|
||||
with torch.no_grad():
|
||||
out = net(x)
|
||||
print(out.shape)
|
233
contrast/feat_extract/model/lcnet.py
Normal file
233
contrast/feat_extract/model/lcnet.py
Normal file
@ -0,0 +1,233 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import thop
|
||||
|
||||
# try:
|
||||
# import softpool_cuda
|
||||
# from SoftPool import soft_pool2d, SoftPool2d
|
||||
# except ImportError:
|
||||
# print('Please install SoftPool first: https://github.com/alexandrosstergiou/SoftPool')
|
||||
# exit(0)
|
||||
|
||||
NET_CONFIG = {
|
||||
# k, in_c, out_c, s, use_se
|
||||
"blocks2": [[3, 16, 32, 1, False]],
|
||||
"blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
|
||||
"blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
|
||||
"blocks5": [[3, 128, 256, 2, False], [5, 256, 256, 1, False],
|
||||
[5, 256, 256, 1, False], [5, 256, 256, 1, False],
|
||||
[5, 256, 256, 1, False], [5, 256, 256, 1, False]],
|
||||
"blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
|
||||
}
|
||||
|
||||
|
||||
def autopad(k, p=None):
|
||||
if p is None:
|
||||
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
|
||||
return p
|
||||
|
||||
|
||||
def make_divisible(v, divisor=8, min_value=None):
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class HardSwish(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(HardSwish, self).__init__()
|
||||
self.relu6 = nn.ReLU6(inplace=inplace)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.relu6(x+3) / 6
|
||||
|
||||
|
||||
class HardSigmoid(nn.Module):
|
||||
def __init__(self, inplace=True):
|
||||
super(HardSigmoid, self).__init__()
|
||||
self.relu6 = nn.ReLU6(inplace=inplace)
|
||||
|
||||
def forward(self, x):
|
||||
return (self.relu6(x+3)) / 6
|
||||
|
||||
|
||||
class SELayer(nn.Module):
|
||||
def __init__(self, channel, reduction=16):
|
||||
super(SELayer, self).__init__()
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(channel // reduction, channel, bias=False),
|
||||
HardSigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.size()
|
||||
y = self.avgpool(x).view(b, c)
|
||||
y = self.fc(y).view(b, c, 1, 1)
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
||||
class DepthwiseSeparable(nn.Module):
|
||||
def __init__(self, inp, oup, dw_size, stride, use_se=False):
|
||||
super(DepthwiseSeparable, self).__init__()
|
||||
self.use_se = use_se
|
||||
self.stride = stride
|
||||
self.inp = inp
|
||||
self.oup = oup
|
||||
self.dw_size = dw_size
|
||||
self.dw_sp = nn.Sequential(
|
||||
nn.Conv2d(self.inp, self.inp, kernel_size=self.dw_size, stride=self.stride,
|
||||
padding=autopad(self.dw_size, None), groups=self.inp, bias=False),
|
||||
nn.BatchNorm2d(self.inp),
|
||||
HardSwish(),
|
||||
|
||||
nn.Conv2d(self.inp, self.oup, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(self.oup),
|
||||
HardSwish(),
|
||||
)
|
||||
self.se = SELayer(self.oup)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dw_sp(x)
|
||||
if self.use_se:
|
||||
x = self.se(x)
|
||||
return x
|
||||
|
||||
|
||||
class PP_LCNet(nn.Module):
|
||||
def __init__(self, scale=1.0, class_num=10, class_expand=1280, dropout_prob=0.2):
|
||||
super(PP_LCNet, self).__init__()
|
||||
self.scale = scale
|
||||
self.conv1 = nn.Conv2d(3, out_channels=make_divisible(16 * self.scale),
|
||||
kernel_size=3, stride=2, padding=1, bias=False)
|
||||
# k, in_c, out_c, s, use_se inp, oup, dw_size, stride, use_se=False
|
||||
self.blocks2 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks2"])
|
||||
])
|
||||
|
||||
self.blocks3 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks3"])
|
||||
])
|
||||
|
||||
self.blocks4 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks4"])
|
||||
])
|
||||
# k, in_c, out_c, s, use_se inp, oup, dw_size, stride, use_se=False
|
||||
self.blocks5 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks5"])
|
||||
])
|
||||
|
||||
self.blocks6 = nn.Sequential(*[
|
||||
DepthwiseSeparable(inp=make_divisible(in_c * self.scale),
|
||||
oup=make_divisible(out_c * self.scale),
|
||||
dw_size=k, stride=s, use_se=use_se)
|
||||
for i, (k, in_c, out_c, s, use_se) in enumerate(NET_CONFIG["blocks6"])
|
||||
])
|
||||
|
||||
self.GAP = nn.AdaptiveAvgPool2d(1)
|
||||
|
||||
self.last_conv = nn.Conv2d(in_channels=make_divisible(NET_CONFIG["blocks6"][-1][2] * scale),
|
||||
out_channels=class_expand,
|
||||
kernel_size=1, stride=1, padding=0, bias=False)
|
||||
|
||||
self.hardswish = HardSwish()
|
||||
self.dropout = nn.Dropout(p=dropout_prob)
|
||||
|
||||
self.fc = nn.Linear(class_expand, class_num)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
print(x.shape)
|
||||
x = self.blocks2(x)
|
||||
print(x.shape)
|
||||
x = self.blocks3(x)
|
||||
print(x.shape)
|
||||
x = self.blocks4(x)
|
||||
print(x.shape)
|
||||
x = self.blocks5(x)
|
||||
print(x.shape)
|
||||
x = self.blocks6(x)
|
||||
print(x.shape)
|
||||
|
||||
x = self.GAP(x)
|
||||
x = self.last_conv(x)
|
||||
x = self.hardswish(x)
|
||||
x = self.dropout(x)
|
||||
x = torch.flatten(x, start_dim=1, end_dim=-1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def PPLCNET_x0_25(**kwargs):
|
||||
model = PP_LCNet(scale=0.25, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x0_35(**kwargs):
|
||||
model = PP_LCNet(scale=0.35, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x0_5(**kwargs):
|
||||
model = PP_LCNet(scale=0.5, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x0_75(**kwargs):
|
||||
model = PP_LCNet(scale=0.75, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x1_0(**kwargs):
|
||||
model = PP_LCNet(scale=1.0, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x1_5(**kwargs):
|
||||
model = PP_LCNet(scale=1.5, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def PPLCNET_x2_0(**kwargs):
|
||||
model = PP_LCNet(scale=2.0, **kwargs)
|
||||
return model
|
||||
|
||||
def PPLCNET_x2_5(**kwargs):
|
||||
model = PP_LCNet(scale=2.5, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# input = torch.randn(1, 3, 640, 640)
|
||||
# model = PPLCNET_x2_5()
|
||||
# flops, params = thop.profile(model, inputs=(input,))
|
||||
# print('flops:', flops / 1000000000)
|
||||
# print('params:', params / 1000000)
|
||||
|
||||
model = PPLCNET_x1_0()
|
||||
# model_1 = PW_Conv(3, 16)
|
||||
input = torch.randn(2, 3, 256, 256)
|
||||
print(input.shape)
|
||||
output = model(input)
|
||||
print(output.shape) # [1, num_class]
|
||||
|
18
contrast/feat_extract/model/loss.py
Normal file
18
contrast/feat_extract/model/loss.py
Normal file
@ -0,0 +1,18 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class FocalLoss(nn.Module):
|
||||
|
||||
def __init__(self, gamma=2):
|
||||
super().__init__()
|
||||
self.gamma = gamma
|
||||
self.ce = torch.nn.CrossEntropyLoss()
|
||||
|
||||
def forward(self, input, target):
|
||||
|
||||
#print(f'theta {input.shape, input[0]}, target {target.shape, target}')
|
||||
logp = self.ce(input, target)
|
||||
p = torch.exp(-logp)
|
||||
loss = (1 - p) ** self.gamma * logp
|
||||
return loss.mean()
|
83
contrast/feat_extract/model/metric.py
Normal file
83
contrast/feat_extract/model/metric.py
Normal file
@ -0,0 +1,83 @@
|
||||
# Definition of ArcFace loss and CosFace loss
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ArcFace(nn.Module):
|
||||
|
||||
def __init__(self, embedding_size, class_num, s=30.0, m=0.50):
|
||||
"""ArcFace formula:
|
||||
cos(m + theta) = cos(m)cos(theta) - sin(m)sin(theta)
|
||||
Note that:
|
||||
0 <= m + theta <= Pi
|
||||
So if (m + theta) >= Pi, then theta >= Pi - m. In [0, Pi]
|
||||
we have:
|
||||
cos(theta) < cos(Pi - m)
|
||||
So we can use cos(Pi - m) as threshold to check whether
|
||||
(m + theta) go out of [0, Pi]
|
||||
|
||||
Args:
|
||||
embedding_size: usually 128, 256, 512 ...
|
||||
class_num: num of people when training
|
||||
s: scale, see normface https://arxiv.org/abs/1704.06369
|
||||
m: margin, see SphereFace, CosFace, and ArcFace paper
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_features = embedding_size
|
||||
self.out_features = class_num
|
||||
self.s = s
|
||||
self.m = m
|
||||
self.weight = nn.Parameter(torch.FloatTensor(class_num, embedding_size))
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
|
||||
self.cos_m = math.cos(m)
|
||||
self.sin_m = math.sin(m)
|
||||
self.th = math.cos(math.pi - m)
|
||||
self.mm = math.sin(math.pi - m) * m
|
||||
|
||||
def forward(self, input, label):
|
||||
#print(f"embding {self.in_features}, class_num {self.out_features}, input {len(input)}, label {len(label)}")
|
||||
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
|
||||
# print('F.normalize(input)',input.shape)
|
||||
# print('F.normalize(self.weight)',F.normalize(self.weight).shape)
|
||||
sine = ((1.0 - cosine.pow(2)).clamp(0, 1)).sqrt()
|
||||
phi = cosine * self.cos_m - sine * self.sin_m
|
||||
phi = torch.where(cosine > self.th, phi, cosine - self.mm) # drop to CosFace
|
||||
#print(f'consine {cosine.shape, cosine}, sine {sine.shape, sine}, phi {phi.shape, phi}')
|
||||
# update y_i by phi in cosine
|
||||
output = cosine * 1.0 # make backward works
|
||||
batch_size = len(output)
|
||||
output[range(batch_size), label] = phi[range(batch_size), label]
|
||||
# print(f'output {(output * self.s).shape}')
|
||||
# print(f'phi[range(batch_size), label] {phi[range(batch_size), label]}')
|
||||
return output * self.s
|
||||
|
||||
|
||||
class CosFace(nn.Module):
|
||||
|
||||
def __init__(self, in_features, out_features, s=30.0, m=0.40):
|
||||
"""
|
||||
Args:
|
||||
embedding_size: usually 128, 256, 512 ...
|
||||
class_num: num of people when training
|
||||
s: scale, see normface https://arxiv.org/abs/1704.06369
|
||||
m: margin, see SphereFace, CosFace, and ArcFace paper
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.s = s
|
||||
self.m = m
|
||||
self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
|
||||
def forward(self, input, label):
|
||||
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
|
||||
phi = cosine - self.m
|
||||
output = cosine * 1.0 # make backward works
|
||||
batch_size = len(output)
|
||||
output[range(batch_size), label] = phi[range(batch_size), label]
|
||||
return output * self.s
|
148
contrast/feat_extract/model/mobilenet_v1.py
Normal file
148
contrast/feat_extract/model/mobilenet_v1.py
Normal file
@ -0,0 +1,148 @@
|
||||
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from typing import Callable, Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
from torchvision.ops.misc import Conv2dNormActivation
|
||||
from config import config as conf
|
||||
|
||||
__all__ = [
|
||||
"MobileNetV1",
|
||||
"DepthWiseSeparableConv2d",
|
||||
"mobilenet_v1",
|
||||
]
|
||||
|
||||
|
||||
class MobileNetV1(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: int = conf.embedding_size,
|
||||
) -> None:
|
||||
super(MobileNetV1, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
Conv2dNormActivation(3,
|
||||
32,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
activation_layer=nn.ReLU,
|
||||
inplace=True,
|
||||
bias=False,
|
||||
),
|
||||
|
||||
DepthWiseSeparableConv2d(32, 64, 1),
|
||||
DepthWiseSeparableConv2d(64, 128, 2),
|
||||
DepthWiseSeparableConv2d(128, 128, 1),
|
||||
DepthWiseSeparableConv2d(128, 256, 2),
|
||||
DepthWiseSeparableConv2d(256, 256, 1),
|
||||
DepthWiseSeparableConv2d(256, 512, 2),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 512, 1),
|
||||
DepthWiseSeparableConv2d(512, 1024, 2),
|
||||
DepthWiseSeparableConv2d(1024, 1024, 1),
|
||||
)
|
||||
|
||||
self.avgpool = nn.AvgPool2d((7, 7))
|
||||
|
||||
self.classifier = nn.Linear(1024, num_classes)
|
||||
|
||||
# Initialize neural network weights
|
||||
self._initialize_weights()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
out = self._forward_impl(x)
|
||||
|
||||
return out
|
||||
|
||||
# Support torch.script function
|
||||
def _forward_impl(self, x: Tensor) -> Tensor:
|
||||
out = self.features(x)
|
||||
out = self.avgpool(out)
|
||||
out = torch.flatten(out, 1)
|
||||
out = self.classifier(out)
|
||||
|
||||
return out
|
||||
|
||||
def _initialize_weights(self) -> None:
|
||||
for module in self.modules():
|
||||
if isinstance(module, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.ones_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, 0, 0.01)
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
|
||||
class DepthWiseSeparableConv2d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: int,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(DepthWiseSeparableConv2d, self).__init__()
|
||||
self.stride = stride
|
||||
if stride not in [1, 2]:
|
||||
raise ValueError(f"stride should be 1 or 2 instead of {stride}")
|
||||
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
Conv2dNormActivation(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
groups=in_channels,
|
||||
norm_layer=norm_layer,
|
||||
activation_layer=nn.ReLU,
|
||||
inplace=True,
|
||||
bias=False,
|
||||
),
|
||||
Conv2dNormActivation(in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
norm_layer=norm_layer,
|
||||
activation_layer=nn.ReLU,
|
||||
inplace=True,
|
||||
bias=False,
|
||||
),
|
||||
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
out = self.conv(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def mobilenet_v1(**kwargs: Any) -> MobileNetV1:
|
||||
model = MobileNetV1(**kwargs)
|
||||
|
||||
return model
|
200
contrast/feat_extract/model/mobilenet_v2.py
Normal file
200
contrast/feat_extract/model/mobilenet_v2.py
Normal file
@ -0,0 +1,200 @@
|
||||
from torch import nn
|
||||
from .utils import load_state_dict_from_url
|
||||
from ..config import config as conf
|
||||
|
||||
__all__ = ['MobileNetV2', 'mobilenet_v2']
|
||||
|
||||
|
||||
model_urls = {
|
||||
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
|
||||
}
|
||||
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
|
||||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
|
||||
padding = (kernel_size - 1) // 2
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
super(ConvBNReLU, self).__init__(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
||||
norm_layer(out_planes),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
|
||||
layers.extend([
|
||||
# dw
|
||||
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
norm_layer(oup),
|
||||
])
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
def __init__(self,
|
||||
num_classes=conf.embedding_size,
|
||||
width_mult=1.0,
|
||||
inverted_residual_setting=None,
|
||||
round_nearest=8,
|
||||
block=None,
|
||||
norm_layer=None):
|
||||
"""
|
||||
MobileNet V2 main class
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of classes
|
||||
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
||||
inverted_residual_setting: Network structure
|
||||
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
||||
Set to 1 to turn off rounding
|
||||
block: Module specifying inverted residual building block for mobilenet
|
||||
norm_layer: Module specifying the normalization layer to use
|
||||
|
||||
"""
|
||||
super(MobileNetV2, self).__init__()
|
||||
|
||||
if block is None:
|
||||
block = InvertedResidual
|
||||
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
|
||||
input_channel = 32
|
||||
last_channel = 1280
|
||||
|
||||
if inverted_residual_setting is None:
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1],
|
||||
]
|
||||
|
||||
# only check the first element, assuming user knows t,c,n,s are required
|
||||
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
||||
raise ValueError("inverted_residual_setting should be non-empty "
|
||||
"or a 4-element list, got {}".format(inverted_residual_setting))
|
||||
|
||||
# building first layer
|
||||
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
||||
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
||||
features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = _make_divisible(c * width_mult, round_nearest)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
|
||||
input_channel = output_channel
|
||||
# building last several layers
|
||||
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
|
||||
# make it nn.Sequential
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
# building classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(self.last_channel, num_classes),
|
||||
)
|
||||
|
||||
# weight initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
||||
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
||||
x = self.features(x)
|
||||
# Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
|
||||
x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
def mobilenet_v2(pretrained=True, progress=True, **kwargs):
|
||||
"""
|
||||
Constructs a MobileNetV2 architecture from
|
||||
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
model = MobileNetV2(**kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
|
||||
progress=progress)
|
||||
src_state_dict = state_dict
|
||||
target_state_dict = model.state_dict()
|
||||
skip_keys = []
|
||||
# skip mismatch size tensors in case of pretraining
|
||||
for k in src_state_dict.keys():
|
||||
if k not in target_state_dict:
|
||||
continue
|
||||
if src_state_dict[k].size() != target_state_dict[k].size():
|
||||
skip_keys.append(k)
|
||||
for k in skip_keys:
|
||||
del src_state_dict[k]
|
||||
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
|
||||
#.load_state_dict(state_dict)
|
||||
return model
|
200
contrast/feat_extract/model/mobilenet_v3.py
Normal file
200
contrast/feat_extract/model/mobilenet_v3.py
Normal file
@ -0,0 +1,200 @@
|
||||
'''MobileNetV3 in PyTorch.
|
||||
|
||||
See the paper "Inverted Residuals and Linear Bottlenecks:
|
||||
Mobile Networks for Classification, Detection and Segmentation" for more details.
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import init
|
||||
from ..config import config as conf
|
||||
|
||||
|
||||
class hswish(nn.Module):
|
||||
def forward(self, x):
|
||||
out = x * F.relu6(x + 3, inplace=True) / 6
|
||||
return out
|
||||
|
||||
|
||||
class hsigmoid(nn.Module):
|
||||
def forward(self, x):
|
||||
out = F.relu6(x + 3, inplace=True) / 6
|
||||
return out
|
||||
|
||||
|
||||
class SeModule(nn.Module):
|
||||
def __init__(self, in_size, reduction=4):
|
||||
super(SeModule, self).__init__()
|
||||
self.se = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(in_size // reduction),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(in_size),
|
||||
hsigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.se(x)
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
'''expand + depthwise + pointwise'''
|
||||
def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
|
||||
super(Block, self).__init__()
|
||||
self.stride = stride
|
||||
self.se = semodule
|
||||
|
||||
self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(expand_size)
|
||||
self.nolinear1 = nolinear
|
||||
self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(expand_size)
|
||||
self.nolinear2 = nolinear
|
||||
self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(out_size)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride == 1 and in_size != out_size:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(out_size),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.nolinear1(self.bn1(self.conv1(x)))
|
||||
out = self.nolinear2(self.bn2(self.conv2(out)))
|
||||
out = self.bn3(self.conv3(out))
|
||||
if self.se != None:
|
||||
out = self.se(out)
|
||||
out = out + self.shortcut(x) if self.stride==1 else out
|
||||
return out
|
||||
|
||||
|
||||
class MobileNetV3_Large(nn.Module):
|
||||
def __init__(self, num_classes=conf.embedding_size):
|
||||
super(MobileNetV3_Large, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(16)
|
||||
self.hs1 = hswish()
|
||||
|
||||
self.bneck = nn.Sequential(
|
||||
Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
|
||||
Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
|
||||
Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
|
||||
Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
|
||||
Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
|
||||
Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
|
||||
Block(3, 40, 240, 80, hswish(), None, 2),
|
||||
Block(3, 80, 200, 80, hswish(), None, 1),
|
||||
Block(3, 80, 184, 80, hswish(), None, 1),
|
||||
Block(3, 80, 184, 80, hswish(), None, 1),
|
||||
Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
|
||||
Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
|
||||
Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
|
||||
Block(5, 160, 672, 160, hswish(), SeModule(160), 2),
|
||||
Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
|
||||
)
|
||||
|
||||
|
||||
self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(960)
|
||||
self.hs2 = hswish()
|
||||
self.linear3 = nn.Linear(960, 1280)
|
||||
self.bn3 = nn.BatchNorm1d(1280)
|
||||
self.hs3 = hswish()
|
||||
self.linear4 = nn.Linear(1280, num_classes)
|
||||
self.init_params()
|
||||
|
||||
def init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.001)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.hs1(self.bn1(self.conv1(x)))
|
||||
out = self.bneck(out)
|
||||
out = self.hs2(self.bn2(self.conv2(out)))
|
||||
out = F.avg_pool2d(out, conf.img_size // 32)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.hs3(self.bn3(self.linear3(out)))
|
||||
out = self.linear4(out)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
class MobileNetV3_Small(nn.Module):
|
||||
def __init__(self, num_classes=conf.embedding_size):
|
||||
super(MobileNetV3_Small, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(16)
|
||||
self.hs1 = hswish()
|
||||
|
||||
self.bneck = nn.Sequential(
|
||||
Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
|
||||
Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
|
||||
Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
|
||||
Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
|
||||
Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
|
||||
Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
|
||||
Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
|
||||
Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
|
||||
Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
|
||||
Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
|
||||
Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
|
||||
)
|
||||
|
||||
|
||||
self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(576)
|
||||
self.hs2 = hswish()
|
||||
self.linear3 = nn.Linear(576, 1280)
|
||||
self.bn3 = nn.BatchNorm1d(1280)
|
||||
self.hs3 = hswish()
|
||||
self.linear4 = nn.Linear(1280, num_classes)
|
||||
self.init_params()
|
||||
|
||||
def init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal_(m.weight, std=0.001)
|
||||
if m.bias is not None:
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.hs1(self.bn1(self.conv1(x)))
|
||||
out = self.bneck(out)
|
||||
out = self.hs2(self.bn2(self.conv2(out)))
|
||||
out = F.avg_pool2d(out, conf.img_size // 32)
|
||||
out = out.view(out.size(0), -1)
|
||||
|
||||
out = self.hs3(self.bn3(self.linear3(out)))
|
||||
out = self.linear4(out)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
def test():
|
||||
net = MobileNetV3_Small()
|
||||
x = torch.randn(2,3,224,224)
|
||||
y = net(x)
|
||||
print(y.size())
|
||||
|
||||
# test()
|
268
contrast/feat_extract/model/mobilevit.py
Normal file
268
contrast/feat_extract/model/mobilevit.py
Normal file
@ -0,0 +1,268 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
# import sys
|
||||
# sys.path.append(r"D:\DetectTracking")
|
||||
from ..config import config as conf
|
||||
|
||||
|
||||
def conv_1x1_bn(inp, oup):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
|
||||
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim=-1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
attn = self.attend(dots)
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b p h n d -> b p n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
|
||||
class MV2Block(nn.Module):
|
||||
def __init__(self, inp, oup, stride=1, expansion=4):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(inp * expansion)
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
if expansion == 1:
|
||||
self.conv = nn.Sequential(
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Sequential(
|
||||
# pw
|
||||
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# dw
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.SiLU(),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileViTBlock(nn.Module):
|
||||
def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.ph, self.pw = patch_size
|
||||
|
||||
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
|
||||
self.conv2 = conv_1x1_bn(channel, dim)
|
||||
|
||||
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
|
||||
|
||||
self.conv3 = conv_1x1_bn(dim, channel)
|
||||
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
y = x.clone()
|
||||
|
||||
# Local representations
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
|
||||
# Global representations
|
||||
_, _, h, w = x.shape
|
||||
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
|
||||
x = self.transformer(x)
|
||||
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
|
||||
pw=self.pw)
|
||||
|
||||
# Fusion
|
||||
x = self.conv3(x)
|
||||
x = torch.cat((x, y), 1)
|
||||
x = self.conv4(x)
|
||||
return x
|
||||
|
||||
|
||||
class MobileViT(nn.Module):
|
||||
def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
|
||||
super().__init__()
|
||||
ih, iw = image_size
|
||||
ph, pw = patch_size
|
||||
assert ih % ph == 0 and iw % pw == 0
|
||||
|
||||
L = [2, 4, 3]
|
||||
|
||||
self.conv1 = conv_nxn_bn(3, channels[0], stride=2)
|
||||
|
||||
self.mv2 = nn.ModuleList([])
|
||||
self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
|
||||
self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
|
||||
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
|
||||
self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion)) # Repeat
|
||||
self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
|
||||
self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
|
||||
self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
|
||||
|
||||
self.mvit = nn.ModuleList([])
|
||||
self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
|
||||
self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
|
||||
self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))
|
||||
|
||||
self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
|
||||
|
||||
self.pool = nn.AvgPool2d(ih // 32, 1)
|
||||
self.fc = nn.Linear(channels[-1], num_classes, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
#print('x',x.shape)
|
||||
x = self.conv1(x)
|
||||
x = self.mv2[0](x)
|
||||
|
||||
x = self.mv2[1](x)
|
||||
x = self.mv2[2](x)
|
||||
x = self.mv2[3](x) # Repeat
|
||||
|
||||
x = self.mv2[4](x)
|
||||
x = self.mvit[0](x)
|
||||
|
||||
x = self.mv2[5](x)
|
||||
x = self.mvit[1](x)
|
||||
|
||||
x = self.mv2[6](x)
|
||||
x = self.mvit[2](x)
|
||||
x = self.conv2(x)
|
||||
|
||||
|
||||
#print('pool_before',x.shape)
|
||||
x = self.pool(x).view(-1, x.shape[1])
|
||||
#print('self_pool',self.pool)
|
||||
#print('pool_after',x.shape)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def mobilevit_xxs():
|
||||
dims = [64, 80, 96]
|
||||
channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
|
||||
return MobileViT((256, 256), dims, channels, num_classes=1000, expansion=2)
|
||||
|
||||
|
||||
def mobilevit_xs():
|
||||
dims = [96, 120, 144]
|
||||
channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
|
||||
return MobileViT((256, 256), dims, channels, num_classes=1000)
|
||||
|
||||
|
||||
def mobilevit_s():
|
||||
dims = [144, 192, 240]
|
||||
channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
|
||||
return MobileViT((conf.img_size, conf.img_size), dims, channels, num_classes=conf.embedding_size)
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
img = torch.randn(5, 3, 256, 256)
|
||||
|
||||
vit = mobilevit_xxs()
|
||||
out = vit(img)
|
||||
print(out.shape)
|
||||
print(count_parameters(vit))
|
||||
|
||||
vit = mobilevit_xs()
|
||||
out = vit(img)
|
||||
print(out.shape)
|
||||
print(count_parameters(vit))
|
||||
|
||||
vit = mobilevit_s()
|
||||
out = vit(img)
|
||||
print(out.shape)
|
||||
print(count_parameters(vit))
|
145
contrast/feat_extract/model/resbam.py
Normal file
145
contrast/feat_extract/model/resbam.py
Normal file
@ -0,0 +1,145 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .CBAM import CBAM
|
||||
from .Tool import GeM as gem
|
||||
# from model.CBAM import CBAM
|
||||
# from model.Tool import GeM as gem
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inchannel, outchannel, stride=1, dowsample=None):
|
||||
# super(Bottleneck, self).__init__()
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(in_channels=inchannel, out_channels=outchannel, kernel_size=1, stride=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(outchannel)
|
||||
self.conv2 = nn.Conv2d(in_channels=outchannel, out_channels=outchannel, kernel_size=3, bias=False,
|
||||
stride=stride, padding=1)
|
||||
self.bn2 = nn.BatchNorm2d(outchannel)
|
||||
self.conv3 = nn.Conv2d(in_channels=outchannel, out_channels=outchannel * self.expansion, stride=1, bias=False,
|
||||
kernel_size=1)
|
||||
self.bn3 = nn.BatchNorm2d(outchannel * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = dowsample
|
||||
|
||||
def forward(self, x):
|
||||
self.identity = x
|
||||
# print('>>>>>>>>',type(x))
|
||||
if self.downsample is not None:
|
||||
# print('>>>>downsample>>>>', type(self.downsample))
|
||||
self.identity = self.downsample(x)
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
# print('>>>>out>>>identity',out.size(),self.identity.size())
|
||||
out = out + self.identity
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class resnet(nn.Module):
|
||||
def __init__(self, block=Bottleneck, block_num=[3, 4, 6, 3], num_class=1000):
|
||||
super().__init__()
|
||||
self.in_channel = 64
|
||||
self.conv1 = nn.Conv2d(in_channels=3,
|
||||
out_channels=self.in_channel,
|
||||
stride=2,
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(self.in_channel)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.cbam = CBAM(self.in_channel)
|
||||
self.cbam1 = CBAM(2048)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, block_num[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, block_num[3], stride=2)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.gem = gem()
|
||||
self.fc = nn.Linear(512 * block.expansion, num_class)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal(m.weight, mode='fan_out',
|
||||
nonlinearity='relu')
|
||||
if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
nn.init.constant_(m.bias, 1.0)
|
||||
|
||||
def _make_layer(self, block, channel, block_num, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.in_channel != channel * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(channel * block.expansion))
|
||||
layer = []
|
||||
layer.append(block(self.in_channel, channel, stride, downsample))
|
||||
self.in_channel = channel * block.expansion
|
||||
for _ in range(1, block_num):
|
||||
layer.append(block(self.in_channel, channel))
|
||||
return nn.Sequential(*layer)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
x = self.cbam(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.cbam1(x)
|
||||
# x = self.avgpool(x)
|
||||
x = self.gem(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
class TripletNet(nn.Module):
|
||||
def __init__(self, num_class, flag=True):
|
||||
super(TripletNet, self).__init__()
|
||||
self.initnet = rescbam(num_class)
|
||||
self.flag = flag
|
||||
|
||||
def forward(self, x1, x2=None, x3=None):
|
||||
if self.flag:
|
||||
output1 = self.initnet(x1)
|
||||
output2 = self.initnet(x2)
|
||||
output3 = self.initnet(x3)
|
||||
return output1, output2, output3
|
||||
else:
|
||||
output = self.initnet(x1)
|
||||
return output
|
||||
|
||||
|
||||
def rescbam(num_class):
|
||||
return resnet(block=Bottleneck, block_num=[3, 4, 6, 3], num_class=num_class)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
input1 = torch.randn(4, 3, 640, 640)
|
||||
input2 = torch.randn(4, 3, 640, 640)
|
||||
input3 = torch.randn(4, 3, 640, 640)
|
||||
|
||||
# rescbam测试
|
||||
# Resnet50 = rescbam(512)
|
||||
# output = Resnet50.forward(input1)
|
||||
# print(Resnet50)
|
||||
|
||||
# trnet测试
|
||||
trnet = TripletNet(512)
|
||||
output = trnet(input1, input2, input3)
|
||||
print(output)
|
189
contrast/feat_extract/model/resnet.py
Normal file
189
contrast/feat_extract/model/resnet.py
Normal file
@ -0,0 +1,189 @@
|
||||
"""resnet in pytorch
|
||||
|
||||
|
||||
|
||||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
|
||||
|
||||
Deep Residual Learning for Image Recognition
|
||||
https://arxiv.org/abs/1512.03385v1
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from config import config as conf
|
||||
from CBAM import CBAM
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
"""Basic Block for resnet 18 and resnet 34
|
||||
|
||||
"""
|
||||
|
||||
#BasicBlock and BottleNeck block
|
||||
#have different output size
|
||||
#we use class attribute expansion
|
||||
#to distinct
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super().__init__()
|
||||
|
||||
#residual function
|
||||
self.residual_function = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
|
||||
)
|
||||
|
||||
#shortcut
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
#the shortcut output dimension is not the same with residual function
|
||||
#use 1*1 convolution to match the dimension
|
||||
if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||
|
||||
class BottleNeck(nn.Module):
|
||||
"""Residual block for resnet over 50 layers
|
||||
|
||||
"""
|
||||
expansion = 4
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super().__init__()
|
||||
self.residual_function = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BottleNeck.expansion),
|
||||
)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BottleNeck.expansion)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, num_block, cbam = False, num_classes=conf.embedding_size):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = 64
|
||||
|
||||
# self.conv1 = nn.Sequential(
|
||||
# nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(64),
|
||||
# nn.ReLU(inplace=True))
|
||||
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(3, 64,stride=2,kernel_size=7,padding=3,bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
||||
|
||||
self.cbam = CBAM(self.in_channels)
|
||||
|
||||
#we use a different inputsize than the original paper
|
||||
#so conv2_x's stride is 1
|
||||
self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
|
||||
self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
|
||||
self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
|
||||
self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
|
||||
self.cbam1 = CBAM(self.in_channels)
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal(m.weight,mode = 'fan_out',
|
||||
nonlinearity='relu')
|
||||
if isinstance(m, (nn.BatchNorm2d)):
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
nn.init.constant_(m.bias, 1.0)
|
||||
|
||||
def _make_layer(self, block, out_channels, num_blocks, stride):
|
||||
"""make resnet layers(by layer i didnt mean this 'layer' was the
|
||||
same as a neuron netowork layer, ex. conv layer), one layer may
|
||||
contain more than one residual block
|
||||
|
||||
Args:
|
||||
block: block type, basic block or bottle neck block
|
||||
out_channels: output depth channel number of this layer
|
||||
num_blocks: how many blocks per layer
|
||||
stride: the stride of the first block of this layer
|
||||
|
||||
Return:
|
||||
return a resnet layer
|
||||
"""
|
||||
|
||||
# we have num_block blocks per layer, the first block
|
||||
# could be 1 or 2, other blocks would always be 1
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_channels, out_channels, stride))
|
||||
self.in_channels = out_channels * block.expansion
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
output = self.conv1(x)
|
||||
if cbam:
|
||||
output = self.cbam(x)
|
||||
output = self.conv2_x(output)
|
||||
output = self.conv3_x(output)
|
||||
output = self.conv4_x(output)
|
||||
output = self.conv5_x(output)
|
||||
if cbam:
|
||||
output = self.cbam1(x)
|
||||
print('pollBefore',output.shape)
|
||||
output = self.avg_pool(output)
|
||||
print('poolAfter',output.shape)
|
||||
output = output.view(output.size(0), -1)
|
||||
print('fcBefore',output.shape)
|
||||
output = self.fc(output)
|
||||
|
||||
return output
|
||||
|
||||
def resnet18(cbam = False):
|
||||
""" return a ResNet 18 object
|
||||
"""
|
||||
return ResNet(BasicBlock, [2, 2, 2, 2], cbam)
|
||||
|
||||
def resnet34():
|
||||
""" return a ResNet 34 object
|
||||
"""
|
||||
return ResNet(BasicBlock, [3, 4, 6, 3])
|
||||
|
||||
def resnet50():
|
||||
""" return a ResNet 50 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 4, 6, 3])
|
||||
|
||||
def resnet101():
|
||||
""" return a ResNet 101 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 4, 23, 3])
|
||||
|
||||
def resnet152():
|
||||
""" return a ResNet 152 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 8, 36, 3])
|
||||
|
||||
|
121
contrast/feat_extract/model/resnet_face.py
Normal file
121
contrast/feat_extract/model/resnet_face.py
Normal file
@ -0,0 +1,121 @@
|
||||
""" Resnet_IR_SE in ArcFace """
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, x):
|
||||
return x.reshape(x.shape[0], -1)
|
||||
|
||||
|
||||
class SEConv(nn.Module):
|
||||
"""Use Convolution instead of FullyConnection in SE"""
|
||||
|
||||
def __init__(self, channels, reduction):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
nn.Conv2d(channels, channels // reduction, kernel_size=1, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(channels // reduction, channels, kernel_size=1, bias=False),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) * x
|
||||
|
||||
|
||||
class SE(nn.Module):
|
||||
|
||||
def __init__(self, channels, reduction):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
nn.Linear(channels, channels // reduction),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(channels // reduction, channels),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) * x
|
||||
|
||||
|
||||
class IRSE(nn.Module):
|
||||
|
||||
def __init__(self, channels, depth, stride):
|
||||
super().__init__()
|
||||
if channels == depth:
|
||||
self.shortcut = nn.MaxPool2d(kernel_size=1, stride=stride)
|
||||
else:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(channels, depth, (1, 1), stride, bias=False),
|
||||
nn.BatchNorm2d(depth),
|
||||
)
|
||||
self.residual = nn.Sequential(
|
||||
nn.BatchNorm2d(channels),
|
||||
nn.Conv2d(channels, depth, (3, 3), 1, 1, bias=False),
|
||||
nn.PReLU(depth),
|
||||
nn.Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
|
||||
nn.BatchNorm2d(depth),
|
||||
SEConv(depth, 16),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.shortcut(x) + self.residual(x)
|
||||
|
||||
|
||||
class ResIRSE(nn.Module):
|
||||
"""Resnet50-IRSE backbone"""
|
||||
|
||||
def __init__(self, ih, embedding_size, drop_ratio):
|
||||
super().__init__()
|
||||
ih_last = ih // 16
|
||||
self.input_layer = nn.Sequential(
|
||||
nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.PReLU(64),
|
||||
)
|
||||
self.output_layer = nn.Sequential(
|
||||
nn.BatchNorm2d(512),
|
||||
nn.Dropout(drop_ratio),
|
||||
Flatten(),
|
||||
nn.Linear(512 * ih_last * ih_last, embedding_size),
|
||||
nn.BatchNorm1d(embedding_size),
|
||||
)
|
||||
|
||||
# ["channels", "depth", "stride"],
|
||||
self.res50_arch = [
|
||||
[64, 64, 2], [64, 64, 1], [64, 64, 1],
|
||||
[64, 128, 2], [128, 128, 1], [128, 128, 1], [128, 128, 1],
|
||||
[128, 256, 2], [256, 256, 1], [256, 256, 1], [256, 256, 1], [256, 256, 1],
|
||||
[256, 256, 1], [256, 256, 1], [256, 256, 1], [256, 256, 1], [256, 256, 1],
|
||||
[256, 256, 1], [256, 256, 1], [256, 256, 1], [256, 256, 1],
|
||||
[256, 512, 2], [512, 512, 1], [512, 512, 1],
|
||||
]
|
||||
|
||||
self.body = nn.Sequential(*[IRSE(a, b, c) for (a, b, c) in self.res50_arch])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.input_layer(x)
|
||||
x = self.body(x)
|
||||
x = self.output_layer(x)
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
x = Image.open("../samples/009.jpg").convert('L')
|
||||
x = x.resize((128, 128))
|
||||
x = np.asarray(x, dtype=np.float32)
|
||||
x = x[None, None, ...]
|
||||
x = torch.from_numpy(x)
|
||||
net = ResIRSE(512, 0.6)
|
||||
net.eval()
|
||||
with torch.no_grad():
|
||||
out = net(x)
|
||||
print(out.shape)
|
462
contrast/feat_extract/model/resnet_pre.py
Normal file
462
contrast/feat_extract/model/resnet_pre.py
Normal file
@ -0,0 +1,462 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..config import config as conf
|
||||
|
||||
try:
|
||||
from torch.hub import load_state_dict_from_url
|
||||
except ImportError:
|
||||
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
||||
# from .utils import load_state_dict_from_url
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||
'wide_resnet50_2', 'wide_resnet101_2']
|
||||
|
||||
model_urls = {
|
||||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
||||
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
||||
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
||||
}
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
def __init__(self, kernel_size=7):
|
||||
super(SpatialAttention, self).__init__()
|
||||
|
||||
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
|
||||
padding = 3 if kernel_size == 7 else 1
|
||||
|
||||
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
avg_out = torch.mean(x, dim=1, keepdim=True)
|
||||
max_out, _ = torch.max(x, dim=1, keepdim=True)
|
||||
x = torch.cat([avg_out, max_out], dim=1)
|
||||
x = self.conv1(x)
|
||||
return self.sigmoid(x)
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
|
||||
super(BasicBlock, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
self.cam = cam
|
||||
self.bam = bam
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
if self.cam:
|
||||
if planes == 64:
|
||||
self.globalAvgPool = nn.AvgPool2d(56, stride=1)
|
||||
elif planes == 128:
|
||||
self.globalAvgPool = nn.AvgPool2d(28, stride=1)
|
||||
elif planes == 256:
|
||||
self.globalAvgPool = nn.AvgPool2d(14, stride=1)
|
||||
elif planes == 512:
|
||||
self.globalAvgPool = nn.AvgPool2d(7, stride=1)
|
||||
|
||||
self.fc1 = nn.Linear(in_features=planes, out_features=round(planes / 16))
|
||||
self.fc2 = nn.Linear(in_features=round(planes / 16), out_features=planes)
|
||||
self.sigmod = nn.Sigmoid()
|
||||
if self.bam:
|
||||
self.bam = SpatialAttention()
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
if self.cam:
|
||||
ori_out = self.globalAvgPool(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.fc1(out)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
out = self.sigmod(out)
|
||||
out = out.view(out.size(0), out.size(-1), 1, 1)
|
||||
out = out * ori_out
|
||||
|
||||
if self.bam:
|
||||
out = out*self.bam(out)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
|
||||
super(Bottleneck, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
self.cam = cam
|
||||
self.bam = bam
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
if self.cam:
|
||||
if planes == 64:
|
||||
self.globalAvgPool = nn.AvgPool2d(56, stride=1)
|
||||
elif planes == 128:
|
||||
self.globalAvgPool = nn.AvgPool2d(28, stride=1)
|
||||
elif planes == 256:
|
||||
self.globalAvgPool = nn.AvgPool2d(14, stride=1)
|
||||
elif planes == 512:
|
||||
self.globalAvgPool = nn.AvgPool2d(7, stride=1)
|
||||
|
||||
self.fc1 = nn.Linear(planes * self.expansion, round(planes / 4))
|
||||
self.fc2 = nn.Linear(round(planes / 4), planes * self.expansion)
|
||||
self.sigmod = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
if self.cam:
|
||||
ori_out = self.globalAvgPool(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.fc1(out)
|
||||
out = self.relu(out)
|
||||
out = self.fc2(out)
|
||||
out = self.sigmod(out)
|
||||
out = out.view(out.size(0), out.size(-1), 1, 1)
|
||||
out = out * ori_out
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, num_classes=conf.embedding_size, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||
norm_layer=None, scale=0.75):
|
||||
super(ResNet, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, int(64*scale), layers[0])
|
||||
self.layer2 = self._make_layer(block, int(128*scale), layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, int(256*scale), layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, int(512*scale), layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(int(512 * block.expansion*scale), num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# See note [TorchScript super()]
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
# print('poolBefore', x.shape)
|
||||
x = self.avgpool(x)
|
||||
# print('poolAfter', x.shape)
|
||||
x = torch.flatten(x, 1)
|
||||
# print('fcBefore',x.shape)
|
||||
x = self.fc(x)
|
||||
|
||||
# print('fcAfter',x.shape)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||
# model = ResNet(block, layers, **kwargs)
|
||||
# if pretrained:
|
||||
# state_dict = load_state_dict_from_url(model_urls[arch],
|
||||
# progress=progress)
|
||||
# model.load_state_dict(state_dict, strict=False)
|
||||
# return model
|
||||
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||
model = ResNet(block, layers, **kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls[arch],
|
||||
progress=progress)
|
||||
|
||||
src_state_dict = state_dict
|
||||
target_state_dict = model.state_dict()
|
||||
skip_keys = []
|
||||
# skip mismatch size tensors in case of pretraining
|
||||
for k in src_state_dict.keys():
|
||||
if k not in target_state_dict:
|
||||
continue
|
||||
if src_state_dict[k].size() != target_state_dict[k].size():
|
||||
skip_keys.append(k)
|
||||
for k in skip_keys:
|
||||
del src_state_dict[k]
|
||||
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def resnet14(pretrained=True, progress=True, **kwargs):
|
||||
r"""ResNet-14 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet18', BasicBlock, [2, 1, 1, 2], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet18(pretrained=True, progress=True, **kwargs):
|
||||
r"""ResNet-18 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet34(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-34 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet50(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-50 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet101(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-101 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet152(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNet-152 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNeXt-50 32x4d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 4
|
||||
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
|
||||
r"""ResNeXt-101 32x8d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 8
|
||||
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
|
||||
r"""Wide ResNet-50-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
|
||||
r"""Wide ResNet-101-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
4
contrast/feat_extract/model/utils.py
Normal file
4
contrast/feat_extract/model/utils.py
Normal file
@ -0,0 +1,4 @@
|
||||
try:
|
||||
from torch.hub import load_state_dict_from_url
|
||||
except ImportError:
|
||||
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
137
contrast/feat_extract/model/vit.py
Normal file
137
contrast/feat_extract/model/vit.py
Normal file
@ -0,0 +1,137 @@
|
||||
import torch
|
||||
from vit_pytorch.mobile_vit import MobileViT
|
||||
from vit_pytorch import vit
|
||||
from vit_pytorch import SimpleViT
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim=-1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
|
||||
FeedForward(dim, mlp_dim, dropout=dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3,
|
||||
dim_head=64, dropout=0., emb_dropout=0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user