Files
ieemoo-ai-zhanting/anchors.py
2022-04-08 16:44:51 +08:00

90 lines
3.1 KiB
Python

# 根据标签文件求先验框
import os

try:
    import xml.etree.cElementTree as et
except ImportError:  # cElementTree was removed in Python 3.9; ElementTree is C-accelerated already
    import xml.etree.ElementTree as et

import numpy as np

from kmeans import kmeans, avg_iou
# Dataset root directory (must end with a path separator or be joinable).
FILE_ROOT = "/home/nxy/nxy_project/python_project/Data/zhanting_add/"
# Folder, relative to FILE_ROOT, holding the XML label files.
ANNOTATION_ROOT = "xmls"
# Full path of the label folder that load_data() scans.
ANNOTATION_PATH = os.path.join(FILE_ROOT, ANNOTATION_ROOT)
# Text file that receives the best anchors / ratios summary.
ANCHORS_TXT_PATH = "data/anchors.txt"
# Number of anchor boxes to cluster (the k passed to kmeans()).
CLUSTERS = 9
# Class labels of the dataset.
CLASS_NAMES = [
    "6925303773908", "6924743915848", "6920152471616", "6920005772716", "6902227018162",
    "6920459905012", "6972194461407", "6935284412918", "6921489033706", "6904012526494",
    "6923644272159", "6924882486100", "6956511907458",
]
def load_data(anno_dir, class_names):
    """Collect relative box sizes from the XML label files in *anno_dir*.

    Every ``*.xml`` file directly inside *anno_dir* is parsed. For each
    ``<object>`` the ``bndbox`` coordinates are divided by the image
    ``size/width`` and ``size/height``, and ``[box_width, box_height]`` is
    recorded, i.e. sizes relative to the image.

    Problems are reported with a ``***`` line and skipped rather than
    aborting the run:
      * unparsable XML or a missing/non-numeric ``<size>`` -> file skipped;
      * non-positive image width/height -> file skipped (would divide by 0);
      * a missing/non-numeric coordinate -> only that object skipped.

    Args:
        anno_dir: Directory containing the XML annotation files.
        class_names: Accepted for interface compatibility but currently
            unused; boxes of every class are returned.
            NOTE(review): presumably intended to filter objects by ``<name>``
            -- confirm before relying on it.

    Returns:
        ``np.ndarray`` of shape ``(n_boxes, 2)`` with one ``[width, height]``
        row per box; shape ``(0, 2)`` when no boxes were found.
    """
    boxes = []
    # Sorted so the box order is deterministic instead of depending on the
    # directory-listing order; non-XML entries (e.g. README, .DS_Store)
    # would make et.parse() fail.
    xml_names = sorted(
        name for name in os.listdir(anno_dir) if name.lower().endswith(".xml")
    )
    for xml_name in xml_names:
        xml_pth = os.path.join(anno_dir, xml_name)
        if not os.path.isfile(xml_pth):  # e.g. a directory named "*.xml"
            continue
        try:
            tree = et.parse(xml_pth)
            # float(None) raises TypeError when <size> is missing.
            width = float(tree.findtext("./size/width"))
            height = float(tree.findtext("./size/height"))
        except (et.ParseError, TypeError, ValueError) as err:
            print("*** skipping unreadable annotation", xml_pth, err)
            continue
        if width <= 0 or height <= 0:
            # Normalising by a zero size would raise ZeroDivisionError.
            print("*** skipping annotation with invalid image size", xml_pth)
            continue
        for obj in tree.iter("object"):
            try:
                xmin = float(obj.findtext("bndbox/xmin")) / width
                ymin = float(obj.findtext("bndbox/ymin")) / height
                xmax = float(obj.findtext("bndbox/xmax")) / width
                ymax = float(obj.findtext("bndbox/ymax")) / height
            except (TypeError, ValueError):
                # Drop just this object; the file's other objects are kept.
                print("***", xml_pth, obj)
                continue
            boxes.append([xmax - xmin, ymax - ymin])
    # reshape keeps the 2-column shape even when boxes is empty.
    return np.array(boxes, dtype=float).reshape(-1, 2)
if __name__ == '__main__':
    # Run k-means several times on the label boxes, keep the clustering with
    # the best average IoU, and write its anchors/ratios to ANCHORS_TXT_PATH.

    # Number of k-means restarts. Adjustable, but keep it modest: each trial
    # re-clusters every box, so large values make the run slow.
    n_trials = 10
    # load_data() returns sizes relative to the image; they are scaled by
    # this many pixels to get anchors (presumably the training input size --
    # confirm it matches the model config).
    img_size = 640

    print(ANNOTATION_PATH)
    train_boxes = load_data(ANNOTATION_PATH, CLASS_NAMES)
    print(train_boxes)
    # Fail with a readable message instead of an obscure error inside
    # kmeans() when there are fewer boxes than requested clusters.
    if len(train_boxes) < CLUSTERS:
        raise SystemExit(
            f"Need at least {CLUSTERS} boxes to cluster, found "
            f"{len(train_boxes)} in {ANNOTATION_PATH}"
        )

    # Start below any achievable score (avg IoU * 100 >= 0) so the first
    # trial is always recorded and best_anchors can never stay empty.
    best_accuracy = -1.0
    best_anchors = []
    best_ratios = []
    for trial in range(1, n_trials + 1):
        clusters = kmeans(train_boxes, k=CLUSTERS)
        clusters = clusters[clusters[:, 0].argsort()]  # order by width
        anchors_tmp = []
        for cluster in clusters:
            anchor = [round(cluster[0] * img_size), round(cluster[1] * img_size)]
            anchors_tmp.append(anchor)
            print(f"Anchors:{anchor}")
        temp_accuracy = avg_iou(train_boxes, clusters) * 100
        print("Train_Accuracy:{:.2f}%".format(temp_accuracy))
        ratios = sorted(np.around(clusters[:, 0] / clusters[:, 1], decimals=2).tolist())
        print("Ratios:{}".format(ratios))
        print(20 * "*" + " {} ".format(trial) + 20 * "*")
        if temp_accuracy > best_accuracy:
            best_accuracy = temp_accuracy
            best_anchors = anchors_tmp
            best_ratios = ratios

    # Final anchors are listed from smallest to largest area.
    best_anchors = np.array(best_anchors)
    best_anchors_sort = best_anchors[(best_anchors[:, 0] * best_anchors[:, 1]).argsort()]
    print("best_anchors_sort:", best_anchors_sort)

    # The file is only opened once the results exist, so a failed run does not
    # truncate a previous anchors.txt; `with` guarantees it is closed.
    anchors_dir = os.path.dirname(ANCHORS_TXT_PATH)
    if anchors_dir:
        os.makedirs(anchors_dir, exist_ok=True)
    # Text mode already translates "\n" to the platform line ending; writing
    # "\r\n" here would produce "\r\r\n" on Windows.
    with open(ANCHORS_TXT_PATH, "w") as anchors_txt:
        anchors_txt.write("Best Accuracy = " + str(round(best_accuracy, 2)) + "%\n")
        anchors_txt.write("Best Anchors = " + str(best_anchors_sort) + "\n")
        anchors_txt.write("Best Ratios = " + str(best_ratios) + "\n")