热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

机器学习算法:SVM(支持向量机)

SVM算法(SupportVectorMachine,支持向量机)的核心思想有2点:1、如果数据线性可分,那么基于最大间隔的方式来确定超平面,以确保全局最优,

SVM算法(Support Vector Machine,支持向量机)的核心思想有2点:1、如果数据线性可分,那么基于最大间隔的方式来确定超平面,以确保全局最优,使得分类器尽可能健壮;2、如果数据线性不可分,通过核函数将低维样本转化为高维样本使其线性可分。注意和AdaBoost类似,SVM只能解决二分类问题。


SVM的算法在数学上实在是太复杂了,没研究明白。建议还是直接使用现成的第三方组件吧,比如libsvm的C#版本,推荐这个:http://www.matthewajohnson.org/software/svm.html。


虽然没研究明白,不过这几天照着Python版本的代码试着用C#改写了一下,算是研究SVM过程中唯一的收获吧。此版本基于SMO(序列最小优化)算法求解,核函数使用的是比较常用的径向基函数(RBF)。别问我为什么没有注释,我只是从Python移植过来的,我也没看懂,等我看懂了再来补注释吧。

using System;
using System.Collections.Generic;
using System.Linq;

namespace MachineLearning
{
    /// 
    /// 支持向量机(SMO算法,RBF核)
    /// 
    public class SVM
    {
        private Random m_Rand;
        private double[][] m_Kernel;
        private double[] m_Alpha;
        private double m_C = 1.0;
        private double m_B = 0.0;
        private double m_Toler = 0.0;
        private double[][] m_Cache;
        private double[][] m_Data;
        private double m_Reach;
        private int[] m_Label;
        private int m_Count;
        private int m_Dimension;
        
        public SVM()
        {
            m_Rand = new Random();
        }
        
        /// 
        /// 训练
        /// 
        /// 
        /// 
        /// 
        /// 
        /// 
        public void Train(List> trainingSet, double C, double toler, double reach, int iterateCount = 10)
        {
            //初始化
            m_Count = trainingSet.Count;
            m_Dimension = trainingSet[0].Dimension;
            m_Toler = toler;
            m_C = C;
            m_Reach = reach;
            this.Init(trainingSet);
            this.InitKernel();
            
            int iter = 0;
            int alphaChanged = 0;
            bool entireSet = true;
            while(iter < iterateCount && (alphaChanged > 0 || entireSet))
            {
                alphaChanged = 0;
                if(entireSet)
                {
                    for(int i = 0;i < m_Count;++i)
                        alphaChanged += InnerL(i);
                    iter++;
                }
                else
                {
                    for(int i = 0;i < m_Count;++i)
                    {
                        if(m_Alpha[i] > 0 && m_Alpha[i] < m_C)
                            alphaChanged += InnerL(i);
                    }
                    iter += 1;
                }
                
                if(entireSet)
                    entireSet = false;
                else if(alphaChanged == 0)
                    entireSet = true;
            }
        }
        
        /// 
        /// 分类
        /// 
        /// 
        /// 
        public int Classify(DataVector vector)
        {
            double predict = 0.0;
            
            int svCnt = m_Alpha.Count(a => a > 0);
            var supportVectors = new double[svCnt][];
            var supportLabels = new int[svCnt];
            var supportAlphas = new double[svCnt];
            int index = 0;
            for(int i = 0;i < m_Count;++i)
            {
                if(m_Alpha[i] > 0)
                {
                    supportVectors[index] = m_Data[i];
                    supportLabels[index] = m_Label[i];
                    supportAlphas[index] = m_Alpha[i];
                    index++;
                }
            }
            
            var kernelEval = KernelTrans(supportVectors, vector.Data);
            for(int i = 0;i < svCnt;++i)
                predict += kernelEval[i] * supportAlphas[i] * supportLabels[i];
            predict += m_B;
            
            return Math.Sign(predict);
        }
        
