热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

使用迁移学习(TransferLearning)完成图像的多标签分类(MultiLabel)任务sunwq06

使用迁移学习(TransferLearning)完成图像的多标签分类(Multi-Label)任务本文通过迁移学习将训练好的模型应用到图像的多标签分类问题中本文通过迁移学习将训练好

使用迁移学习(Transfer Learning)完成图像的多标签分类(Multi-Label)任务


本文通过迁移学习将训练好的模型应用到图像的多标签分类问题中

本文通过迁移学习将训练好的VGG16模型应用到图像的多标签分类问题中。该项目数据来自于Kaggle,每张图片可同时属于多个标签。模型的准确度使用F score进行量化,如下表所示:




















标签预测为Positive(1)预测为Negative(0)
真值为Positive(1)TPFN
真值为Negative(0)FPTN

例如真实标签是(1,0,1,1,0,0), 预测标签是(1,1,0,1,1,0), 则TP=2, FN=1, FP=2, TN=1。$$Precision=\frac{TP}{TP+FP},\text{  }Recall=\frac{TP}{TP+FN},\text{  }F{\_}score=\frac{(1+\beta^2)*Presicion*Recall}{Recall+\beta^2*Precision}$$其中$\beta$越小,F score中Precision的权重越大,$\beta$等于0时F score就变为Precision;$\beta$越大,F score中Recall的权重越大,$\beta$趋于无穷大时F score就变为Recall。可以在Keras中自定义该函数(y_pred表示预测概率):

from tensorflow.keras import backend

# calculate fbeta score for multi-label classification
def fbeta(y_true, y_pred, beta=2):
# clip predictions
y_pred = backend.clip(y_pred, 0, 1)
# calculate elements for each sample
tp = backend.sum(backend.round(backend.clip(y_true * y_pred, 0, 1)), axis=1)
fp
= backend.sum(backend.round(backend.clip(y_pred - y_true, 0, 1)), axis=1)
fn
= backend.sum(backend.round(backend.clip(y_true - y_pred, 0, 1)), axis=1)
# calculate precision
p = tp / (tp + fp + backend.epsilon())
# calculate recall
r = tp / (tp + fn + backend.epsilon())
# calculate fbeta, averaged across samples
bb = beta ** 2
fbeta_score
= backend.mean((1 + bb) * (p * r) / (bb * p + r + backend.epsilon()))
return fbeta_score

此外在损失函数的使用上多标签分类和多类别(multi-class)分类也有区别,多标签分类使用binary_crossentropy,假设一个样本的真实标签是(1,0,1,1,0,0),预测概率是(0.2, 0.3, 0.4, 0.7, 0.9, 0.2): $$binary{\_}crossentropy\text{  }loss=-(\ln 0.2 + \ln 0.7 + \ln 0.4 + \ln 0.7 + \ln 0.1 + \ln 0.8)/6=0.96$$另外多标签分类输出层的激活函数选择sigmoid而非softmax。模型架构如下所示:

