热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

利用TensorFlowObjectDetectionAPI的maskrcnn模型训练自己的样本数据集

之前写过TensorFlowObjectDetectionAPI的部署方法,如何用样本标定工具标定自己的样本数据,以及用tensorflowkereas版本mask-rcnn进行训

之前写过TensorFlow Object Detection API的部署方法,如何用样本标定工具标定自己的样本数据,以及用tensorflow/kereas版本mask-rcnn进行训练。 本文记录如何用 TensorFlow Object Detection API 和 tensorflow的预训练模型训练自己的样本。

目录

准备工作:

将标定样本生成为.record格式文件

转换代码create_tf_record

编辑*.pbtxt类别定义

生成.record数据

训练样本数据

1、下载预训练模型

2、编辑pipeline_config文件

1、num_classes 修改为自己样本类别数

 2、修改与训练模型路径:

3、修改train_input_reader: 

4、修改eval_input_reader: 

进行训练:

tensorboad:

导出模型:



 


准备工作:

准备工作可参考我之前的博客:



  1. Tensorflow Object Detection API 环境搭建

  2. 标定自己的训练数据集,参考自制图像标注软件

 


将标定样本生成为.record格式文件

本步骤参考了mahxn0的代码:

 


转换代码create_tf_record

创建create_tf_record_Label_Image.py文件,内容如下:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Monday February 25 09:34:09 2019
@author: QingShui-Cheng
"""
"""Convert labeled dataset with Label_Image tool to TFRecord for object_detection.
Please note that this tool only applies to Label_Image's annotations(json file).
Example usage:
python3 create_tf_record.py \
--images_dir=your absolute path to read images and annotaion json files.
--label_map_path=your path to label_map.pbtxt
--output_path=your path to write .record.
"""
import cv2
import glob
import hashlib
import io
import json
import numpy as np
import os
import PIL.Image
import tensorflow as tf
import logging
from object_detection.utils import label_map_util
flags = tf.app.flags
flags.DEFINE_string('images_dir', None, 'Absolute path to images and annotaion json files.')
flags.DEFINE_string('label_map_path', None, 'Path to label map proto.')
flags.DEFINE_string('output_path', None, 'Path to the output tfrecord.')
FLAGS = flags.FLAGS
def int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def int64_list_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def bytes_list_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def float_list_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def create_tf_example(annotation_dict, label_map_dict=None):
"""Converts image and annotations to a tf.Example proto.
Args:
annotation_dict: A dictionary containing the following keys:
['height', 'width', 'filename', 'sha256_key', 'encoded_jpg',
'format', 'xmins', 'xmaxs', 'ymins', 'ymaxs', 'masks',
'class_names'].
label_map_dict: A dictionary maping class_names to indices.
Returns:
example: The converted tf.Example.
Raises:
ValueError: If label_map_dict is None or is not containing a class_name.
"""
if annotation_dict is None:
return None
if label_map_dict is None:
raise ValueError('`label_map_dict` is None')
height = annotation_dict.get('height', None)
width = annotation_dict.get('width', None)
filename = annotation_dict.get('filename', None)
sha256_key = annotation_dict.get('sha256_key', None)
encoded_jpg = annotation_dict.get('encoded_jpg', None)
image_format = annotation_dict.get('format', None)
xmins = annotation_dict.get('xmins', None)
xmaxs = annotation_dict.get('xmaxs', None)
ymins = annotation_dict.get('ymins', None)
ymaxs = annotation_dict.get('ymaxs', None)
masks = annotation_dict.get('masks', None)
class_names = annotation_dict.get('class_names', None)
print("class_names:",class_names)
labels = []
for class_name in class_names:
label = label_map_dict.get(class_name, 'None')
print("label:",label)
if label is None:
raise ValueError('`label_map_dict` is not containing {}.'.format(
class_name))
labels.append(label)
encoded_masks = []
for mask in masks:
pil_image = PIL.Image.fromarray(mask.astype(np.uint8))
output_io = io.BytesIO()
pil_image.save(output_io, format='PNG')
encoded_masks.append(output_io.getvalue())
feature_dict = {
'image/height': int64_feature(height),
'image/width': int64_feature(width),
'image/filename': bytes_feature(filename.encode('utf8')),
'image/source_id': bytes_feature(filename.encode('utf8')),
'image/key/sha256': bytes_feature(sha256_key.encode('utf8')),
'image/encoded': bytes_feature(encoded_jpg),
'image/format': bytes_feature(image_format.encode('utf8')),
'image/object/bbox/xmin': float_list_feature(xmins),
'image/object/bbox/xmax': float_list_feature(xmaxs),
'image/object/bbox/ymin': float_list_feature(ymins),
'image/object/bbox/ymax': float_list_feature(ymaxs),
'image/object/mask': bytes_list_feature(encoded_masks),
'image/object/class/label': int64_list_feature(labels)}
example = tf.train.Example(features=tf.train.Features(
feature=feature_dict))
return example
def _get_annotation_dict(images_path, json_path):
"""Get boundingboxes and masks.
Args:
images_path: Path to image.
json_path: Path to annotated json file corresponding to
the image. The json file annotated by labelme with keys:
"filename": "2018_6_27_18_3_45_523.jpg",
"size": "393954",
"file_attributes": "@Qingshui [email protected] @ All Right Reserved!",
"regions": [
{
"region_attributes": {
"type": "stone"
},
"shape_attributes": {
"name": "polygon",
"all_points_x": [ ... ],
"all_points_y": [ ... ]
}
}
]
Returns:
annotation_dict: A dictionary containing the following keys:
['height', 'width', 'filename', 'sha256_key', 'encoded_jpg',
'format', 'xmins', 'xmaxs', 'ymins', 'ymaxs', 'masks',
'class_names'].
#
# Raises:
# ValueError: If images_path or json_path is not exist.
"""
if (not os.path.exists(images_path) or
not os.path.exists(json_path)):
return None
fo = open(json_path,encoding='utf-8')
text = fo.read()
fo.close()
if text.startswith(u'\ufeff'):
text = text.encode('utf8')[3:].decode('utf8')
annotatiOns= json.loads(text)
regiOns= annotations.get('regions', None)
if regions is None:
return None

image_relative_path = images_path
print("imagePath",image_relative_path)
image_name = image_relative_path.split('/')[-1]
image_format = image_name.split('.')[-1].replace('jpg', 'jpeg')
with tf.gfile.GFile(images_path, 'rb') as fid:
encoded_jpg = fid.read()
image = cv2.imread(images_path)
height = image.shape[0]
width = image.shape[1]
key = hashlib.sha256(encoded_jpg).hexdigest()
xmins = []
xmaxs = []
ymins = []
ymaxs = []
masks = []
class_names = []
for mark in regions:
class_name = mark['region_attributes']['type']
class_names.append(class_name)
xarray = np.array(mark['shape_attributes']['all_points_x'])
yarray = np.array(mark['shape_attributes']['all_points_y'])
polygon = [xarray,yarray];
polygon = np.array(polygon).T

mask = np.zeros(image.shape[:2])
cv2.fillPoly(mask, [polygon], 1)
masks.append(mask)
# Boundingbox
x = polygon[:, 0]
y = polygon[:, 1]
xmin = np.min(x)
xmax = np.max(x)
ymin = np.min(y)
ymax = np.max(y)
xmins.append(float(xmin) / width)
xmaxs.append(float(xmax) / width)
ymins.append(float(ymin) / height)
ymaxs.append(float(ymax) / height)
annotation_dict = {'height': height,
'width': width,
'filename': image_name,
'sha256_key': key,
'encoded_jpg': encoded_jpg,
'format': image_format,
'xmins': xmins,
'xmaxs': xmaxs,
'ymins': ymins,
'ymaxs': ymaxs,
'masks': masks,
'class_names': class_names}
return annotation_dict
def create_tf_record(output_filename,
label_map_dict,
sample_dir,
samples_list):
"""Creates a TFRecord file from examples.