        /// 
        /// 将原始数据转化成方便使用的形式
        /// 
        /// 
        private void Init(List> trainingSet)
        {
            m_Data = new double[m_Count][];
            m_Label = new int[m_Count];
            m_Alpha = new double[m_Count];
            m_Cache = new double[m_Count][];
            
            for(int i = 0;i < m_Count;++i)
            {
                m_Label[i] = trainingSet[i].Label;
                m_Alpha[i] = 0.0;
                m_Cache[i] = new double[2];
                m_Cache[i][0] = 0.0;
                m_Cache[i][1] = 0.0;
                m_Data[i] = new double[m_Dimension];
                for(int j = 0;j < m_Dimension;++j)
                    m_Data[i][j] = trainingSet[i].Data[j];
            }
        }
        
        /// 
        /// 初始化RBF核
        /// 
        private void InitKernel()
        {
            m_Kernel = new double[m_Count][];
            
            for(int i = 0;i < m_Count;++i)
            {
                m_Kernel[i] = new double[m_Count];
                var kernels = KernelTrans(m_Data, m_Data[i]);
                for(int k = 0;k < kernels.Length;++k)
                    m_Kernel[i][k] = kernels[k];
            }
        }
        
        private double[] KernelTrans(double[][] X, double[] A)
        {
            var kernel = new double[X.Length];
            
            for(int i = 0;i < X.Length;++i)
            {
                double delta = 0.0;
                for(int k = 0;k < X[0].Length;++k)
                    delta += Math.Pow(X[i][k] - A[k], 2);
                kernel[i] = Math.Exp(delta * -1.0 / Math.Pow(m_Reach, 2));
            }
            
            return kernel;
        }
        
        private double E(int k)
        {
            double x = 0.0;
            for(int i = 0;i < m_Count;++i)
                x += m_Alpha[i] * m_Label[i] * m_Kernel[i][k];
            x += m_B;
            
            return x - m_Label[k];
        }
        
        private void UpdateE(int k)
        {
            double Ek = E(k);
            m_Cache[k][0] = 1.0;
            m_Cache[k][1] = Ek;
        }
        
        private int InnerL(int i)
        {
            double Ei = E(i);
            
            if((m_Label[i] * Ei < -m_Toler && m_Alpha[i] < m_C) || (m_Label[i] * Ei > m_Toler && m_Alpha[i] > 0))
            {
                double Ej = 0.0;
                int j = SelectJ(i, Ei, out Ej);
                double oldAi = m_Alpha[i];
                double oldAj = m_Alpha[j];
                
                double H, L;
                if(m_Label[i] != m_Label[j])
                {
                    L = Math.Max(0, m_Alpha[j] - m_Alpha[i]);
                    H = Math.Min(m_C, m_C + m_Alpha[j] - m_Alpha[i]);
                }
                else
                {
                    L = Math.Max(0, m_Alpha[j] + m_Alpha[i] - m_C);
                    H = Math.Min(m_C, m_Alpha[j] + m_Alpha[i]);
                }
                
                if(L == H)
                    return 0;
                    
                double eta = 2.0 * m_Kernel[i][j] - m_Kernel[i][i] - m_Kernel[j][j];
                if(eta >= 0)
                    return 0;
                    
                m_Alpha[j] -= m_Label[j] * (Ei - Ej) / eta;
                m_Alpha[j] = ClipAlpha(m_Alpha[j], H, L);
                UpdateE(j);
                
                if(Math.Abs(m_Alpha[j] - oldAj) < 0.00001)
                    return 0;
                    
                m_Alpha[i] += m_Label[j] * m_Label[i] * (oldAj - m_Alpha[j]);
                UpdateE(i);
                
                double b1 = m_B - Ei - m_Label[i] * (m_Alpha[i] - oldAi) * m_Kernel[i][i] - m_Label[j] * (m_Alpha[j] - oldAj) * m_Kernel[i][j];
                double b2 = m_B - Ej - m_Label[i] * (m_Alpha[i] - oldAi) * m_Kernel[i][j] - m_Label[j] * (m_Alpha[j] - oldAj) * m_Kernel[j][j];
                
                if(m_Alpha[i] > 0 && m_Alpha[i] < m_C)
                    m_B = b1;
                else if(m_Alpha[j] > 0 && m_Alpha[j] < m_C)
                    m_B = b2;
                else
                    m_B = (b1 + b2) / 2.0;
                    
                return 1;
            }
            
            return 0;
        }
        