from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.models import Model
def define_model(in_shape=(128, 128, 3), out_shape=17):
# load model
base_model = VGG16(weights=\'imagenet\', include_top=False, input_shape=in_shape)
# mark loaded layers as not trainable
for layer in base_model.layers: layer.trainable = False
# make the last block trainable
tune_layers = [layer.name for layer in base_model.layers if layer.name.startswith(\'block5_\')]
for layer_name in tune_layers: base_model.get_layer(layer_name).trainable = True
# add new classifier layers
flat1 = Flatten()(base_model.layers[-1].output)
class1
= Dense(128, activation=\'relu\', kernel_initializer=\'he_uniform\')(flat1)
output
= Dense(out_shape, activation=\'sigmoid\')(class1)
# define new model
model = Model(inputs=base_model.input, outputs=output)
# compile model
opt = Adam(learning_rate=1e-3)
model.compile(optimizer
=opt, loss=\'binary_crossentropy\', metrics=[fbeta])
model.summary()
return model

从Kaggle网站上下载数据并解压,将其处理成可被模型读取的数据格式

from os import listdir
from numpy import zeros, asarray, savez_compressed
from pandas import read_csv
from tensorflow.keras.preprocessing.image import load_img, img_to_array
# create a mapping of tags to integers given the loaded mapping file
def create_tag_mapping(mapping_csv):
labels
= set() # create a set of all known tags
for i in range(len(mapping_csv)):
tags
= mapping_csv[\'tags\'][i].split(\' \') # convert spaced separated tags into an array of tags
labels.update(tags) # add tags to the set of known labels
labels = sorted(list(labels)) # convert set of labels to a sorted list
# dict that maps labels to integers, and the reverse
labels_map = {labels[i]:i for i in range(len(labels))}
inv_labels_map
= {i:labels[i] for i in range(len(labels))}
return labels_map, inv_labels_map
# create a mapping of filename to a list of tags
def create_file_mapping(mapping_csv):
mapping
= dict()
for i in range(len(mapping_csv)):
name, tags
= mapping_csv[\'image_name\'][i], mapping_csv[\'tags\'][i]
mapping[name]
= tags.split(\' \')
return mapping
# create a one hot encoding for one list of tags
def one_hot_encode(tags, mapping):
encoding
= zeros(len(mapping), dtype=\'uint8\') # create empty vector
# mark 1 for each tag in the vector
for tag in tags: encoding[mapping[tag]] = 1
return encoding
# load all images into memory
def load_dataset(path, file_mapping, tag_mapping):
photos, targets
= list(), list()
# enumerate files in the directory
for filename in listdir(path):
photo
= load_img(path + filename, target_size=(128,128)) # load image
photo = img_to_array(photo, dtype=\'uint8\') # convert to numpy array
tags = file_mapping[filename[:-4]] # get tags
target = one_hot_encode(tags, tag_mapping) # one hot encode tags
photos.append(photo)
targets.append(target)
X
= asarray(photos, dtype=\'uint8\')
y
= asarray(targets, dtype=\'uint8\')
return X, y
filename
= \'train_v2.csv\' # load the target file
mapping_csv = read_csv(filename)
tag_mapping, _
= create_tag_mapping(mapping_csv) # create a mapping of tags to integers
file_mapping = create_file_mapping(mapping_csv) # create a mapping of filenames to tag lists
folder = \'train-jpg/\' # load the jpeg images
X, y = load_dataset(folder, file_mapping, tag_mapping)
print(X.shape, y.shape)
savez_compressed(
\'planet_data.npz\', X, y) # save both arrays to one file in compressed format

View Code

接下来再建立两个辅助函数,第一个函数用来分割训练集和验证集,第二个函数用来画出模型在训练过程中的学习曲线

import numpy as np
from matplotlib import pyplot
from sklearn.model_selection import train_test_split
# load train and test dataset
def load_dataset():
# load dataset
data = np.load(\'planet_data.npz\')
X, y
= data[\'arr_0\'], data[\'arr_1\']
# separate into train and test datasets
trainX, testX, trainY, testY = train_test_split(X, y, test_size=0.3, random_state=1)
print(trainX.shape, trainY.shape, testX.shape, testY.shape)
return trainX, trainY, testX, testY
# plot diagnostic learning curves
def summarize_diagnostics(history):
# plot loss
pyplot.subplot(121)
pyplot.title(
\'Cross Entropy Loss\')
pyplot.plot(history.history[
\'loss\'], color=\'blue\', label=\'train\')
pyplot.plot(history.history[
\'val_loss\'], color=\'orange\', label=\'test\')
# plot accuracy
pyplot.subplot(122)
pyplot.title(
\'Fbeta\')
pyplot.plot(history.history[
\'fbeta\'], color=\'blue\', label=\'train\')
pyplot.plot(history.history[
\'val_fbeta\'], color=\'orange\', label=\'test\')
pyplot.show()

View Code

使用数据扩充技术(Data Augmentation)对模型进行训练

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.callbacks import ModelCheckpoint
trainX, trainY, testX, testY
= load_dataset() # load dataset
#
create data generator using augmentation
#
vertical flip is reasonable since the pictures are satellite images
train_datagen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True, rotation_range=90, preprocessing_function=preprocess_input)
test_datagen
= ImageDataGenerator(preprocessing_function=preprocess_input)
# prepare generators
train_it = train_datagen.flow(trainX, trainY, batch_size=128)
test_it
= test_datagen.flow(testX, testY, batch_size=128)
# define model
model = define_model()
# fit model
#
When one epoch ends, the validation generator will yield validation_steps batches, then average the evaluation results of all batches
checkpointer = ModelCheckpoint(filepath=\'./weights.best.vgg16.hdf5\', verbose=1, save_best_Only=True)
history
= model.fit_generator(train_it, steps_per_epoch=len(train_it), validation_data=test_it, validation_steps=len(test_it), \
epochs
=15, callbacks=[checkpointer], verbose=0)
# evaluate optimal model
#
For simplicity, the validation set is used to test the model here. In fact an entirely new test set should have been used.
model.load_weights(\'./weights.best.vgg16.hdf5\') #load stored optimal coefficients
loss, fbeta = model.evaluate_generator(test_it, steps=len(test_it), verbose=0)
print(\'> loss=%.3f, fbeta=%.3f\' % (loss, fbeta)) # loss=0.108, fbeta=0.884
model.save(\'final_model.h5\')
# learning curves
summarize_diagnostics(history)

 蓝线代表训练集,黄线代表验证集



推荐阅读
  • 2019年后蚂蚁集团与拼多多面试经验详述与深度剖析
    2019年后蚂蚁集团与拼多多面试经验详述与深度剖析 ... [详细]
  • Go语言实现Redis客户端与服务器的交互机制深入解析
    在前文对Godis v1.0版本的基础功能进行了详细介绍后,本文将重点探讨如何实现客户端与服务器之间的交互机制。通过具体代码实现,使客户端与服务器能够顺利通信,赋予项目实际运行的能力。本文将详细解析Go语言在实现这一过程中的关键技术和实现细节,帮助读者深入了解Redis客户端与服务器的交互原理。 ... [详细]
  • 深入解析十大经典排序算法:动画演示、原理分析与代码实现
    本文深入探讨了十种经典的排序算法,不仅通过动画直观展示了每种算法的运行过程,还详细解析了其背后的原理与机制,并提供了相应的代码实现,帮助读者全面理解和掌握这些算法的核心要点。 ... [详细]
  • 深入解析 UIImageView 与 UIImage 的关键细节与应用技巧
    本文深入探讨了 UIImageView 和 UIImage 的核心特性及应用技巧。首先,详细介绍了如何在 UIImageView 中实现动画效果,包括创建和配置 UIImageView 实例的具体步骤。此外,还探讨了 UIImage 的加载方式及其对性能的影响,提供了优化图像显示和内存管理的有效方法。通过实例代码和实际应用场景,帮助开发者更好地理解和掌握这两个重要类的使用技巧。 ... [详细]
  • 如何使用 net.sf.extjwnl.data.Word 类及其代码示例详解 ... [详细]
  • 计算 n 叉树中各节点子树的叶节点数量分析 ... [详细]
  • 深入解析 Django 中用户模型的自定义方法与技巧 ... [详细]
  • 在CentOS上部署和配置FreeSWITCH
    在CentOS系统上部署和配置FreeSWITCH的过程涉及多个步骤。本文详细介绍了从源代码安装FreeSWITCH的方法,包括必要的依赖项安装、编译和配置过程。此外,还提供了常见的配置选项和故障排除技巧,帮助用户顺利完成部署并确保系统的稳定运行。 ... [详细]
  • Spring Boot 实战(一):基础的CRUD操作详解
    在《Spring Boot 实战(一)》中,详细介绍了基础的CRUD操作,涵盖创建、读取、更新和删除等核心功能,适合初学者快速掌握Spring Boot框架的应用开发技巧。 ... [详细]
  • 在Spring与Ibatis集成的环境中,通过Spring AOP配置事务管理至服务层。当在一个服务方法中引入自定义多线程时,发现事务管理功能失效。若不使用多线程,事务管理则能正常工作。本文深入分析了这一现象背后的潜在风险,并探讨了可能的解决方案,以确保事务一致性和线程安全。 ... [详细]
  • 本文详细介绍了如何在Linux系统中搭建51单片机的开发与编程环境,重点讲解了使用Makefile进行项目管理的方法。首先,文章指导读者安装SDCC(Small Device C Compiler),这是一个专为小型设备设计的C语言编译器,适合用于51单片机的开发。随后,通过具体的实例演示了如何配置Makefile文件,以实现代码的自动化编译与链接过程,从而提高开发效率。此外,还提供了常见问题的解决方案及优化建议,帮助开发者快速上手并解决实际开发中可能遇到的技术难题。 ... [详细]
  • MySQL性能优化与调参指南【数据库管理】
    本文详细探讨了MySQL数据库的性能优化与参数调整技巧,旨在帮助数据库管理员和开发人员提升系统的运行效率。内容涵盖索引优化、查询优化、配置参数调整等方面,结合实际案例进行深入分析,提供实用的操作建议。此外,还介绍了常见的性能监控工具和方法,助力读者全面掌握MySQL性能优化的核心技能。 ... [详细]
  • voc生成xml 代码
    目录 lxmlwindows安装 读取示例 可视化 生成示例 上面是代码,下面有调用示例 api调用代码,其实只有几行:这个生成代码也很简 ... [详细]
  • 本文介绍了实现链表数据结构的方法与技巧,通过定义一个 `MyLinkedList` 类来管理链表节点。该类包含三个主要属性:`first` 用于指向链表的第一个节点,`last` 用于指向链表的最后一个节点,以及 `size` 用于记录链表中节点的数量。此外,还详细探讨了如何通过这些属性高效地进行链表的操作,如插入、删除和查找等。 ... [详细]
  • 利用ViewComponents在Asp.Net Core中构建高效分页组件
    通过运用 ViewComponents 技术,在 Asp.Net Core 中实现了高效的分页组件开发。本文详细介绍了如何通过创建 `PaginationViewComponent` 类并利用 `HelloWorld.DataContext` 上下文,实现对分页参数的定义与管理,从而提升 Web 应用程序的性能和用户体验。 ... [详细]
author-avatar
与幸福约定2502895163
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有