本文共 5275 字,大约阅读时间需要 17 分钟。
Maven dependencies:
- org.apache.spark : spark-mllib_2.11 : 2.2.0
- ml.combust.mleap : mleap-spark_2.11 : 0.11.0
- org.apache.spark : spark-sql_2.11 : 2.2.0
- org.apache.spark : spark-hive_2.11 : 2.2.0
import ml.combust.bundle.BundleFile
import ml.combust.mleap.spark.SparkSupport._
import org.apache.spark.ml.bundle.SparkBundleContext
import org.apache.spark.ml.classification.{GBTClassifier, LogisticRegression}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}
import org.apache.spark.sql._
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.{Pipeline, PipelineStage}
import resource.managed

/**
 * Trains a Spark ML pipeline (StringIndexer stages + VectorAssembler +
 * LogisticRegression + GBTClassifier) on data pulled from Hive, then exports
 * the fitted pipeline as an MLeap bundle at /tmp/lc.model.zip.
 *
 * NOTE(review): `constantConfig.testTrainDataSetSql` is a project-local
 * definition not visible in this file — confirm it is on the classpath.
 */
object trainModelLeap {

  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setMaster("local[2]").setAppName("zhunshibao_test")
    val sc = new SparkContext(sparkConf)
    val spark = SparkSession.builder()
      .config(sc.getConf)
      .config("hive.metastore.uris", "thrift://10.202.77.200:9083")
      .enableHiveSupport()
      .getOrCreate()

    // Originally declared as `var`s but never reassigned — use `val`.
    val data = spark.sql(constantConfig.testTrainDataSetSql).na.drop
    val splited = data.randomSplit(Array(0.8, 0.2), 2L)
    val trainSet = splited(0)
    val testSet = splited(1) // held-out split; not used further in this job
    trainSet.show(5)

    // One StringIndexer per categorical column (replaces the var-accumulating
    // loop; the parallel "<col>Inc" name list was never read and is dropped).
    val stringColumns = Array("new_cargo", "oprid")
    val indexers: List[PipelineStage] = stringColumns.map { field =>
      new StringIndexer()
        .setInputCol(field)
        .setOutputCol(field + "Inc")
        .setHandleInvalid("skip")
    }.toList

    // Merge the numeric inputs into a single feature vector.
    val assembler = new VectorAssembler()
      .setInputCols(Array("top1", "top2"))
      .setOutputCol("features")

    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01)
      .setFeaturesCol("features")
      .setLabelCol("output1")
      .setPredictionCol("predict")

    val gbt = new GBTClassifier()
      .setLabelCol("output1")
      .setFeaturesCol("features")
      .setPredictionCol("gbt_prediction")
      .setProbabilityCol("gbt_prediction_prob")
      .setRawPredictionCol("gbt_prediction_raw")
      .setMaxBins(80)
      .setMaxIter(50)

    val stages = indexers :+ assembler :+ lr :+ gbt
    val pipeline = new Pipeline().setStages(stages.toArray)
    val model = pipeline.fit(trainSet)
    val ds = model.transform(trainSet)

    // MLeap needs a transformed dataset in the bundle context to capture the schema.
    implicit val context = SparkBundleContext().withDataset(ds)

    // Persist the fitted pipeline as an MLeap bundle; `managed` closes the file.
    for (bf <- managed(BundleFile("jar:file:/tmp/lc.model.zip"))) {
      model.writeBundle.save(bf)(context).get
    }
  }
}
注意输入model中transform的数据格式。
package com.sf

import ml.combust.bundle.BundleFile
import ml.combust.mleap.core.types.{ScalarType, StructField, StructType}
import ml.combust.mleap.runtime.MleapSupport._
import ml.combust.mleap.runtime.frame.{DefaultLeapFrame, Row}
import ml.combust.mleap.spark.SparkLeapFrame
import ml.combust.mleap.spark.SparkSupport._
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkConf, SparkContext}
import resource._

/**
 * Loads the MLeap bundle exported by the training job and scores data with
 * it, either one row at a time (no Spark required) or as a Spark leap frame.
 */
object PredictLeap {

  // Single source of truth for the bundle location (was duplicated in both
  // scoring methods).
  private val ModelPath = "jar:file:/tmp/lc.model.zip"

  /**
   * Loads the MLeap transformer from the bundle zip. Extracted because the
   * identical load block appeared verbatim in both public methods.
   *
   * NOTE(review): `.opt.get` throws if the bundle is missing or unreadable —
   * consider propagating the Try/Option to callers instead.
   */
  private def loadModel() = {
    val bundle = (for (bundleFile <- managed(BundleFile(ModelPath))) yield {
      bundleFile.loadMleapBundle().get
    }).opt.get
    bundle.root
  }

  /** Batch scoring: applies the pipeline to a Spark-backed leap frame. */
  def getModelProbBatch(frame: SparkLeapFrame) = {
    loadModel().transform(frame).get
  }

  /**
   * Single-record scoring: wraps the rows in a DefaultLeapFrame whose schema
   * must match the columns the pipeline was trained on
   * (top1, top2, new_cargo, oprid).
   */
  def getModelProbOne(data: Seq[Row]) = {
    val schema = StructType(
      StructField("top1", ScalarType.Double),
      StructField("top2", ScalarType.Double),
      StructField("new_cargo", ScalarType.String),
      StructField("oprid", ScalarType.String)
    ).get
    val frame = DefaultLeapFrame(schema, data)
    loadModel().transform(frame).get
  }

  def main(args: Array[String]): Unit = {
    // Single-row prediction (pure MLeap runtime, no Spark session needed).
    val dataOne: Seq[Row] = Seq(Row(1.0, 1.0, "other", "1247226"))
    val resOne = getModelProbOne(dataOne)
    resOne.show(5)

    // Batch prediction over a Spark DataFrame pulled from Hive.
    val sparkConf = new SparkConf().setMaster("local[2]").setAppName("zhunshibao_test")
    val sc = new SparkContext(sparkConf)
    val spark = SparkSession.builder()
      .config(sc.getConf)
      .config("hive.metastore.uris", "thrift://10.202.77.200:9083")
      .enableHiveSupport()
      .getOrCreate()
    // `val` instead of the original `var` — never reassigned.
    val dataBatch = spark.sql(constantConfig.testTrainDataSetSql).na.drop
    val splited = dataBatch.randomSplit(Array(0.8, 0.2), 2L)
    val testSet = splited(1)
    val resBatch = getModelProbBatch(testSet.toSparkLeapFrame)
    resBatch.toSpark.toDF().show(5)
  }
}
参考文档:
https://www.bookstack.cn/read/mleap-zh/mleap-runtime-create-leap-frame.md
转载地址: http://awwzb.baihongyu.com/