这期题目和Leetcode中的一些搜索题目有点类似。
想处理的问题是:统计一个单词与其前后各两位相邻单词共同出现的次数。例如有w1,w2,w3,w4,w5,w6,则w3的相邻单词为w1,w2,w4,w5。
最终要输出为(word,neighbor,frequency)。
我们用五种方法实现:
- MapReduce
- Spark
- Spark SQL的方法
- Scala方法
- Scala版Spark SQL
MapReduce
//map函数(Java)
@Override
protected void map(LongWritable key, Text value, Context context)
throws IOException, InterruptedException {
String[] tokens = StringUtils.split(value.toString(), " ");
//String[] tokens = StringUtils.split(value.toString(), "\s ");
if ((tokens == null) || (tokens.length < 2)) {
return;
}
//计算相邻两个单词的计算规则
for (int i = 0; i < tokens.length; i ) {
tokens[i] = tokens[i].replaceAll("\W ", "");
if (tokens[i].equals("")) {
continue;
}
pair.setWord(tokens[i]);
int start = (i - neighborWindow < 0) ? 0 : i - neighborWindow;
int end = (i neighborWindow >= tokens.length) ? tokens.length - 1 : i neighborWindow;
for (int j = start; j <= end; j ) {
if (j == i) {
continue;
}
pair.setNeighbor(tokens[j].replaceAll("\W", ""));
context.write(pair, ONE);
}
//
pair.setNeighbor("*");
totalCount.set(end - start);
context.write(pair, totalCount);
}
}
//reduce函数(Java)
@Override
protected void reduce(PairOfWords key, Iterable<IntWritable> values, Context context)
throws IOException, InterruptedException {
//等于*表示为单词本身,它的count为totalCount
if (key.getNeighbor().equals("*")) {
if (key.getWord().equals(currentWord)) {
totalCount = totalCount getTotalCount(values);
} else {
currentWord = key.getWord();
totalCount = getTotalCount(values);
}
} else {
//其它的则为单次的word,需要通过getTotalCount获得相加
int count = getTotalCount(values);
relativeCount.set((double) count / totalCount);
context.write(key, relativeCount);
}
}
Spark
代码语言:javascript复制public static void main(String[] args) {
if (args.length < 3) {
System.out.println("Usage: RelativeFrequencyJava <neighbor-window> <input-dir> <output-dir>");
System.exit(1);
}
SparkConf sparkConf = new SparkConf().setAppName("RelativeFrequency");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
int neighborWindow = Integer.parseInt(args[0]);
String input = args[1];
String output = args[2];
final Broadcast<Integer> brodcastWindow = sc.broadcast(neighborWindow);
JavaRDD<String> rawData = sc.textFile(input);
/*
* Transform the input to the format: (word, (neighbour, 1))
*/
JavaPairRDD<String, Tuple2<String, Integer>> pairs = rawData.flatMapToPair(
new PairFlatMapFunction<String, String, Tuple2<String, Integer>>() {
private static final long serialVersionUID = -6098905144106374491L;
@Override
public java.util.Iterator<scala.Tuple2<String, scala.Tuple2<String, Integer>>> call(String line) throws Exception {
List<Tuple2<String, Tuple2<String, Integer>>> list = new ArrayList<Tuple2<String, Tuple2<String, Integer>>>();
String[] tokens = line.split("\s");
for (int i = 0; i < tokens.length; i ) {
int start = (i - brodcastWindow.value() < 0) ? 0 : i - brodcastWindow.value();
int end = (i brodcastWindow.value() >= tokens.length) ? tokens.length - 1 : i brodcastWindow.value();
for (int j = start; j <= end; j ) {
if (j != i) {
list.add(new Tuple2<String, Tuple2<String, Integer>>(tokens[i], new Tuple2<String, Integer>(tokens[j], 1)));
} else {
// do nothing
continue;
}
}
}
return list.iterator();
}
}
);
// (word, sum(word))
//PairFunction<T, K, V> T => Tuple2<K, V>
JavaPairRDD<String, Integer> totalByKey = pairs.mapToPair(
new PairFunction<Tuple2<String, Tuple2<String, Integer>>, String, Integer>() {
private static final long serialVersionUID = -213550053743494205L;
@Override
public Tuple2<String, Integer> call(Tuple2<String, Tuple2<String, Integer>> tuple) throws Exception {
return new Tuple2<String, Integer>(tuple._1, tuple._2._2);
}
}).reduceByKey(
new Function2<Integer, Integer, Integer>() {
private static final long serialVersionUID = -2380022035302195793L;
@Override
public Integer call(Integer v1, Integer v2) throws Exception {
return (v1 v2);
}
});
JavaPairRDD<String, Iterable<Tuple2<String, Integer>>> grouped = pairs.groupByKey();
// (word, (neighbour, 1)) -> (word, (neighbour, sum(neighbour)))
//flatMapValues至少对value进行操作,但是不改变key的顺序
JavaPairRDD<String, Tuple2<String, Integer>> uniquePairs = grouped.flatMapValues(
//Function<T1, R> -> R call(T1 v1)
new Function<Iterable<Tuple2<String, Integer>>, Iterable<Tuple2<String, Integer>>>() {
private static final long serialVersionUID = 5790208031487657081L;
@Override
public Iterable<Tuple2<String, Integer>> call(Iterable<Tuple2<String, Integer>> values) throws Exception {
Map<String, Integer> map = new HashMap<>();
List<Tuple2<String, Integer>> list = new ArrayList<>();
Iterator<Tuple2<String, Integer>> iterator = values.iterator();
while (iterator.hasNext()) {
Tuple2<String, Integer> value = iterator.next();
int total = value._2;
if (map.containsKey(value._1)) {
total = map.get(value._1);
}
map.put(value._1, total);
}
for (Map.Entry<String, Integer> kv : map.entrySet()) {
list.add(new Tuple2<String, Integer>(kv.getKey(), kv.getValue()));
}
return list;
}
});
// (word, ((neighbour, sum(neighbour)), sum(word)))
JavaPairRDD<String, Tuple2<Tuple2<String, Integer>, Integer>> joined = uniquePairs.join(totalByKey);
// ((key, neighbour), sum(neighbour)/sum(word))
JavaPairRDD<Tuple2<String, String>, Double> relativeFrequency = joined.mapToPair(
new PairFunction<Tuple2<String, Tuple2<Tuple2<String, Integer>, Integer>>, Tuple2<String, String>, Double>() {
private static final long serialVersionUID = 3870784537024717320L;
@Override
public Tuple2<Tuple2<String, String>, Double> call(Tuple2<String, Tuple2<Tuple2<String, Integer>, Integer>> tuple) throws Exception {
return new Tuple2<Tuple2<String, String>, Double>(new Tuple2<String, String>(tuple._1, tuple._2._1._1), ((double) tuple._2._1._2 / tuple._2._2));
}
});
// For saving the output in tab separated format
// ((key, neighbour), relative_frequency)
//将结果转换成一个String
JavaRDD<String> formatResult_tab_separated = relativeFrequency.map(
new Function<Tuple2<Tuple2<String, String>, Double>, String>() {
private static final long serialVersionUID = 7312542139027147922L;
@Override
public String call(Tuple2<Tuple2<String, String>, Double> tuple) throws Exception {
return tuple._1._1 "t" tuple._1._2 "t" tuple._2;
}
});
// save output
formatResult_tab_separated.saveAsTextFile(output);
// done
sc.close();
}
Spark SQL
代码语言:javascript复制 public static void main(String[] args) {
if (args.length < 3) {
System.out.println("Usage: SparkSQLRelativeFrequency <neighbor-window> <input-dir> <output-dir>");
System.exit(1);
}
SparkConf sparkConf = new SparkConf().setAppName("SparkSQLRelativeFrequency");
//创建SparkSQL需要的SparkSession
SparkSession spark = SparkSession
.builder()
.appName("SparkSQLRelativeFrequency")
.config(sparkConf)
.getOrCreate();
JavaSparkContext sc = new JavaSparkContext(spark.sparkContext());
int neighborWindow = Integer.parseInt(args[0]);
String input = args[1];
String output = args[2];
final Broadcast<Integer> brodcastWindow = sc.broadcast(neighborWindow);
/*
*注册一个Schema表,这个frequency等会要用
* Schema (word, neighbour, frequency)
*/
StructType rfSchema = new StructType(new StructField[]{
new StructField("word", DataTypes.StringType, false, Metadata.empty()),
new StructField("neighbour", DataTypes.StringType, false, Metadata.empty()),
new StructField("frequency", DataTypes.IntegerType, false, Metadata.empty())});
JavaRDD<String> rawData = sc.textFile(input);
/*
* Transform the input to the format: (word, (neighbour, 1))
*/
JavaRDD<Row> rowRDD = rawData
.flatMap(new FlatMapFunction<String, Row>() {
private static final long serialVersionUID = 5481855142090322683L;
@Override
public Iterator<Row> call(String line) throws Exception {
List<Row> list = new ArrayList<>();
String[] tokens = line.split("\s");
for (int i = 0; i < tokens.length; i ) {
int start = (i - brodcastWindow.value() < 0) ? 0
: i - brodcastWindow.value();
int end = (i brodcastWindow.value() >= tokens.length) ? tokens.length - 1
: i brodcastWindow.value();
for (int j = start; j <= end; j ) {
if (j != i) {
list.add(RowFactory.create(tokens[i], tokens[j], 1));
} else {
// do nothing
continue;
}
}
}
return list.iterator();
}
});
//创建DataFrame
Dataset<Row> rfDataset = spark.createDataFrame(rowRDD, rfSchema);
//将rfDataset转成一个table,可以进行查询
rfDataset.createOrReplaceTempView("rfTable");
String query = "SELECT a.word, a.neighbour, (a.feq_total/b.total) rf "
"FROM (SELECT word, neighbour, SUM(frequency) feq_total FROM rfTable GROUP BY word, neighbour) a "
"INNER JOIN (SELECT word, SUM(frequency) as total FROM rfTable GROUP BY word) b ON a.word = b.word";
Dataset<Row> sqlResult = spark.sql(query);
sqlResult.show(); // print first 20 records on the console
sqlResult.write().parquet(output "/parquetFormat"); // saves output in compressed Parquet format, recommended for large projects.
sqlResult.rdd().saveAsTextFile(output "/textFormat"); // to see output via cat command
// done
sc.close();
spark.stop();
}
Scala
/**
 * Spark (Scala RDD API) job that computes, for every word, the relative
 * frequency of each neighbor found within a window of positions on the same line.
 *
 * Output lines have the form `word<TAB>neighbour<TAB>relative_frequency`.
 *
 * @param args `<neighbor-window> <input-dir> <output-dir>`
 */
