diff --git a/data/anchors.txt b/data/anchors.txt index c87b7e3..b299248 100644 --- a/data/anchors.txt +++ b/data/anchors.txt @@ -1,11 +1,11 @@ -Best Accuracy = 77.39% -Best Anchors = [[ 87 51] - [ 78 85] - [136 65] - [ 99 152] - [147 108] - [238 72] - [240 127] - [173 180] - [264 203]] -Best Ratios = [0.65, 0.92, 0.96, 1.3, 1.35, 1.69, 1.89, 2.1, 3.31] \ No newline at end of file +Best Accuracy = 77.86% +Best Anchors = [[109 52] + [ 78 81] + [ 96 152] + [139 106] + [230 70] + [160 172] + [241 126] + [217 202] + [307 201]] +Best Ratios = [0.63, 0.93, 0.95, 1.07, 1.31, 1.53, 1.91, 2.08, 3.29] \ No newline at end of file diff --git a/detect.py b/detect.py index 1da7335..cd7299e 100644 --- a/detect.py +++ b/detect.py @@ -150,7 +150,8 @@ def detect(opt, save_img=False): "6923644272159", "6924882486100", "6956511907458"] targets = [] for target in pred[0]: - targets.append({"Class": names[int(target[5].item())], "precision": target[4].item(), "xy1": [target[0].item(), target[1].item()], + targets.append({"Class": names[int(target[5].item())], "precision": target[4].item(), + "xy1": [target[0].item(), target[1].item()], "xy2": [target[2].item(), target[3].item()]}) resu = {"TargetDetect": targets} print(resu) diff --git a/getval_imgs.py b/getval_imgs.py index 056d4ea..2ffbbba 100644 --- a/getval_imgs.py +++ b/getval_imgs.py @@ -1,7 +1,7 @@ import os,shutil allimgs_path = "paper_data/images" -valimgs_path = "paper_data/val_imgs" +valimgs_path = "paper_data/val_imgs_7.1" valtxt_path = "paper_data/ImageSets/Main/val.txt" with open(valtxt_path, "r", encoding='UTF-8') as val_file: diff --git a/ieemoo-ai-zhanting.py b/ieemoo-ai-zhanting.py index 31e68be..24ab19f 100644 --- a/ieemoo-ai-zhanting.py +++ b/ieemoo-ai-zhanting.py @@ -40,7 +40,7 @@ parser = argparse.ArgumentParser() parser.add_argument('--weights', nargs='+', type=str, default='../module/ieemoo-ai-zhanting/model/now/best.pt', help='model.pt path(s)') parser.add_argument('--source', type=str, 
def caculate_label_numble(file_path):
    """Count how often each class label occurs in a folder of XML annotations.

    Every file in ``file_path`` whose name ends with ``.xml`` is parsed, and
    the text of each ``<name>`` element that is a direct child of a top-level
    ``<object>`` element is tallied.

    Args:
        file_path: Directory containing the annotation ``.xml`` files.

    Returns:
        A plain ``dict`` mapping label text to its number of occurrences.
        Keys appear in first-seen order, with files visited in sorted
        filename order so the report is reproducible between runs.

    Raises:
        OSError: If ``file_path`` cannot be listed or a file cannot be read.
        xml.etree.ElementTree.ParseError: If an ``.xml`` file is malformed.
    """
    # Function-scope import so this block does not depend on an edit to the
    # module's import section.
    from collections import Counter

    label_counts = Counter()
    # os.listdir() returns entries in arbitrary order; sorting only affects
    # the order of keys in the result, never the counts.
    for entry in sorted(os.listdir(file_path)):
        if not entry.endswith(".xml"):
            continue
        root = ET.parse(os.path.join(file_path, entry)).getroot()
        # 'object/name' matches <name> children of top-level <object> nodes,
        # exactly what the nested findall() loops used to visit.
        label_counts.update(node.text for node in root.findall('object/name'))
    return dict(label_counts)
def prettyXml(element, indent, newline, level=0):
    """Indent an ElementTree element (and its whole subtree) in place.

    Args:
        element: The ``Element`` to format.
        indent: String used for one indentation step (e.g. ``'\\t'``).
        newline: String used as the line break (e.g. ``'\\n'``).
        level: Nesting depth of ``element``; 0 for the root.

    Returns:
        The same ``element`` object, after its ``text``/``tail`` whitespace
        has been rewritten.
    """
    children = list(element)
    child_pad = newline + indent * (level + 1)

    # Only elements that have children get their text rewritten; leaf text is
    # left untouched. (Checked via the child list, because truth-testing an
    # Element directly is deprecated.)
    if children:
        if element.text is None or element.text.isspace():
            element.text = child_pad
        else:
            element.text = child_pad + element.text.strip() + child_pad

    last_position = len(children) - 1
    for position, child in enumerate(children):
        if position < last_position:
            # Next line starts a sibling: keep the same indentation.
            child.tail = child_pad
        else:
            # Next line closes the parent: one indentation step less.
            child.tail = newline + indent * level
        prettyXml(child, indent, newline, level=level + 1)
    return element


class ParseXml(object):
    """Read the fields of an annotation XML file and write them back out.

    The XML layout handled here is: ``folder``, ``filename``, ``path``,
    ``source/database``, ``size/{width,height,depth}``, ``segmented`` and any
    number of ``object`` nodes, each with ``name``, ``pose``, ``truncated``,
    ``difficult`` and ``bndbox/{xmin,ymin,xmax,ymax}``.
    """

    def __init__(self, input_xml):
        """Remember the path of the XML file to read.

        Args:
            input_xml: Path (or file object) accepted by ``ET.parse``.
        """
        # Top-level tags printed by get_info(); a tuple keeps the print
        # order stable.
        self.__ttrs = ("folder", "filename", "path", "segmented", "size", "source", "object")
        self.__xml_info_dict = {}  # reserved for parsed info; not used by the methods below
        self.input_xml = input_xml

    def get_info(self):
        """Print the ``text`` of every top-level element with a known tag."""
        root = ET.parse(self.input_xml).getroot()
        for tag in self.__ttrs:
            for element in root.findall(tag):
                print(element.text)

    def get_xml_info(self):
        """Collect the annotation fields of ``self.input_xml`` into a dict.

        Returns:
            A dict that contains only the keys whose elements exist in the
            file: ``folder``, ``filename``, ``path``, ``database``, ``width``,
            ``height``, ``depth`` (element text, as ``str``), ``segmented``,
            and always ``object`` - a list with one dict per ``<object>``
            holding ``name``, ``pose``, ``truncated``, ``difficult`` and
            ``bndbox`` (a dict of ``xmin``/``ymin``/``xmax``/``ymax`` text).
            When an element is repeated, the last occurrence wins.

        NOTE(review): ``database`` is always reported as ``"Unknown"`` and
        ``segmented``/``truncated``/``difficult`` as ``0`` regardless of the
        file content. This looks like deliberate normalisation and is kept
        as-is; confirm with the callers.
        """
        root = ET.parse(self.input_xml).getroot()
        xml_info = {}

        for tag in ("folder", "filename", "path"):
            for node in root.findall(tag):
                xml_info[tag] = node.text

        if root.find("source/database") is not None:
            xml_info['database'] = "Unknown"

        for tag in ("width", "height", "depth"):
            for node in root.findall("size/" + tag):
                xml_info[tag] = node.text

        if root.find("segmented") is not None:
            xml_info['segmented'] = 0

        boxes = []
        for obj in root.findall('object'):
            box = {}
            for tag in ("name", "pose"):
                for node in obj.findall(tag):
                    box[tag] = node.text
            for tag in ("truncated", "difficult"):
                if obj.find(tag) is not None:
                    box[tag] = 0
            for bnd in obj.findall("bndbox"):
                # A later <bndbox> replaces an earlier one entirely.
                box['bndbox'] = {}
                for tag in ("xmin", "ymin", "xmax", "ymax"):
                    for node in bnd.findall(tag):
                        box['bndbox'][tag] = node.text
            boxes.append(box)
        xml_info['object'] = boxes
        return xml_info

    def write_to_xml(self, result, save_dir):
        """Write the annotation info in ``result`` to a tab-indented XML file.

        Args:
            result: Dict shaped like the return value of ``get_xml_info``;
                ``folder``, ``filename``, ``path``, ``width``, ``height`` and
                ``object`` are required, and every object needs ``name``,
                ``pose``, ``truncated``, ``difficult`` and a full ``bndbox``.
            save_dir: Output directory; created if it does not exist.

        Returns:
            Path of the written file: ``save_dir`` joined with the stem of
            ``result['filename']`` plus ``.xml``.

        Raises:
            KeyError: If a required key is missing from ``result``.
            ValueError: If a box coordinate is not numeric.
        """
        imgname = os.path.splitext(result['filename'])[0]
        xmlPath = os.path.join(save_dir, imgname) + ".xml"

        # NOTE(review): the root tag is "annotations" and depth is always
        # written as "3" (result['depth'] is ignored); both kept unchanged.
        root = ET.Element("annotations")
        ET.SubElement(root, "folder").text = result['folder']
        ET.SubElement(root, "filename").text = result['filename']
        ET.SubElement(root, "path").text = result['path']
        size = ET.SubElement(root, "size")
        ET.SubElement(size, "width").text = str(result['width'])
        ET.SubElement(size, "height").text = str(result['height'])
        ET.SubElement(size, "depth").text = "3"

        for info in result['object']:
            obj = ET.SubElement(root, "object")
            for tag in ("name", "pose", "truncated", "difficult"):
                ET.SubElement(obj, tag).text = str(info[tag])

            bbox = ET.SubElement(obj, "bndbox")
            for tag in ("xmin", "ymin", "xmax", "ymax"):
                # Coordinates may arrive as float-like strings; store whole pixels.
                ET.SubElement(bbox, tag).text = str(int(float(info['bndbox'][tag])))

        # Pretty-print in memory and write once, instead of writing,
        # re-parsing and writing the same file again.
        prettyXml(root, '\t', '\n')
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        ET.ElementTree(root).write(xmlPath)

        return xmlPath

    def update_node(self, xml, save_dir, note, note_value):
        """Set one kind of node in ``xml`` to a new value and save a copy.

        Every element matching ``note`` is updated (e.g. ``"xmin"`` changes
        the ``xmin`` of every object). The result is written to ``save_dir``
        under the same file name as ``xml``; the file is written even when
        nothing matched.

        Args:
            xml: Path of the XML file to read.
            save_dir: Output directory; created if it does not exist.
            note: Tag to change. One of ``folder``/``filename``/``path``,
                ``width``/``height``/``depth``, ``name``/``pose``/
                ``truncated``/``difficult`` or ``xmin``/``xmax``/``ymin``/``ymax``.
            note_value: New value; stored as text.
        """
        xmlPath = os.path.join(save_dir, os.path.basename(xml))

        tree = ET.parse(xml)
        root = tree.getroot()

        top_level = note in ("folder", "filename", "path")
        if top_level:
            search_path = note
        elif note in ("width", "height", "depth"):
            search_path = "size/" + note
        elif note in ("name", "pose", "truncated", "difficult"):
            search_path = "object/" + note
        elif note in ("xmin", "xmax", "ymin", "ymax"):
            search_path = "object/bndbox/" + note
        else:
            search_path = None

        matches = root.findall(search_path) if search_path is not None else []
        if matches:
            # Top-level fields keep a str/None value untouched (None yields an
            # empty element); anything else is stringified so that
            # tree.write() cannot fail on a non-string text.
            if top_level and (note_value is None or isinstance(note_value, str)):
                new_text = note_value
            else:
                new_text = str(note_value)
            for element in matches:
                element.text = new_text
        else:
            # Unknown tag, or the tag does not occur in this file.
            print("修改的节点不在xml")

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        tree.write(xmlPath)
# Images found in this folder are forced into the training split even when
# the random split assigned them to validation.
addtrain_path = r"D:\PycharmProjects\Zhanting\yolov5_1\img_data\getimg_6.30"
# List the folder once: the original re-read the directory and scanned the
# resulting list for every validation sample.
forced_train_images = set(os.listdir(addtrain_path))

for i in list_index:
    name = total_xml[i][:-4] + '\n'       # annotation file name without its extension
    addimg_name = name.strip() + ".jpg"   # matching image file name

    if i not in trainval:
        file_test.write(name)
        continue

    file_trainval.write(name)
    if i in train:
        file_train.write(name)
    elif addimg_name in forced_train_images:
        # Sample was drawn for validation but is moved to the training set.
        print("addimg_name:", addimg_name)
        file_train.write(name)
    else:
        file_val.write(name)

# Close the split files explicitly so buffered lines are flushed to disk.
# NOTE(review): assumes nothing later in the script writes to these handles -
# the visible hunk ends with the loop; confirm.
file_trainval.close()
file_train.close()
file_val.close()
file_test.close()