Args:
output_filename: File Path to where output file is saved.
label_map_dict: The label map dictionary.
sample_dir: Directory where image files are stored.
examples: Examples to parse and save to tf record.

"""
writer = tf.python_io.TFRecordWriter(output_filename)

for idx, jpgname in enumerate(samples_list):
if idx % 100 == 0:
logging.info('On image %d of %d', idx, len(samples_list))

jsOnname= jpgname[:-3]+"json"
try:
annotation_dict = _get_annotation_dict(jpgname,jsonname)
if annotation_dict is None:
continue
#print(annotation_dict)
tf_example = create_tf_example(annotation_dict, label_map_dict)
writer.write(tf_example.SerializeToString())
except ValueError:
logging.warning('Invalid example: %s, ignoring.', jpgname)

writer.close()
def main(_):
if not os.path.exists(FLAGS.images_dir):
raise ValueError('`images_dir` is not exist.')
if not os.path.exists(FLAGS.label_map_path):
raise ValueError('`label_map_path` is not exist.')
label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
#搜寻样本图片和标定文件
sublist = os.listdir(FLAGS.images_dir)
imagelist = [];
for i in range(0, len(sublist)):
path = os.path.join(FLAGS.images_dir, sublist[i])
if os.path.isfile(path) and path.lower().endswith(".jpg"):
jsOnname= path[:-3]+"json";
if os.path.isfile(jsonname):
imagelist.append(path);

print("search {} sample images", len(imagelist))

#拆分训练和验证数据集
np.random.seed(25)
np.random.shuffle(imagelist)
num_images = len(imagelist)
num_train = int(0.7 * num_images)
train_images = imagelist[:num_train]
val_images = imagelist[num_train:]
logging.info('%d training and %d validation examples.',
len(train_images), len(val_images ))
train_output_path = os.path.join(FLAGS.output_path, 'sample_train.record')
val_output_path = os.path.join(FLAGS.output_path, 'sample_val.record')

create_tf_record(
train_output_path,
label_map_dict,
FLAGS.images_dir,
train_images)
create_tf_record(
val_output_path,
label_map_dict,
FLAGS.images_dir,
val_images)
print('Successfully created TFRecord to {}.'.format(FLAGS.output_path))
if __name__ == '__main__':
tf.app.run()

编辑*.pbtxt类别定义

我们命名为为sample_class.pbtxt文件,文件内容格式如下


可参考上述内容,修改item数目及内容,定义自己的样本类别。

 


生成.record数据

执行命令


python create_tf_record_mine.py   --images_dir /home/kc/code/tensorflow/MyData_Samples/   --label_map_path /home/kc/code/tensorflow/sample_class.pbtxt   --output_path /home/kc/code/tensorflow/sample 


 




  • --images_dir:样本图片及标注.json文件所在目录


  • --label_map_path:类别定义文件sample_class.pbtxt 全路径文件名


  • --output_path:转换后sample_train.record和sample_val.record存放路径。

 


训练样本数据


1、下载预训练模型

在Tensorflow提供的Tensorflow detection model zoo 下载COCO-trained models mask_rcnn_inception_v2_coco

下载后解压:


tar -zxvf mask_rcnn_inception_v2_coco_2018_01_28.tar.gz



2、编辑pipeline_config文件

/models/research/object_detection/samples/configs/下的 mask_rcnn_inception_v2_coco.config拷贝一份


cp /home/kc/code/tensorflow/models/research/object_detection/samples/configs/mask_rcnn_inception_v2_coco.config /home/kc/code/tensorflow/sample/


修改其中内容:


1、num_classes 修改为自己样本类别数

model {
  faster_rcnn {
    num_classes: 2
    image_resizer {
      keep_aspect_ratio_resizer {
        min_dimension: 800
        max_dimension: 1365
      }
    }

 2、修改与训练模型路径:



  • fine_tune_checkpoint 项修改为:mask_rcnn_inception_v2_coco_2018_01_28.tar.gz解压后目录/model.ckpt

gradient_clipping_by_norm: 10.0
fine_tune_checkpoint: "/home/kc/code/tensorflow/mask_rcnn_inception_v2_coco_2018_01_28/model.ckpt"
from_detection_checkpoint: true
# Note: The below line limits the training process to 200K steps, which we
# empirically found to be sufficient enough to train the pets dataset. This
# effectively bypasses the learning rate schedule (the learning rate will
# never decay). Remove the below line to train indefinitely.

3、修改train_input_reader: 



  • input_path: 前面生成的 sample_train.record 全路径

  • label_map_path:前面编辑的自己样本类别定义文件sample_class.pbtxt 全路径 

train_input_reader: {
tf_record_input_reader {
input_path: "/home/kc/code/tensorflow/sample/sample_train.record"
}
label_map_path: "/home/kc/code/tensorflow/sample/sample_class.pbtxt"
load_instance_masks: true
mask_type: PNG_MASKS
}

4、修改eval_input_reader: 



  • input_path: 前面生成的 sample_val.record 全路径

  • label_map_path:前面编辑的自己样本类别定义文件sample_class.pbtxt 全路径 

eval_input_reader: {
tf_record_input_reader {
input_path: "/home/kc/code/tensorflow/sample/sample_val.record"
}
label_map_path: "/home/kc/code/tensorflow/sample/sample_class.pbtxt"
load_instance_masks: true
mask_type: PNG_MASKS
shuffle: false
num_readers: 1
}

进行训练:

进入搭建好的环境或conda虚拟环境


source ./anaconda/bin/activate py35tf


进入 /models/research/ 目录,执行


export PYTHOnPATH=$PYTHONPATH:`pwd`:`pwd`/slim


进入 /models/research/object_detection/legacy/目录  开始训练 执行:


python train.py --logtostderr --train_dir=/home/kc/code/tensorflow/log --pipeline_config_path /home/kc/code/tensorflow/sample/mask_rcnn_inception_v2_coal.config


 




  • --train_dir: 训练模型保存目录


  • --pipeline_config_path: 前面修改的pipeline_config文件全路径

如果没有意外,会出现训练信息:



tensorboad:

输入命令:


tensorboard --logdir=/home/kc/code/tensorflow/log


在浏览器中输入http://0.0.0.0:6006,就能看到训练曲线了


导出模型:

进入 /models/models/research/object_detection/


python3 export_inference_graph.py \

             --input_type image_tensor \

             --pipeline_config_path /home/kc/code/tensorflow/sample/mask_rcnn_inception_v2_coco.config \

             --trained_checkpoint_prefix /home/kc/code/tensorflow/log/model.ckpt-200000\

             --output_directory /home/kc/code/tensorflow/output





  • --pipeline_config_path: 前面修改的pipeline_config文件全路径


  • --trained_checkpoint_prefix: 训练模型目录训练最后保存的model.ckpt-????


  • --output_directory: 导出模型目录

 

 



 



推荐阅读
  • 从 .NET 转 Java 的自学之路:IO 流基础篇
    本文详细介绍了 Java 中的 IO 流,包括字节流和字符流的基本概念及其操作方式。探讨了如何处理不同类型的文件数据,并结合编码机制确保字符数据的正确读写。同时,文中还涵盖了装饰设计模式的应用,以及多种常见的 IO 操作实例。 ... [详细]
  • 尽管使用TensorFlow和PyTorch等成熟框架可以显著降低实现递归神经网络(RNN)的门槛,但对于初学者来说,理解其底层原理至关重要。本文将引导您使用NumPy从头构建一个用于自然语言处理(NLP)的RNN模型。 ... [详细]
  • Scala 实现 UTF-8 编码属性文件读取与克隆
    本文介绍如何使用 Scala 以 UTF-8 编码方式读取属性文件,并实现属性文件的克隆功能。通过这种方式,可以确保配置文件在多线程环境下的一致性和高效性。 ... [详细]
  • Keras 实战:自编码器入门指南
    本文介绍了使用 Keras 框架实现自编码器的基本方法。自编码器是一种用于无监督学习的神经网络模型,主要功能包括数据降维、特征提取等。通过实际案例,我们将展示如何使用全连接层和卷积层来构建自编码器,并讨论不同维度对重建效果的影响。 ... [详细]
  • 本文详细介绍了Java中org.eclipse.ui.forms.widgets.ExpandableComposite类的addExpansionListener()方法,并提供了多个实际代码示例,帮助开发者更好地理解和使用该方法。这些示例来源于多个知名开源项目,具有很高的参考价值。 ... [详细]
  • 深入解析Spring Cloud Ribbon负载均衡机制
    本文详细介绍了Spring Cloud中的Ribbon组件如何实现服务调用的负载均衡。通过分析其工作原理、源码结构及配置方式,帮助读者理解Ribbon在分布式系统中的重要作用。 ... [详细]
  • 本文深入探讨了 Java 中的 Serializable 接口,解释了其实现机制、用途及注意事项,帮助开发者更好地理解和使用序列化功能。 ... [详细]
  • 本文详细介绍了Java编程语言中的核心概念和常见面试问题,包括集合类、数据结构、线程处理、Java虚拟机(JVM)、HTTP协议以及Git操作等方面的内容。通过深入分析每个主题,帮助读者更好地理解Java的关键特性和最佳实践。 ... [详细]
  • XNA 3.0 游戏编程:从 XML 文件加载数据
    本文介绍如何在 XNA 3.0 游戏项目中从 XML 文件加载数据。我们将探讨如何将 XML 数据序列化为二进制文件,并通过内容管道加载到游戏中。此外,还会涉及自定义类型读取器和写入器的实现。 ... [详细]
  • 本文深入探讨了Linux系统中网卡绑定(bonding)的七种工作模式。网卡绑定技术通过将多个物理网卡组合成一个逻辑网卡,实现网络冗余、带宽聚合和负载均衡,在生产环境中广泛应用。文章详细介绍了每种模式的特点、适用场景及配置方法。 ... [详细]
  • 2023年京东Android面试真题解析与经验分享
    本文由一位拥有6年Android开发经验的工程师撰写,详细解析了京东面试中常见的技术问题。涵盖引用传递、Handler机制、ListView优化、多线程控制及ANR处理等核心知识点。 ... [详细]
  • 本文介绍了在Windows环境下使用pydoc工具的方法,并详细解释了如何通过命令行和浏览器查看Python内置函数的文档。此外,还提供了关于raw_input和open函数的具体用法和功能说明。 ... [详细]
  • 本文详细介绍了中央电视台电影频道的节目预告,并通过专业工具分析了其加载方式,确保用户能够获取最准确的电视节目信息。 ... [详细]
  • 本文详细探讨了JDBC(Java数据库连接)的内部机制,重点分析其作为服务提供者接口(SPI)框架的应用。通过类图和代码示例,展示了JDBC如何注册驱动程序、建立数据库连接以及执行SQL查询的过程。 ... [详细]
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
author-avatar
死了才能爱_403
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有