# Compute prior (anchor) boxes from the annotation label files.

import os

try:
    # xml.etree.cElementTree was removed in Python 3.9; keep it for older
    # interpreters and fall back to ElementTree everywhere else.
    import xml.etree.cElementTree as et
except ImportError:
    import xml.etree.ElementTree as et

import numpy as np

from kmeans import kmeans, avg_iou


# Dataset root directory.
FILE_ROOT = "/home/nxy/nxy_project/python_project/Data/zhanting_add/"
# Name of the annotation (label) folder inside FILE_ROOT.
ANNOTATION_ROOT = "xmls"
# Joined with os.path.join so the result does not depend on FILE_ROOT
# ending with a path separator.
ANNOTATION_PATH = os.path.join(FILE_ROOT, ANNOTATION_ROOT)

# Text file the script entry point writes the best anchors to.
ANCHORS_TXT_PATH = "data/anchors.txt"

# Number of clusters requested from kmeans (passed as ``k``).
CLUSTERS = 9
# Class identifiers handed to load_data().
CLASS_NAMES = [
    "6925303773908", "6924743915848", "6920152471616", "6920005772716", "6902227018162",
    "6920459905012", "6972194461407", "6935284412918", "6921489033706", "6904012526494",
    "6923644272159", "6924882486100", "6956511907458",
]


def load_data(anno_dir, class_names):
    """Collect normalised bounding-box sizes from XML annotation files.

    Every ``*.xml`` file directly inside ``anno_dir`` is parsed. The image
    size is read from ``size/width`` and ``size/height``; for each
    ``object`` element the ``bndbox`` corner coordinates are divided by the
    image size and the resulting ``[width, height]`` pair is recorded.

    Args:
        anno_dir: Directory containing the XML annotation files.
        class_names: Accepted for interface compatibility; the boxes are
            not filtered by class, so this argument is currently unused.

    Returns:
        A float ``numpy.ndarray`` of shape ``(N, 2)`` holding one
        ``[width, height]`` row per readable object. ``N`` is 0 when no box
        was found, so column indexing on the result stays valid.

    Raises:
        TypeError, ValueError: if a file lacks a numeric ``size/width`` or
            ``size/height`` entry.
        xml.etree.ElementTree.ParseError: if an ``.xml`` file is malformed.
    """
    boxes = []

    # Sorted so the row order does not depend on the file system's
    # directory listing order.
    for xml_name in sorted(os.listdir(anno_dir)):
        # Ignore stray non-annotation entries (editor backups, sub-folders).
        if not xml_name.lower().endswith(".xml"):
            continue

        xml_pth = os.path.join(anno_dir, xml_name)
        tree = et.parse(xml_pth)

        width = float(tree.findtext("./size/width"))
        height = float(tree.findtext("./size/height"))

        for obj in tree.iter("object"):
            try:
                xmin = float(obj.findtext("bndbox/xmin")) / width
                ymin = float(obj.findtext("bndbox/ymin")) / height
                xmax = float(obj.findtext("bndbox/xmax")) / width
                ymax = float(obj.findtext("bndbox/ymax")) / height
            except (TypeError, ValueError, ZeroDivisionError):
                # Missing tag (None), non-numeric text, or a zero image
                # size: report the object and keep going with the rest.
                print("***", xml_pth, obj)
                continue

            boxes.append([xmax - xmin, ymax - ymin])

    # reshape keeps the (N, 2) layout even when no box was collected.
    return np.array(boxes, dtype=float).reshape(-1, 2)


def main():
    """Cluster the annotation boxes several times and save the best anchors.

    Loads the normalised box sizes with ``load_data``, calls ``kmeans``
    ``n_runs`` times, scores every run with ``avg_iou`` and keeps the run
    with the highest score. The winning anchors (scaled to ``input_size``
    pixels and ordered by area), the score and the sorted width/height
    ratios are written to ``ANCHORS_TXT_PATH``.

    Raises:
        SystemExit: if no boxes were loaded or no run produced a score
            above zero, since there is nothing meaningful to write.
    """
    n_runs = 10       # adjustable; keep it small, every run takes a long time
    input_size = 640  # cluster sizes are fractions of the image; scale to pixels

    print(ANNOTATION_PATH)

    train_boxes = load_data(ANNOTATION_PATH, CLASS_NAMES)
    print(train_boxes)
    if len(train_boxes) == 0:
        raise SystemExit("No bounding boxes found in " + ANNOTATION_PATH)

    # Create the output folder up front so a missing directory is not
    # discovered only after all clustering runs have finished.
    out_dir = os.path.dirname(ANCHORS_TXT_PATH)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    best_accuracy = 0
    best_anchors = []
    best_ratios = []

    for run in range(1, n_runs + 1):
        clusters = kmeans(train_boxes, k=CLUSTERS)
        # Order the clusters by their first column (width).
        clusters = clusters[clusters[:, 0].argsort()]

        anchors_tmp = []
        for j in range(CLUSTERS):
            anchor = [round(clusters[j][0] * input_size),
                      round(clusters[j][1] * input_size)]
            anchors_tmp.append(anchor)
            print(f"Anchors:{anchor}")

        temp_accuracy = avg_iou(train_boxes, clusters) * 100
        print("Train_Accuracy:{:.2f}%".format(temp_accuracy))

        ratios = np.around(clusters[:, 0] / clusters[:, 1], decimals=2).tolist()
        ratios.sort()
        print("Ratios:{}".format(ratios))
        print(20 * "*" + " {} ".format(run) + 20 * "*")

        if temp_accuracy > best_accuracy:
            best_accuracy = temp_accuracy
            best_anchors = anchors_tmp
            best_ratios = ratios

    if not best_anchors:
        raise SystemExit("No clustering run produced a usable set of anchors")

    # Sort the winning anchors by area (width * height).
    best_anchors = np.array(best_anchors)
    best_anchors_idx = (best_anchors[:, 0] * best_anchors[:, 1]).argsort()
    best_anchors_sort = best_anchors[best_anchors_idx]
    print("best_anchors_sort:", best_anchors_sort)

    # The file is opened only once the results exist, so a failed run does
    # not truncate a previous anchors file. newline="" writes the explicit
    # "\r\n" unchanged instead of expanding it to "\r\r\n" on Windows.
    with open(ANCHORS_TXT_PATH, "w", newline="") as anchors_txt:
        anchors_txt.write("Best Accuracy = " + str(round(best_accuracy, 2)) + '%' + "\r\n")
        anchors_txt.write("Best Anchors = " + str(best_anchors_sort) + "\r\n")
        anchors_txt.write("Best Ratios = " + str(best_ratios))


if __name__ == '__main__':
    main()