热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

139、TensorFlowServing实现模型的部署(二)TextCnn文本分类模型

昨晚终于实现了Tensorflow模型的部署使用TensorFlowServing 1、使用Docker获取TensorflowServing的镜像,Docker在国内的需要将镜像

昨晚终于实现了Tensorflow模型的部署 使用TensorFlow Serving

 

1、使用Docker 获取Tensorflow Serving的镜像,Docker在国内的需要将镜像的Repository地址设置为阿里云的加速地址,这个大家可以自己去CSDN上面找

然后启动docker

2、使用Tensorflow 的 SaveModelBuilder保存Tensorflow的计算图模型,并且设置Signature,

Signature主要用来标识模型的输入值的名称和类型

        builder = saved_model_builder.SavedModelBuilder(export_path)
        
        
        classification_inputs = utils.build_tensor_info(cnn.input_x)
        classification_dropout_keep_prob = utils.build_tensor_info(cnn.dropout_keep_prob)
        classification_outputs_classes = utils.build_tensor_info(prediction_classes)
        classification_outputs_scores = utils.build_tensor_info(cnn.scores)

   
        classification_signature = signature_def_utils.build_signature_def(
        inputs={signature_constants.CLASSIFY_INPUTS: classification_inputs,
                     signature_constants.CLASSIFY_INPUTS:classification_dropout_keep_prob
                     },
        outputs={
              signature_constants.CLASSIFY_OUTPUT_CLASSES:
              classification_outputs_classes,
              signature_constants.CLASSIFY_OUTPUT_SCORES:
              classification_outputs_scores
         },
         method_name=signature_constants.CLASSIFY_METHOD_NAME)

        tensor_info_x = utils.build_tensor_info(cnn.input_x)
        tensor_info_y = utils.build_tensor_info(cnn.predictions)
        tensor_info_dropout_keep_prob = utils.build_tensor_info(cnn.dropout_keep_prob)

        prediction_signature = signature_def_utils.build_signature_def(
               inputs={'inputX': tensor_info_x,
                            'input_dropout_keep_prob':tensor_info_dropout_keep_prob},
               outputs={'predictClass': tensor_info_y},
        method_name=signature_constants.PREDICT_METHOD_NAME)

        legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
  
        #add the sigs to the servable
        builder.add_meta_graph_and_variables(
                sess, [tag_constants.SERVING],
                signature_def_map={
                    'textclassified':
                    prediction_signature,
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    classification_signature,
         },
         legacy_init_op=legacy_init_op)
         #save it!
        builder.save(True)

保存之后的计算图的结构可以从下面这里看见,下面这里只给出模型的signature部分,因为signature里面定义了你到时候call restful接口的参数名称和类型

  signature_def {
    key: "serving_default"
    value {
      inputs {
        key: "inputs"
        value {
          name: "dropout_keep_prob:0"
          dtype: DT_FLOAT
          tensor_shape {
            unknown_rank: true
          }
        }
      }
      outputs {
        key: "classes"
        value {
          name: "index_to_string_Lookup:0"
          dtype: DT_STRING
          tensor_shape {
            dim {
              size: 1
            }
          }
        }
      }
      outputs {
        key: "scores"
        value {
          name: "output/scores:0"
          dtype: DT_FLOAT
          tensor_shape {
            dim {
              size: -1
            }
            dim {
              size: 2
            }
          }
        }
      }
      method_name: "tensorflow/serving/classify"
    }
  }
  signature_def {
    key: "textclassified"
    value {
      inputs {
        key: "inputX"
        value {
          name: "input_x:0"
          dtype: DT_INT32
          tensor_shape {
            dim {
              size: -1
            }
            dim {
              size: 40
            }
          }
        }
      }
      inputs {
        key: "input_dropout_keep_prob"
        value {
          name: "dropout_keep_prob:0"
          dtype: DT_FLOAT
          tensor_shape {
            unknown_rank: true
          }
        }
      }
      outputs {
        key: "predictClass"
        value {
          name: "output/predictions:0"
          dtype: DT_INT64
          tensor_shape {
            dim {
              size: -1
            }
          }
        }
      }
      method_name: "tensorflow/serving/predict"
    }
  }
}

从上面的Signature定义可以看出 到时候call restfull 接口需要传两个参数,

