Files
ieemoo-ai-zhanting/spilt_train_val.py
huangtao 4bb117c407 1.3
2022-07-01 14:19:10 +08:00

56 lines
1.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import os
import random
import sys
def _parse_args(argv=None):
    """Parse and validate the command-line options for the dataset split.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. Exits via ``parser.error`` (status 2)
        if a fraction is outside ``[0, 1]``.
    """
    parser = argparse.ArgumentParser()
    # Location of the XML annotation files; adjust to your own dataset
    # (the XML files usually live under Annotations).
    parser.add_argument('--xml_path', default='paper_data/Annotations', type=str, help='input xml label path')
    # Output directory for the split lists; usually <dataset>/ImageSets/Main.
    parser.add_argument('--txt_path', default='paper_data/ImageSets/Main', type=str, help='output txt label path')
    # Samples that would land in val but whose "<name>.jpg" exists in this
    # directory are moved into train instead. The default is the original
    # machine-specific path, kept for backward compatibility; '' disables it.
    parser.add_argument('--add_train_path',
                        default=r"D:\PycharmProjects\Zhanting\yolov5_1\img_data\getimg_6.30",
                        type=str,
                        help='directory of images whose samples are always put into train (empty string disables)')
    parser.add_argument('--trainval_percent', default=1.0, type=float,
                        help='fraction of samples used for train+val; the rest go to test')
    parser.add_argument('--train_percent', default=0.8, type=float,
                        help='fraction of the train+val samples that go to train')
    parser.add_argument('--seed', default=None, type=int,
                        help='random seed for a reproducible split (default: unseeded)')
    opt = parser.parse_args(argv)
    for flag in ('trainval_percent', 'train_percent'):
        value = getattr(opt, flag)
        # Out-of-range fractions would otherwise surface as an opaque
        # "Sample larger than population or is negative" error from random.sample.
        if not 0.0 <= value <= 1.0:
            parser.error(f'--{flag} must be between 0 and 1, got {value}')
    return opt


def _list_annotation_stems(xml_dir):
    """Return the sorted base names (extension removed) of the .xml files in *xml_dir*.

    Non-XML entries are skipped so they cannot leak into the split lists;
    sorting makes the output independent of the filesystem's listing order.
    """
    return sorted(
        os.path.splitext(entry)[0]
        for entry in os.listdir(xml_dir)
        if entry.lower().endswith('.xml')
    )


def _list_extra_train_images(add_dir):
    """Return the set of file names in *add_dir*, or an empty set if disabled/missing.

    A missing directory prints a warning to stderr instead of aborting the run.
    """
    if not add_dir:
        return set()
    try:
        return set(os.listdir(add_dir))
    except FileNotFoundError:
        print(f'warning: --add_train_path {add_dir!r} not found; '
              f'no validation samples will be moved into train', file=sys.stderr)
        return set()


def _split_indices(num, trainval_percent, train_percent, rng):
    """Randomly choose the trainval and train index sets for *num* samples.

    Returns:
        ``(trainval, train)`` as sets of indices in ``range(num)``, with
        ``len(trainval) == int(num * trainval_percent)``,
        ``len(train) == int(len(trainval) * train_percent)`` and
        ``train`` a subset of ``trainval``.
    """
    n_trainval = int(num * trainval_percent)
    n_train = int(n_trainval * train_percent)
    trainval = rng.sample(range(num), n_trainval)
    train = rng.sample(trainval, n_train)
    # Sets give O(1) membership tests in the write loop.
    return set(trainval), set(train)


def main(argv=None):
    """Split the annotations into trainval/train/val/test lists.

    Writes ``trainval.txt``, ``train.txt``, ``val.txt`` and ``test.txt`` (one
    sample name per line) into ``--txt_path``, creating it if necessary.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    opt = _parse_args(argv)
    names = _list_annotation_stems(opt.xml_path)
    os.makedirs(opt.txt_path, exist_ok=True)

    trainval, train = _split_indices(len(names), opt.trainval_percent,
                                     opt.train_percent, random.Random(opt.seed))
    # Only consult the extra-train directory when some sample can reach val.
    has_val_candidates = len(train) < len(trainval)
    extra_train = _list_extra_train_images(opt.add_train_path) if has_val_candidates else set()

    out_dir = opt.txt_path
    with open(os.path.join(out_dir, 'trainval.txt'), 'w') as file_trainval, \
            open(os.path.join(out_dir, 'test.txt'), 'w') as file_test, \
            open(os.path.join(out_dir, 'train.txt'), 'w') as file_train, \
            open(os.path.join(out_dir, 'val.txt'), 'w') as file_val:
        for i, stem in enumerate(names):
            line = stem + '\n'
            if i not in trainval:
                file_test.write(line)
                continue
            file_trainval.write(line)
            if i in train:
                file_train.write(line)
                continue
            addimg_name = stem.strip() + ".jpg"
            if addimg_name in extra_train:
                # Force selected samples that were drawn into val back into train.
                print("addimg_name:", addimg_name)
                file_train.write(line)
            else:
                file_val.write(line)


if __name__ == '__main__':
    main()