热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

tensorflowC++接口调用目标检测pb模型代码

CMakeLists.txt内容如下目

#include
#include
"tensorflow/cc/ops/const_op.h"
#include
"tensorflow/cc/ops/image_ops.h"
#include
"tensorflow/cc/ops/standard_ops.h"
#include
"tensorflow/core/framework/graph.pb.h"
#include
"tensorflow/core/framework/tensor.h"
#include
"tensorflow/core/graph/default_device.h"
#include
"tensorflow/core/graph/graph_def_builder.h"
#include
"tensorflow/core/lib/core/errors.h"
#include
"tensorflow/core/lib/core/stringpiece.h"
#include
"tensorflow/core/lib/core/threadpool.h"
#include
"tensorflow/core/lib/io/path.h"
#include
"tensorflow/core/lib/strings/stringprintf.h"
#include
"tensorflow/core/platform/env.h"
#include
"tensorflow/core/platform/init_main.h"
#include
"tensorflow/core/platform/logging.h"
#include
"tensorflow/core/platform/types.h"
#include
"tensorflow/core/public/session.h"
#include
"tensorflow/core/util/command_line_flags.h"
#include

#include

#include

#include

#include

using namespace std;
using namespace cv;
using namespace tensorflow;

// 定义一个函数讲OpenCV的Mat数据转化为tensor,python里面只要对cv2.read读进来的矩阵进行np.reshape之后,
// 数据类型就成了一个tensor,即tensor与矩阵一样,然后就可以输入到网络的入口了,但是C++版本,我们网络开放的入口
// 也需要将输入图片转化成一个tensor,所以如果用OpenCV读取图片的话,就是一个Mat,然后就要考虑怎么将Mat转化为
// Tensor了
void CVMat_to_Tensor(Mat img,Tensor* output_tensor,int input_rows,int input_cols)
{
//imshow("input image",img);
//图像进行resize处理
resize(img,img,cv::Size(input_cols,input_rows));
//imshow("resized image",img);
//归一化
img.convertTo(img,CV_8UC3); // CV_32FC3
//img=1-img/255;
//创建一个指向tensor的内容的指针
uint8 *p = output_tensor->flat().data();
//创建一个Mat,与tensor的指针绑定,改变这个Mat的值,就相当于改变tensor的值
cv::Mat tempMat(input_rows, input_cols, CV_8UC3, p);
img.convertTo(tempMat,CV_8UC3);
// waitKey(0);

}
int main()
{
/*--------------------------------配置关键信息------------------------------*/
string model_path="../model/coco.pb";
string image_path="../test.jpg";
int input_height = 1000;
int input_width = 1000;
string input_tensor_name="image_tensor";
vector
<string> out_put_nodes; //注意,在object detection中输出的三个节点名称为以下三个
out_put_nodes.push_back("detection_scores"); //detection_scores detection_classes detection_boxes
out_put_nodes.push_back("detection_classes");
out_put_nodes.push_back(
"detection_boxes");
/*--------------------------------创建session------------------------------*/
Session
* session;
Status status
= NewSession(SessionOptions(), &session);//创建新会话Session
/*--------------------------------从pb文件中读取模型--------------------------------*/
GraphDef graphdef;
//Graph Definition for current model

Status status_load
= ReadBinaryProto(Env::Default(), model_path, &graphdef); //从pb文件中读取图模型;
if (!status_load.ok()) {
cout
<<"ERROR: Loading model failed..." < std::endl;
cout <"\n";
return -1;
}
Status status_create
= session->Create(graphdef); //将模型导入会话Session中;
if (!status_create.ok()) {
cout
<<"ERROR: Creating graph in session failed..." < std::endl;
return -1;
}
cout
<<"<----Successfully created session and load graph.------->"<< endl;
/*---------------------------------载入测试图片-------------------------------------*/
cout
<"<------------loading test_image-------------->"<<endl;
Mat img;
img
= imread(image_path);
cvtColor(img, img, CV_BGR2RGB);
if(img.empty())
{
cout
<<"can‘t open the image!!!!!!!"<<endl;
return -1;
}
//创建一个tensor作为输入网络的接口
Tensor resized_tensor(DT_UINT8, TensorShape({1,input_height,input_width,3})); //DT_FLOAT
//将Opencv的Mat格式的图片存入tensor
CVMat_to_Tensor(img,&resized_tensor,input_height,input_width);
cout
<endl;
/*-----------------------------------用网络进行测试-----------------------------------------*/
cout
<"<-------------Running the model with test_image--------------->"<<endl;
//前向运行,输出结果一定是一个tensor的vector
vector outputs;
Status status_run
= session->Run({{input_tensor_name, resized_tensor}}, {out_put_nodes}, {}, &outputs);
if (!status_run.ok()) {
cout
<<"ERROR: RUN failed..." << std::endl;
cout
<"\n";
return -1;
}
//把输出值给提取出
cout <<"Output tensor size:" <//3
for (int i = 0; i )
{
cout <// [1, 50], [1, 50], [1, 50, 4]
}
cvtColor(img, img, CV_RGB2BGR);
// opencv读入的是BGR格式输入网络前转为RGB
resize(img,img,cv::Size(1000,1000)); // 模型输入图像大小
int pre_num = outputs[0].dim_size(1); // 50 模型预测的目标数量
auto tmap_pro = outputs[0].tensor<float, 2>(); //第一个是score输出shape为[1,50]
auto tmap_clas = outputs[1].tensor<float, 2>(); //第二个是class输出shape为[1,50]
auto tmap_coor = outputs[2].tensor<float, 3>(); //第三个是coordinate输出shape为[1,50,4]
float probability = 0.5; //自己设定的score阈值
for (int pre_i = 0; pre_i )
{
if (tmap_pro(0, pre_i) < probability)
{
break;
}
cout
<<"Class ID: " <0, pre_i) << endl;
cout
<<"Probability: " <0, pre_i) << endl;
string id = to_string(int(tmap_clas(0, pre_i)));
int xmin = int(tmap_coor(0, pre_i, 1) * input_width);
int ymin = int(tmap_coor(0, pre_i, 0) * input_height);
int xmax = int(tmap_coor(0, pre_i, 3) * input_width);
int ymax = int(tmap_coor(0, pre_i, 2) * input_height);
cout
<<"Xmin is: " < endl;
cout <<"Ymin is: " < endl;
cout <<"Xmax is: " < endl;
cout <<"Ymax is: " < endl;
rectangle(img, cvPoint(xmin, ymin), cvPoint(xmax, ymax), Scalar(255, 0, 0), 1, 1, 0);
putText(img, id, cvPoint(xmin, ymin), FONT_HERSHEY_COMPLEX,
1.0, Scalar(255,0,0), 1);
}
imshow(
"1", img);
cvWaitKey(
0);
return 0;
}