        private int SelectJ(int i, double Ei, out double Ej)
        {
            Ej = 0.0;
            
            int j = 0;
            int maxK = -1;
            double maxDeltaE = 0.0;
            
            m_Cache[i][0] = 1;
            m_Cache[i][1] = Ei;
            
            for(int k = 0;k < m_Count;++k)
            {
                if(k == i || m_Cache[k][0] == 0)
                    continue;
                    
                double Ek = E(k);
                double deltaE = Math.Abs(Ei - Ek);
                if(deltaE > maxDeltaE)
                {
                    maxK = k;
                    maxDeltaE = deltaE;
                    Ej = Ek;
                }
            }
            
            if(maxK >= 0)
            {
                j = maxK;
            }
            else
            {
                j = RandomSelect(i);
            }
            
            return j;
        }
        
        private int RandomSelect(int i)
        {
            int j = 0;
            do 
            {
                j = m_Rand.Next(0, m_Count);
            }
            while(j == i);
            
            return j;
        }
        
        private double ClipAlpha(double alpha, double H, double L)
        {
            return alpha > H ? H : (alpha < L ? L : alpha);
        }
    }
}


最后上测试,还是使用上次的breast-cancer-wisconsin.txt做测试,之前用kNN和AdaBoost测试的错误率分别是2.02%和1.01%,这回用SVM对比一下。上测试代码:


