"""Compute anchor (prior) boxes from a folder of XML annotation files.

The annotation files are expected to follow a Pascal VOC-style layout
(``size/width``, ``size/height`` and ``object/bndbox/{xmin,ymin,xmax,ymax}``).
Normalized box sizes are clustered with ``kmeans`` from the local ``kmeans``
module over several trials; the trial with the highest ``avg_iou`` score is
written to ``ANCHORS_TXT_PATH``.
"""
import os
# xml.etree.cElementTree was removed in Python 3.9; ElementTree picks the
# C accelerator automatically when it is available.
import xml.etree.ElementTree as et

import numpy as np

from kmeans import kmeans, avg_iou

FILE_ROOT = "paper_data/"  # dataset root path
ANNOTATION_ROOT = "Annotations"  # annotation folder inside the dataset root
ANNOTATION_PATH = FILE_ROOT + ANNOTATION_ROOT
ANCHORS_TXT_PATH = "data/anchors.txt"
CLUSTERS = 9
# Pixel scale applied to the normalized cluster sizes
# (presumably the network input size -- confirm against the training config).
INPUT_SIZE = 640
# Number of clustering trials. May be changed, but keep it small:
# every trial reruns k-means on the full box set.
N_TRIALS = 10
CLASS_NAMES = ["6925303773908", "6924743915848", "6920152471616",
               "6920005772716", "6902227018162", "6920459905012",
               "6972194461407", "6935284412918", "6921489033706",
               "6904012526494", "6923644272159", "6924882486100",
               "6956511907458"]


def load_data(anno_dir, class_names):
    """Collect normalized box sizes from every annotation file in a folder.

    Args:
        anno_dir: Directory whose entries are each parsed as an XML
            annotation file.
        class_names: Accepted for interface compatibility; currently unused
            (boxes are not filtered by class).

    Returns:
        ``np.ndarray`` with one ``[width, height]`` row per object, where both
        values are divided by the image size declared in the same file. The
        array is empty when no box could be read.

    Raises:
        xml.etree.ElementTree.ParseError: If an entry is not well-formed XML.
        TypeError, ValueError: If ``size/width`` or ``size/height`` is missing
            or not numeric.
    """
    boxes = []
    for xml_name in os.listdir(anno_dir):
        xml_pth = os.path.join(anno_dir, xml_name)
        tree = et.parse(xml_pth)
        width = float(tree.findtext("./size/width"))
        height = float(tree.findtext("./size/height"))
        for obj in tree.iter("object"):
            # Only the failures a malformed annotation can cause are handled:
            # a missing tag (float(None) -> TypeError), a non-numeric value
            # (ValueError) or a zero image dimension (ZeroDivisionError).
            # The bad object is reported and skipped; the remaining objects
            # of the file are still collected.
            try:
                xmin = float(obj.findtext("bndbox/xmin")) / width
                ymin = float(obj.findtext("bndbox/ymin")) / height
                xmax = float(obj.findtext("bndbox/xmax")) / width
                ymax = float(obj.findtext("bndbox/ymax")) / height
            except (TypeError, ValueError, ZeroDivisionError):
                print("***", obj)
                continue
            boxes.append([xmax - xmin, ymax - ymin])
    return np.array(boxes)


def _run_trial(boxes):
    """Run one clustering trial, print its summary, and return its results.

    Returns:
        ``(accuracy, anchors, ratios)``: the ``avg_iou`` score as a
        percentage, the clusters scaled by ``INPUT_SIZE`` and rounded (in
        ascending order of cluster width), and the sorted width/height ratios
        rounded to two decimals.
    """
    clusters = kmeans(boxes, k=CLUSTERS)
    clusters = clusters[clusters[:, 0].argsort()]  # ascending cluster width

    anchors = []
    for row in clusters:
        anchor = [round(row[0] * INPUT_SIZE), round(row[1] * INPUT_SIZE)]
        anchors.append(anchor)
        print(f"Anchors:{anchor}")

    accuracy = avg_iou(boxes, clusters) * 100
    print("Train_Accuracy:{:.2f}%".format(accuracy))

    ratios = np.around(clusters[:, 0] / clusters[:, 1], decimals=2).tolist()
    ratios.sort()
    print("Ratios:{}".format(ratios))
    return accuracy, anchors, ratios


def main():
    """Cluster the dataset boxes and write the best anchors to disk."""
    # The output file is opened before any work is done, so an unwritable
    # path fails immediately; ``with`` guarantees the handle is closed even
    # if loading or clustering raises.
    with open(ANCHORS_TXT_PATH, "w") as anchors_txt:
        print(ANNOTATION_PATH)
        train_boxes = load_data(ANNOTATION_PATH, CLASS_NAMES)
        print(train_boxes)
        if train_boxes.size == 0:
            raise SystemExit(
                "No bounding boxes found in {!r}; cannot compute anchors."
                .format(ANNOTATION_PATH))

        best_accuracy = 0
        best_anchors = []
        best_ratios = []
        for trial in range(1, N_TRIALS + 1):
            accuracy, anchors, ratios = _run_trial(train_boxes)
            print(20 * "*" + " {} ".format(trial) + 20 * "*")
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_anchors = anchors
                best_ratios = ratios

        if not best_anchors:
            # Every trial scored <= 0 (or NaN): there is nothing to sort.
            raise SystemExit(
                "No clustering trial produced a positive average IoU.")

        # Order the winning anchors by area (width * height), smallest first.
        best_anchors = np.array(best_anchors)
        area_order = (best_anchors[:, 0] * best_anchors[:, 1]).argsort()
        best_anchors_sort = best_anchors[area_order]
        print("best_anchors_sort:", best_anchors_sort)

        # NOTE(review): in text mode on Windows "\r\n" is written as
        # "\r\r\n"; the literal is kept so the output format is unchanged.
        anchors_txt.write(
            "Best Accuracy = " + str(round(best_accuracy, 2)) + '%' + "\r\n")
        anchors_txt.write("Best Anchors = " + str(best_anchors_sort) + "\r\n")
        anchors_txt.write("Best Ratios = " + str(best_ratios))


if __name__ == '__main__':
    main()