Spark分布式机器学习源码分析:如何用分布式集群构建线性模型?
suiw9 2024-11-13 14:43 40 浏览 0 评论
Spark是一个极为优秀的大数据框架,在大数据批处理上基本无人能敌,流处理上也有一席之地,机器学习则是当前正火热AI人工智能的驱动引擎,在大数据场景下如何发挥AI技术成为优秀的大数据挖掘工程师必备技能。
本文采用的组件版本为:Ubuntu 19.10、Jdk 1.8.0_241、Scala 2.11.12、Hadoop 3.2.1、Spark 2.4.5,老规矩先开启一系列Hadoop、Spark服务与Spark-shell窗口:
spark.mllib提供了多种方法用于用于二分类、多分类以及回归分析。下表介绍了每种问题类型支持的算法。
问题类型支持的方法二分类线性SVMs、逻辑回归、决策树、随机森林、梯度增强树、朴素贝叶斯多分类逻辑回归、决策树、随机森林、朴素贝叶斯回归线性最小二乘、决策树、随机森林、梯度增强树、保序回归
1.数学公式
许多标准的机器学习方法可以表述为凸优化问题,即寻找凸函数f的最小化器的任务,该函数取决于具有d个条目的变量向量w(在代码中称为权重)。形式上,我们可以将其写为优化问题minw∈Rdf(w),其中目标函数的形式为:
下表总结了spark.mllib支持的方法的损失函数及其梯度或子梯度:
正则化器的目的是鼓励简化模型并避免过度拟合。我们在spark.mllib中支持以下正则化器:
在幕后,线性方法使用凸优化方法来优化目标函数。spark.mllib使用两种方法,SGD和L-BFGS,在优化部分中进行了介绍。当前,大多数算法API支持随机梯度下降(SGD),而少数支持L-BFGS。有关在两种优化方法之间进行选择的指导,请参阅此优化部分。
线性回归假设特征和结果都满足线性。即不大于一次方。收集的数据中,每一个分量,就可以看做一个特征数据。每个特征至少对应一个未知的参数。这样就形成了一个线性模型函数,向量表示形式:
线性最小二乘法是回归问题的最常见表示。这是一种线性方法,公式中的损失函数由平方损失给出:
2.线性回归
线性最小二乘法是回归问题的最常见表示。这是一种线性方法,如上面的公式所述,公式中的损失函数由平方损失给出。通过使用不同类型的正则化来推导各种相关的回归方法:普通最小二乘法或线性最小二乘不使用正则化;岭回归使用L2正则化;Lasso使用L1正则化。对于所有这些模型,平均损失或训练误差
被称为均方误差。
下面的示例演示如何加载训练数据,将其解析为LabeledPoint的RDD。然后,该示例使用LinearRegressionWithSGD构建一个简单的线性模型来预测标签值。我们在最后计算均方误差以评估拟合优度。
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.regression.LinearRegressionModel
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
// 加载和解析数据(本文已将位于/usr/local/spark/data/mllib的数据上传至dfs,如不上传路径需自己指定)
val data = sc.textFile("data/mllib/ridge-data/lpsa.data")
val parsedData = data.map { line =>
val parts = line.split(',')
LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
}.cache()
// 建立模型
val numIterations = 100
val stepSize = 0.00000001
val model = LinearRegressionWithSGD.train(parsedData, numIterations, stepSize)
// 在训练集上评估模型并且计算误差
val valuesAndPreds = parsedData.map { point =>
val prediction = sameModel.predict(point.features)
(point.label, prediction)
}
val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2) }.mean()
println(s"training Mean Squared Error $MSE")
// Save and load model 保存和加载模型
model.save(sc, "target/tmp/scalaLinearRegressionWithSGDModel")
val sameModel = LinearRegressionModel.load(sc, "target/tmp/scalaLinearRegressionWithSGDModel")
下面截取spark mllib线性回归模型train训练参数的源码进行分析:
// * @paraminpu RDD标签点键值对。 每对描述一行数据矩阵A以及相应的右侧标签y
// * @param numIterations要运行的梯度下降的迭代次数。
// * @param stepSize用于梯度下降的每次迭代的步长。
// * @param miniBatchFraction每次迭代使用的数据分数。
// * @param initialWeights要使用的初始权重集。 数组的大小应等于数据中的要素数量。
@Since("1.0.0")
def train(
input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
initialWeights: Vector): LinearRegressionModel = {
new LinearRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction)
.run(input, initialWeights)
}
MLlib的线性回归模型采用随机梯度下降算法来优化目标函数,MLlib实现了分布式的随机梯度下降算法,其分布方法是:在每次迭代中,随机抽取一定比例的样本作为当前迭代的计算样本;对计算样本中的每一个样本分别计算梯度(分布式计算每个样本的梯度);然后再通过聚合函数对样本的梯度进行累加,得到该样本的平均梯度及损失;最后根据最新的梯度及上次迭代的权重进行权重的更新。源码分解说明:
3.逻辑回归
Logistic回归广泛用于预测二进制响应。这是一种线性方法,如上面的的损失函数公式所述,公式中的损失函数由逻辑损失给出:
对于二进制分类问题,该算法输出二进制logistic回归模型。给定一个新的数据点,用x表示,该模型通过应用逻辑函数进行预测:
默认情况下,如果f(z)> 0.5,则结果为正,否则为负,尽管与线性SVM不同,对数回归模型的原始输出具有概率解释(即,x的概率是肯定的)。
以下代码说明了如何加载示例多类数据集,将其拆分为训练和测试,以及如何使用LogisticRegressionWithLBFGS来拟合Logistic回归模型。然后根据测试数据集评估模型并将其保存到磁盘。
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
// 加载libsvm格式的训练数据
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// 划分数据集:训练集60% 测试集40%
val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0).cache()
val test = splits(1)
// 训练逻辑回归算法去建立模型
val model = new LogisticRegressionWithLBFGS().setNumClasses(10).run(training)
// 计算在测试集上的原始分数
val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
val prediction = model.predict(features)
(prediction, label)
}
// 获取多分类指标
val metrics = new MulticlassMetrics(predictionAndLabels)
val accuracy = metrics.accuracy
println(s"Accuracy = $accuracy")
// 保存和加载模型
model.save(sc, "target/tmp/scalaLogisticRegressionWithLBFGSModel")
val sameModel = LogisticRegressionModel.load(sc, "target/tmp/scalaLogisticRegressionWithLBFGSModel")
如上所述,在MLlib中,分别使用了梯度下降法和L-BFGS实现逻辑回归参数的计算。这两个算法的实现我们会在最优化章节介绍,这里我们介绍公共的部分。
LogisticRegressionWithLBFGS和LogisticRegressionWithSGD的入口函数均是GeneralizedLinearAlgorithm.run,下面详细分析该方法。
def run(input: RDD[LabeledPoint]): M = {
if (numFeatures < 0) {
// 计算特征数
numFeatures = input.map(_.features.size).first()
}
val initialWeights = {
if (numOfLinearPredictor == 1) {
Vectors.zeros(numFeatures)
} else if (addIntercept) {
Vectors.zeros((numFeatures + 1) * numOfLinearPredictor)
} else {
Vectors.zeros(numFeatures * numOfLinearPredictor)
}
}
run(input, initialWeights)
}
4.线性支持向量机
线性SVM是用于大规模分类任务的标准方法。这是一种线性方法,如上面的损失函数公式所述,公式中的损耗函数由铰链损耗给出:
默认情况下,线性SVM使用L2正则化训练。我们还支持替代的L1正则化。在这种情况下,问题就变成了线性程序。
线性SVM算法输出一个SVM模型。给定一个新的数据点,用x表示,该模型基于wTx的值进行预测。默认情况下,如果wTx≥0,则结果为正,否则为负。
以下代码段说明了如何加载样本数据集,如何使用算法对象中的静态方法对此训练数据执行训练算法以及如何使用所得模型进行预测以计算训练误差。
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.util.MLUtils
// 加载libsvm格式训练数据
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// 划分数据集:训练集60% 测试集40%
val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0).cache()
val test = splits(1)
// 训练支持向量机算法去建立模型
val numIterations = 100
val model = SVMWithSGD.train(training, numIterations)
// 清除默认阈值
model.clearThreshold()
// 计算在测试集上的原始分数
val scoreAndLabels = test.map { point =>
val score = model.predict(point.features)
(score, point.label)
}
// 获取多分类指标
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val auROC = metrics.areaUnderROC()
println(s"Area under ROC = $auROC")
// 保存和加载模型
model.save(sc, "target/tmp/scalaSVMWithSGDModel")
val sameModel = SVMModel.load(sc, "target/tmp/scalaSVMWithSGDModel")
和逻辑回归一样,训练过程均使用GeneralizedLinearModel中的run训练,只是训练使用的Gradient和Updater不同。在线性支持向量机中,使用HingeGradient计算梯度,使用SquaredL2Updater进行更新。它的实现过程分为4步。参加逻辑回归了解这五步的详细情况。我们只需要了解HingeGradient和SquaredL2Updater的实现。
class HingeGradient extends Gradient {
override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
val dotProduct = dot(data, weights)
// 我们的损失函数是 max(0, 1 - (2y - 1) (f_w(x)))
// 所以梯度是 -(2y - 1)*x
val labelScaled = 2 * label - 1.0
if (1.0 > labelScaled * dotProduct) {
val gradient = data.copy
scal(-labelScaled, gradient)
(gradient, 1.0 - labelScaled * dotProduct)
} else {
(Vectors.sparse(weights.size, Array.empty, Array.empty), 0.0)
}
}
override def compute(
data: Vector,
label: Double,
weights: Vector,
cumGradient: Vector): Double = {
val dotProduct = dot(data, weights)
// 我们的损失函数是 max(0, 1 - (2y - 1) (f_w(x)))
// 所以梯度是 -(2y - 1)*x
val labelScaled = 2 * label - 1.0
if (1.0 > labelScaled * dotProduct) {
//cumGradient -= labelScaled * data
axpy(-labelScaled, data, cumGradient)
//损失值
1.0 - labelScaled * dotProduct
} else {
0.0
}
}
}
Spark线性模型的内容至此结束,有关Spark的基础文章可参考前文:
阿里是怎么做大数据的?淘宝怎么能承载双11?大数据之眸告诉你
高频面经总结:最全大数据+AI方向面试100题(附答案详解)
参考链接:
https://github.com/endymecy
https://github.com/endymecy/spark-ml-source-analysis
相关推荐
- 俄罗斯的 HTTPS 也要被废了?(俄罗斯网站关闭)
-
发布该推文的ScottHelme是一名黑客,SecurityHeaders和ReportUri的创始人、Pluralsight作者、BBC常驻黑客。他表示,CAs现在似乎正在停止为俄罗斯域名颁发...
- 如何强制所有流量使用 HTTPS一网上用户
-
如何强制所有流量使用HTTPS一网上用户使用.htaccess强制流量到https的最常见方法可能是使用.htaccess重定向请求。.htaccess是一个简单的文本文件,简称为“.h...
- https和http的区别(https和http有何区别)
-
“HTTPS和HTTP都是数据传输的应用层协议,区别在于HTTPS比HTTP安全”。区别在哪里,我们接着往下看:...
- 快码住!带你十分钟搞懂HTTP与HTTPS协议及请求的区别
-
什么是协议?网络协议是计算机之间为了实现网络通信从而达成的一种“约定”或“规则”,正是因为这个“规则”的存在,不同厂商的生产设备、及不同操作系统组成的计算机之间,才可以实现通信。简单来说,计算机与网络...
- 简述HTTPS工作原理(简述https原理,以及与http的区别)
-
https是在http协议的基础上加了一层SSL(由网景公司开发),加密由ssl实现,它的目的是为用户提供对网站服务器的身份认证(需要CA),以至于保护交换数据的隐私和完整性,原理如图示。1、客户端发...
- 21、HTTPS 有几次握手和挥手?HTTPS 的原理什么是(高薪 常问)
-
HTTPS是3次握手和4次挥手,和HTTP是一样的。HTTPS的原理...
- 一次安全可靠的通信——HTTPS原理
-
为什么HTTPS协议就比HTTP安全呢?一次安全可靠的通信应该包含什么东西呢,这篇文章我会尝试讲清楚这些细节。Alice与Bob的通信...
- 为什么有的网站没有使用https(为什么有的网站点不开)
-
有的网站没有使用HTTPS的原因可能涉及多个方面,以下是.com、.top域名的一些见解:服务器性能限制:HTTPS使用公钥加密和私钥解密技术,这要求服务器具备足够的计算能力来处理加解密操作。如果服务...
- HTTPS是什么?加密原理和证书。SSL/TLS握手过程
-
秘钥的产生过程非对称加密...
- 图解HTTPS「转」(图解http 完整版 彩色版 pdf)
-
我们都知道HTTPS能够加密信息,以免敏感信息被第三方获取。所以很多银行网站或电子邮箱等等安全级别较高的服务都会采用HTTPS协议。...
- HTTP 和 HTTPS 有何不同?一文带你全面了解
-
随着互联网时代的高速发展,Web服务器和客户端之间的安全通信需求也越来越高。HTTP和HTTPS是两种广泛使用的Web通信协议。本文将介绍HTTP和HTTPS的区别,并探讨为什么HTTPS已成为We...
- HTTP与HTTPS的区别,详细介绍(http与https有什么区别)
-
HTTP与HTTPS介绍超文本传输协议HTTP协议被用于在Web浏览器和网站服务器之间传递信息,HTTP协议以明文方式发送内容,不提供任何方式的数据加密,如果攻击者截取了Web浏览器和网站服务器之间的...
- 一文让你轻松掌握 HTTPS(https详解)
-
一文让你轻松掌握HTTPS原文作者:UC国际研发泽原写在最前:欢迎你来到“UC国际技术”公众号,我们将为大家提供与客户端、服务端、算法、测试、数据、前端等相关的高质量技术文章,不限于原创与翻译。...
- 如何在Spring Boot应用程序上启用HTTPS?
-
HTTPS是HTTP的安全版本,旨在提供传输层安全性(TLS)[安全套接字层(SSL)的后继产品],这是地址栏中的挂锁图标,用于在Web服务器和浏览器之间建立加密连接。HTTPS加密每个数据包以安全方...
- 一文彻底搞明白Http以及Https(http0)
-
早期以信息发布为主的Web1.0时代,HTTP已可以满足绝大部分需要。证书费用、服务器的计算资源都比较昂贵,作为HTTP安全扩展的HTTPS,通常只应用在登录、交易等少数环境中。但随着越来越多的重要...
你 发表评论:
欢迎- 一周热门
-
-
Linux:Ubuntu22.04上安装python3.11,简单易上手
-
宝马阿布达比分公司推出独特M4升级套件,整套升级约在20万
-
MATLAB中图片保存的五种方法(一)(matlab中保存图片命令)
-
别再傻傻搞不清楚Workstation Player和Workstation Pro的区别了
-
Linux上使用tinyproxy快速搭建HTTP/HTTPS代理器
-
如何提取、修改、强刷A卡bios a卡刷bios工具
-
Element Plus 的 Dialog 组件实现点击遮罩层不关闭对话框
-
日本组合“岚”将于2020年12月31日停止团体活动
-
SpringCloud OpenFeign 使用 okhttp 发送 HTTP 请求与 HTTP/2 探索
-
tinymce 号称富文本编辑器世界第一,大家同意么?
-
- 最近发表
- 标签列表
-
- dialog.js (57)
- importnew (44)
- windows93网页版 (44)
- yii2框架的优缺点 (45)
- tinyeditor (45)
- qt5.5 (60)
- windowsserver2016镜像下载 (52)
- okhttputils (51)
- android-gif-drawable (53)
- 时间轴插件 (56)
- docker systemd (65)
- slider.js (47)
- android webview缓存 (46)
- pagination.js (59)
- loadjs (62)
- openssl1.0.2 (48)
- velocity模板引擎 (48)
- pcre library (47)
- zabbix微信报警脚本 (63)
- jnetpcap (49)
- pdfrenderer (43)
- fastutil (48)
- uinavigationcontroller (53)
- bitbucket.org (44)
- python websocket-client (47)