Spark custom partitioner, Java version

When iterating over a Spark Dataset, foreachPartition is commonly used to process each partition. Under the default partitioning (determined when the Dataset is created), the data distribution can cause skew, making the whole Dataset slow to process and preventing Spark from exploiting its parallelism across multiple executors (JVM processes) and partitions (threads). The usual practice is therefore to call repartition before iterating, so the data is redistributed by a chosen key and Spark's parallel processing power can be fully used, for example:

dataset.repartition(9, new Column("name")).foreachPartition(it -> {
    while (it.hasNext()) {
        Row row = it.next();
        ....
    }
});

First, a look at the prepared source dataset:

[screenshot: the source dataset]

With the code above, the expected result is that records with the same name end up in the same partition, records with different names end up in different partitions, and no partition contains records with more than one name. The actual partition distribution, however, looks like this:

[screenshot: the actual partition distribution — several partitions contain records with different name values]

(The code for inspecting the partition distribution is explained in an earlier post, spark sql 在mysql的應用實踐. If repartition is called without specifying the partition count of 9, it defaults to 200 from spark.sql.shuffle.partitions; for plain RDDs the default partition count comes from the spark.default.parallelism setting, as can be seen in the Partitioner object defined in Partitioner.scala.)
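
For reference, a minimal sketch of pinning those defaults explicitly on the SparkSession builder (the class name and the value 9 are just for illustration, 9 being the partition count used in this example; it assumes the job is launched through spark-submit so that a master is provided):

import org.apache.spark.sql.SparkSession;

public class PartitionConfigSketch {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder()
                .appName("partition-demo")
                // default parallelism for RDD operations (e.g. partitionBy without an explicit count)
                .config("spark.default.parallelism", "9")
                // partition count used by Dataset.repartition(Column...) when no count is given
                .config("spark.sql.shuffle.partitions", "9")
                .getOrCreate();

        // confirm the effective setting
        System.out.println(sparkSession.conf().get("spark.sql.shuffle.partitions"));
    }
}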

This is rather puzzling... at first glance it is hard to tell what is going on. Digging through the source shows that RDDs come with two partitioners, HashPartitioner and RangePartitioner, with HashPartitioner used by default. Starting from the repartition source:

/**
 * Dataset.scala
 * Returns a new Dataset partitioned by the given partitioning expressions into
 * `numPartitions`. The resulting Dataset is hash partitioned.
 *
 * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
 *
 * @group typedrel
 * @since 2.0.0
 */
@scala.annotation.varargs
def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan {
  RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions))
}

"The resulting Dataset is hash partitioned" states it plainly: hash partitioning is used. So let's look at the hash partitioner source:

/**
 * Partitioner.scala
 * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using
 * Java's `Object.hashCode`.
 *
 * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
 * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
 * produce an unexpected or incorrect result.
 */
class HashPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

  def numPartitions: Int = partitions

  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }

  override def equals(other: Any): Boolean = other match {
    case h: HashPartitioner =>
      h.numPartitions == numPartitions
    case _ =>
      false
  }

  override def hashCode: Int = numPartitions
}

Utils.nonNegativeMod(key.hashCode, numPartitions) shows that when the partition for the current row is computed, the hashCode of the partition key is used as the actual partitioning value. Now look at nonNegativeMod:

/* Calculates 'x' modulo 'mod', takes to consideration sign of x,
 * i.e. if 'x' is negative, than 'x' % 'mod' is negative too
 * so function return (x % mod) + mod in that case.
 */
def nonNegativeMod(x: Int, mod: Int): Int = {
  val rawMod = x % mod
  rawMod + (if (rawMod < 0) mod else 0)
}
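
To make this concrete, here is a minimal Java sketch of what the hash partitioning effectively computes for a key; the name values below are made up purely for illustration:

import java.util.Arrays;

public class HashPartitionDemo {

    // Same logic as Utils.nonNegativeMod in Spark: a modulo that never returns a negative value
    static int nonNegativeMod(int x, int mod) {
        int rawMod = x % mod;
        return rawMod + (rawMod < 0 ? mod : 0);
    }

    public static void main(String[] args) {
        int numPartitions = 9;
        // hypothetical name values; with only 9 partitions, different names
        // can easily hash to the same partition index
        for (String name : Arrays.asList("alice", "bob", "carol", "dave")) {
            System.out.println(name + " -> partition " + nonNegativeMod(name.hashCode(), numPartitions));
        }
    }
}

With an arbitrary hashCode taken modulo a small partition count, collisions between different names are expected rather than exceptional, which is exactly what the screenshot above shows.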

With that, the earlier observation of different name values inside the same partition is easy to understand: different names can land in the same partition after hashCode % numPartitions. A simple workaround is to use a HashMap inside the per-partition loop to keep records with different name values apart. But what if we want to control the partitioning ourselves? Custom group-based partitioning code is more direct and easier to follow, and fortunately Spark exposes the Partitioner abstraction, so we can plug in our own partitioner to support exactly this kind of grouping. Here is a simple implementation that guarantees a partition only ever contains records with the same name:

import org.apache.commons.collections.CollectionUtils;
import org.apache.spark.Partitioner;
import org.junit.Assert;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Created by lesly.lai on 2018/7/25.
 */
public class CuxGroupPartitioner extends Partitioner {

    private int partitions;

    /**
     * Maps each group key to its own partition index,
     * so that different keys never share a partition.
     */
    private Map<Object, Integer> hashCodePartitionIndexMap = new ConcurrentHashMap<>();

    public CuxGroupPartitioner(List<Object> groupList) {
        int size = groupList.size();
        this.partitions = size;
        initMap(partitions, groupList);
    }

    private void initMap(int size, List<Object> groupList) {
        Assert.assertTrue(CollectionUtils.isNotEmpty(groupList));
        for (int i = 0; i < size; i++) {
            hashCodePartitionIndexMap.put(groupList.get(i), i);
        }
    }

    @Override
    public int numPartitions() {
        return partitions;
    }

    @Override
    public int getPartition(Object key) {
        return hashCodePartitionIndexMap.get(key);
    }

    @Override
    public boolean equals(Object obj) {
        if (obj instanceof CuxGroupPartitioner) {
            return ((CuxGroupPartitioner) obj).partitions == partitions;
        }
        return false;
    }

    @Override
    public int hashCode() {
        return partitions;
    }
}
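
A quick sanity check of the mapping, using hypothetical name values, just to show that each distinct key gets its own partition index:

import java.util.Arrays;
import java.util.List;

public class CuxGroupPartitionerCheck {
    public static void main(String[] args) {
        // hypothetical group keys; in the real job they come from a "group by name" query
        List<Object> names = Arrays.<Object>asList("alice", "bob", "carol");
        CuxGroupPartitioner partitioner = new CuxGroupPartitioner(names);

        System.out.println("numPartitions = " + partitioner.numPartitions()); // 3
        for (Object name : names) {
            // each distinct name maps to its own partition index
            System.out.println(name + " -> partition " + partitioner.getPartition(name));
        }
    }
}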

The utility class for inspecting the partition distribution:

import org.apache.spark.sql.{Dataset, Row}

/**
 * Created by lesly.lai on 2017/12/25.
 */
class SparkRddTaskInfo {
  def getTask(dataSet: Dataset[Row]) {
    val size = dataSet.rdd.partitions.length
    println(s"==> partition size: $size")
    import scala.collection.Iterator
    val showElements = (it: Iterator[Row]) => {
      val ns = it.toSeq
      import org.apache.spark.TaskContext
      val pid = TaskContext.get.partitionId
      println(s"[partition: $pid][size: ${ns.size}] ${ns.mkString(" ")}")
    }
    dataSet.foreachPartition(showElements)
  }
}

How it is invoked:

import com.vip.spark.db.ConnectionInfos;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import scala.Tuple2;

import java.util.List;
import java.util.stream.Collectors;

/**
 * Created by lesly.lai on 2018/7/23.
 */
public class SparkSimpleTestPartition {
    public static void main(String[] args) throws InterruptedException {
        SparkSession sparkSession = SparkSession.builder().appName("Java Spark SQL basic example").getOrCreate();

        // source dataset
        Dataset<Row> originSet = sparkSession.read().jdbc(ConnectionInfos.TEST_MYSQL_CONNECTION_URL, "people", ConnectionInfos.getTestUserAndPasswordProperties());
        originSet.createOrReplaceTempView("people");

        // utility for inspecting the partition distribution
        SparkRddTaskInfo taskInfo = new SparkRddTaskInfo();

        // distinct name values, one partition per name
        Dataset<Row> groupSet = sparkSession.sql("select name from people group by name");
        List<Object> groupList = groupSet.javaRDD().collect().stream().map(row -> (Object) row.getAs("name")).collect(Collectors.toList());

        // build a pair RDD: currently only pair RDDs support a custom partitioner, so convert first
        JavaPairRDD<Object, Row> pairRDD = originSet.javaRDD().mapToPair(row -> new Tuple2<>(row.getAs("name"), row));

        // apply the custom partitioner, then drop the key again
        JavaRDD<Row> javaRdd = pairRDD.partitionBy(new CuxGroupPartitioner(groupList)).map(new Function<Tuple2<Object, Row>, Row>() {
            @Override
            public Row call(Tuple2<Object, Row> v1) throws Exception {
                return v1._2;
            }
        });
        Dataset<Row> result = sparkSession.createDataFrame(javaRdd, originSet.schema());

        // print the partition distribution
        taskInfo.getTask(result);
    }
}

The result of the run:

[screenshot: the partition distribution after applying the custom partitioner — one name value per partition]

As we can see, the data is now partitioned by name value, and no two different name values land in the same partition.

