机器学习框架ML.NET学习笔记【6】TensorFlow图片分类

时间:2022-03-19 02:47:00

一、概述

通过之前两篇文章的学习,我们应该已经了解了多元分类的工作原理,图片的分类其流程和之前完全一致,其中最核心的问题就是特征的提取,只要完成特征提取,分类算法就很好处理了,具体流程如下:

机器学习框架ML.NET学习笔记【6】TensorFlow图片分类

之前介绍过,图片的特征是不能采用像素的灰度值的,这部分原理的台阶有点高,还好可以直接使用通过TensorFlow训练过的特征提取模型(美其名曰迁移学习)。

模型文件为:tensorflow_inception_graph.pb

二、样本介绍

我随便在网上找了一些图片,分成6类:男孩、女孩、猫、狗、男人、女人。tags文件标记了每个文件所代表的类型标签(Label)。

机器学习框架ML.NET学习笔记【6】TensorFlow图片分类

通过对这六类图片的学习,期望输入新的图片时,可以判断出是何种类型。

三、代码

全部代码:

namespace TensorFlow_ImageClassification
{ class Program
{
//Assets files download from:https://gitee.com/seabluescn/ML_Assets
static readonly string AssetsFolder = @"D:\StepByStep\Blogs\ML_Assets";
static readonly string TrainDataFolder = Path.Combine(AssetsFolder, "ImageClassification", "train");
static readonly string TrainTagsPath = Path.Combine(AssetsFolder, "ImageClassification", "train_tags.tsv");
static readonly string TestDataFolder = Path.Combine(AssetsFolder, "ImageClassification","test");
static readonly string inceptionPb = Path.Combine(AssetsFolder, "TensorFlow", "tensorflow_inception_graph.pb");
static readonly string imageClassifierZip = Path.Combine(Environment.CurrentDirectory, "MLModel", "imageClassifier.zip"); //配置用常量
private struct ImageNetSettings
{
public const int imageHeight = ;
public const int imageWidth = ;
public const float mean = ;
public const float scale = ;
public const bool channelsLast = true;
} static void Main(string[] args)
{
TrainAndSaveModel();
LoadAndPrediction(); Console.WriteLine("Hit any key to finish the app");
Console.ReadKey();
} public static void TrainAndSaveModel()
{
MLContext mlContext = new MLContext(seed: ); // STEP 1: 准备数据
var fulldata = mlContext.Data.LoadFromTextFile<ImageNetData>(path: TrainTagsPath, separatorChar: '\t', hasHeader: false); var trainTestData = mlContext.Data.TrainTestSplit(fulldata, testFraction: 0.1);
var trainData = trainTestData.TrainSet;
var testData = trainTestData.TestSet; // STEP 2:创建学习管道
var pipeline = mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: "LabelTokey", inputColumnName: "Label")
.Append(mlContext.Transforms.LoadImages(outputColumnName: "input", imageFolder: TrainDataFolder, inputColumnName: nameof(ImageNetData.ImagePath)))
.Append(mlContext.Transforms.ResizeImages(outputColumnName: "input", imageWidth: ImageNetSettings.imageWidth, imageHeight: ImageNetSettings.imageHeight, inputColumnName: "input"))
.Append(mlContext.Transforms.ExtractPixels(outputColumnName: "input", interleavePixelColors: ImageNetSettings.channelsLast, offsetImage: ImageNetSettings.mean))
.Append(mlContext.Model.LoadTensorFlowModel(inceptionPb).
ScoreTensorFlowModel(outputColumnNames: new[] { "softmax2_pre_activation" }, inputColumnNames: new[] { "input" }, addBatchDimensionInput: true))
.Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: "LabelTokey", featureColumnName: "softmax2_pre_activation"))
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabelValue", "PredictedLabel"))
.AppendCacheCheckpoint(mlContext); // STEP 3:通过训练数据调整模型
ITransformer model = pipeline.Fit(trainData); // STEP 4:评估模型
Console.WriteLine("===== Evaluate model =======");
var evaData = model.Transform(testData);
var metrics = mlContext.MulticlassClassification.Evaluate(evaData, labelColumnName: "LabelTokey", predictedLabelColumnName: "PredictedLabel");
PrintMultiClassClassificationMetrics(metrics); //STEP 5:保存模型
Console.WriteLine("====== Save model to local file =========");
mlContext.Model.Save(model, trainData.Schema, imageClassifierZip);
} static void LoadAndPrediction()
{
MLContext mlContext = new MLContext(seed: ); // Load the model
ITransformer loadedModel = mlContext.Model.Load(imageClassifierZip, out var modelInputSchema); // Make prediction function (input = ImageNetData, output = ImageNetPrediction)
var predictor = mlContext.Model.CreatePredictionEngine<ImageNetData, ImageNetPrediction>(loadedModel); DirectoryInfo testdir = new DirectoryInfo(TestDataFolder);
foreach (var jpgfile in testdir.GetFiles("*.jpg"))
{
ImageNetData image = new ImageNetData();
image.ImagePath = jpgfile.FullName;
var pred = predictor.Predict(image); Console.WriteLine($"Filename:{jpgfile.Name}:\tPredict Result:{pred.PredictedLabelValue}");
}
}
} public class ImageNetData
{
[LoadColumn()]
public string ImagePath; [LoadColumn()]
public string Label;
} public class ImageNetPrediction
{
//public float[] Score;
public string PredictedLabelValue;
}
}

四、分析

 1、数据处理通道

可以看出,其代码流程与结构与上两篇文章介绍的完全一致,这里就介绍一下核心的数据处理模型部分的代码:

var pipeline = mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: "LabelTokey", inputColumnName: "Label")
.Append(mlContext.Transforms.LoadImages(outputColumnName: "input", imageFolder: TrainDataFolder, inputColumnName: nameof(ImageNetData.ImagePath)))
.Append(mlContext.Transforms.ResizeImages(outputColumnName: "input", imageWidth: ImageNetSettings.imageWidth, imageHeight: ImageNetSettings.imageHeight, inputColumnName: "input"))
.Append(mlContext.Transforms.ExtractPixels(outputColumnName: "input", interleavePixelColors: ImageNetSettings.channelsLast, offsetImage: ImageNetSettings.mean))
.Append(mlContext.Model.LoadTensorFlowModel(inceptionPb).
ScoreTensorFlowModel(outputColumnNames: new[] { "softmax2_pre_activation" }, inputColumnNames: new[] { "input" }, addBatchDimensionInput: true))
.Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: "LabelTokey", featureColumnName: "softmax2_pre_activation"))
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabelValue", "PredictedLabel"))

MapValueToKey与MapKeyToValue之前已经介绍过了;
LoadImages是读取文件,输入为文件名、输出为Image;
ResizeImages是改变图片尺寸,这一步是必须的,即使所有训练图片都是标准划一的图片也需要这个操作,后面需要根据这个尺寸确定容纳图片像素信息的数组大小;
ExtractPixels是将图片转换为包含像素数据的矩阵;
LoadTensorFlowModel是加载第三方模型,ScoreTensorFlowModel是调用模型处理数据,其输入为:“input”,输出为:“softmax2_pre_activation”,由于模型中输入、输出的名称是规定的,所以,这里的名称不可以随便修改。
分类算法采用的是L-BFGS最大熵分类算法,其特征数据为TensorFlow网络输出的值,标签值为"LabelTokey"。

2、验证过程
            MLContext mlContext = new MLContext(seed: );
ITransformer loadedModel = mlContext.Model.Load(imageClassifierZip, out var modelInputSchema);
var predictor = mlContext.Model.CreatePredictionEngine<ImageNetData, ImageNetPrediction>(loadedModel); ImageNetData image = new ImageNetData();
image.ImagePath = jpgfile.FullName;
var pred = predictor.Predict(image);
Console.WriteLine($"Filename:{jpgfile.Name}:\tPredict Result:{pred.PredictedLabelValue}");

两个实体类代码:

    public class ImageNetData
{
[LoadColumn()]
public string ImagePath; [LoadColumn()]
public string Label;
} public class ImageNetPrediction
{
public string PredictedLabelValue;
}
3、验证结果
我在网络上又随便找了20张图片进行验证,发现验证成功率是非常高的,基本都是准确的,只有两个出错了。

机器学习框架ML.NET学习笔记【6】TensorFlow图片分类

上图片被识别为girl(我认为是Woman),这个情有可原,本来girl和worman在外貌上也没有一个明确的分界点。

机器学习框架ML.NET学习笔记【6】TensorFlow图片分类

上图被识别为woman,这个也情有可原,不解释。

需要了解的是:不管你输入什么图片,最终的结果只能是以上六个类型之一,算法会寻找到和六个分类中特征最接近的一个分类作为结果。

4、调试
注意看实体类的话,我们只提供了三个基本属性,如果想看一下在学习过程中数据是如何处理的,可以给ImageNetPrediction类增加一些字段进行调试。
首先我们需要看一下IDateView有哪些列(Column)
            var predictions = trainedModel.Transform(testData);          

            var OutputColumnNames = predictions.Schema.Where(col => !col.IsHidden).Select(col => col.Name);
foreach (string column in OutputColumnNames)
{
Console.WriteLine($"OutputColumnName:{ column }");
}

将我们要调试的列加入到实体对象中去,特别要注意数据类型。

    public class ImageNetPrediction
{
public float[] Score;
public string PredictedLabelValue;
public string Label; public void PrintToConsole()
{
//打印字段信息
}
}

查看数据集详细信息:

           var predictions = trainedModel.Transform(testData);
var DataShowList = new List<ImageNetPrediction>(mlContext.Data.CreateEnumerable<ImageNetPrediction>(predictions, false, true));
foreach (var dataline in DataShowList)
{
dataline.PrintToConsole();
}

五、资源获取 

源码下载地址:https://github.com/seabluescn/Study_ML.NET

工程名称:TensorFlow_ImageClassification

资源获取:https://gitee.com/seabluescn/ML_Assets

点击查看机器学习框架ML.NET学习笔记系列文章目录

机器学习框架ML.NET学习笔记【6】TensorFlow图片分类的更多相关文章

  1. 机器学习框架ML&period;NET学习笔记【1】基本概念与系列文章目录

    一.序言 微软的机器学习框架于2018年5月出了0.1版本,2019年5月发布1.0版本.期间各版本之间差异(包括命名空间.方法等)还是比较大的,随着1.0版发布,应该是趋于稳定了.之前在园子里也看到 ...

  2. 机器学习框架ML&period;NET学习笔记【5】多元分类之手写数字识别(续)

    一.概述 上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断.思路很简单,就是 ...

  3. 机器学习框架ML&period;NET学习笔记【7】人物图片颜值判断

    一.概述 这次要解决的问题是输入一张照片,输出人物的颜值数据. 学习样本来源于华南理工大学发布的SCUT-FBP5500数据集,数据集包括 5500 人,每人按颜值魅力打分,分值在 1 到 5 分之间 ...

  4. 机器学习框架ML&period;NET学习笔记【4】多元分类之手写数字识别

    一.问题与解决方案 通过多元分类算法进行手写数字识别,手写数字的图片分辨率为8*8的灰度图片.已经预先进行过处理,读取了各像素点的灰度值,并进行了标记. 其中第0列是序号(不参与运算).1-64列是像 ...

  5. 机器学习框架ML&period;NET学习笔记【3】文本特征分析

    一.要解决的问题 问题:常常一些单位或组织召开会议时需要录入会议记录,我们需要通过机器学习对用户输入的文本内容进行自动评判,合格或不合格.(同样的问题还类似垃圾短信检测.工作日志质量分析等.) 处理思 ...

  6. 机器学习框架ML&period;NET学习笔记【2】入门之二元分类

    一.准备样本 接上一篇文章提到的问题:根据一个人的身高.体重来判断一个人的身材是否很好.但我手上没有样本数据,只能伪造一批数据了,伪造的数据比较标准,用来学习还是蛮合适的. 下面是我用来伪造数据的代码 ...

  7. 机器学习框架ML&period;NET学习笔记【8】目标检测(采用YOLO2模型)

    一.概述 本篇文章介绍通过YOLO模型进行目标识别的应用,原始代码来源于:https://github.com/dotnet/machinelearning-samples 实现的功能是输入一张图片, ...

  8. 机器学习框架ML&period;NET学习笔记【9】自动学习

    一.概述 本篇我们首先通过回归算法实现一个葡萄酒品质预测的程序,然后通过AutoML的方法再重新实现,通过对比两种实现方式来学习AutoML的应用. 首先数据集来自于竞赛网站kaggle.com的UC ...

  9. 【学习笔记】tensorflow图片读取

    目录 图像基本概念 图像基本操作 图像基本操作API 图像读取API 狗图片读取 CIFAR-10二进制数据读取 TFRecords TFRecords存储 TFRecords读取方法 图像基本概念 ...

随机推荐

  1. zju&lpar;8&rpar;串口通信实验

    1.实验目的 1.学习和掌握linux下串口的操作方法以及应用程序的编写: 二.实验内容 1.编写EduKit-IV实验箱Linux操作系统下串口的应用程序,运行时只需要将串口线的一端连接到开发板的c ...

  2. javaWeb 使用cookie显示上次访问网站时间

    package de.bvb.cookie; import java.io.IOException; import java.io.PrintWriter; import java.util.Date ...

  3. MVC验证08-jQuery异步验证

    原文:MVC验证08-jQuery异步验证 本文主要体验通过jQuery异步验证. 在很多的教材和案例中,MVC验证都是通过提交表单进行的.通过提交表单,可以很容易获得验证出错信息.因为,无论是客户端 ...

  4. 线程内唯一对象HttpContext

    在asp.net中,HttpContext是主线程内唯一对象.在web应用中开启多线程,在另外一个线程中是无法访问HttpContext. 如果需要在另外一个线程中使用HttpContext.Curr ...

  5. Exp6 信息收集与漏洞扫描

    一.实践过程 1.信息收集 1.1 通过DNS和IP查询目标网站的信息 (1)whois命令用来进行域名注册信息查询,可查询到3R注册信息,包括注册人的姓名.组织和城市等信息. whois baidu ...

  6. 用Vue&period;js搭建一个小说阅读网站

    目录 1.简介 2.如何使用vue.js 3.部署api服务器 4.vue.js路由配置 5.实现页面加载数据 6.测试vue项目 7.在正式环境部署 8.Vue前端代码下载 1.简介 这是一个使用v ...

  7. mysql 开源~canal维护相关问题

    一 简介:咱们来讨论下canal的一些技巧 二 场景 场景1 canal过滤指定库后,后端java调用读取相关数据时候出现大量的空事务,为何会出现空事务呢,空事务是由于配置了指定的过滤规则,导致了其他 ...

  8. YII缓存整理

    缓存 缓存是用于提升网站性能的一种即简单又有效的途径.通过存储相对静态的数据至缓存以备所需,我们可以省去生成这些数据的时间.在 Yii 中使用缓存主要包括配置和访问缓存组件 . 如下的应用配置指定了一 ...

  9. ajax个人总结

    ajax是什么? AJAX = Asynchronous JavaScript and XML(异步的 JavaScript 和 XML). AJAX 不是新的编程语言,而是一种使用现有标准的新方法. ...

  10. ConstraintLayout导读

    ConstraintLayout是Android Studio 2.2中主要的新增功能之一,也是Google在去年的I/O大会上重点宣传的一个功能,可以把ConstraintLayout看成是一个更高 ...