这是我下一篇要学习解读的训练代码:
from __future__ import print_function
import os,time,cv2, sys, math
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
import time, datetime
import argparse
import random
import os, sys
import subprocess# use 'Agg' on matplotlib so that plots could be generated even without Xserver
# running
import matplotlib
matplotlib.use('Agg')from utils import utils, helpers
from builders import model_builderimport matplotlib.pyplot as pltdef str2bool(v):if v.lower() in ('yes', 'true', 't', 'y', '1'):return Trueelif v.lower() in ('no', 'false', 'f', 'n', '0'):return Falseelse:raise argparse.ArgumentTypeError('Boolean value expected.')parser = argparse.ArgumentParser()
# Command-line options for the training run.  Each entry is a
# (flag, add_argument keyword arguments) pair; they are registered in the
# original order so the --help output is unchanged.
_CLI_OPTIONS = [
    ('--num_epochs', dict(type=int, default=30, help='Number of epochs to train for')),
    ('--epoch_start_i', dict(type=int, default=0, help='Start counting epochs from this number')),
    ('--checkpoint_step', dict(type=int, default=5, help='How often to save checkpoints (epochs)')),
    ('--validation_step', dict(type=int, default=1, help='How often to perform validation (epochs)')),
    ('--image', dict(type=str, default=None, help='The image you want to predict on. Only valid in "predict" mode.')),
    ('--continue_training', dict(type=str2bool, default=False, help='Whether to continue training from a checkpoint')),
    ('--dataset', dict(type=str, default="CamVid", help='Dataset you are using.')),
    ('--crop_height', dict(type=int, default=512, help='Height of cropped input image to network')),
    ('--crop_width', dict(type=int, default=512, help='Width of cropped input image to network')),
    ('--batch_size', dict(type=int, default=1, help='Number of images in each batch')),
    ('--num_val_images', dict(type=int, default=20, help='The number of images to used for validations')),
    ('--h_flip', dict(type=str2bool, default=False, help='Whether to randomly flip the image horizontally for data augmentation')),
    ('--v_flip', dict(type=str2bool, default=False, help='Whether to randomly flip the image vertically for data augmentation')),
    ('--brightness', dict(type=float, default=None, help='Whether to randomly change the image brightness for data augmentation. Specifies the max bightness change as a factor between 0.0 and 1.0. For example, 0.1 represents a max brightness change of 10%% (+-).')),
    ('--rotation', dict(type=float, default=None, help='Whether to randomly rotate the image for data augmentation. Specifies the max rotation angle in degrees.')),
    ('--model', dict(type=str, default="DeepLabV3", help='The model you are using. See model_builder.py for supported models')),
    ('--frontend', dict(type=str, default="ResNet101", help='The frontend you are using. See frontend_builder.py for supported models')),
]
for _flag, _spec in _CLI_OPTIONS:
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()