CMakeLists.txt内容如下

cmake_minimum_required(VERSION 3.0.0)
project(tensorflow_cpp)
set(CMAKE_CXX_STANDARD 11)
find_package(OpenCV
3.0 QUIET)
if(NOT OpenCV_FOUND)
find_package(OpenCV
2.4.3 QUIET)
if(NOT OpenCV_FOUND)
message(FATAL_ERROR
"OpenCV > 2.4.3 not found.")
endif()
endif()
set(TENSORFLOW_INCLUDES
/usr/local/include/tf/
/usr/local/include/tf/bazel-genfiles
/usr/local/include/tf/tensorflow/
/usr/local/include/tf/tensorflow/third_party)
set(TENSORFLOW_LIBS
/usr/local/lib/libtensorflow_cc.so
/usr/local/lib//libtensorflow_framework.so)

include_directories(
${TENSORFLOW_INCLUDES}
${PROJECT_SOURCE_DIR}
/third_party/eigen3
)
add_executable(predict predict.cpp)
target_link_libraries(predict
${TENSORFLOW_LIBS}
${OpenCV_LIBS}
)

目录结构如图所示

技术分享图片


推荐阅读
  • 本文介绍了 Confluence 6 中使用的其他 Cookie,这些 Cookie 主要用于存储产品的基本持久性和用户偏好设置,以提升用户体验。 ... [详细]
  • 如何解决TS1219:实验性装饰器功能可能在未来版本中更改的问题
    本文介绍了两种方法来解决TS1219错误:通过VSCode设置启用实验性装饰器,或在项目根目录下创建配置文件(jsconfig.json或tsconfig.json)。 ... [详细]
  • 蒜头君的倒水问题(矩阵快速幂优化)
    蒜头君将两杯热水分别倒入两个杯子中,每杯水的初始量分别为a毫升和b毫升。为了使水冷却,蒜头君采用了一种特殊的方式,即每次将第一杯中的x%的水倒入第二杯,同时将第二杯中的y%的水倒入第一杯。这种操作会重复进行k次,最终求出两杯水中各自的水量。 ... [详细]
  • malloc 是 C 语言中的一个标准库函数,全称为 memory allocation,即动态内存分配。它用于在程序运行时申请一块指定大小的连续内存区域,并返回该区域的起始地址。当无法预先确定内存的具体位置时,可以通过 malloc 动态分配内存。 ... [详细]
  • 2017年5月9日学习总结
    本文记录了2017年5月9日的学习内容,包括技术分享和相关知识点的深入探讨。 ... [详细]
  • 本文章提供了适用于 Cacti 的多核 CPU 监控模板,支持 2、4、8、12、16、24 和 32 核配置。请注意,0.87g 版本的 Cacti 需要手动修改哈希值为 0021 才能使用,而 0.88 及以上版本则可直接导入。 ... [详细]
  • Gty的二逼妹子序列 - 分块与莫队算法的应用
    Autumn 和 Bakser 正在研究 Gty 的妹子序列,但遇到了一个难题。他们希望计算某个区间内美丽度属于 [a, b] 的妹子的美丽度种类数。本文将详细介绍如何利用分块和莫队算法解决这一问题。 ... [详细]
  • JavaSE For循环入门示例
    本文将介绍Java中For循环的基本概念和使用方法,通过几个简单的示例帮助初学者更好地理解和掌握For循环。 ... [详细]
  • 本文介绍了一种支付平台异步风控系统的架构模型,旨在为开发类似系统的工程师提供参考。 ... [详细]
  • 使用 Git Rebase -i 合并多个提交
    在开发过程中,频繁的小改动往往会生成多个提交记录。为了保持代码仓库的整洁,我们可以使用 git rebase -i 命令将多个提交合并成一个。 ... [详细]
  • Manacher算法详解:寻找最长回文子串
    本文将详细介绍Manacher算法,该算法用于高效地找到字符串中的最长回文子串。通过在字符间插入特殊符号,Manacher算法能够同时处理奇数和偶数长度的回文子串问题。 ... [详细]
  • 本文介绍了多种开源数据库及其核心数据结构和算法,包括MySQL的B+树、MVCC和WAL,MongoDB的tokuDB和cola,boltDB的追加仅树和mmap,levelDB的LSM树,以及内存缓存中的一致性哈希。 ... [详细]
  • Python多线程详解与示例
    本文介绍了Python中的多线程编程,包括僵尸进程和孤儿进程的概念,并提供了具体的代码示例。同时,详细解释了0号进程和1号进程在系统中的作用。 ... [详细]
  • 本文详细介绍了Linux系统中用于管理IPC(Inter-Process Communication)资源的两个重要命令:ipcs和ipcrm。通过这些命令,用户可以查看和删除系统中的消息队列、共享内存和信号量。 ... [详细]
  • A*算法在AI路径规划中的应用
    路径规划算法用于在地图上找到从起点到终点的最佳路径,特别是在存在障碍物的情况下。A*算法是一种高效且广泛使用的路径规划算法,适用于静态和动态环境。 ... [详细]
author-avatar
ppqq21
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有