Spark Learning: Deploying a Spark Pipeline Model with MLeap
Published: 2019-05-01



1. Required Dependencies

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.11</artifactId>
    <version>2.2.0</version>
</dependency>
<dependency>
    <groupId>ml.combust.mleap</groupId>
    <artifactId>mleap-spark_2.11</artifactId>
    <version>0.11.0</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-sql_2.11</artifactId>
    <version>2.2.0</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-hive_2.11</artifactId>
    <version>2.2.0</version>
</dependency>
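The versions above have to agree with each other: the _2.11 suffix is the Scala binary version, so the Spark artifacts, the MLeap artifact, and the project's Scala version must all line up. If you build with sbt instead of Maven, a minimal sketch of the equivalent declarations (assuming scalaVersion is set to a 2.11.x release so that %% resolves the _2.11 artifacts):

// build.sbt (sketch): same coordinates as the Maven dependencies above
scalaVersion := "2.11.12"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-mllib" % "2.2.0",
  "org.apache.spark" %% "spark-sql"   % "2.2.0",
  "org.apache.spark" %% "spark-hive"  % "2.2.0",
  "ml.combust.mleap" %% "mleap-spark" % "0.11.0"
)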

2. Code

2.1 Data Preprocessing, Model Training, and Model Export

import ml.combust.bundle.BundleFile
import ml.combust.mleap.spark.SparkSupport._
import org.apache.spark.ml.bundle.SparkBundleContext
import org.apache.spark.ml.classification.{GBTClassifier, LogisticRegression}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.sql._
import org.apache.spark.{SparkConf, SparkContext}
import resource.managed

object trainModelLeap {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setMaster("local[2]").setAppName("zhunshibao_test")
    val sc = new SparkContext(sparkConf)
    val spark = SparkSession.builder()
      .config(sc.getConf)
      .config("hive.metastore.uris", "thrift://10.202.77.200:9083")
      .enableHiveSupport()
      .getOrCreate()

    val data = spark.sql(constantConfig.testTrainDataSetSql).na.drop
    val splited = data.randomSplit(Array(0.8, 0.2), 2L)
    val trainSet = splited(0)
    val testSet = splited(1)
    trainSet.show(5)

    var dataProcessList = List[PipelineStage]()

    /** StringIndexer: map each string column to a numeric index column */
    val stringColumns = Array("new_cargo", "oprid")
    var stringColumnsInc = List[String]()
    for (field <- stringColumns) {
      val indexer = new StringIndexer()
        .setInputCol(field)
        .setOutputCol(field + "Inc")
        .setHandleInvalid("skip")
      dataProcessList = dataProcessList :+ indexer
      stringColumnsInc = stringColumnsInc :+ (field + "Inc")
    }

    /** Assemble the feature columns into a single vector */
    val assembler = new VectorAssembler()
      .setInputCols(Array("top1", "top2"))
      .setOutputCol("features")
    dataProcessList = dataProcessList :+ assembler

    /** Models */
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01)
      .setFeaturesCol("features")
      .setLabelCol("output1")
      .setPredictionCol("predict")
    val gbt = new GBTClassifier()
      .setLabelCol("output1")
      .setFeaturesCol("features")
      .setPredictionCol("gbt_prediction")
      .setProbabilityCol("gbt_prediction_prob")
      .setRawPredictionCol("gbt_prediction_raw")
      .setMaxBins(80)
      .setMaxIter(50)
    dataProcessList = dataProcessList :+ lr :+ gbt

    val pipeline = new Pipeline().setStages(dataProcessList.toArray)
    val model = pipeline.fit(trainSet)
    val ds = model.transform(trainSet)

    /** Save the fitted pipeline as an MLeap bundle */
    implicit val context = SparkBundleContext().withDataset(ds)
    for (bf <- managed(BundleFile("jar:file:/tmp/lc.model.zip"))) {
      model.writeBundle.save(bf)(context).get
    }
  }
}
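The code above imports BinaryClassificationEvaluator but never uses it, and testSet is split off but left untouched. Before exporting the bundle, the held-out set could be scored like this (a minimal sketch, assuming the label column output1 is binary, as the evaluator requires):

// Sketch: score the held-out testSet with the fitted pipeline.
// Column names ("output1", "gbt_prediction_raw") match the training code above.
val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("output1")
  .setRawPredictionCol("gbt_prediction_raw")
  .setMetricName("areaUnderROC")
val testAuc = evaluator.evaluate(model.transform(testSet))
println(s"GBT test AUC = $testAuc")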

2.2 Prediction (Single-Record and Batch Testing)

Pay attention to the format of the data passed into the model's transform: the MLeap runtime expects a LeapFrame whose schema matches the columns the pipeline was trained on, and since rows are positional, the values in each Row must appear in the same order as the schema fields (here top1, top2, new_cargo, oprid).

package com.sf

import ml.combust.bundle.BundleFile
import ml.combust.mleap.core.types.{ScalarType, StructField, StructType}
import ml.combust.mleap.runtime.MleapSupport._
import ml.combust.mleap.runtime.frame.{DefaultLeapFrame, Row}
import ml.combust.mleap.spark.SparkLeapFrame
import ml.combust.mleap.spark.SparkSupport._
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkConf, SparkContext}
import resource._

object PredictLeap {
  /** Batch prediction: run the MLeap bundle over a SparkLeapFrame */
  def getModelProbBatch(frame: SparkLeapFrame) = {
    val bundle = (for (bundleFile <- managed(BundleFile("jar:file:/tmp/lc.model.zip"))) yield {
      bundleFile.loadMleapBundle().get
    }).opt.get
    val model = bundle.root
    val df = model.transform(frame).get
    df
  }

  /** Single-record prediction: needs only the MLeap runtime, no SparkSession */
  def getModelProbOne(data: Seq[Row]) = {
    val bundle = (for (bundleFile <- managed(BundleFile("jar:file:/tmp/lc.model.zip"))) yield {
      bundleFile.loadMleapBundle().get
    }).opt.get
    val schema = StructType(
      StructField("top1", ScalarType.Double),
      StructField("top2", ScalarType.Double),
      StructField("new_cargo", ScalarType.String),
      StructField("oprid", ScalarType.String)).get
    val frame = DefaultLeapFrame(schema, data)
    val model = bundle.root
    val df = model.transform(frame).get
    df
  }

  def main(args: Array[String]): Unit = {
    /** Single-record prediction */
    val dataOne: Seq[Row] = Seq(Row(1.0, 1.0, "other", "1247226"))
    val resOne = getModelProbOne(dataOne)
    resOne.show(5)

    /** Batch prediction on a DataFrame */
    val sparkConf = new SparkConf().setMaster("local[2]").setAppName("zhunshibao_test")
    val sc = new SparkContext(sparkConf)
    val spark = SparkSession.builder()
      .config(sc.getConf)
      .config("hive.metastore.uris", "thrift://10.202.77.200:9083")
      .enableHiveSupport()
      .getOrCreate()
    val dataBatch = spark.sql(constantConfig.testTrainDataSetSql).na.drop
    val splited = dataBatch.randomSplit(Array(0.8, 0.2), 2L)
    val testSet = splited(1)
    val resBatch = getModelProbBatch(testSet.toSparkLeapFrame)
    resBatch.toSpark.toDF().show(5)
  }
}
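After transform succeeds, the predictions sit in the output columns the pipeline added. A minimal sketch of reading one value back out of the returned LeapFrame, using the select and dataset accessors from the MLeap runtime (the gbt_prediction column name comes from the training code; treat the exact accessor chain as an assumption to verify against your MLeap version):

// Sketch: extract the GBT prediction from the frame returned by getModelProbOne.
// "gbt_prediction" is the output column configured during training.
val predicted: Double = resOne
  .select("gbt_prediction").get // keep only the prediction column
  .dataset.head                 // the single row we scored
  .getDouble(0)
println(s"predicted label = $predicted")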

References:

https://www.bookstack.cn/read/mleap-zh/mleap-runtime-create-leap-frame.md

