pipeline.py 等更新

This commit is contained in:
王庆刚
2025-01-21 18:22:56 +08:00
parent bfe7bc0fd5
commit 64248b1557
15 changed files with 106 additions and 77 deletions

View File

@ -61,8 +61,9 @@ class Config:
test_val = "D:/比对/cl"
# test_val = "./data/test_data_100"
test_model = "checkpoints/best_resnet18_v12.pth"
# test_model = "checkpoints/best_resnet18_v12.pth"
# test_model = "checkpoints/zhanting_res_801.pth"
test_model = "checkpoints/zhanting_res_abroad_8021.pth"

View File

@ -61,8 +61,17 @@ class FeatsInterface:
batch_patches = []
patches = []
for i, img in enumerate(images):
img = img.copy()
patch = self.transform(img)
img = img.copy()
## 对 img 进行补黑边生成新的图像new_img
width, height = img.size
new_size = max(width, height)
new_img = Image.new("RGB", (new_size, new_size), (0, 0, 0))
paste_x = (new_size - width) // 2
paste_y = (new_size - height) // 2
new_img.paste(img, (paste_x, paste_y))
patch = self.transform(new_img)
if str(self.device) != "cpu":
patch = patch.to(device=self.device).half()
else:
@ -107,10 +116,12 @@ class FeatsInterface:
patch = self.transform(img1)
# patch = patch.to(device=self.device).half()
if str(self.device) != "cpu":
patch = patch.to(device=self.device).half()
else:
patch = patch.to(device=self.device)
# if str(self.device) != "cpu":
# patch = patch.to(device=self.device).half()
# patch = patch.to(device=self.device)
# else:
# patch = patch.to(device=self.device)
patch = patch.to(device=self.device)
patches.append(patch)
if (d + 1) % self.batch_size == 0: