spark 之TF-IDF提取文章关键词

2019-06-28 11:58:28 浏览数 (1)

提取一篇文章中的关键词时,一个很常见的思路就是找到出现次数最多的词。但是很多时候一些副词、形容词,英文中的a the an on等,中文里的 的、是、在等在文档中出现的词数会比较多,但是并不是关键词,没有实际意义,所以这些被列入停用词范畴。下面咱们就来探索一下使用spark的ml-lib来提取文章的关键 词以及在寻找关键词过程中出现的一些概念说明。 对于下面这样一篇金融类的文章(由于篇幅有限,只截取部分):

代码语言:javascript复制
"<p><img src=\"https://xxx.com/collect/article/236725779415695360.png\" "  
            "title=\"继续突破or M型顶?比特币短线尚未走出迷雾但已进入上升抛物线趋势\"></p>\n<br>\n<p>"  
            "本文来自小葱区块链,阅读更多请登陆<a>www.xcong.com</a>或小葱APP</p>\n<br>\n<p>"  
            "转载请注明出处</p>\n<br>\n<p>大家壕,本栏目为小葱APP原创栏目——小葱龙虎榜,持续追踪每日资金流入/流出最多的各十大币种。"  
            "本栏目由小葱APP和AICoin联合推出。</p>\n<br>\n<p>==本文数据来源:AICoin==</p>\n<br>\n<p>"  
            "本文数据皆以人民币进行统计</p>\n<br>\n<p>--------------------------------</p>\n<br>\n<p>"  
            "今天数字货币市场大部分反弹,比特币重新回到5100美元以上,而且一度突破5200美元,现在稍稍回落,其它主流币种大多数跟随上涨。"  
            "而昨天受益BSV被交易所下架消息的BCH今天小幅下跌,BSV继续下跌。</p>\n<br>\n<p>昨天火币PRIME第二期NEW开盘,盘前HT曾一度上涨,"  
            "但开盘后HT回落,今天小幅下跌。另外的平台币BNB和OKB均上涨,OKB领涨。</p>\n<br>\n<p>市值方面,数字货币市场总市值为1781.31亿美元..."  

1. Term Frequency与Inverse Document Frequency

  • Term Frequency:缩写为TF,也就是词频统计。在统计的时候,会发现"比特币","流入","数字货币","区块链","上涨"等词都有出现。相对而言,“流入”,“上涨”这些词的重要程度不及“比特币”、“区块链”。 在统计的时候,对较常见的词“流入”、“上涨”给予较小的权重,对“比特币”、“区块链”给予较大的权重。这个权重也就是Inverse Document Frequency,缩写为IDF,与一个词的常见程度成反比。

TF=某个词在文章中的出现次数/文章的总词数

  • Inverse Document Frequency:缩写为IDF

IDF(逆文档频率) = log(语料库的文档总数/(包含该词的文档数 1))

  • TF-IDF:"词频"(TF)和"逆文档频率"(IDF)以后,将这两个值相乘,就得到了一个词的TF-IDF值。某个词对文章的重要性越高,它的TF-IDF值就越大。

TF-IDF = TF * IDF

可以看到,TF-IDF与一个词在文档中的出现次数成正比,与该词在整个语言中的出现次数成反比。所以,自动提取关键词的算法就很清楚了,就是计算出文档的每个词的TF-IDF值,然后按降序排列,取排在最前面的几个词。

2. 用spark计算TF-IDF

使用spark-mllib包进行计算,mllib包中提供了计算TF-IDF算法的封装。

1. 计算tf的值

  • 使用方法为:org.apache.spark.ml.feature.HashingTF#HashingTF()
  • HashingTF的解释是:通过取hash值的方式映射一组词条和它们词频之间的关系。在spark ml包中目前使用的hash算法是Austin Appleby的MurmurHash 3算法,也就是著名的MurmurHash3_x86_32算法来计算每个词条对象的 hashcode值。因为简单的模被用来把hash函数转变成一列索引,所以建议使用2的次幂作为numFeatures的参数,否则features值会被映射的毫无规律。 模运算部分可以参考:
代码语言:javascript复制
 org.apache.spark.mllib.feature.HashingTF#indexOf:
 /**
   * Returns the index of the input term.
   */
  @Since("1.1.0")
  def indexOf(term: Any): Int = {
	Utils.nonNegativeMod(getHashFunction(term), numFeatures)
  }
  
 org.apache.spark.util.Utils#nonNegativeMod:
  /* Calculates 'x' modulo 'mod', takes to consideration sign of x,
   * i.e. if 'x' is negative, than 'x' % 'mod' is negative too
   * so function return (x % mod)   mod in that case.
   */
   def nonNegativeMod(x: Int, mod: Int): Int = {
	 val rawMod = x % mod
	 rawMod   (if (rawMod < 0) mod else 0)
   }	

由于spark.ml包的HashingTF中没有通过hash值取到词条的方法,所以需要对HashingTF进行改造:

代码语言:javascript复制
1. 添加实例私有变量:private[this] var hashingTF: feature.HashingTF = _ 注意,这个HashingTF是属于org.apache.spark.mllib.feature包的
2. 添加对上面变量的初始化方法:
@Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
	val outputSchema = transformSchema(dataset.schema)
	hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
	// TODO: Make the hashingTF.transform natively in ml framework to avoid extra conversion.
	val t = udf { terms: Seq[_] => hashingTF.transform(terms).asML }
	val metadata = outputSchema($(outputCol)).metadata
	dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
  }
  
3. 添加indexOf方法:
/**
	* Returns the index of the input term.
	*/
  @Since("2.3.0")
  def indexOf(term: Any): Int = {
	if (hashingTF != null) {
	//这里对应的就是上面那个mllib包中的hashingTF
	  hashingTF.indexOf(term)
	} else {
	  throw UninitializedFieldError("Use transform method to initialize the model at first.")
	}
  }  

2. 计算idf的值:

使用方法:org.apache.spark.ml.feature.IDF#IDF() 看如下代码,idf的fit方法需要以tf的结果为入参来生成IDFModel,然后通过IDFModel去生成tf-idf的值:

代码语言:javascript复制
val featurizedData = hashingTF.transform(wordsData)
// CountVectorizer也可获取词频向量
val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
//生成idfModel
val idfModel = idf.fit(featurizedData)
//生成tf-idf
val rescaledData = idfModel.transform(featurizedData)

org.apache.spark.mllib.feature.IDF#fit方法:

代码语言:javascript复制
@Since("2.0.0")
  override def fit(dataset: Dataset[_]): IDFModel = {
    transformSchema(dataset.schema, logging = true)
    val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
      case Row(v: Vector) => OldVectors.fromML(v)
    }
    val idf = new feature.IDF($(minDocFreq)).fit(input)
    copyValues(new IDFModel(uid, idf).setParent(this))
  }

org.apache.spark.ml.feature.IDFModel#transform方法如下:

代码语言:javascript复制
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    // TODO: Make the idfModel.transform natively in ml framework to avoid extra conversion.
    //调用的是org.apache.spark.mllib.feature.IDFModel#transform(org.apache.spark.mllib.linalg.Vector)
    val idf = udf { vec: Vector => idfModel.transform(OldVectors.fromML(vec)).asML }
    dataset.withColumn($(outputCol), idf(col($(inputCol))))
  }

上面的idfModel.transform(OldVectors.fromML(vec)).asML调用的是mllib中的IDFModel的transform方法。

代码语言:javascript复制
 /**
   * Transforms a term frequency (TF) vector to a TF-IDF vector
   *
   * @param v a term frequency vector
   * @return a TF-IDF vector
   */
  @Since("1.3.0")
  def transform(v: Vector): Vector = IDFModel.transform(idf, v)
  
   /**
     * Transforms a term frequency (TF) vector to a TF-IDF vector with a IDF vector
     *
     * @param idf an IDF vector
     * @param v a term frequency vector
     * @return a TF-IDF vector
     */
    def transform(idf: Vector, v: Vector): Vector = {
      val n = v.size
      v match {
        case SparseVector(size, indices, values) =>
          val nnz = indices.length
          val newValues = new Array[Double](nnz)
          var k = 0
          while (k < nnz) {
            newValues(k) = values(k) * idf(indices(k))
            k  = 1
          }
          Vectors.sparse(n, indices, newValues)
        case DenseVector(values) =>
          val newValues = new Array[Double](n)
          var j = 0
          while (j < n) {
            newValues(j) = values(j) * idf(j)
            j  = 1
          }
          Vectors.dense(newValues)
        case other =>
          throw new UnsupportedOperationException(
            s"Only sparse and dense vectors are supported but got ${other.getClass}.")
      }
    }
  }

可以看到,上面返回的rescaledData对应的dataset的column是一个矩阵。

3. 分词器ansj

这个分词器可以添加停用词,过滤词性,添加自定义词典等。

4. 实例:

有了上面这些概念就可以开始看下面的实例代码片段了:

代码语言:javascript复制
def main(args: Array[String]): Unit = {
  //formatNature
  //refactorFile
  splitStringContent()
}

输出结果为:WrappedArray((正邦科技,15.380572041353537), (牧原股份,3.295836866004329), (巨亏,2.772588722239781), (天邦股份,2.1972245773362196), (天风证券,2.1972245773362196), (,1.6218604324326575), (温氏股份,1.0986122886681098), (testtesttest,0.8109302162163288), (营收,0.6931471805599453))。关键词提取成功! 上面也可以直接从mongodb中拉取数据:

代码语言:javascript复制
val spark = SparkSession
      .builder()
      .master("local") // 本地测试,否则报错 A master URL must be set in your configuration at org.apache.spark.SparkContext.
      .appName("test")
      .config("spark.mongodb.input.uri","mongodb://zx:zx123456@192.168.1.41:28071/zx.business_collect_article_info")
      .config("spark.mongodb.output.uri","mongodb://zx:zx123456@192.168.1.41:28071/zx.business_collect_article_info")
      // .enableHiveSupport()
      .getOrCreate() // 有就获取无则创建
	//生成dataFrame
    var df = MongoSpark.load(spark)

5. 附录

停用词片段:

代码语言:javascript复制
今天
平整
位置
抬
dgb
app
aicoin
a
b
c
d
e
f
g
h
i
j
k
l
m
n

自定义词典(dict4)片段:

代码语言:javascript复制
股票	nl	100
商品	nl	100
外汇	nl	100
虚拟货币	nl	100
虚拟币	nl	100
区块链	nl	100
数字货币	nl	100
数字币	nl	100
期货	nl	100
指数	nl	100
火币	nl	100
ETH	nl	100
eth	nl	100
以太坊	nl	100
BTC	nl	100
btc	nl	100
BCH	nl	100
bch	nl	100
XRP	nl	100
xrp	nl	100
bsv	nl	100
okb	nl	100
瑞波	nl	100
ETC	nl	100
etc	nl	100
EOS	nl	100
eos	nl	100
柚子	nl	100
大饼	nl	100
资金	nl	100
流入	nl	100
推文	nl	100
流出	nl	100

清除特殊字符的方法:

代码语言:javascript复制
  /**
     * 去掉注释和平特殊字符
     */
   val regex:String = "\<!--(. )--\>|\\n|[`~!@#$%^&*() =|{}':'\[\].<>/?~!@#¥%……&*()—— |{}【】‘;:”“’。,、?-]" ;
 
 
   def clearComment(html:String):String = {
     val p = Pattern.compile(regex)
     var newHtml = Jsoup.clean(html, Whitelist.none()) //jsoup得到的html代码
     val m = p.matcher(newHtml)
     while ( {
       m.find
     }) newHtml = newHtml.replace(m.group(),"")
     newHtml
   }

tf-idf结果解析部分:

代码语言:javascript复制
 rescaledData.select("features").rdd.map{
       x => {
         //x的值:
         //[(100,[22,24,25,26,28,31,34,38,39,40,44,48,52,55,59,60,63,68,72,80,84,87,95,96],[1.0986122886681098,3.295836866004329,
         // 4.394449154672439,0.4054651081081644,1.0986122886681098,1.0986122886681098,3.295836866004329,2.1972245773362196,
         // 1.3862943611198906,1.0986122886681098,1.0986122886681098,2.1972245773362196,8.788898309344878,1.3862943611198906,
         // 1.0986122886681098,3.295836866004329,1.0986122886681098,1.0986122886681098,1.6218604324326575,3.295836866004329,
         // 1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098])]
 
         //v的值为:
         //(100,[22,24,25,26,28,31,34,38,39,40,44,48,52,55,59,60,63,68,72,80,84,87,95,96],
         // [1.0986122886681098,3.295836866004329,4.394449154672439,0.4054651081081644,1.0986122886681098,
         // 1.0986122886681098,3.295836866004329,2.1972245773362196,1.3862943611198906,1.0986122886681098,
         // 1.0986122886681098,2.1972245773362196,8.788898309344878,1.3862943611198906,1.0986122886681098,
         // 3.295836866004329,1.0986122886681098,1.0986122886681098,1.6218604324326575,3.295836866004329,
         // 1.0986122886681098,1.0986122886681098,1.0986122886681098,1.0986122886681098])
         val v = x.getAs[SparseVector](0)
         //v.indices.foreach(println)
         //v.indices.zip(v.values)的值为:
         //(22,1.0986122886681098) (24,3.295836866004329) (25,4.394449154672439)(26,0.4054651081081644)(28,1.0986122886681098)
         //(31,1.0986122886681098)(34,3.295836866004329)(38,2.1972245773362196)(39,1.3862943611198906)(40,1.0986122886681098)
         //(44,1.0986122886681098)(48,2.1972245773362196)(52,8.788898309344878)(55,1.3862943611198906)(59,1.0986122886681098)
         //(60,3.295836866004329)(63,1.0986122886681098)(68,1.0986122886681098)(72,1.6218604324326575)(80,3.295836866004329)
         //(84,1.0986122886681098)(87,1.0986122886681098)(95,1.0986122886681098)(96,1.0986122886681098)
         //对v.indices.zip(v.values)的结果按第二个值从大到小排序
         v.indices.zip(v.values).sortWith((a,b) => {
           a._2 > b._2
         }).take(10).map(x => (wordMap.get(x._1).get,x._2))
       }
     }

0 人点赞