热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

python怎么使用预训练的模型_Keras使用ImageNet上预训练的模型方式

我就废话不多说了,大家还是直接看代码吧!importkerasimportnumpyasnpfromkeras.applicationsimpor

我就废话不多说了,大家还是直接看代码吧!

import keras

import numpy as np

from keras.applications import vgg16, inception_v3, resnet50, mobilenet

#Load the VGG model

vgg_model = vgg16.VGG16(weights='imagenet')

#Load the Inception_V3 model

inception_model = inception_v3.InceptionV3(weights='imagenet')

#Load the ResNet50 model

resnet_model = resnet50.ResNet50(weights='imagenet')

#Load the MobileNet model

mobilenet_model = mobilenet.MobileNet(weights='imagenet')

在以上代码中,我们首先import各种模型对应的module,然后load模型,并用ImageNet的参数初始化模型的参数。

如果不想使用ImageNet上预训练到的权重初始话模型,可以将各语句的中'imagenet'替换为'None'。

补充知识:keras上使用alexnet模型来高准确度对mnist数据进行分类

纲要

本文有两个特点:一是直接对本地mnist数据进行读取(假设事先已经下载或从别处拷来)二是基于keras框架(网上多是基于tf)使用alexnet对mnist数据进行分类,并获得较高准确度(约为98%)

本地数据读取和分析

很多代码都是一开始简单调用一行代码来从网站上下载mnist数据,虽然只有10来MB,但是现在下载速度非常慢,而且经常中途出错,要费很大的劲才能拿到数据。

(X_train, y_train), (X_test, y_test) = mnist.load_data()

其实可以单独来获得这些数据(一共4个gz包,如下所示),然后调用别的接口来分析它们。

20200523150647.jpg

mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录

x_train = mnist.train.images

y_train = mnist.train.labels

x_test = mnist.test.images

y_test = mnist.test.labels

这里面要注意的是,两种接口拿到的数据形式是不一样的。 从网上直接下载下来的数据 其image data值的范围是0~255,且label值为0,1,2,3...9。 而第二种接口获取的数据 image值已经除以255(归一化)变成0~1范围,且label值已经是one-hot形式(one_hot=True时),比如label值2的one-hot code为(0 0 1 0 0 0 0 0 0 0)

所以,以第一种方式获取的数据需要做一些预处理(归一和one-hot)才能输入网络模型进行训练 而第二种接口拿到的数据则可以直接进行训练。

Alexnet模型的微调

按照公开的模型框架,Alexnet只有第1、2个卷积层才跟着BatchNormalization,后面三个CNN都没有(如有说错,请指正)。如果按照这个来搭建网络模型,很容易导致梯度消失,现象就是 accuracy值一直处在很低的值。 如下所示。

20200523150705.jpg

在每个卷积层后面都加上BN后,准确度才迭代提高。如下所示

20200523150719.jpg

完整代码

import keras

from keras.datasets import mnist

from keras.models import Sequential

from keras.layers.core import Dense, Dropout, Activation, Flatten

from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D

from keras.layers.normalization import BatchNormalization

from keras.callbacks import ModelCheckpoint

import numpy as np

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data #tensorflow已经包含了mnist案例的数据

batch_size = 64

num_classes = 10

epochs = 10

img_shape = (28,28,1)

# input dimensions

img_rows, img_cols = 28,28

# dataset input

#(x_train, y_train), (x_test, y_test) = mnist.load_data()

mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录

print(mnist.train.images.shape, mnist.train.labels.shape)

print(mnist.test.images.shape, mnist.test.labels.shape)

print(mnist.validation.images.shape, mnist.validation.labels.shape)

x_train = mnist.train.images

y_train = mnist.train.labels

x_test = mnist.test.images

y_test = mnist.test.labels

# data initialization

x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)

x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)

input_shape = (img_rows, img_cols, 1)

# Define the input layer

inputs = keras.Input(shape = [img_rows, img_cols, 1])

#Define the converlutional layer 1

conv1 = keras.layers.Conv2D(filters= 64, kernel_size= [11, 11], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(inputs)

# Define the pooling layer 1

pooling1 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv1)

# Define the standardization layer 1

stand1 = keras.layers.BatchNormalization(axis= 1)(pooling1)

# Define the converlutional layer 2

conv2 = keras.layers.Conv2D(filters= 192, kernel_size= [5, 5], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand1)

# Defien the pooling layer 2

pooling2 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv2)

# Define the standardization layer 2

stand2 = keras.layers.BatchNormalization(axis= 1)(pooling2)

# Define the converlutional layer 3

conv3 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand2)

stand3 = keras.layers.BatchNormalization(axis=1)(conv3)

# Define the converlutional layer 4

conv4 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand3)

stand4 = keras.layers.BatchNormalization(axis=1)(conv4)

# Define the converlutional layer 5

conv5 = keras.layers.Conv2D(filters= 256, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand4)

pooling5 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv5)

stand5 = keras.layers.BatchNormalization(axis=1)(pooling5)

# Define the fully connected layer

flatten = keras.layers.Flatten()(stand5)

fc1 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(flatten)

drop1 = keras.layers.Dropout(0.5)(fc1)

