/**
 * Dataset.scala
 * Returns a new Dataset partitioned by the given partitioning expressions into
 * `numPartitions`. The resulting Dataset is hash partitioned.
 *
 * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
 *
 * @group typedrel
 * @since 2.0.0
 */
@scala.annotation.varargs
def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan {
  // Wraps the current logical plan in a RepartitionByExpression node;
  // the actual hash partitioning happens when the plan is executed.
  RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions))
}
The resulting Dataset is hash partitioned,說得很清楚:repartition 使用的是 hash 分區。那我們來看看 hash 分區器的源碼:
/**
 * Partitioner.scala
 * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using
 * Java's `Object.hashCode`.
 *
 * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
 * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
 * produce an unexpected or incorrect result.
 */
class HashPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

  def numPartitions: Int = partitions

  // null keys are all sent to partition 0; every other key is bucketed by
  // hashCode modulo numPartitions, kept non-negative via Utils.nonNegativeMod.
  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }

  // Two HashPartitioners route keys identically iff they have the same
  // partition count, so equality (and hashCode below) use only numPartitions.
  override def equals(other: Any): Boolean = other match {
    case h: HashPartitioner =>
      h.numPartitions == numPartitions
    case _ =>
      false
  }

  override def hashCode: Int = numPartitions
}
/* Computes 'x' modulo 'mod' with a result that is never negative for a
 * positive modulus. The JVM's '%' operator takes the sign of 'x', so when
 * the raw remainder comes out below zero we shift it back into the
 * [0, mod) range by adding 'mod' once.
 */
def nonNegativeMod(x: Int, mod: Int): Int = {
  val remainder = x % mod
  if (remainder < 0) remainder + mod else remainder
}
看到這裡,前面「相同分區裡存在不同 name 值的記錄」這個現象就不難理解了:不同的 name 值經過 hashCode % 分區數 之後,完全可能落到同一個分區。一種簡單的調整方式是在遍歷分區時用 HashMap 按 name 值分開處理同一分區內的不同記錄。但如果我們想從根本上自定義分區,自己寫分組分區邏輯會更直觀、更容易理解。幸好 Spark 提供了 Partitioner 接口,支持自定義 Partitioner 來實現這種按分組劃分分區的方式。下面是我的一個簡單實現類,可以保證同一個分區內只有相同 name 值的記錄。
import org.apache.commons.collections.CollectionUtils;
import org.apache.spark.Partitioner;
import org.junit.Assert;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Custom partitioner meant to assign one partition per distinct grouping key,
 * so that every partition holds only records sharing the same key.
 * Created by lesly.lai on 2018/7/25.
 */
// NOTE(review): this snippet is truncated — the generic type parameters of the
// Map field and the remainder of the class body (constructor, numPartitions,
// getPartition, equals) are missing, most likely because '<...>' was stripped
// as HTML during extraction. Recover the full class from the original post
// before attempting to compile.
public class CuxGroupPartitioner extends Partitioner {
    private int partitions;
    /**
     * Lookup map from grouping key to partition index;
     * this is what separates records into distinct partitions.
     */
    private Map
查看分區分佈情況工具類
import org.apache.spark.sql.{Dataset, Row}

/**
 * Debugging helper that reports how the rows of a Dataset are spread
 * across its underlying RDD partitions.
 * Created by lesly.lai.
 */
class SparkRddTaskInfo {
  /**
   * Prints the total partition count, then — from within each executor task —
   * one line per partition with its id, its row count and its rows.
   */
  def getTask(dataSet: Dataset[Row]) {
    val partitionCount = dataSet.rdd.partitions.length
    println(s"==> partition size: $partitionCount " )
    import scala.collection.Iterator
    // Runs on the executors: materialize the partition so we can report
    // both its size and its contents in a single line.
    val dumpPartition = (rows: Iterator[Row]) => {
      val materialized = rows.toSeq
      import org.apache.spark.TaskContext
      val partitionId = TaskContext.get.partitionId
      println(s"[partition: $partitionId][size: ${materialized.size}] ${materialized.mkString(" ")}")
    }
    dataSet.foreachPartition(dumpPartition)
  }
}
調用方式
import com.vip.spark.db.ConnectionInfos; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import scala.Tuple2; import java.util.List; import java.util.stream.Collectors; /** * Created by lesly.lai on 2018/7/23. */ public class SparkSimpleTestPartition { public static void main(String[] args) throws InterruptedException {
SparkSession sparkSession = SparkSession.builder().appName("Java Spark SQL basic example").getOrCreate(); // 原始數據集 Dataset originSet = sparkSession.read().jdbc(ConnectionInfos.TEST_MYSQL_CONNECTION_URL, "people", ConnectionInfos.getTestUserAndPasswordProperties()); originSet.createOrReplaceTempView("people"); // 獲取分區分佈情況工具類 SparkRddTaskInfo taskInfo = new SparkRddTaskInfo(); Dataset groupSet = sparkSession.sql(" select name from people group by name"); List groupList = groupSet.javaRDD().collect().stream().map(row -> row.getAs("name")).collect(Collectors.toList()); // 創建pairRDD 目前只有pairRdd支持自定義partitioner,所以需要先轉成pairRdd JavaPairRDD pairRDD = originSet.javaRDD().mapToPair(row -> { return new Tuple2(row.getAs("name"), row); }); // 指定自定義partitioner JavaRDD javaRdd = pairRDD.partitionBy(new CuxGroupPartitioner(groupList)).map(new Function, Row>(){ @Override public Row call(Tuple2 v1) throws Exception { return v1._2; } }); Dataset result = sparkSession.createDataFrame(javaRdd, originSet.schema()); // 打印分區分佈情況 taskInfo.getTask(result); } }