I have done time-series forecasting work before, but with pyspark + fbprophet (I will write up notes on pyspark + fbprophet another time). This time I use the holtWinters model in sparkts to batch-forecast the revenue of many merchants.

5. Train the models and forecast

Ideally, if every merchant had revenue data for the whole time span I configured, training and prediction would go smoothly. In practice, with about a year of training data and tens of thousands of merchants, many merchants' forecasts come back as Double.NaN, or so many values get auto-filled that the forecast is distorted. So I first run the normal pipeline once, then pull out the merchants that got no forecast and train and predict them again from a plain RDD[(key, Vector)].

/**
 * 5. Training and prediction: the core of batch per-merchant training is to put each merchant's data
 * into one DenseVector (sorted by time); the HoltWinters model then forecasts the specified next N
 * values that follow this series.
 */
val result_1 = processStepOne(dtIndex, trainDF, testDF)
result_1.persist()
println("result_1 count: " + result_1.count())

// Merchants that got no forecast in the first pass, keeping only those with at least 12 observations
val lessDateLotDF = trainDF.join(result_1, Seq("lot_id"), "left_anti")
  .groupBy("lot_id")
  .agg(collect_list("total_money").as("values"))
  .where(size($"values") >= 12)
lessDateLotDF.persist()
lessDateLotDF.show(5, false)
println("lessDateLotDF count: " + lessDateLotDF.count())

val trainRdd: RDD[(String, DenseVector)] = lessDateLotDF.rdd.map(row => {
  val series = row.getAs[mutable.WrappedArray[Double]]("values").toArray
  (row.getAs[String]("lot_id"), new DenseVector(series))
})

val result_2 = processStepTwo(trainRdd, testDF)
val result = result_1.unionByName(result_2)
result.persist()
result.show(false)
.....
}

// First pass: train the models through a TimeSeriesRDD
def processStepOne(dtIndex: UniformDateTimeIndex, trainDF: DataFrame, testDF: DataFrame): DataFrame = {
  val trainTsRdd = TimeSeriesRDD
    .timeSeriesRDDFromObservations(dtIndex, trainDF, "flow_date", "lot_id", "total_money")
    .filterStartingBefore(startZonedDateTime)
  trainTsRdd.persist()
  // Fill missing values; "linear" is linear interpolation, and there are several other strategies
  // (trend, zero, previous value, next value) -- I will cover them when I go through the source
  val filledTrainTsRdd: TimeSeriesRDD[String] = trainTsRdd.fill("linear")
  val forecast = holtWintersModelTrainKey(1, filledTrainTsRdd, period, holtWintersModelType)
  processPredictResult(forecast, testDF)
}

// Second pass: train the models from a plain RDD
def processStepTwo(trainRdd: RDD[(String, DenseVector)], testDF: DataFrame): DataFrame = {
  val forecast = holtWintersModelTrainKey2(1, trainRdd, period, holtWintersModelType)
  processPredictResult(forecast, testDF)
}

// Merge the test data with the predictions. Because only one day is predicted here (one value per
// merchant), I can simply add .withColumn("flow_date", lit(endTimeStr).cast(DataTypes.TimestampType));
// if you forecast several values, the test data and predictions cannot be joined this way
def processPredictResult(forecast: RDD[(String, Array[Double])], testDF: DataFrame): DataFrame = {
  val predictDF = forecast.toDF("lot_id", "fs")
    .select($"lot_id", explode($"fs").as("predict"))
    .withColumn("predict", bround($"predict"))
    .where("predict is not null and predict != 'NaN'")
    .withColumn("flow_date", lit(endTimeStr).cast(DataTypes.TimestampType))
  val result = testDF.join(predictDF, Seq("lot_id", "flow_date"), "right")
  result
}

/**
 * Produce the forecasts
 */
def modelPredict(predictedN: Int, holtWintersAndVectorRdd: RDD[(String, HoltWintersModel, DenseVector)],
                 period: Int, holtWintersModelType: String): RDD[(String, Array[Double])] = {
  // Forecast the next N values: build a vector of length N and pass it to HoltWinters' forecast method
  val predictedArrayBuffer = new ArrayBuffer[Double]()
  var i = 0
  while (i < predictedN) {
    predictedArrayBuffer += i
    i = i + 1
  }
  val predictedVectors = Vectors.dense(predictedArrayBuffer.toArray)
  // Forecast
  val forecast: RDD[(String, Array[Double])] = holtWintersAndVectorRdd.mapPartitions { rows =>
    rows.map { case (key, holtWintersModel, denseVector) =>
      val vector = holtWintersModel.forecast(denseVector, predictedVectors)
      (key, vector.toArray)
    }
  }
  forecast
}

/**
 * Train the models using a TimeSeriesRDD
 */
def holtWintersModelTrainKey(predictedN: Int, trainTsRdd: TimeSeriesRDD[String], period: Int,
                             holtWintersModelType: String): RDD[(String, Array[Double])] = {
  // Parameter setup:
  // create and train a HoltWinters model per key; the resulting RDD has the shape (key, HoltWintersModel, Vector)
  val holtWintersAndVectorRdd = trainTsRdd.mapPartitions { lines =>
    lines.map { case (key, denseVector: DenseVector) =>
      (key, HoltWinters.fitModel(denseVector, period, holtWintersModelType), denseVector)
    }
  }
  modelPredict(predictedN, holtWintersAndVectorRdd, period, holtWintersModelType)
}

/**
 * Train the models using a plain RDD
 */
def holtWintersModelTrainKey2(predictedN: Int, trainTsRdd: RDD[(String, DenseVector)], period: Int,
                              holtWintersModelType: String): RDD[(String, Array[Double])] = {
  // Parameter setup: create and train a HoltWinters model per key; the resulting RDD has the shape (key, HoltWintersModel, Vector)
  val holtWintersAndVectorRdd = trainTsRdd.mapPartitions { lines =>
    lines.map { case (key, denseVector: Vector) =>
      // warn("current key: " + key)
      (key, HoltWinters.fitModel(denseVector, period, holtWintersModelType), denseVector)
    }
  }
  modelPredict(predictedN, holtWintersAndVectorRdd, period, holtWintersModelType)
}

6. Problem analysis: modifying the sparkts source
When I use sparkts to batch-forecast tens of thousands of merchants, a "trust region step has failed to reduce Q" exception appears. It comes from the method def fitModelWithBOBYQA(ts: Vector, period: Int, modelType: String): HoltWintersModel in the com.cloudera.sparkts.models.HoltWinters class; the call optimizer.optimize(objectiveFunction, goal, bounds, initGuess, maxIter, maxEval) inside that method is what throws it.
def fitModelWithBOBYQA(ts: Vector, period: Int, modelType: String): HoltWintersModel = {
  val optimizer = new BOBYQAOptimizer(7)
  val objectiveFunction = new ObjectiveFunction(new MultivariateFunction() {
    def value(params: Array[Double]): Double = {
      new HoltWintersModel(modelType, period, params(0), params(1), params(2)).sse(ts)
    }
  })
  // The starting guesses in R's stats:HoltWinters
  val initGuess = new InitialGuess(Array(0.3, 0.1, 0.1))
  val maxIter = new MaxIter(30000)
  val maxEval = new MaxEval(30000)
  val goal = GoalType.MINIMIZE
  val bounds = new SimpleBounds(Array(0.0, 0.0, 0.0), Array(1.0, 1.0, 1.0))
  val optimal = optimizer.optimize(objectiveFunction, goal, bounds, initGuess, maxIter, maxEval)
  val params = optimal.getPoint
  new HoltWintersModel(modelType, period, params(0), params(1), params(2))
}
The method needs to be modified so that the exception is caught, the initial guess is changed, and the method is called again, as follows:
def fitModelWithBOBYQA(ts: Vector, period: Int, modelType: String,
                       initGuess: InitialGuess = new InitialGuess(Array(0.3, 0.1, 0.1))): HoltWintersModel = {
  val optimizer = new BOBYQAOptimizer(7)
  val objectiveFunction = new ObjectiveFunction(new MultivariateFunction() {
    def value(params: Array[Double]): Double = {
      new HoltWintersModel(modelType, period, params(0), params(1), params(2)).sse(ts)
    }
  })
  // The starting guesses in R's stats:HoltWinters
  // val initGuess = new InitialGuess(Array(0.3, 0.1, 0.1))
  val maxIter = new MaxIter(30000)
  val maxEval = new MaxEval(30000)
  val goal = GoalType.MINIMIZE
  // Using a tiny non-zero lower bound reduces how often the exception occurs
  val bounds = new SimpleBounds(Array(0.00001, 0.00001, 0.00001), Array(1.0, 1.0, 1.0))
  var optimal: PointValuePair = null
  var params = Array(0D, 0D, 0D)
  try {
    optimal = optimizer.optimize(objectiveFunction, goal, bounds, initGuess, maxIter, maxEval)
    params = optimal.getPoint
    new HoltWintersModel(modelType, period, params(0), params(1), params(2))
  } catch {
    case e: MathIllegalStateException =>
      // On "MathIllegalStateException: trust region step has failed to reduce Q",
      // retry with a new random initGuess until the optimization no longer fails
      fitModelWithBOBYQA(ts, period, modelType, new InitialGuess(Array(math.random, 0.1, 0.1)))
    case e => throw e
  }
}
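If recompiling sparkts is not convenient, a workaround in a similar spirit can live at the call site instead: catch the failure per merchant and drop that merchant rather than failing the whole job. This is a minimal sketch of my own (not code from sparkts or from the original pipeline); it reuses the public HoltWinters.fitModel call shown earlier.

import scala.util.{Failure, Success, Try}
import org.apache.commons.math3.exception.MathIllegalStateException
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.rdd.RDD
import com.cloudera.sparkts.models.{HoltWinters, HoltWintersModel}

// Fit one model per merchant, skipping the series whose BOBYQA optimization still fails;
// any other error is rethrown so real bugs are not hidden.
def fitSafely(trainRdd: RDD[(String, DenseVector)], period: Int, modelType: String)
    : RDD[(String, HoltWintersModel, DenseVector)] = {
  trainRdd.flatMap { case (key, vector) =>
    Try(HoltWinters.fitModel(vector, period, modelType)) match {
      case Success(model)                        => Some((key, model, vector))
      case Failure(_: MathIllegalStateException) => None // "trust region step has failed to reduce Q"
      case Failure(other)                        => throw other
    }
  }
}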
7. Forecast results
Here is a forecast that came out fairly well…
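Beyond eyeballing the chart, accuracy can also be checked directly on the result DataFrame built in processPredictResult, since it carries both the actual total_money and the rounded predict column. A small sketch of my own (column names as above):

import org.apache.spark.sql.functions.{abs, avg, col}

// Mean absolute percentage error over the test day, ignoring merchants with missing or zero actuals
val mape = result
  .where(col("total_money").isNotNull && col("total_money") =!= 0 && col("predict").isNotNull)
  .select(avg(abs(col("predict") - col("total_money")) / col("total_money")).as("mape"))
mape.show(false)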
1. Add the sparkts Maven dependency
sparkts has not been updated in a long time; to be fair, sparkml is still not that mainstream compared with other commonly used machine-learning platforms.
<dependency>
    <groupId>com.cloudera.sparkts</groupId>
    <artifactId>sparkts</artifactId>
    <version>0.4.1</version>
</dependency>
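For reference, the classes used in the rest of the post come roughly from the following packages. This is my own summary, assuming the 0.4.1 package layout, rather than something spelled out in the original code:

import java.time.{Instant, ZoneId, ZonedDateTime}
import org.joda.time.DateTime
import org.joda.time.format.DateTimeFormat
import com.cloudera.sparkts.{DateTimeIndex, DayFrequency, TimeSeriesRDD, UniformDateTimeIndex}
import com.cloudera.sparkts.models.{HoltWinters, HoltWintersModel}
import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes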
2. Prepare the parameters
// Period length, an important seasonality (I read it as cyclicality) parameter of the holtWinters model; here the period is one week of 7 days
val period: Int = conf.getInt("spark.hw.period", 7)
// holtWinters model type: additive or Multiplicative (the multiplicative model, which is the commonly used one)
val holtWintersModelType: String = conf.get("spark.hw.modeltype", "Multiplicative")
val zone: ZoneId = ZoneId.systemDefault()
// Start and end of the training data window; ZonedDateTime values are needed later
private val startTimeStr: String = conf.get("spark.start.time", new DateTime().plusMonths(-6).toString("yyyy-MM-dd"))
val startTime: DateTime = DateTime.parse(startTimeStr, DateTimeFormat.forPattern("yyyy-MM-dd"))
val startZonedDateTime: ZonedDateTime = ZonedDateTime.ofInstant(Instant.ofEpochMilli(startTime.getMillis), zone)
private val endTimeStr: String = conf.get("spark.end.time", new DateTime().plusDays(-1).toString("yyyy-MM-dd"))
val endTime: DateTime = DateTime.parse(endTimeStr, DateTimeFormat.forPattern("yyyy-MM-dd"))
val endZonedDateTime: ZonedDateTime = ZonedDateTime.ofInstant(Instant.ofEpochMilli(endTime.getMillis), zone)
// Merchant ids to forecast
private val lotIds: String = conf.get("spark.lot.filter", "123,456").trim()
3. Fetch the historical data and split it into training and test/validation sets
def main(args: Array[String]): Unit = {
  .....
  val data = getLotDayReportHistoryData()
  // Use the last day's data as the test/validation set and everything before it as the training set
  val testDF = data.where(s"flow_date='${endTimeStr}'")
  val trainDF = data.where(s"flow_date<'${endTimeStr}'")
  .....
}

def getLotDayReportHistoryData(): DataFrame = {
  val lotFilter = if (lotIds.contains("all")) {
    "1=1"
  } else {
    val lotIdArr = lotIds.split(",")
    s"lot_id in (${lotIdArr.map("'" + _ + "'").mkString(",")})"
  }
  // The time-series date column must be cast to Timestamp, and the label column to Double
  // It is especially important to sort by time in ascending order
  spark.sql(s"select lot_id, flow_date, total_money from data_table where $lotFilter and flow_date >= '${startTimeStr}' and flow_date <= '${endTimeStr}' order by lot_id, flow_date asc")
    .withColumn("flow_date", $"flow_date".cast(DataTypes.TimestampType))
    .withColumn("total_money", $"total_money".cast(DataTypes.DoubleType))
    .withColumn("is_holiday", isHoliday2($"flow_date", lit("yyyy-MM-dd HH:mm:ss")))
}
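To experiment with the pipeline without access to the real table, a stand-in DataFrame with the same schema is enough. Everything below (the sample merchants and amounts) is hypothetical and only mirrors the columns the rest of the code expects:

import org.apache.spark.sql.types.DataTypes
import spark.implicits._

// Hypothetical sample data: lot_id, flow_date, total_money, kept in time order per merchant
val sampleData = Seq(
  ("123", "2025-06-01", 100.0),
  ("123", "2025-06-02", 120.0),
  ("456", "2025-06-01", 80.0),
  ("456", "2025-06-02", 95.0)
).toDF("lot_id", "flow_date", "total_money")
  .withColumn("flow_date", $"flow_date".cast(DataTypes.TimestampType))
  .withColumn("total_money", $"total_money".cast(DataTypes.DoubleType))
  .orderBy("lot_id", "flow_date")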
4. Specify the time span of the training data
This index helps us create the TimeSeriesRDD, and the TimeSeriesRDD in turn lets us fill in and filter the data. The third argument of DateTimeIndex.uniformFromInterval() is the time increment; mine is 7 days here, and you should pick an increment that fits your business. In main it is the line shown below.
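Built from the start and end times prepared in step 2:

val dtIndex: UniformDateTimeIndex =
  DateTimeIndex.uniformFromInterval(startZonedDateTime, endZonedDateTime, new DayFrequency(7))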