int32类型的名称为inputX参数

float类型名称为input_drop_out_keep_prob的参数

 

上面的protocol buffer 中

textclassified表示使用TextCnn卷积神经网络来进行预测,然后预测功能的名称叫做textclassified

 

 3、将模型部署到Tensorflow Serving 上面

首先把模型通过工具传输到docker上面

模型的结构如下

139、TensorFlow Serving 实现模型的部署(二)  TextCnn文本分类模型

 

 传到docker上面,然后在外边套一个文件夹名字起位模型的名字,叫做

 

text_classified_model
然后执行下面这条命令运行tensorflow/serving
docker run -p 8500:8500 --mount type=bind,source=/home/docker/model/text_classified_model,target=/mo
dels/text_classified_model -e MODEL_NAME=text_classified_model -t tensorflow/serving

 

source表示模型在docker上面的路径
target表示模型在docker中TensorFlow/serving container上面的路径

 然后输入如下所示

139、TensorFlow Serving 实现模型的部署(二)  TextCnn文本分类模型

 

上面显示运行了两个接口一个是REST API 接口,端口是8501

另一个是gRPC接口端口是8500

gRPC是HTTP/2协议,REST API 是HTTP/1协议

区别是gRPC只有POST/GET两种请求方式

REST API还有其余很多种 列如 PUT/DELETE 等

 

 

4、客户端调用gPRC接口

 

需要传两个参数,

一个是

 

inputX

另一个是

input_dropout_keep_prob
'''
Created on 2018年10月17日

@author: 95890
'''

"""Send text to tensorflow serving and gets result
"""


# This is a placeholder for a Google-internal import.

from grpc.beta import implementations
import tensorflow as tf
import data_helpers
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
from tensorflow.contrib import learn
import numpy as np


tf.flags.DEFINE_string("positive_data_file", "./data/rt-polaritydata/rt-polarity.pos", "Data source for the positive data.")
tf.flags.DEFINE_string("negative_data_file", "./data/rt-polaritydata/rt-polarity.neg", "Data source for the negative data.")
tf.flags.DEFINE_string('server', '192.168.99.100:8500',
                           'PredictionService host:port')
FLAGS = tf.flags.FLAGS

x_text=[]
y=[]
max_document_length=40


def main(_):


  testStr =["wisegirls is its low-key quality and genuine"]

  
  if x_text.__len__()==0:
      x_text, y = data_helpers.load_data_and_labels(FLAGS.positive_data_file, FLAGS.negative_data_file)
      max_document_length = max([len(x.split(" ")) for x in x_text])

  vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
  vocab_processor.fit(x_text)
  x = np.array(list(vocab_processor.fit_transform(testStr)))
  
  host, port = FLAGS.server.split(':')
  channel = implementations.insecure_channel(host, int(port))
  stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
  request = predict_pb2.PredictRequest()
  request.model_spec.name = "text_classified_model"
  request.model_spec.signature_name = 'textclassified'
  dropout_keep_prob = np.float(1.0)
  
  request.inputs['inputX'].CopyFrom(
  tf.contrib.util.make_tensor_proto(x, shape=[1,40],dtype=np.int32))
  
  request.inputs['input_dropout_keep_prob'].CopyFrom(
  tf.contrib.util.make_tensor_proto(dropout_keep_prob, shape=[1],dtype=np.float))
  
  result = stub.Predict(request, 10.0)  # 10 secs timeout
  print(result)


if __name__ == '__main__':
  tf.app.run()

调用的结果如下所示

outputs {
  key: "predictClass"
  value {
    dtype: DT_INT64
    tensor_shape {
      dim {
        size: 1
      }
    }
    int64_val: 1
  }
}
model_spec {
  name: "text_classified_model"
  version {
    value: 1
  }
  signature_name: "textclassified"
}

从上面的结果可以看出,我们传入了一句话

wisegirls is its low-key quality and genuine

分类的结果

predictClass
int64_val: 1

分成第一类

 

这个真的是神经网络的部署呀。

 

 

啦啦啦 ,  Tensorflow真的很牛,上至浏览器,下到手机,一次训练,一次导出。处处运行。

没有不敢想,只有不敢做

 

 

 The Full version can be find here

https://github.com/weizhenzhao/TextCNN_Tensorflow_Serving/tree/master

 