def data_augmentation(input_image, output_image):
    """Apply random crop / flip / brightness / rotation augmentation.

    The image and its label map receive identical geometric transforms so
    they stay aligned; the brightness change is applied to the image only.
    Which augmentations run is controlled by the global ``args`` flags.

    Args:
        input_image: the training image (H x W x 3 uint8 array).
        output_image: the corresponding label map.

    Returns:
        The (input_image, output_image) pair after augmentation.
    """
    input_image, output_image = utils.random_crop(input_image, output_image, args.crop_height, args.crop_width)

    # random.randint(0, 1) gives a 50% chance of applying each enabled flip.
    if args.h_flip and random.randint(0, 1):
        input_image = cv2.flip(input_image, 1)
        output_image = cv2.flip(output_image, 1)
    if args.v_flip and random.randint(0, 1):
        input_image = cv2.flip(input_image, 0)
        output_image = cv2.flip(output_image, 0)
    if args.brightness:
        factor = 1.0 + random.uniform(-1.0 * args.brightness, args.brightness)
        # BUG FIX: for factor > 1 the scaled values can exceed 255, and the
        # original astype(np.uint8) silently wrapped them around (e.g. 260
        # became 4, turning bright pixels dark).  Clip before casting.
        table = np.clip(
            np.array([((i / 255.0) * factor) * 255 for i in np.arange(0, 256)]),
            0, 255).astype(np.uint8)
        input_image = cv2.LUT(input_image, table)
    if args.rotation:
        # The original code re-tested `args.rotation` in a second nested
        # `if`; the redundant check is folded into this single branch.
        angle = random.uniform(-1 * args.rotation, args.rotation)
        M = cv2.getRotationMatrix2D((input_image.shape[1] // 2, input_image.shape[0] // 2), angle, 1.0)
        # INTER_NEAREST keeps label values discrete (no interpolation
        # between class ids).
        input_image = cv2.warpAffine(input_image, M, (input_image.shape[1], input_image.shape[0]), flags=cv2.INTER_NEAREST)
        output_image = cv2.warpAffine(output_image, M, (output_image.shape[1], output_image.shape[0]), flags=cv2.INTER_NEAREST)

    return input_image, output_image

# Get the names of the classes so we can record the evaluation results
# Load the class names and their label colours from the dataset's CSV.
class_names_list, label_values = helpers.get_label_info(os.path.join(args.dataset, "class_dict.csv"))

# Comma-separated class names, used later as the header of the per-epoch
# validation-score CSV.  The original built this with a manual loop that
# compared each name against the LAST list element to decide whether to
# append ", " — which mis-fires when a class name appears twice.  join()
# is both correct and idiomatic.
class_names_string = ", ".join(class_names_list)

num_classes = len(label_values)

config = tf.ConfigProto()
# Let TensorFlow grow GPU memory on demand instead of grabbing it all.
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Network inputs: an RGB image batch and its one-hot ground-truth labels.
net_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
net_output = tf.placeholder(tf.float32, shape=[None, None, None, num_classes])

network, init_fn = model_builder.build_model(
    model_name=args.model, frontend=args.frontend, net_input=net_input,
    num_classes=num_classes, crop_width=args.crop_width,
    crop_height=args.crop_height, is_training=True)

# Softmax cross-entropy loss, averaged over every pixel in the batch.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=network, labels=net_output))

# tf.trainable_variables() already returns a list; the original wrapped it
# in a redundant list comprehension.
opt = tf.train.RMSPropOptimizer(learning_rate=0.0001, decay=0.995).minimize(
    loss, var_list=tf.trainable_variables())

saver = tf.train.Saver(max_to_keep=1000)
sess.run(tf.global_variables_initializer())

utils.count_params()

# If a pre-trained ResNet is required, load the weights.
# This must be done AFTER the variables are initialized with
# sess.run(tf.global_variables_initializer()), or the init would be clobbered.
if init_fn is not None:
    init_fn(sess)

# Load a previous checkpoint if desired.
model_checkpoint_name = "checkpoints/latest_model_" + args.model + "_" + args.dataset + ".ckpt"
if args.continue_training:
    saver.restore(sess, model_checkpoint_name)
    # BUG FIX: the original printed this message BEFORE calling restore(),
    # so it claimed success even when the restore was about to fail.
    print('Loaded latest model checkpoint')
print("Loading the data ...")
train_input_names,train_output_names, val_input_names, val_output_names, test_input_names, test_output_names = utils.prepare_data(dataset_dir=args.dataset)print("\n***** Begin training *****")
print("Dataset -->", args.dataset)
print("Model -->", args.model)
print("Crop Height -->", args.crop_height)
print("Crop Width -->", args.crop_width)
print("Num Epochs -->", args.num_epochs)
print("Batch Size -->", args.batch_size)
print("Num Classes -->", num_classes)print("Data Augmentation:")
print("\tVertical Flip -->", args.v_flip)
print("\tHorizontal Flip -->", args.h_flip)
print("\tBrightness Alteration -->", args.brightness)
print("\tRotation -->", args.rotation)
print("")avg_loss_per_epoch = []
avg_scores_per_epoch = []
avg_iou_per_epoch = []# Which validation images do we want
val_indices = []
num_vals = min(args.num_val_images, len(val_input_names))# Set random seed to make sure models are validated on the same validation images.
# So you can compare the results of different models more intuitively.
random.seed(16)
val_indices=random.sample(range(0,len(val_input_names)),num_vals)# Do the training here
# ---------------------------------------------------------------------------
# Training loop.
#
# Per-class validation-accuracy history, keyed by class name.  This replaces
# the original 32 hand-written parallel module-level lists (Animal, Archway,
# ..., Wall) and the giant if/elif chain that dispatched on the CamVid class
# name (and crashed with an unbound `ax` for any unknown class).  Keying on
# whatever names the dataset declares also makes the tracking work for
# datasets other than CamVid.
# ---------------------------------------------------------------------------
class_accuracy_history = {name: [] for name in class_names_list}

for epoch in range(args.epoch_start_i, args.num_epochs):

    current_losses = []
    cnt = 0

    # Equivalent to shuffling the training set for this epoch.
    id_list = np.random.permutation(len(train_input_names))

    num_iters = int(np.floor(len(id_list) / args.batch_size))
    st = time.time()
    epoch_st = time.time()

    for i in range(num_iters):

        input_image_batch = []
        output_image_batch = []

        # Collect a batch of images.
        for j in range(args.batch_size):
            index = i * args.batch_size + j
            img_id = id_list[index]  # renamed from `id`, which shadowed the builtin
            input_image = utils.load_image(train_input_names[img_id])
            output_image = utils.load_image(train_output_names[img_id])

            # Keep the augmentation / one-hot preprocessing on the CPU.
            with tf.device('/cpu:0'):
                input_image, output_image = data_augmentation(input_image, output_image)

                # Prep the data. Make sure the labels are in one-hot format.
                input_image = np.float32(input_image) / 255.0
                output_image = np.float32(helpers.one_hot_it(label=output_image, label_values=label_values))

                input_image_batch.append(np.expand_dims(input_image, axis=0))
                output_image_batch.append(np.expand_dims(output_image, axis=0))

        if args.batch_size == 1:
            input_image_batch = input_image_batch[0]
            output_image_batch = output_image_batch[0]
        else:
            input_image_batch = np.squeeze(np.stack(input_image_batch, axis=1))
            output_image_batch = np.squeeze(np.stack(output_image_batch, axis=1))

        # Do the training step.
        _, current = sess.run([opt, loss],
                              feed_dict={net_input: input_image_batch,
                                         net_output: output_image_batch})
        current_losses.append(current)
        cnt = cnt + args.batch_size
        if cnt % 20 == 0:
            string_print = "Epoch = %d Count = %d Current_Loss = %.4f Time = %.2f" % (epoch, cnt, current, time.time() - st)
            utils.LOG(string_print)
            st = time.time()

    mean_loss = np.mean(current_losses)
    avg_loss_per_epoch.append(mean_loss)

    # Create directories if needed.
    if not os.path.isdir("%s/%04d" % ("checkpoints", epoch)):
        os.makedirs("%s/%04d" % ("checkpoints", epoch))

    # Save latest checkpoint to same file name.
    print("Saving latest checkpoint")
    saver.save(sess, model_checkpoint_name)

    # BUG FIX: the original guard was `val_indices != 0`, which compares a
    # list against an int and is therefore always True.  Test the list's
    # emptiness instead.
    if len(val_indices) != 0 and epoch % args.checkpoint_step == 0:
        print("Saving checkpoint for this epoch")
        saver.save(sess, "%s/%04d/model.ckpt" % ("checkpoints", epoch))

    if epoch % args.validation_step == 0:
        print("Performing validation")

        scores_list = []
        class_scores_list = []
        precision_list = []
        recall_list = []
        f1_list = []
        iou_list = []

        # One CSV row of metrics per validation image.  The file is managed
        # with a context manager so it is closed even if evaluating an
        # image raises.
        with open("%s/%04d/val_scores.csv" % ("checkpoints", epoch), 'w') as target:
            target.write("val_name, avg_accuracy, precision, recall, f1 score, mean iou, %s\n" % (class_names_string))

            # Do the validation on a small set of validation images.
            for ind in val_indices:
                input_image = np.expand_dims(np.float32(utils.load_image(val_input_names[ind])[:args.crop_height, :args.crop_width]), axis=0) / 255.0
                gt = utils.load_image(val_output_names[ind])[:args.crop_height, :args.crop_width]
                gt = helpers.reverse_one_hot(helpers.one_hot_it(gt, label_values))

                output_image = sess.run(network, feed_dict={net_input: input_image})
                output_image = np.array(output_image[0, :, :, :])
                output_image = helpers.reverse_one_hot(output_image)
                out_vis_image = helpers.colour_code_segmentation(output_image, label_values)

                accuracy, class_accuracies, prec, rec, f1, iou = utils.evaluate_segmentation(pred=output_image, label=gt, num_classes=num_classes)

                file_name = utils.filepath_to_name(val_input_names[ind])
                target.write("%s, %f, %f, %f, %f, %f" % (file_name, accuracy, prec, rec, f1, iou))
                for item in class_accuracies:
                    target.write(", %f" % (item))
                target.write("\n")

                scores_list.append(accuracy)
                class_scores_list.append(class_accuracies)
                precision_list.append(prec)
                recall_list.append(rec)
                f1_list.append(f1)
                iou_list.append(iou)

                gt = helpers.colour_code_segmentation(gt, label_values)

                file_name = os.path.splitext(os.path.basename(val_input_names[ind]))[0]
                # OpenCV expects BGR, the colour-coded outputs are RGB.
                cv2.imwrite("%s/%04d/%s_pred.png" % ("checkpoints", epoch, file_name), cv2.cvtColor(np.uint8(out_vis_image), cv2.COLOR_RGB2BGR))
                cv2.imwrite("%s/%04d/%s_gt.png" % ("checkpoints", epoch, file_name), cv2.cvtColor(np.uint8(gt), cv2.COLOR_RGB2BGR))

        avg_score = np.mean(scores_list)
        class_avg_scores = np.mean(class_scores_list, axis=0)
        avg_scores_per_epoch.append(avg_score)
        avg_precision = np.mean(precision_list)
        avg_recall = np.mean(recall_list)
        avg_f1 = np.mean(f1_list)
        avg_iou = np.mean(iou_list)
        avg_iou_per_epoch.append(avg_iou)

        print("\nAverage validation accuracy for epoch # %04d = %f" % (epoch, avg_score))
        print("Average per class validation accuracies for epoch # %04d:" % (epoch))
        for index, item in enumerate(class_avg_scores):
            print("%s = %f" % (class_names_list[index], item))
        print("Validation precision = ", avg_precision)
        print("Validation recall = ", avg_recall)
        print("Validation F1 score = ", avg_f1)
        print("Validation IoU score = ", avg_iou)

        # Track and plot the per-class accuracy history.  BUG FIX: moved
        # inside the validation branch (the original ran it every epoch and
        # read `class_avg_scores` from the last validation) and plotted
        # against `range(epoch + 1)`, which raises a length-mismatch error
        # whenever --validation_step > 1; plot against the history's own
        # length instead.
        for index, item in enumerate(class_avg_scores):
            name = class_names_list[index]
            history = class_accuracy_history.setdefault(name, [])
            history.append(item)

            fig, ax = plt.subplots(figsize=(11, 8))
            ax.plot(range(len(history)), history)
            ax.set_title("%s vs epochs" % name)
            ax.set_xlabel("Epoch")
            ax.set_ylabel("accuracy")
            plt.savefig("%s_accuracy_vs_epochs.png" % name)
            # Free the figure; the original leaked one figure per class per
            # epoch, which grows memory without bound on long runs.
            plt.close(fig)

    # Estimate the remaining wall-clock time from this epoch's duration.
    epoch_time = time.time() - epoch_st
    remain_time = epoch_time * (args.num_epochs - 1 - epoch)
    m, s = divmod(remain_time, 60)
    h, m = divmod(m, 60)
    if s != 0:
        train_time = "Remaining training time = %d hours %d minutes %d seconds\n" % (h, m, s)
    else:
        train_time = "Remaining training time : Training completed.\n"
    utils.LOG(train_time)

    # Summary plots, refreshed once per epoch.  BUG FIX: as above, these
    # are plotted against each list's own length, because the score and IoU
    # lists only grow on validation epochs while `epoch` always advances.
    fig1, ax1 = plt.subplots(figsize=(11, 8))
    ax1.plot(range(len(avg_scores_per_epoch)), avg_scores_per_epoch)
    ax1.set_title("Average validation accuracy vs epochs")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Avg. val. accuracy")
    plt.savefig('accuracy_vs_epochs.png')
    plt.close(fig1)

    fig2, ax2 = plt.subplots(figsize=(11, 8))
    ax2.plot(range(len(avg_loss_per_epoch)), avg_loss_per_epoch)
    ax2.set_title("Average loss vs epochs")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Current loss")
    plt.savefig('loss_vs_epochs.png')
    plt.close(fig2)

    fig3, ax3 = plt.subplots(figsize=(11, 8))
    ax3.plot(range(len(avg_iou_per_epoch)), avg_iou_per_epoch)
    ax3.set_title("Average IoU vs epochs")
    ax3.set_xlabel("Epoch")
    ax3.set_ylabel("Current IoU")
    plt.savefig('iou_vs_epochs.png')
    plt.close(fig3)
参考:https://blog.csdn.net/qq_34606546/article/details/86514306