From e00fb468477c2afefb09056d7b8d0365e7f95742 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E5=BA=86=E5=88=9A?=
Date: Mon, 2 Sep 2024 18:39:12 +0800
Subject: [PATCH] update 20240902

---
 contrast/.idea/.gitignore                     |   8 +
 contrast/.idea/contrastInference.iml          |  12 +
 contrast/.idea/deployment.xml                 |  14 +
 .../inspectionProfiles/Project_Default.xml    |  12 +
 .../inspectionProfiles/profiles_settings.xml  |   6 +
 contrast/.idea/misc.xml                       |   7 +
 contrast/.idea/modules.xml                    |   8 +
 contrast/__init__.py                          |   1 +
 contrast/__pycache__/__init__.cpython-39.pyc  | Bin 0 -> 139 bytes
 contrast/__pycache__/config.cpython-38.pyc    | Bin 0 -> 1638 bytes
 contrast/__pycache__/config.cpython-39.pyc    | Bin 0 -> 1637 bytes
 contrast/__pycache__/inference.cpython-39.pyc | Bin 0 -> 2395 bytes
 contrast/config.py                            |  84 ++++
 contrast/contrast_one2one.py                  | 380 ++++++++++++++
 contrast/inference.py                         | 103 ++++
 contrast/model/__init__.py                    |   1 +
 .../model/__pycache__/__init__.cpython-38.pyc | Bin 0 -> 246 bytes
 .../model/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 245 bytes
 .../__pycache__/resnet_pre.cpython-38.pyc     | Bin 0 -> 13929 bytes
 .../__pycache__/resnet_pre.cpython-39.pyc     | Bin 0 -> 13904 bytes
 contrast/model/resnet_pre.py                  | 462 ++++++++++++++++++
 tracking/contrast_one2one.py                  |  10 +-
 22 files changed, 1105 insertions(+), 3 deletions(-)
 create mode 100644 contrast/.idea/.gitignore
 create mode 100644 contrast/.idea/contrastInference.iml
 create mode 100644 contrast/.idea/deployment.xml
 create mode 100644 contrast/.idea/inspectionProfiles/Project_Default.xml
 create mode 100644 contrast/.idea/inspectionProfiles/profiles_settings.xml
 create mode 100644 contrast/.idea/misc.xml
 create mode 100644 contrast/.idea/modules.xml
 create mode 100644 contrast/__init__.py
 create mode 100644 contrast/__pycache__/__init__.cpython-39.pyc
 create mode 100644 contrast/__pycache__/config.cpython-38.pyc
 create mode 100644 contrast/__pycache__/config.cpython-39.pyc
 create mode 100644 contrast/__pycache__/inference.cpython-39.pyc
 create mode 100644 contrast/config.py
 create mode 100644 contrast/contrast_one2one.py
 create mode 100644 contrast/inference.py
 create mode 100644 contrast/model/__init__.py
 create mode 100644 contrast/model/__pycache__/__init__.cpython-38.pyc
 create mode 100644 contrast/model/__pycache__/__init__.cpython-39.pyc
 create mode 100644 contrast/model/__pycache__/resnet_pre.cpython-38.pyc
 create mode 100644 contrast/model/__pycache__/resnet_pre.cpython-39.pyc
 create mode 100644 contrast/model/resnet_pre.py

diff --git a/contrast/.idea/.gitignore b/contrast/.idea/.gitignore
new file mode 100644
index 0000000..35410ca
--- /dev/null
+++ b/contrast/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml

diff --git a/contrast/.idea/contrastInference.iml b/contrast/.idea/contrastInference.iml
new file mode 100644
index 0000000..6d6038d
--- /dev/null
+++ b/contrast/.idea/contrastInference.iml
@@ -0,0 +1,12 @@
+[XML module definition not recovered]
\ No newline at end of file

diff --git a/contrast/.idea/deployment.xml b/contrast/.idea/deployment.xml
new file mode 100644
index 0000000..b7f9a78
--- /dev/null
+++ b/contrast/.idea/deployment.xml
@@ -0,0 +1,14 @@
+[XML deployment settings not recovered]
\ No newline at end of file

diff --git a/contrast/.idea/inspectionProfiles/Project_Default.xml b/contrast/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..920d523
--- /dev/null
+++ b/contrast/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,12 @@
+[XML inspection profile not recovered]
\ No newline at end of file

diff --git a/contrast/.idea/inspectionProfiles/profiles_settings.xml b/contrast/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/contrast/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+[XML profile settings not recovered]
\ No newline at end of file

diff --git a/contrast/.idea/misc.xml b/contrast/.idea/misc.xml
new file mode 100644
index 0000000..3afa107
--- /dev/null
+++ b/contrast/.idea/misc.xml
@@ -0,0 +1,7 @@
+[XML project settings not recovered]
\ No newline at end of file

diff --git a/contrast/.idea/modules.xml b/contrast/.idea/modules.xml
new file mode 100644
index 0000000..316bf04
--- /dev/null
+++ b/contrast/.idea/modules.xml
@@ -0,0 +1,8 @@
+[XML module list not recovered]
\ No newline at end of file

diff --git a/contrast/__init__.py b/contrast/__init__.py
new file mode 100644
index 0000000..e3e0f3f
--- /dev/null
+++ b/contrast/__init__.py
@@ -0,0 +1 @@
+# from .config import config
\ No newline at end of file

diff --git a/contrast/__pycache__/__init__.cpython-39.pyc b/contrast/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f894aad46697b9f3e50d4cf9ec4c2809527e33c0
GIT binary patch
literal 139
[base85 payload omitted]

diff --git a/contrast/__pycache__/config.cpython-38.pyc b/contrast/__pycache__/config.cpython-38.pyc
new file mode 100644
GIT binary patch
literal 1638
[base85 payload omitted]

diff --git a/contrast/__pycache__/config.cpython-39.pyc b/contrast/__pycache__/config.cpython-39.pyc
new file mode 100644
GIT binary patch
literal 1637
[base85 payload omitted]

diff --git a/contrast/__pycache__/inference.cpython-39.pyc b/contrast/__pycache__/inference.cpython-39.pyc
new file mode 100644
GIT binary patch
literal 2395
[base85 payload omitted]

diff --git a/contrast/config.py b/contrast/config.py
new file mode 100644
--- /dev/null
+++ b/contrast/config.py
@@ -0,0 +1,84 @@
+[84-line configuration module not recovered]
diff --git a/contrast/contrast_one2one.py b/contrast/contrast_one2one.py
new file mode 100644
--- /dev/null
+++ b/contrast/contrast_one2one.py
@@ -0,0 +1,380 @@
+[module header and the opening of creat_shopping_event() not recovered]
+        # bboxes, ffeats, trackerboxes, tracker_feat_dict, trackingboxes, tracking_feat_dict = extract_data(datapath)
+
+        ''' 3.2 Read data from 0/1_tracking_output.data '''
+        if dataname.find("_tracking_output.data") > 0:
+            tracking_output_boxes, tracking_output_feats = read_tracking_output(datapath)
+            if len(tracking_output_boxes) != len(tracking_output_feats): continue
+            if CamerType == '0':
+                event['back_boxes'] = tracking_output_boxes
+                event['back_feats'] = tracking_output_feats
+            elif CamerType == '1':
+                event['front_boxes'] = tracking_output_boxes
+                event['front_feats'] = tracking_output_feats
+
+        # '''1.1 Choose the feature representation for the event'''
+        # bk_feats = event['back_feats']
+        # ft_feats = event['front_feats']
+
+        # feats_compose = np.empty((0, 256), dtype=np.float64)
+        # if len(ft_feats):
+        #     feats_compose = np.concatenate((feats_compose, ft_feats), axis=0)
+        # if len(bk_feats):
+        #     feats_compose = np.concatenate((feats_compose, bk_feats), axis=0)
+        # event['feats_compose'] = feats_compose
+
+        # '''3. Build the front-camera feature set'''
+        # if len(ft_feats):
+        #     event['feats_select'] = ft_feats
+
+        '''================ 2. Read image paths and sort them by frame ID ================'''
+        frontImgs, frontFid = [], []
+        backImgs, backFid = [], []
+        for imgname in os.listdir(filepath):
+            name, ext = os.path.splitext(imgname)
+            if ext not in IMG_FORMAT or name.find('frameId') < 0: continue
+
+            CamerType = name.split('_')[0]
+            frameId = int(name.split('_')[3])
+            imgpath = os.path.join(filepath, imgname)
+            if CamerType == '0':
+                backImgs.append(imgpath)
+                backFid.append(frameId)
+            if CamerType == '1':
+                frontImgs.append(imgpath)
+                frontFid.append(frameId)
+
+        frontIdx = np.argsort(np.array(frontFid))
+        backIdx = np.argsort(np.array(backFid))
+
+        '''2.1 Build the front/back camera image path lists, ordered by frame ID'''
+        frontImgs = [frontImgs[i] for i in frontIdx]
+        backImgs = [backImgs[i] for i in backIdx]
+
+        '''2.2 Attach the front/back camera image paths to the event dict'''
+        bfid = event['back_boxes'][:, 7].astype(np.int64)
+        ffid = event['front_boxes'][:, 7].astype(np.int64)
+        if len(bfid) and max(bfid) <= len(backImgs):
+            event['back_imgpaths'] = [backImgs[i-1] for i in bfid]
+        if len(ffid) and max(ffid) <= len(frontImgs):
+            event['front_imgpaths'] = [frontImgs[i-1] for i in ffid]
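+        # Each box row is [x1, y1, x2, y2, tid, score, cls, fid, bid]; column 7
+        # (fid) is the 1-based frame id, hence the backImgs[i-1] / frontImgs[i-1]
+        # indexing above.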
+
+        '''================ 3. Validate the current event and append it to the event list ================'''
+        condt1 = len(event['back_imgpaths']) == 0 or len(event['front_imgpaths']) == 0
+        condt2 = len(event['front_feats']) == 0 and len(event['back_feats']) == 0
+
+        if condt1 or condt2:
+            print(f" Error, condt1: {condt1}, condt2: {condt2}")
+            continue
+
+        eventList.append(event)
+
+        # k += 1
+        # if k == 1:
+        #     continue
+
+    '''I. Building the put-in product event list (not handled for now)'''
+    # delepath = os.path.join(basepath, 'deletedBarcode.txt')
+    # bcdList = read_deletedBarcode_file(delepath)
+    # for slist in bcdList:
+    #     getoutFold = slist['SeqDir'].strip()
+    #     getoutPath = os.path.join(basepath, getoutFold)
+
+    #     '''Skip if the take-out event folder does not exist'''
+    #     if not os.path.exists(getoutPath) and not os.path.isdir(getoutPath):
+    #         continue
+
+    #     ''' Build the take-out event dict '''
+    #     event = {}
+    #     event['barcode'] = slist['Deleted'].strip()
+    #     event['type'] = 'getout'
+    #     event['basepath'] = getoutPath
+
+    return eventList
+
+def get_std_barcodeDict(bcdpath):
+    stdBlist = []
+    for filename in os.listdir(bcdpath):
+        filepath = os.path.join(bcdpath, filename)
+        if not os.path.isdir(filepath) or not filename.isdigit(): continue
+
+        stdBlist.append(filename)
+
+    bcdpaths = [(barcode, os.path.join(bcdpath, barcode)) for barcode in stdBlist]
+
+    stdBarcodeDict = {}
+    k = 0
+    for barcode, bpath in bcdpaths:
+        stdBarcodeDict[barcode] = []
+        for root, dirs, files in os.walk(bpath):
+            imgpaths = []
+            if "base" in dirs:
+                broot = os.path.join(root, "base")
+                for imgname in os.listdir(broot):
+                    imgpath = os.path.join(broot, imgname)
+                    _, ext = os.path.splitext(imgpath)
+                    if ext not in IMG_FORMAT: continue
+                    imgpaths.append(imgpath)
+                stdBarcodeDict[barcode].extend(imgpaths)
+                break
+            else:
+                for imgname in files:
+                    imgpath = os.path.join(root, imgname)
+                    _, ext = os.path.splitext(imgpath)
+                    if ext not in IMG_FORMAT: continue
+                    imgpaths.append(imgpath)
+                stdBarcodeDict[barcode].extend(imgpaths)
+
+        picklepath = os.path.join(r'\\192.168.1.28\share\测试_202406\contrast\barcodes', f"{barcode}.pickle")
+        with open(picklepath, 'wb') as f:
+            pickle.dump({barcode: stdBarcodeDict[barcode]}, f)
+
+        print(f"Barcode: {barcode}")
+
+        k += 1
+        if k == 10:
+            break
+
+    return stdBarcodeDict
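+
+# A minimal sketch (not part of this patch) of reading back one of the
+# per-barcode pickle files written above; the UNC share is the one assumed
+# in get_std_barcodeDict and must be reachable.
+# def load_std_barcode(barcode):
+#     ppath = os.path.join(r'\\192.168.1.28\share\测试_202406\contrast\barcodes', f"{barcode}.pickle")
+#     with open(ppath, 'rb') as f:
+#         return pickle.load(f)[barcode]  # list of standard image paths for this barcode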
===========") + barcodeDict = {} + for event in eventList: + '''9 items: barcode, type, filepath, back_imgpaths, front_imgpaths, + back_boxes, front_boxes, back_feats, front_feats + ''' + + barcode = event['barcode'] + if barcode not in stdBarcodeDict.keys(): + continue + + + if len(event['feats_select']): + event_feats = event['feats_select'] + elif len(event['back_feats']): + event_feats = event['back_feats'] + else: + continue + + std_bcdpath = os.path.join(bcdpath, barcode) + + + + for root, dirs, files in os.walk(std_bcdpath): + if "base" in files: + std_bcdpath = os.path.join(root, "base") + break + + + + + + + + + + + + + + '''保存一次购物事件的轨迹子图''' + basename = os.path.basename(event['filepath']) + spath = os.path.join(savepath, basename) + if not os.path.exists(spath): + os.makedirs(spath) + cameras = ('front', 'back') + for camera in cameras: + if camera == 'front': + boxes = event['front_boxes'] + imgpaths = event['front_imgpaths'] + else: + boxes = event['back_boxes'] + imgpaths = event['back_imgpaths'] + + for i, box in enumerate(boxes): + x1, y1, x2, y2, tid, score, cls, fid, bid = box + + imgpath = imgpaths[i] + image = cv2.imread(imgpath) + subimg = image[int(y1/2):int(y2/2), int(x1/2):int(x2/2), :] + + camerType, timeTamp, _, frameID = os.path.basename(imgpath).split('.')[0].split('_') + subimgName = f"{camerType}_{tid}_fid({fid}, {frameID}).png" + subimgPath = os.path.join(spath, subimgName) + + cv2.imwrite(subimgPath, subimg) + print(f"Image saved: {basename}") + + + +def batch_inference(imgpaths, batch): + size = len(imgpaths) + groups = [] + for i in range(0, size, batch): + end = min(batch + i, size) + groups.append(imgpaths[i: end]) + + features = [] + for group in groups: + feature = featurize(group, conf.test_transform, model, conf.device) + features.append(feature) + + return features + +def main_infer(): + + + + bpath = r"\\192.168.1.28\share\测试_202406\contrast\barcodes" + for filename in os.listdir(bpath): + filepath = os.path.join(bpath, filename) + + with open(filepath, 'rb') as f: + bpDict = pickle.load(f) + + for barcode, imgpaths in bpDict.items(): + feature = batch_inference(imgpaths, 8) + + print("Done!!!") + + + +def main(): + fplist = [#r'\\192.168.1.28\share\测试_202406\0723\0723_1', + #r'\\192.168.1.28\share\测试_202406\0723\0723_2', + r'\\192.168.1.28\share\测试_202406\0723\0723_3', + #r'\\192.168.1.28\share\测试_202406\0722\0722_01', + #r'\\192.168.1.28\share\测试_202406\0722\0722_02' + ] + + + + for filepath in fplist: + one2one_test(filepath) + + # for filepath in fplist: + # try: + # one2one_test(filepath) + + # except Exception as e: + # print(f'{filepath}, Error: {e}') + +if __name__ == '__main__': + # main() + main_infer() \ No newline at end of file diff --git a/contrast/inference.py b/contrast/inference.py new file mode 100644 index 0000000..e7a6dbc --- /dev/null +++ b/contrast/inference.py @@ -0,0 +1,103 @@ +import os +import os.path as osp + +import torch + +import numpy as np +from model import resnet18 +from PIL import Image + +from torch.nn.functional import softmax +from config import config as conf +import time + +embedding_size = conf.embedding_size +img_size = conf.img_size +device = conf.device + +def load_contrast_model(): + model = resnet18().to(conf.device) + model.load_state_dict(torch.load(conf.test_model, map_location=conf.device)) + model.eval() + print('load model {} '.format(conf.testbackbone)) + + return model + + +def group_image(imageDirs, batch) -> list: + images = [] + """Group image paths by batch size""" + with 
+
+def main_infer():
+    bpath = r"\\192.168.1.28\share\测试_202406\contrast\barcodes"
+    for filename in os.listdir(bpath):
+        filepath = os.path.join(bpath, filename)
+
+        with open(filepath, 'rb') as f:
+            bpDict = pickle.load(f)
+
+        for barcode, imgpaths in bpDict.items():
+            feature = batch_inference(imgpaths, 8)
+
+    print("Done!!!")
+
+def main():
+    fplist = [#r'\\192.168.1.28\share\测试_202406\0723\0723_1',
+              #r'\\192.168.1.28\share\测试_202406\0723\0723_2',
+              r'\\192.168.1.28\share\测试_202406\0723\0723_3',
+              #r'\\192.168.1.28\share\测试_202406\0722\0722_01',
+              #r'\\192.168.1.28\share\测试_202406\0722\0722_02'
+              ]
+
+    for filepath in fplist:
+        one2one_test(filepath)
+
+    # for filepath in fplist:
+    #     try:
+    #         one2one_test(filepath)
+    #     except Exception as e:
+    #         print(f'{filepath}, Error: {e}')
+
+if __name__ == '__main__':
+    # main()
+    main_infer()
\ No newline at end of file

diff --git a/contrast/inference.py b/contrast/inference.py
new file mode 100644
index 0000000..e7a6dbc
--- /dev/null
+++ b/contrast/inference.py
@@ -0,0 +1,103 @@
+import os
+import os.path as osp
+
+import torch
+
+import numpy as np
+from model import resnet18
+from PIL import Image
+
+from torch.nn.functional import softmax
+from config import config as conf
+import time
+
+embedding_size = conf.embedding_size
+img_size = conf.img_size
+device = conf.device
+
+def load_contrast_model():
+    model = resnet18().to(conf.device)
+    model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
+    model.eval()
+    print('loaded model {}'.format(conf.testbackbone))
+
+    return model
+
+def group_image(imageDirs, batch) -> list:
+    """Group the image paths under a directory into batches of size `batch`."""
+    images = []
+    with os.scandir(imageDirs) as entries:
+        for imgpth in entries:
+            # print(imgpth)
+            images.append(os.sep.join([imageDirs, imgpth.name]))
+    print(f"{len(images)} images in {imageDirs}")
+    size = len(images)
+    res = []
+    for i in range(0, size, batch):
+        end = min(batch + i, size)
+        res.append(images[i: end])
+    return res
+
+def test_preprocess(images: list, transform) -> torch.Tensor:
+    res = []
+    for img in images:
+        # print(img)
+        im = Image.open(img)
+        im = transform(im)
+        res.append(im)
+    # data = torch.cat(res, dim=0)  # shape: (batch, 128, 128)
+    # data = data[:, None, :, :]    # shape: (batch, 1, 128, 128)
+    data = torch.stack(res)
+    return data
+
+def featurize(images: list, transform, net, device) -> torch.Tensor:
+    """Featurize a batch of images.
+    Args:
+        images: image paths
+        transform: test transform
+        net: pretrained model
+        device: cpu or cuda
+    Returns:
+        Tensor of features, one row per input image
+    """
+    data = test_preprocess(images, transform)
+    data = data.to(device)
+    net = net.to(device)
+    with torch.no_grad():
+        features = net(data)
+    # res = {img: feature for (img, feature) in zip(images, features)}
+    return features
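+
+# Sketch (not part of this module): downstream, the one-to-one comparison can
+# score event features against standard-barcode features by cosine similarity.
+# import torch.nn.functional as F
+# def cosine_sim(event_feats, std_feats):
+#     a = F.normalize(event_feats, dim=1)
+#     b = F.normalize(std_feats, dim=1)
+#     return a @ b.T   # (n_event, n_std) similarity matrix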
+
+if __name__ == '__main__':
+    # Network setup
+    if conf.testbackbone == 'resnet18':
+        model = resnet18().to(device)
+    else:
+        raise ValueError('Unsupported backbone {}'.format(conf.testbackbone))
+
+    print('loaded model {}'.format(conf.testbackbone))
+    # model = nn.DataParallel(model).to(conf.device)
+    model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
+    model.eval()
+
+    # images = unique_image(conf.test_list)
+    # images = [osp.join(conf.test_val, img) for img in images]
+    # print('images', images)
+    # images = ['./data/2250_train/val/6920616313186/6920616313186_6920616313186_20240220-124502_53d2e103-ae3a-4689-b745-9d8723b770fe_front_returnGood_70f75407b7ae_31_01.jpg']
+
+    # groups = group_image(conf.test_val, conf.test_batch_size)  # group images by batch size
+    groups = group_image('img_test', 1)  # group images by batch size (default batch_size = 8)
+
+    feature_dict = dict()
+    for group in groups:
+        s = time.time()
+        features = featurize(group, conf.test_transform, model, conf.device)
+        e = time.time()
+        print('time: {}'.format(e - s))
+        # out = softmax(features, dim=1).argmax(dim=1)
+        # print('d >>> {}'.format(out))
+        # feature_dict.update(d)

diff --git a/contrast/model/__init__.py b/contrast/model/__init__.py
new file mode 100644
index 0000000..9eebc77
--- /dev/null
+++ b/contrast/model/__init__.py
@@ -0,0 +1 @@
+from .resnet_pre import resnet18, resnet34, resnet50, resnet14

diff --git a/contrast/model/__pycache__/__init__.cpython-38.pyc b/contrast/model/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..160bc374c70bf29f078e992def813facbfef5562
GIT binary patch
literal 246
[base85 payload omitted]

diff --git a/contrast/model/__pycache__/__init__.cpython-39.pyc b/contrast/model/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..539ec67dda6c7e7d2f9ee70186d96fef1084d6b7
GIT binary patch
literal 245
[base85 payload omitted]

diff --git a/contrast/model/__pycache__/resnet_pre.cpython-38.pyc b/contrast/model/__pycache__/resnet_pre.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..199085defa483fcb5f375a6fa70a20cf5f54e5df
GIT binary patch
literal 13929
[base85 payload omitted]

diff --git a/contrast/model/__pycache__/resnet_pre.cpython-39.pyc b/contrast/model/__pycache__/resnet_pre.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0301f55938ba959456d20d3832de4ec4aa589f19
GIT binary patch
literal 13904
[base85 payload omitted]
diff --git a/contrast/model/resnet_pre.py b/contrast/model/resnet_pre.py
new file mode 100644
index 0000000..5e52ad9
--- /dev/null
+++ b/contrast/model/resnet_pre.py
@@ -0,0 +1,462 @@
+import torch
+import torch.nn as nn
+from config import config as conf
+
+try:
+    from torch.hub import load_state_dict_from_url
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
+# from .utils import load_state_dict_from_url
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+           'wide_resnet50_2', 'wide_resnet101_2']
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class SpatialAttention(nn.Module):
+    def __init__(self, kernel_size=7):
+        super(SpatialAttention, self).__init__()
+
+        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
+        padding = 3 if kernel_size == 7 else 1
+
+        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        avg_out = torch.mean(x, dim=1, keepdim=True)
+        max_out, _ = torch.max(x, dim=1, keepdim=True)
+        x = torch.cat([avg_out, max_out], dim=1)
+        x = self.conv1(x)
+        return self.sigmoid(x)
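+
+# Illustrative usage (not part of the network definition): SpatialAttention
+# turns a (N, C, H, W) map into a (N, 1, H, W) gate in [0, 1] that broadcasts
+# over channels.
+# sa = SpatialAttention(kernel_size=7)
+# x = torch.randn(2, 64, 56, 56)
+# gated = x * sa(x)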
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        self.cam = cam
+        self.bam = bam
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+        if self.cam:
+            if planes == 64:
+                self.globalAvgPool = nn.AvgPool2d(56, stride=1)
+            elif planes == 128:
+                self.globalAvgPool = nn.AvgPool2d(28, stride=1)
+            elif planes == 256:
+                self.globalAvgPool = nn.AvgPool2d(14, stride=1)
+            elif planes == 512:
+                self.globalAvgPool = nn.AvgPool2d(7, stride=1)
+
+            self.fc1 = nn.Linear(in_features=planes, out_features=round(planes / 16))
+            self.fc2 = nn.Linear(in_features=round(planes / 16), out_features=planes)
+            self.sigmod = nn.Sigmoid()
+        if self.bam:
+            self.bam = SpatialAttention()
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        if self.cam:
+            # channel attention (squeeze-and-excitation style): pool the map,
+            # pass it through the bottleneck MLP, then rescale the full map
+            ori_out = out
+            out = self.globalAvgPool(out)
+            out = out.view(out.size(0), -1)
+            out = self.fc1(out)
+            out = self.relu(out)
+            out = self.fc2(out)
+            out = self.sigmod(out)
+            out = out.view(out.size(0), out.size(-1), 1, 1)
+            out = out * ori_out
+
+        if self.bam:
+            out = out * self.bam(out)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
+    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
+    # according to "Deep residual learning for image recognition" <https://arxiv.org/abs/1512.03385>.
+    # This variant is also known as ResNet V1.5 and improves accuracy according to
+    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None, cam=False, bam=False):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        self.cam = cam
+        self.bam = bam
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+        if self.cam:
+            if planes == 64:
+                self.globalAvgPool = nn.AvgPool2d(56, stride=1)
+            elif planes == 128:
+                self.globalAvgPool = nn.AvgPool2d(28, stride=1)
+            elif planes == 256:
+                self.globalAvgPool = nn.AvgPool2d(14, stride=1)
+            elif planes == 512:
+                self.globalAvgPool = nn.AvgPool2d(7, stride=1)
+
+            self.fc1 = nn.Linear(planes * self.expansion, round(planes / 4))
+            self.fc2 = nn.Linear(round(planes / 4), planes * self.expansion)
+            self.sigmod = nn.Sigmoid()
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        if self.cam:
+            ori_out = out
+            out = self.globalAvgPool(out)
+            out = out.view(out.size(0), -1)
+            out = self.fc1(out)
+            out = self.relu(out)
+            out = self.fc2(out)
+            out = self.sigmod(out)
+            out = out.view(out.size(0), out.size(-1), 1, 1)
+            out = out * ori_out
+
+        out += identity
+        out = self.relu(out)
+        return out
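+
+# Shape walk-through of the cam branch above (BasicBlock, planes=64, 56x56 map,
+# assuming a 224x224 network input): out (N,64,56,56) -> avgpool (N,64,1,1)
+# -> flatten (N,64) -> fc1 (N,4) -> relu -> fc2 (N,64) -> sigmoid -> view
+# (N,64,1,1) -> out * ori_out rescales each channel of the residual branch.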
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=conf.embedding_size, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None, scale=0.75):
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, int(64*scale), layers[0])
+        self.layer2 = self._make_layer(block, int(128*scale), layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, int(256*scale), layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, int(512*scale), layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(int(512 * block.expansion*scale), num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x):
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        # print('poolBefore', x.shape)
+        x = self.avgpool(x)
+        # print('poolAfter', x.shape)
+        x = torch.flatten(x, 1)
+        # print('fcBefore', x.shape)
+        x = self.fc(x)
+        # print('fcAfter', x.shape)
+
+        return x
+
+    def forward(self, x):
+        return self._forward_impl(x)
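+
+# Minimal smoke test (sketch, assuming a 224x224 input; the default scale=0.75
+# shrinks every layer's channel width, and the embedding size comes from conf):
+# net = ResNet(BasicBlock, [2, 2, 2, 2])       # resnet18-style topology
+# y = net(torch.randn(1, 3, 224, 224))         # -> (1, conf.embedding_size)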
+
+# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+#     model = ResNet(block, layers, **kwargs)
+#     if pretrained:
+#         state_dict = load_state_dict_from_url(model_urls[arch],
+#                                               progress=progress)
+#         model.load_state_dict(state_dict, strict=False)
+#     return model
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+    model = ResNet(block, layers, **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+
+        src_state_dict = state_dict
+        target_state_dict = model.state_dict()
+        skip_keys = []
+        # skip size-mismatched tensors when loading pretrained weights
+        for k in src_state_dict.keys():
+            if k not in target_state_dict:
+                continue
+            if src_state_dict[k].size() != target_state_dict[k].size():
+                skip_keys.append(k)
+        for k in skip_keys:
+            del src_state_dict[k]
+        missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=False)
+
+    return model
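+
+# Because size-mismatched tensors are dropped before load_state_dict(strict=False),
+# reduced variants such as resnet14 below (and any scale != 1 model) can still be
+# seeded from the stock resnet18 checkpoint: only shape-compatible tensors are
+# copied, and the rest keep their random initialisation.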
+
+def resnet14(pretrained=True, progress=True, **kwargs):
+    r"""ResNet-14 model, a reduced [2, 1, 1, 2] variant initialised from the
+    resnet18 checkpoint, based on
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet18', BasicBlock, [2, 1, 1, 2], pretrained, progress,
+                   **kwargs)
+
+
+def resnet18(pretrained=True, progress=True, **kwargs):
+    r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                   **kwargs)
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet101(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet152(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+    r"""ResNeXt-50 32x4d model from
+    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 4
+    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+                   pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+    r"""ResNeXt-101 32x8d model from
+    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 8
+    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+    r"""Wide ResNet-50-2 model from
+    `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
+
+    The model is the same as ResNet except for the bottleneck number of channels,
+    which is twice as large in every block. The number of channels in outer 1x1
+    convolutions is the same; e.g. the last block in ResNet-50 has 2048-512-2048
+    channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+                   pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+    r"""Wide ResNet-101-2 model from
+    `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
+
+    The model is the same as ResNet except for the bottleneck number of channels,
+    which is twice as large in every block. The number of channels in outer 1x1
+    convolutions is the same; e.g. the last block in ResNet-50 has 2048-512-2048
+    channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)

diff --git a/tracking/contrast_one2one.py b/tracking/contrast_one2one.py
index bb7cc3f..1625275 100644
--- a/tracking/contrast_one2one.py
+++ b/tracking/contrast_one2one.py
@@ -34,6 +34,7 @@ import cv2
 import os
 import sys
 import json
+import pickle
 
 sys.path.append(r"D:\DetectTracking")
 from tracking.utils.read_data import extract_data, read_tracking_output, read_deletedBarcode_file
@@ -213,9 +214,12 @@ def get_std_barcodeDict(bcdpath):
             if ext not in IMG_FORMAT: continue
             imgpaths.append(imgpath)
         stdBarcodeDict[barcode].extend(imgpaths)
-
-    with open('stdBarcodeDict.json', 'wb') as f:
-        json.dump(stdBarcodeDict, f)
+
+    picklepath = os.path.join(r'\\192.168.1.28\share\测试_202406\contrast\barcodes', f"{barcode}.pickle")
+    with open(picklepath, 'wb') as f:
+        pickle.dump(stdBarcodeDict, f)
+
+    print(f"Barcode: {barcode}")
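
A note on the json -> pickle switch in the hunk above: json.dump writes text,
so it would fail against a file opened in binary 'wb' mode; moving to pickle
also resolves that mode mismatch. A minimal round-trip sketch (path shortened,
assuming the same per-barcode layout as in the patch):

    import pickle
    with open(f"{barcode}.pickle", 'wb') as f:
        pickle.dump(stdBarcodeDict, f)
    with open(f"{barcode}.pickle", 'rb') as f:
        restored = pickle.load(f)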