Thanks

WeiZhen


推荐阅读
  • 个人学习使用:谨慎参考1Client类importcom.thoughtworks.gauge.Step;importcom.thoughtworks.gauge.T ... [详细]
  • 海马s5近光灯能否直接更换为H7?
    本文主要介绍了海马s5车型的近光灯是否可以直接更换为H7灯泡,并提供了完整的教程下载地址。此外,还详细讲解了DSP功能函数中的数据拷贝、数据填充和浮点数转换为定点数的相关内容。 ... [详细]
  • PDO MySQL
    PDOMySQL如果文章有成千上万篇,该怎样保存?数据保存有多种方式,比如单机文件、单机数据库(SQLite)、网络数据库(MySQL、MariaDB)等等。根据项目来选择,做We ... [详细]
  • 本文介绍了在iOS开发中使用UITextField实现字符限制的方法,包括利用代理方法和使用BNTextField-Limit库的实现策略。通过这些方法,开发者可以方便地限制UITextField的字符个数和输入规则。 ... [详细]
  • 本文介绍了Swing组件的用法,重点讲解了图标接口的定义和创建方法。图标接口用来将图标与各种组件相关联,可以是简单的绘画或使用磁盘上的GIF格式图像。文章详细介绍了图标接口的属性和绘制方法,并给出了一个菱形图标的实现示例。该示例可以配置图标的尺寸、颜色和填充状态。 ... [详细]
  • Android工程师面试准备及设计模式使用场景
    本文介绍了Android工程师面试准备的经验,包括面试流程和重点准备内容。同时,还介绍了建造者模式的使用场景,以及在Android开发中的具体应用。 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • Spring源码解密之默认标签的解析方式分析
    本文分析了Spring源码解密中默认标签的解析方式。通过对命名空间的判断,区分默认命名空间和自定义命名空间,并采用不同的解析方式。其中,bean标签的解析最为复杂和重要。 ... [详细]
  • 本文介绍了在开发Android新闻App时,搭建本地服务器的步骤。通过使用XAMPP软件,可以一键式搭建起开发环境,包括Apache、MySQL、PHP、PERL。在本地服务器上新建数据库和表,并设置相应的属性。最后,给出了创建new表的SQL语句。这个教程适合初学者参考。 ... [详细]
  • 本文介绍了在rhel5.5操作系统下搭建网关+LAMP+postfix+dhcp的步骤和配置方法。通过配置dhcp自动分配ip、实现外网访问公司网站、内网收发邮件、内网上网以及SNAT转换等功能。详细介绍了安装dhcp和配置相关文件的步骤,并提供了相关的命令和配置示例。 ... [详细]
  • 阿,里,云,物,联网,net,core,客户端,czgl,aliiotclient, ... [详细]
  • 本文详细介绍了如何使用MySQL来显示SQL语句的执行时间,并通过MySQL Query Profiler获取CPU和内存使用量以及系统锁和表锁的时间。同时介绍了效能分析的三种方法:瓶颈分析、工作负载分析和基于比率的分析。 ... [详细]
  • 本文介绍了绕过WAF的XSS检测机制的方法,包括确定payload结构、测试和混淆。同时提出了一种构建XSS payload的方法,该payload与安全机制使用的正则表达式不匹配。通过清理用户输入、转义输出、使用文档对象模型(DOM)接收器和源、实施适当的跨域资源共享(CORS)策略和其他安全策略,可以有效阻止XSS漏洞。但是,WAF或自定义过滤器仍然被广泛使用来增加安全性。本文的方法可以绕过这种安全机制,构建与正则表达式不匹配的XSS payload。 ... [详细]
  • 本文讨论了在使用Git进行版本控制时,如何提供类似CVS中自动增加版本号的功能。作者介绍了Git中的其他版本表示方式,如git describe命令,并提供了使用这些表示方式来确定文件更新情况的示例。此外,文章还介绍了启用$Id:$功能的方法,并讨论了一些开发者在使用Git时的需求和使用场景。 ... [详细]
  • 1223  drf引入以及restful规范
    [toc]前后台的数据交互前台安装axios插件,进行与后台的数据交互安装axios,并在main.js中设置params传递拼接参数data携带数据包参数headers中发送头部 ... [详细]
author-avatar
回忆回忆194567
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有