public void TestSvm()
{
    var trainingSet = new List>();
    var testSet = new List>();
    
    //读取数据
    var file = new StreamReader("breast-cancer-wisconsin.txt", Encoding.Default);
    for(int i = 0;i < 699;++i)
    {
        string line = file.ReadLine();
        var parts = line.Split(&#39;,&#39;);
        var p = new DataVector(9);
        for(int j = 0;j < p.Dimension;++j)
        {
            if(parts[j + 1] == "?")
                parts[j + 1] = "0";
            p.Data[j] = Convert.ToDouble(parts[j + 1]);
        }
        p.Label = Convert.ToInt32(parts[10]) == 2 ? 1 : -1;
        
        //和上次一样,600个做训练,99个做测试
        if(i < 600)
            trainingSet.Add(p);
        else
            testSet.Add(p);
    }
    file.Close();
    
    //检验
    var svm = new SVM();
    svm.Train(trainingSet, 1, 0.01, 3.0, 10);
    int error = 0;
    foreach(var p in testSet)
    {
        var label = boost.Classify(p);
        if(label != p.Label)
            error++;
    }
    
    Console.WriteLine("Error = {0}/{1}, {2}%", error, testSet.Count, (error * 100.0 / testSet.Count));
}


最终结果是99个测试样本猜错1个,错误率1.01%,和AdaBoost相当。


Train时使用不同的参数,错误率会变化,很可惜的是参数的选择往往没有固定的方法,需要在一定范围内尝试以得到最小错误率。



另外对于为什么核函数可以处理线性不可分数据,网上有2张图很能说明问题,转载一下:

以下数据明显是线性不可分的,在二维空间下找不到一条直线能将数据分开:

机器学习算法:SVM(支持向量机)


但在在二维空间下的线性不可分,到了三维空间,是可以找到一个平面来分隔数据的:

机器学习算法:SVM(支持向量机)




推荐阅读
  • 主板IO用W83627THG,用VC如何取得CPU温度,系统温度,CPU风扇转速,VBat的电压. ... [详细]
  • 使用 Azure Service Principal 和 Microsoft Graph API 获取 AAD 用户列表
    本文介绍了一段通用代码示例,该代码不仅能够操作 Azure Active Directory (AAD),还可以通过 Azure Service Principal 的授权访问和管理 Azure 订阅资源。Azure 的架构可以分为两个层级:AAD 和 Subscription。 ... [详细]
  • 本文详细介绍了如何解决Uploadify插件在Internet Explorer(IE)9和10版本中遇到的点击失效及JQuery运行时错误问题。通过修改相关JavaScript代码,确保上传功能在不同浏览器环境中的一致性和稳定性。 ... [详细]
  • 探讨如何高效使用FastJSON进行JSON数据解析,特别是从复杂嵌套结构中提取特定字段值的方法。 ... [详细]
  • 导航栏样式练习:项目实例解析
    本文详细介绍了如何创建一个具有动态效果的导航栏,包括HTML、CSS和JavaScript代码的实现,并附有详细的说明和效果图。 ... [详细]
  • 本文介绍了Java并发库中的阻塞队列(BlockingQueue)及其典型应用场景。通过具体实例,展示了如何利用LinkedBlockingQueue实现线程间高效、安全的数据传递,并结合线程池和原子类优化性能。 ... [详细]
  • 1.如何在运行状态查看源代码?查看函数的源代码,我们通常会使用IDE来完成。比如在PyCharm中,你可以Ctrl+鼠标点击进入函数的源代码。那如果没有IDE呢?当我们想使用一个函 ... [详细]
  • 本文介绍了如何使用JQuery实现省市二级联动和表单验证。首先,通过change事件监听用户选择的省份,并动态加载对应的城市列表。其次,详细讲解了使用Validation插件进行表单验证的方法,包括内置规则、自定义规则及实时验证功能。 ... [详细]
  • 从 .NET 转 Java 的自学之路:IO 流基础篇
    本文详细介绍了 Java 中的 IO 流,包括字节流和字符流的基本概念及其操作方式。探讨了如何处理不同类型的文件数据,并结合编码机制确保字符数据的正确读写。同时,文中还涵盖了装饰设计模式的应用,以及多种常见的 IO 操作实例。 ... [详细]
  • 本文详细介绍了在企业级项目中如何优化 Webpack 配置,特别是在 React 移动端项目中的最佳实践。涵盖资源压缩、代码分割、构建范围缩小、缓存机制以及性能优化等多个方面。 ... [详细]
  • 解决SVN图标显示异常问题的综合指南
    本文详细探讨了SVN图标无法正常显示的问题,并提供了多种有效的解决方案,涵盖不同环境下的具体操作步骤。通过本文,您将了解如何排查和修复这些常见的SVN图标显示故障。 ... [详细]
  • Kubernetes 持久化存储与数据卷详解
    本文深入探讨 Kubernetes 中持久化存储的使用场景、PV/PVC/StorageClass 的基本操作及其实现原理,旨在帮助读者理解如何高效管理容器化应用的数据持久化需求。 ... [详细]
  • 深入理解Cookie与Session会话管理
    本文详细介绍了如何通过HTTP响应和请求处理浏览器的Cookie信息,以及如何创建、设置和管理Cookie。同时探讨了会话跟踪技术中的Session机制,解释其原理及应用场景。 ... [详细]
  • 本文介绍了一款用于自动化部署 Linux 服务的 Bash 脚本。该脚本不仅涵盖了基本的文件复制和目录创建,还处理了系统服务的配置和启动,确保在多种 Linux 发行版上都能顺利运行。 ... [详细]
  • 将Web服务部署到Tomcat
    本文介绍了如何在JDeveloper 12c中创建一个Java项目,并将其打包为Web服务,然后部署到Tomcat服务器。内容涵盖从项目创建、编写Web服务代码、配置相关XML文件到最终的本地部署和验证。 ... [详细]
author-avatar
vuvhvuvh
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有