查看原文
其他

数据算法之反转排序 | 寻找相邻单词的数量

大数据技术与架构点击右侧关注,大数据开发领域最强公众号!

暴走大数据点击右侧关注,暴走大数据!


上一期的题目看这里:《数据算法第三章中的问题你面试和工作中遇到过吗?》。

这期题目和Leetcode中的一些搜索题目有点类似。
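想处理的问题是:统计每个单词与其前后相邻单词(窗口大小由参数 neighbor-window 指定,例如前后各两个)共同出现的频率。如有 w1,w2,w3,w4,w5,w6,窗口为 2 时,w3 的相邻单词为 w1、w2、w4、w5,而 w1 的相邻单词只有 w2、w3。
最终要输出为(word,neighbor,frequency),其中 frequency 是相对频率:(word, neighbor) 共同出现的次数除以 word 所有相邻单词出现的总次数。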
我们用五种方法实现:
  • MapReduce

  • Spark

  • Spark SQL的方法

  • Scala方法

  • Scala版Spark SQL

MapReduce
//map函数 @Override protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
String[] tokens = StringUtils.split(value.toString(), " "); //String[] tokens = StringUtils.split(value.toString(), "\\s+"); if ((tokens == null) || (tokens.length < 2)) { return; } //计算相邻两个单词的计算规则 for (int i = 0; i < tokens.length; i++) { tokens[i] = tokens[i].replaceAll("\\W+", "");
if (tokens[i].equals("")) { continue; }
pair.setWord(tokens[i]);
int start = (i - neighborWindow < 0) ? 0 : i - neighborWindow; int end = (i + neighborWindow >= tokens.length) ? tokens.length - 1 : i + neighborWindow; for (int j = start; j <= end; j++) { if (j == i) { continue; } pair.setNeighbor(tokens[j].replaceAll("\\W", "")); context.write(pair, ONE); } // pair.setNeighbor("*"); totalCount.set(end - start); context.write(pair, totalCount); }    }
//reduce函数 @Override protected void reduce(PairOfWords key, Iterable<IntWritable> values, Context context) throws IOException, InterruptedException { //等于*表示为单词本身,它的count为totalCount if (key.getNeighbor().equals("*")) { if (key.getWord().equals(currentWord)) { totalCount += totalCount + getTotalCount(values); } else { currentWord = key.getWord(); totalCount = getTotalCount(values); } } else { //其它的则为单次的word,需要通过getTotalCount获得相加 int count = getTotalCount(values); relativeCount.set((double) count / totalCount); context.write(key, relativeCount); }
}

Spark
public static void main(String[] args) { if (args.length < 3) { System.out.println("Usage: RelativeFrequencyJava <neighbor-window> <input-dir> <output-dir>"); System.exit(1); }
SparkConf sparkConf = new SparkConf().setAppName("RelativeFrequency"); JavaSparkContext sc = new JavaSparkContext(sparkConf);
int neighborWindow = Integer.parseInt(args[0]); String input = args[1]; String output = args[2];
final Broadcast<Integer> brodcastWindow = sc.broadcast(neighborWindow);
JavaRDD<String> rawData = sc.textFile(input);
/* * Transform the input to the format: (word, (neighbour, 1)) */ JavaPairRDD<String, Tuple2<String, Integer>> pairs = rawData.flatMapToPair( new PairFlatMapFunction<String, String, Tuple2<String, Integer>>() { private static final long serialVersionUID = -6098905144106374491L;
@Override public java.util.Iterator<scala.Tuple2<String, scala.Tuple2<String, Integer>>> call(String line) throws Exception { List<Tuple2<String, Tuple2<String, Integer>>> list = new ArrayList<Tuple2<String, Tuple2<String, Integer>>>(); String[] tokens = line.split("\\s"); for (int i = 0; i < tokens.length; i++) { int start = (i - brodcastWindow.value() < 0) ? 0 : i - brodcastWindow.value(); int end = (i + brodcastWindow.value() >= tokens.length) ? tokens.length - 1 : i + brodcastWindow.value(); for (int j = start; j <= end; j++) { if (j != i) { list.add(new Tuple2<String, Tuple2<String, Integer>>(tokens[i], new Tuple2<String, Integer>(tokens[j], 1))); } else { // do nothing continue; } } } return list.iterator(); } } );
// (word, sum(word)) //PairFunction<T, K, V> T => Tuple2<K, V> JavaPairRDD<String, Integer> totalByKey = pairs.mapToPair(
new PairFunction<Tuple2<String, Tuple2<String, Integer>>, String, Integer>() { private static final long serialVersionUID = -213550053743494205L;
@Override public Tuple2<String, Integer> call(Tuple2<String, Tuple2<String, Integer>> tuple) throws Exception { return new Tuple2<String, Integer>(tuple._1, tuple._2._2); } }).reduceByKey( new Function2<Integer, Integer, Integer>() { private static final long serialVersionUID = -2380022035302195793L;
@Override public Integer call(Integer v1, Integer v2) throws Exception { return (v1 + v2); } });
JavaPairRDD<String, Iterable<Tuple2<String, Integer>>> grouped = pairs.groupByKey();
// (word, (neighbour, 1)) -> (word, (neighbour, sum(neighbour))) //flatMapValues至少对value进行操作,但是不改变key的顺序 JavaPairRDD<String, Tuple2<String, Integer>> uniquePairs = grouped.flatMapValues( //Function<T1, R> -> R call(T1 v1) new Function<Iterable<Tuple2<String, Integer>>, Iterable<Tuple2<String, Integer>>>() { private static final long serialVersionUID = 5790208031487657081L;
@Override public Iterable<Tuple2<String, Integer>> call(Iterable<Tuple2<String, Integer>> values) throws Exception { Map<String, Integer> map = new HashMap<>(); List<Tuple2<String, Integer>> list = new ArrayList<>(); Iterator<Tuple2<String, Integer>> iterator = values.iterator(); while (iterator.hasNext()) { Tuple2<String, Integer> value = iterator.next(); int total = value._2; if (map.containsKey(value._1)) { total += map.get(value._1); } map.put(value._1, total); } for (Map.Entry<String, Integer> kv : map.entrySet()) { list.add(new Tuple2<String, Integer>(kv.getKey(), kv.getValue())); } return list; } });
// (word, ((neighbour, sum(neighbour)), sum(word))) JavaPairRDD<String, Tuple2<Tuple2<String, Integer>, Integer>> joined = uniquePairs.join(totalByKey);
// ((key, neighbour), sum(neighbour)/sum(word)) JavaPairRDD<Tuple2<String, String>, Double> relativeFrequency = joined.mapToPair( new PairFunction<Tuple2<String, Tuple2<Tuple2<String, Integer>, Integer>>, Tuple2<String, String>, Double>() { private static final long serialVersionUID = 3870784537024717320L;
@Override public Tuple2<Tuple2<String, String>, Double> call(Tuple2<String, Tuple2<Tuple2<String, Integer>, Integer>> tuple) throws Exception { return new Tuple2<Tuple2<String, String>, Double>(new Tuple2<String, String>(tuple._1, tuple._2._1._1), ((double) tuple._2._1._2 / tuple._2._2)); } });
// For saving the output in tab separated format // ((key, neighbour), relative_frequency) //将结果转换成一个String JavaRDD<String> formatResult_tab_separated = relativeFrequency.map( new Function<Tuple2<Tuple2<String, String>, Double>, String>() { private static final long serialVersionUID = 7312542139027147922L;
@Override public String call(Tuple2<Tuple2<String, String>, Double> tuple) throws Exception { return tuple._1._1 + "\t" + tuple._1._2 + "\t" + tuple._2; } });
// save output formatResult_tab_separated.saveAsTextFile(output);
// done sc.close();
}

Spark SQL
public static void main(String[] args) { if (args.length < 3) { System.out.println("Usage: SparkSQLRelativeFrequency <neighbor-window> <input-dir> <output-dir>"); System.exit(1); }
SparkConf sparkConf = new SparkConf().setAppName("SparkSQLRelativeFrequency"); //创建SparkSQL需要的SparkSession SparkSession spark = SparkSession .builder() .appName("SparkSQLRelativeFrequency") .config(sparkConf) .getOrCreate();
JavaSparkContext sc = new JavaSparkContext(spark.sparkContext()); int neighborWindow = Integer.parseInt(args[0]); String input = args[1]; String output = args[2];
final Broadcast<Integer> brodcastWindow = sc.broadcast(neighborWindow);
/* *注册一个Schema表,这个frequency等会要用 * Schema (word, neighbour, frequency) */ StructType rfSchema = new StructType(new StructField[]{ new StructField("word", DataTypes.StringType, false, Metadata.empty()), new StructField("neighbour", DataTypes.StringType, false, Metadata.empty()), new StructField("frequency", DataTypes.IntegerType, false, Metadata.empty())});
JavaRDD<String> rawData = sc.textFile(input);
/* * Transform the input to the format: (word, (neighbour, 1)) */ JavaRDD<Row> rowRDD = rawData .flatMap(new FlatMapFunction<String, Row>() { private static final long serialVersionUID = 5481855142090322683L;
@Override public Iterator<Row> call(String line) throws Exception { List<Row> list = new ArrayList<>(); String[] tokens = line.split("\\s"); for (int i = 0; i < tokens.length; i++) { int start = (i - brodcastWindow.value() < 0) ? 0 : i - brodcastWindow.value(); int end = (i + brodcastWindow.value() >= tokens.length) ? tokens.length - 1 : i + brodcastWindow.value(); for (int j = start; j <= end; j++) { if (j != i) { list.add(RowFactory.create(tokens[i], tokens[j], 1)); } else { // do nothing continue; } } } return list.iterator(); } }); //创建DataFrame Dataset<Row> rfDataset = spark.createDataFrame(rowRDD, rfSchema); //将rfDataset转成一个table,可以进行查询 rfDataset.createOrReplaceTempView("rfTable");
String query = "SELECT a.word, a.neighbour, (a.feq_total/b.total) rf " + "FROM (SELECT word, neighbour, SUM(frequency) feq_total FROM rfTable GROUP BY word, neighbour) a " + "INNER JOIN (SELECT word, SUM(frequency) as total FROM rfTable GROUP BY word) b ON a.word = b.word"; Dataset<Row> sqlResult = spark.sql(query);
sqlResult.show(); // print first 20 records on the console sqlResult.write().parquet(output + "/parquetFormat"); // saves output in compressed Parquet format, recommended for large projects. sqlResult.rdd().saveAsTextFile(output + "/textFormat"); // to see output via cat command
// done sc.close(); spark.stop();
}

Scala
/**
 * Spark (Scala RDD API) driver that computes the relative frequency of (word, neighbour) pairs.
 *
 * Usage: `RelativeFrequency <neighbor-window> <input-dir> <output-dir>`
 *
 * For the token at position i of a line, every token at positions [i - window, i + window]
 * (clamped to the line, excluding i itself) is a neighbour. Each output line is
 * `word \t neighbour \t count(word, neighbour) / count(word, *)`, where count(word, *) is
 * the total number of neighbours recorded for the word.
 *
 * @param args neighbour window size, input path, output path
 */
def main(args: Array[String]): Unit = {
  if (args.size < 3) {
    println("Usage: RelativeFrequency <neighbor-window> <input-dir> <output-dir>")
    sys.exit(1)
  }

  val sparkConf = new SparkConf().setAppName("RelativeFrequency")
  val sc = new SparkContext(sparkConf)

  val neighborWindow = args(0).toInt
  val input = args(1)
  val output = args(2)

  val broadcastWindow = sc.broadcast(neighborWindow)

  val rawData = sc.textFile(input)

  // (word, (neighbour, 1)) for every neighbour inside the window
  val pairs = rawData.flatMap { line =>
    // Split on runs of whitespace: split("\\s") turns "a  b" into ["a", "", "b"] and would
    // count the empty string as a word/neighbour.
    val tokens = line.trim.split("\\s+").filter(_.nonEmpty)
    val window = broadcastWindow.value
    for {
      i <- tokens.indices
      j <- math.max(0, i - window) to math.min(tokens.length - 1, i + window)
      if j != i
    } yield (tokens(i), (tokens(j), 1))
  }

  // (word, sum(word)): total number of neighbours recorded for each word
  val totalByKey = pairs.map { case (word, (_, one)) => (word, one) }.reduceByKey(_ + _)

  // (word, (neighbour, sum(neighbour))). reduceByKey on the composite key combines counts
  // map-side, whereas groupByKey would buffer every neighbour of a frequent word in memory.
  val uniquePairs = pairs
    .map { case (word, (neighbour, one)) => ((word, neighbour), one) }
    .reduceByKey(_ + _)
    .map { case ((word, neighbour), count) => (word, (neighbour, count)) }

  // (word, ((neighbour, sum(neighbour)), sum(word)))
  val joined = uniquePairs.join(totalByKey)

  // ((word, neighbour), sum(neighbour) / sum(word))
  val relativeFrequency = joined.map { case (word, ((neighbour, count), total)) =>
    ((word, neighbour), count.toDouble / total.toDouble)
  }

  // Tab separated output: word \t neighbour \t relative_frequency
  val formatResultTabSeparated = relativeFrequency.map { case ((word, neighbour), rf) =>
    s"$word\t$neighbour\t$rf"
  }
  formatResultTabSeparated.saveAsTextFile(output)

  sc.stop()
}

Scala版Spark SQL
/**
 * Spark SQL (Scala API) driver that computes the relative frequency of (word, neighbour) pairs.
 *
 * Usage: `SparkSQLRelativeFrequency <neighbor-window> <input-dir> <output-dir>`
 *
 * Builds Row(word, neighbour, 1) for every neighbour within the window and registers the rows as
 * the temp view `rfTable`. A SQL join then divides each (word, neighbour) count by the word's
 * total neighbour count. The result is written as Parquet (`<output>/parquetFormat`) and as
 * text (`<output>/textFormat`).
 *
 * @param args neighbour window size, input path, output path
 */
def main(args: Array[String]): Unit = {
  if (args.size < 3) {
    println("Usage: SparkSQLRelativeFrequency <neighbor-window> <input-dir> <output-dir>")
    sys.exit(1)
  }

  val sparkConf = new SparkConf().setAppName("SparkSQLRelativeFrequency")

  val spark = SparkSession
    .builder()
    .config(sparkConf)
    .getOrCreate()
  val sc = spark.sparkContext

  val neighborWindow = args(0).toInt
  val input = args(1)
  val output = args(2)

  val broadcastWindow = sc.broadcast(neighborWindow)

  val rawData = sc.textFile(input)

  // Schema: (word, neighbour, frequency)
  val rfSchema = StructType(Seq(
    StructField("word", StringType, false),
    StructField("neighbour", StringType, false),
    StructField("frequency", IntegerType, false)))

  // Row(word, neighbour, 1), matching rfSchema. Unlike the MapReduce mapper, tokens are used
  // as-is (no stripping of non-word characters).
  val rowRDD = rawData.flatMap { line =>
    // Split on runs of whitespace so repeated blanks do not produce empty "words".
    val tokens = line.trim.split("\\s+").filter(_.nonEmpty)
    val window = broadcastWindow.value
    for {
      i <- tokens.indices
      j <- math.max(0, i - window) to math.min(tokens.length - 1, i + window)
      if j != i
    } yield Row(tokens(i), tokens(j), 1)
  }

  val rfDataFrame = spark.createDataFrame(rowRDD, rfSchema)
  // Expose the DataFrame to SQL as the table rfTable.
  rfDataFrame.createOrReplaceTempView("rfTable")

  import spark.sql

  val query = "SELECT a.word, a.neighbour, (a.feq_total/b.total) rf " +
    "FROM (SELECT word, neighbour, SUM(frequency) feq_total FROM rfTable GROUP BY word, neighbour) a " +
    "INNER JOIN (SELECT word, SUM(frequency) as total FROM rfTable GROUP BY word) b ON a.word = b.word"

  val sqlResult = sql(query)
  sqlResult.show() // print first 20 records on the console
  // Request Parquet explicitly: write.save() uses spark.sql.sources.default, which is configurable.
  sqlResult.write.parquet(output + "/parquetFormat") // saves output in compressed Parquet format, recommended for large projects.
  sqlResult.rdd.saveAsTextFile(output + "/textFormat") // to see output via cat command

  spark.stop()
}
以上就是用五种方法解决这个问题。

更多阅读:
《大数据技术与架构,19年文章精选》
《数据算法第三章中的问题你面试和工作中遇到过吗?》

——END——

欢迎点赞+收藏+转发朋友圈素质三连



文章不错?点个【在看】吧! 👇

: . Video Mini Program Like ,轻点两下取消赞 Wow ,轻点两下取消在看

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存