def main(args: Array[String]): Unit = {
  if (args.size < 3) {
    println("Usage: RelativeFrequency <neighbor-window> <input-dir> <output-dir>")
    sys.exit(1)
  }
  val sparkConf = new SparkConf().setAppName("RelativeFrequency")
  val sc = new SparkContext(sparkConf)
  val neighborWindow = args(0).toInt
  val input = args(1)
  val output = args(2)
  val brodcastWindow = sc.broadcast(neighborWindow)
  val rawData = sc.textFile(input)
  /*
   * Transform the input to the format:
   * (word, (neighbour, 1))
   */
  val pairs = rawData.flatMap { line =>
    val tokens = line.split("\\s")
    val window = brodcastWindow.value
    for {
      i <- tokens.indices
      // Clamp the window to the bounds of the line.
      start = math.max(0, i - window)
      end = math.min(tokens.length - 1, i + window)
      j <- start to end
      if j != i // a word is not its own neighbor
    } yield (tokens(i), (tokens(j), 1))
  }
  // (word, sum(word))
  val totalByKey = pairs.map(t => (t._1, t._2._2)).reduceByKey(_ + _)
  val grouped = pairs.groupByKey()
  // (word, (neighbour, sum(neighbour)))
  // A strict map (rather than mapValues) behaves the same on Scala 2.12 and 2.13.
  val uniquePairs = grouped.flatMapValues { neighbours =>
    neighbours.groupBy(_._1).map { case (neighbour, hits) => (neighbour, hits.map(_._2).sum) }
  }
  // Join the two RDDs on the word:
  // (word, ((neighbour, sum(neighbour)), sum(word)))
  val joined = uniquePairs join totalByKey
  // ((key, neighbour), sum(neighbour)/sum(word))
  val relativeFrequency = joined.map { case (word, ((neighbour, count), total)) =>
    ((word, neighbour), count.toDouble / total.toDouble)
  }
  // For saving the output in tab separated format
  // ((key, neighbour), relative_frequency)
  val formatResult_tab_separated = relativeFrequency.map { case ((word, neighbour), rf) =>
    s"$word\t$neighbour\t$rf"
  }
  formatResult_tab_separated.saveAsTextFile(output)
  // done
  sc.stop()
}
Scala版Spark SQL
/**
 * Spark SQL (Scala) job that computes, for every word, the relative frequency
 * of each neighbor found within a window of positions on the same line.
 *
 * The Row(word, neighbour, 1) records are registered as the temp view
 * "rfTable" and aggregated with a SQL query; the result is written with the
 * default data source and as text.
 *
 * @param args `<neighbor-window> <input-dir> <output-dir>`
 */
def main(args: Array[String]): Unit = {
  if (args.size < 3) {
    println("Usage: SparkSQLRelativeFrequency <neighbor-window> <input-dir> <output-dir>")
    sys.exit(1)
  }
  val sparkConf = new SparkConf().setAppName("SparkSQLRelativeFrequency")
  val spark = SparkSession
    .builder()
    .config(sparkConf)
    .getOrCreate()
  val sc = spark.sparkContext
  val neighborWindow = args(0).toInt
  val input = args(1)
  val output = args(2)
  val brodcastWindow = sc.broadcast(neighborWindow)
  val rawData = sc.textFile(input)
  /*
   * Schema
   * (word, neighbour, frequency)
   */
  val rfSchema = StructType(Seq(
    StructField("word", StringType, false),
    StructField("neighbour", StringType, false),
    StructField("frequency", IntegerType, false)))
  /*
   * Transform the input to the format:
   * Row(word, neighbour, 1)
   */
  // Rows must match the StructType declared above.
  val rowRDD = rawData.flatMap { line =>
    val tokens = line.split("\\s")
    val window = brodcastWindow.value
    for {
      i <- tokens.indices
      // Plain window rule; unlike the MapReduce version, tokens are not cleaned.
      start = math.max(0, i - window)
      end = math.min(tokens.length - 1, i + window)
      j <- start to end
      if j != i // a word is not its own neighbor
    } yield Row(tokens(i), tokens(j), 1)
  }
  val rfDataFrame = spark.createDataFrame(rowRDD, rfSchema)
  // Register the temp view rfTable.
  rfDataFrame.createOrReplaceTempView("rfTable")
  import spark.sql
  val query = "SELECT a.word, a.neighbour, (a.feq_total/b.total) rf " +
    "FROM (SELECT word, neighbour, SUM(frequency) feq_total FROM rfTable GROUP BY word, neighbour) a " +
    "INNER JOIN (SELECT word, SUM(frequency) as total FROM rfTable GROUP BY word) b ON a.word = b.word"
  val sqlResult = sql(query)
  sqlResult.show() // print first 20 records on the console
  sqlResult.write.save(output + "/parquetFormat") // saves output in compressed Parquet format, recommended for large projects.
  sqlResult.rdd.saveAsTextFile(output + "/textFormat") // to see output via cat command
  // done
  spark.stop()
}
以上就是用五种方法解决这个问题。