fc2 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(drop1)

drop2 = keras.layers.Dropout(0.5)(fc2)

fc3 = keras.layers.Dense(10, activation= keras.activations.softmax, use_bias= True)(drop2)

# 基于Model方法构建模型

model = keras.Model(inputs= inputs, outputs = fc3)

# 编译模型

model.compile(optimizer= tf.train.AdamOptimizer(0.001),

loss= keras.losses.categorical_crossentropy,

metrics= ['accuracy'])

# 训练配置,仅供参考

model.fit(x_train, y_train, batch_size= batch_size, epochs= epochs, validation_data=(x_test,y_test))

以上这篇Keras使用ImageNet上预训练的模型方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。



推荐阅读
  • 一、MATLAB常用的基本数学函数abs(x):纯量的绝对值或向量的长度angle(z):复数z的相角(Phaseangle)sqrt(x)࿱ ... [详细]
  • 目录预备知识导包构建数据集神经网络结构训练测试精度可视化计算模型精度损失可视化输出网络结构信息训练神经网络定义参数载入数据载入神经网络结构、损失及优化训练及测试损失、精度可视化qu ... [详细]
  • Leetcode学习成长记:天池leetcode基础训练营Task01数组
    前言这是本人第一次参加由Datawhale举办的组队学习活动,这个活动每月一次,之前也一直关注,但未亲身参与过,这次看到活动 ... [详细]
  • 2020年9月15日,Oracle正式发布了最新的JDK 15版本。本次更新带来了许多新特性,包括隐藏类、EdDSA签名算法、模式匹配、记录类、封闭类和文本块等。 ... [详细]
  • 本文介绍如何使用OpenCV和线性支持向量机(SVM)模型来开发一个简单的人脸识别系统,特别关注在只有一个用户数据集时的处理方法。 ... [详细]
  • [转]doc,ppt,xls文件格式转PDF格式http:blog.csdn.netlee353086articledetails7920355确实好用。需要注意的是#import ... [详细]
  • php更新数据库字段的函数是,php更新数据库字段的函数是 ... [详细]
  • poj 3352 Road Construction ... [详细]
  • 利用REM实现移动端布局的高效适配技巧
    在移动设备上实现高效布局适配时,使用rem单位已成为一种流行且有效的技术。本文将分享过去一年中使用rem进行布局适配的经验和心得。rem作为一种相对单位,能够根据根元素的字体大小动态调整,从而确保不同屏幕尺寸下的布局一致性。通过合理设置根元素的字体大小,开发者可以轻松实现响应式设计,提高用户体验。此外,文章还将探讨一些常见的问题和解决方案,帮助开发者更好地掌握这一技术。 ... [详细]
  • 大类|电阻器_使用Requests、Etree、BeautifulSoup、Pandas和Path库进行数据抓取与处理 | 将指定区域内容保存为HTML和Excel格式
    大类|电阻器_使用Requests、Etree、BeautifulSoup、Pandas和Path库进行数据抓取与处理 | 将指定区域内容保存为HTML和Excel格式 ... [详细]
  • 在《Cocos2d-x学习笔记:基础概念解析与内存管理机制深入探讨》中,详细介绍了Cocos2d-x的基础概念,并深入分析了其内存管理机制。特别是针对Boost库引入的智能指针管理方法进行了详细的讲解,例如在处理鱼的运动过程中,可以通过编写自定义函数来动态计算角度变化,利用CallFunc回调机制实现高效的游戏逻辑控制。此外,文章还探讨了如何通过智能指针优化资源管理和避免内存泄漏,为开发者提供了实用的编程技巧和最佳实践。 ... [详细]
  • 在Django中提交表单时遇到值错误问题如何解决?
    在Django项目中,当用户提交包含多个选择目标的表单时,可能会遇到值错误问题。本文将探讨如何通过优化表单处理逻辑和验证机制来有效解决这一问题,确保表单数据的准确性和完整性。 ... [详细]
  • 本文详细解析了 Python 2.x 版本中 `urllib` 模块的核心功能与应用实例,重点介绍了 `urlopen()` 和 `urlretrieve()` 方法的使用技巧。其中,`urlopen()` 方法用于发送网络请求并获取响应内容,而 `urlretrieve()` 方法则用于下载文件并保存到本地。文章通过具体示例展示了这两个方法在实际开发中的应用场景,帮助读者更好地理解和掌握 `urllib` 模块的使用。 ... [详细]
  • 在Python编程中,探讨了并发与并行的概念及其区别。并发指的是系统同时处理多个任务的能力,而并行则指在同一时间点上并行执行多个任务。文章详细解析了阻塞与非阻塞操作、同步与异步编程模型,以及IO多路复用技术的应用。通过模拟socket发送HTTP请求的过程,展示了如何创建连接、发送数据和接收响应,并强调了默认情况下socket的阻塞特性。此外,还介绍了如何利用这些技术优化网络通信性能和提高程序效率。 ... [详细]
  • 深入解析经典卷积神经网络及其实现代码
    深入解析经典卷积神经网络及其实现代码 ... [详细]
author-avatar
手机用户2502896257
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有