Spark ML already ships with many feature-processing utilities in its feature package, which makes data preprocessing much more convenient. But this is clearly not enough: machine learning has many more processing techniques, and plenty of them are missing from the feature package. So how do we implement one ourselves?
The simple approach: Spark ML is essentially a set of operations on DataFrames, so we can manipulate the DataFrame directly in our own code to get the same effect.
In practice, though, this approach falls short. What we write is pure DataFrame transformation and extraction; it cannot be persisted, and it cannot be added to a Pipeline for repeated training.
So I took another route: building on the Spark source code. First, the feature to be implemented this time: WOE (Weight of Evidence).
How WOE is computed:
The calculation logic is fairly clear; the formula is
$$\mathrm{WOE}_i = \ln\left(\frac{good_i / good_{all}}{bad_i / bad_{all}}\right)$$
where $i$ indexes the groups after discretizing the data, $good_i$ and $bad_i$ are the counts of good and bad samples in group $i$, and $good_{all}$ and $bad_{all}$ are the total good and bad counts.
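A quick sanity check with made-up numbers: if a group holds $good_i = 80$ of $good_{all} = 200$ good samples and $bad_i = 20$ of $bad_{all} = 100$ bad samples, then $\mathrm{WOE} = \ln\frac{80/200}{20/100} = \ln\frac{0.4}{0.2} = \ln 2 \approx 0.693$; the positive sign says the group is enriched in good samples.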
Writing the code:
The WOE transformation takes the following parameters (a usage sketch follows this list):
- Input columns: which columns should get the WOE transformation
- Output columns: the names of the new columns after the WOE transformation
- Label column: the name of the label column
- Positive class: positiveLabel, which decides whether 1 or 0 counts as good
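As a preview of the API built below (the classes are defined in the following steps; the column names here are made up for illustration):
val woe = new woeTransform()
  .setInputCols(Array("age_bucket", "city"))    // columns to encode
  .setOutputCols(Array("age_woe", "city_woe"))  // one output name per input
  .setLabelCol("label")
  .setPositiveLabel("1")                        // treat 1 as good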
1. Define a custom params trait
so that woeTransform and woeTransformModel can share the same parameters.
// The shared param traits (HasInputCols etc.) are private[ml], which is why
// this code must live under the org.apache.spark.ml.feature package.
package org.apache.spark.ml.feature

import scala.collection.mutable.{ArrayBuffer, ListBuffer}

import org.apache.hadoop.fs.Path
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators, Params}
import org.apache.spark.ml.param.shared.{HasInputCols, HasLabelCol, HasOutputCols}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, sum, when}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

trait woeTransformParams extends Params with HasInputCols with HasOutputCols with HasLabelCol{
val positiveLabel: Param[String] = new Param(this,"positiveLabel","positiveLabel you want to choose",ParamValidators.inArray(Array(woeTransform.one,woeTransform.zero)))
def getPositiveLabel: String = $(positiveLabel)
}
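Because of the ParamValidators.inArray constraint, an invalid positiveLabel is rejected as soon as it is set; this is standard Spark Params behavior (woeTransform is defined in the next step):
val t = new woeTransform()
t.setPositiveLabel("1")      // fine
// t.setPositiveLabel("yes") // would throw IllegalArgumentException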
2. Write woeTransform
Extend the Estimator abstract class and implement the copy, transformSchema, and fit methods.
- fit builds a surrogate DataFrame and uses it to create the model. When the model later transforms a dataset, it is really applying the rules recorded in this surrogate DataFrame.
- transformSchema: produces the new schema information.
- copy: returns an instance with the same UID, including the information in the extra ParamMap.
The implementation:
class woeTransform(override val uid: String) extends Estimator[woeTransformModel] with woeTransformParams with DefaultParamsWritable{
def this() = this(Identifiable.randomUID("woeTransform"))
def setLabelCol(value:String) = set(labelCol,value)
def setInputCols(value:Array[String]) = set(inputCols,value)
def setOutputCols(value:Array[String]) = set(outputCols,value)
def setPositiveLabel(value:String) = set(positiveLabel,value)
override def copy(extra: ParamMap): Estimator[woeTransformModel] = defaultCopy(extra)
override def transformSchema(schema: StructType): StructType = {
val tmpArr = $(inputCols).filter(schema.fieldNames.contains(_))
require(tmpArr.length == $(inputCols).length, "some input columns do not exist in the schema")
val addedFields = $(outputCols).map{outputCol =>
StructField(outputCol,DoubleType,false)
}
StructType(schema ++ addedFields)
}
/**
* dataset holds the training data; compute the surrogateDF from it and build the model
*/
override def fit(dataset: Dataset[_]): woeTransformModel = {
val schema = dataset.schema
transformSchema(schema,logging = true)
val lb = new ListBuffer[(String,String,Double)]()
val cols: Array[Column] = schema.fieldNames.map(col(_))
val newLabel = "new_" + $(labelCol)
// normalize the label into the string values "1"/"0"
val labelColt = when(col($(labelCol)).equalTo(1.0),woeTransform.one)
.when(col($(labelCol)).equalTo(0.0),woeTransform.zero)
.otherwise(col($(labelCol)))
.as(newLabel)
val dataFrame = dataset.select((cols :+ labelColt): _*)
//for each group of each input column, compute the WOE and append it to the ListBuffer
$(inputCols).foreach{inputCol =>
//crosstab computes a contingency table of group-by-label counts
val singleInfo = dataFrame.stat.crosstab(inputCol, newLabel)
val analyseDF = $(positiveLabel) match {
case woeTransform.zero => singleInfo.withColumnRenamed(woeTransform.zero.toString,"good").withColumnRenamed(woeTransform.one.toString,"bad")
case woeTransform.one => singleInfo.withColumnRenamed(woeTransform.one.toString,"good").withColumnRenamed(woeTransform.zero.toString,"bad")
}
val row = analyseDF.select(sum("bad"),sum("good")).head()
val (bad,good) = (row.getLong(0).toInt,row.getLong(1).toInt)
analyseDF.collect().foreach{row =>
//a small epsilon keeps the ratio and the log well-defined for empty cells
val bi = row.getAs[Long]("bad").toDouble + 0.0000001
val gi = row.getAs[Long]("good").toDouble + 0.0000001
val woe = Math.log((gi / good) / ((bi / bad) + 0.0000001))
lb += ((inputCol,row.getString(0),woe))
}
}
//turn the recorded triples in the ListBuffer into the surrogate DF, then build the woeTransformModel
import dataset.sparkSession.implicits._
val surrogateDF = lb.toList.toDF()
copyValues(new woeTransformModel(uid,surrogateDF).setParent(this))
}
}
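To see fit in action, here is a small sketch with toy data (the SparkSession `spark` and the column names are assumptions for illustration):
import spark.implicits._

val train = Seq(
  ("a", 1.0), ("a", 1.0), ("a", 0.0),
  ("b", 1.0), ("b", 0.0), ("b", 0.0)
).toDF("bucket", "label")

val model = new woeTransform()
  .setInputCols(Array("bucket"))
  .setOutputCols(Array("bucket_woe"))
  .setLabelCol("label")
  .setPositiveLabel("1")
  .fit(train)

// one row per (column, bucket) pair with its WOE value
model.surrogateDF.show()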
object woeTransform extends DefaultParamsReadable[woeTransform]{
val zero = "0"
val one = "1"
override def load(path: String): woeTransform = super.load(path)
}
3. Write woeTransformModel
class woeTransformModel:
- Extends the Model abstract class and implements copy, transformSchema, and transform. The first two work as before. transform applies surrogateDF as the conversion rules to process a new dataset.
- Implements MLWritable so the model can be written out.
object woeTransformModel:
- Implements MLReadable for reading the model back. Reading and writing must mirror each other, otherwise saving and loading the model will fail.
The code:
class woeTransformModel(override val uid: String,val surrogateDF: DataFrame)
extends Model[woeTransformModel] with woeTransformParams with MLWritable{
import woeTransformModel._
def setLabelCol(value:String) = set(labelCol,value)
def setInputCols(value:Array[String]) = set(inputCols,value)
def setOutputCols(value:Array[String]) = set(outputCols,value)
def setPositiveLabel(value:String) = set(positiveLabel,value)
override def copy(extra: ParamMap): woeTransformModel = {
val copied = new woeTransformModel(uid,surrogateDF)
copyValues(copied, extra).setParent(parent)
}
/**
* Transforms the input dataset.
*/
override def transform(dataset: Dataset[_]): DataFrame = {
val newSchema = transformSchema(dataset.schema, logging = true)
val inArray = $(inputCols)
val outArray = $(outputCols)
var ruleMap: Map[String,Double] = Map()
surrogateDF.rdd.collect().foreach{row=>
val colName: String = row.getString(0)
val bucket: String = row.getString(1)
val woe: Double = row.getDouble(2)
ruleMap += (colName + "-" + bucket -> woe)
}
val newRdd = dataset.toDF.rdd.map{ row=>
val ab = new ArrayBuffer[Double]()
for(i <- 0 to inArray.length-1){
val colName= inArray(i)
val bucket = row.getAs[Object](colName).toString
val woe = ruleMap.apply(colName + "-" + bucket)
ab += woe
}
Row.merge(row,Row.fromSeq(ab))
}
dataset.sparkSession.createDataFrame(newRdd,newSchema)
}
override def transformSchema(schema: StructType): StructType = {
val tmpArr = $(inputCols).filter(schema.fieldNames.contains(_))
require(tmpArr.length == $(inputCols).length, "some input columns do not exist in the schema")
val addedFields = $(outputCols).map{outputCol =>
StructField(outputCol,DoubleType,false)
}
StructType(schema ++ addedFields)
}
/**
* Returns an `MLWriter` instance for this ML instance.
*/
override def write: MLWriter = new woeTransformModelWriter(this)
}
object woeTransformModel extends MLReadable[woeTransformModel]{
class woeTransformModelWriter(instance: woeTransformModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val dataPath = new Path(path, "data").toString
instance.surrogateDF.repartition(1).write.parquet(dataPath)
}
}
class woeTransformReader extends MLReader[woeTransformModel]{
private val className = classOf[woeTransformModel].getName
/**
* Loads the ML component from the input path.
*/
override def load(path: String): woeTransformModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val surrogateDF = sqlContext.read.parquet(dataPath)
val model = new woeTransformModel(metadata.uid, surrogateDF)
metadata.getAndSetParams(model)
model
}
}
/**
* Returns an `MLReader` instance for this class.
*/
override def read: MLReader[woeTransformModel] = new woeTransformReader
override def load(path: String): woeTransformModel = super.load(path)
}
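Since the writer and reader mirror each other, a fitted model round-trips through disk. A minimal sketch (the path and `newData` are made up):
model.write.overwrite().save("/tmp/woe_model")
val loaded = woeTransformModel.load("/tmp/woe_model")
loaded.transform(newData).show()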
Verifying correctness
I used a small dataset for the check; below are the expected values, computed by hand with the formula above.
Then, to test the code we just wrote: place the classes under the org.apache.spark.ml.feature package, recompile and repackage Spark, and pull the jar into the project.
Running on the same dataset gives the following results:
They match the hand-computed ones.
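A sketch of how such a check can be wired up, reusing the toy `train` data and `model` from the fit example above:
val transformed = model.transform(train)
// each bucket's bucket_woe should equal ln((good_i/good_all)/(bad_i/bad_all))
transformed.select("bucket", "bucket_woe").distinct().show()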
This only implements the core logic; special cases are not handled. If you have ideas, feel free to point them out so we can discuss.