Mastering Spark 2.x: A Deep Dive into the TaskRunner Execution Source Code



Continuing from the previous article, "Mastering Spark 2.x: A Deep Dive into the Executor Process Source Code": after CoarseGrainedExecutorBackend receives the LaunchTask message, it immediately calls executor.launchTask() to execute the task. That call wraps the task in a TaskRunner thread, so TaskRunner's run() method ends up being executed. Here we look at how it actually runs the Task. There is a lot of ground to cover and it is fairly intricate, so we will work through it piece by piece.
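For context, the hand-off itself is small: launchTask() registers a new TaskRunner under its task id and submits it to the executor's thread pool, which is what eventually drives run(). Below is a simplified, hedged sketch of that shape; the field names follow the Spark 2.x source, but this is a paraphrase, not the real Executor class:
<code>
import java.util.concurrent.{ConcurrentHashMap, Executors}

// Simplified paraphrase of Executor.launchTask() in Spark 2.x -- a sketch, not the full implementation.
class ExecutorSketch {
  // One runner per in-flight task, keyed by taskId; real Spark also uses this map to kill tasks.
  private val runningTasks = new ConcurrentHashMap[Long, Runnable]()
  private val threadPool = Executors.newCachedThreadPool()

  def launchTask(taskId: Long, makeRunner: => Runnable): Unit = {
    val tr = makeRunner            // in Spark: new TaskRunner(context, taskDescription)
    runningTasks.put(taskId, tr)   // remember it so it can be looked up / killed later
    threadPool.execute(tr)         // TaskRunner.run() then executes on a pool thread
  }
}
</code>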


II. In-Depth Analysis


1. Let's go straight to the run() method. The function is long, so I will pick out the most important parts to explain. The code is as follows:
<code>
override def run(): Unit = {
// Plain Java housekeeping: record the current thread id and set the thread name
threadId = Thread.currentThread.getId
Thread.currentThread.setName(threadName)
val threadMXBean = ManagementFactory.getThreadMXBean
// Instantiate a TaskMemoryManager for this task
val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
val deserializeStartTime = System.currentTimeMillis()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")

// Send a RUNNING status update to the CoarseGrainedExecutorBackend side.
// Under the hood this calls statusUpdate() on the CoarseGrainedExecutorBackend that hosts this Executor,
// which then sends a StatusUpdate message to the driver so it can process the state change.
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStart: Long = 0
var taskStartCpu: Long = 0
startGCTime = computeTotalGcTime()

try {
// Set the task's properties before downloading jars and resource files, so the downloads do not fail for lack of permissions
Executor.taskDeserializationProps.set(taskDescription.properties)
// Fetch the files, resources, and jars the task depends on from the remote side to the local node
updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
// Deserialize the Task and apply its properties and memory manager
task = ser.deserialize[Task[Any]](
taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
task.localProperties = taskDescription.properties
task.setTaskMemoryManager(taskMemoryManager)

// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
val killReason = reasonIfKilled
if (killReason.isDefined) {
// Throw an exception rather than returning, because returning within a try{} block
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
throw new TaskKilledException(killReason.get)
}

// The purpose of updating the epoch here is to invalidate executor map output status cache
// in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be
// MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so
// we don't need to make any special calls here.
if (!isLocal) {
logDebug("Task " + taskId + "'s epoch is " + task.epoch)
env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
}

// Task start time
taskStart = System.currentTimeMillis()
taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
var threwException = true
val value = try {
//Finally invoke task.run(), passing three arguments: the task id, the attempt number, and the metrics system.
//The code further down is task-completion and exception handling, which we will not go into in detail here.
//For a ShuffleMapTask, res is in fact the MapStatus; if the following task is again a ShuffleMapTask,
//it will contact the MapOutputTracker to pull the data output by this Task.
val res = task.run(
taskAttemptId = taskId,
attemptNumber = taskDescription.attemptNumber,
metricsSystem = env.metricsSystem)
threwException = false
res
} finally {
val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
// ... (code omitted) ...
}
task.context.fetchFailed.foreach { fetchFailure =>
// uh-oh. it appears the user code has caught the fetch-failure without throwing any
// other exceptions. Its *possible* this is what the user meant to do (though highly
// unlikely). So we will log an error and keep going.
logError(s"TID ${taskId} completed successfully though internally it encountered " +
s"unrecoverable fetch failures! Most likely this means user code is incorrectly " +
s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
}
val taskFinish = System.currentTimeMillis()
val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L

// If the task has been killed, let's fail it.
task.context.killTaskIfInterrupted()

val resultSer = env.serializer.newInstance()
val beforeSerialization = System.currentTimeMillis()
val valueBytes = resultSer.serialize(value)
val afterSerialization = System.currentTimeMillis()

// Deserialization happens in two parts: first, we deserialize a Task object, which
// includes the Partition. Second, Task.run() deserializes the RDD and function to be run.

task.metrics.setExecutorDeserializeTime(
(taskStart - deserializeStartTime) + task.executorDeserializeTime)
task.metrics.setExecutorDeserializeCpuTime(
(taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
// We need to subtract Task.run()'s deserialization time to avoid double-counting
task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
task.metrics.setExecutorCpuTime(
(taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization)

// Expose task metrics using the Dropwizard metrics system.
// Update task metrics counters
// ... (code omitted) ...
// Note: accumulator updates must be collected after TaskMetrics is updated
val accumUpdates = task.collectAccumulatorUpdates()
// TODO: do not serialize value twice
val directResult = new DirectTaskResult(valueBytes, accumUpdates)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit()

// directSend = sending directly back to the driver
val serializedResult: ByteBuffer = {
if (maxResultSize > 0 && resultSize > maxResultSize) {
logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
s"dropping it.")
ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
} else if (resultSize > maxDirectResultSize) {
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId,
new ChunkedByteBuffer(serializedDirectResult.duplicate()),
StorageLevel.MEMORY_AND_DISK_SER)
logInfo(
s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
serializedDirectResult
}
}
//Send a FINISHED status update, together with the serialized result, to the CoarseGrainedExecutorBackend side.
//Again this goes through CoarseGrainedExecutorBackend.statusUpdate(),
//which sends a StatusUpdate message on to the driver for processing.
setTaskFinishedAndClearInterruptStatus()

execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

} catch {
// ... (exception handling omitted; the lines below sit inside an `if` branch that checks whether the JVM is shutting down) ...
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
} else {
logInfo("Not reporting error to driver during JVM shutdown.")
}

// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (!t.isInstanceOf[SparkOutOfMemoryError] && Utils.isFatalError(t)) {
uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
}
} finally {
//When the task is done, remove it from the runningTasks in-memory map
runningTasks.remove(taskId)
}
}
</code>
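One detail in the block above that is easy to skim past is how the serialized result travels back to the driver: there are three cases, driven by two thresholds (roughly spark.task.maxDirectResultSize and spark.driver.maxResultSize). The sketch below isolates just that decision using simplified placeholder types of my own; it is illustrative, not Spark's actual classes:
<code>
// Hypothetical, simplified model of the three-way result decision at the end of TaskRunner.run().
sealed trait TaskResultEnvelope
case class DirectResult(bytes: Array[Byte]) extends TaskResultEnvelope            // shipped inline with the FINISHED update
case class IndirectResult(blockId: String, size: Long) extends TaskResultEnvelope // driver fetches it from the BlockManager
case class DroppedResult(size: Long) extends TaskResultEnvelope                   // too large; only the size is reported

object ResultPackaging {
  def packageResult(taskId: Long,
                    serialized: Array[Byte],
                    maxDirectResultSize: Long,
                    maxResultSize: Long): TaskResultEnvelope = {
    val resultSize = serialized.length.toLong
    if (maxResultSize > 0 && resultSize > maxResultSize) {
      // Bigger than the driver is willing to accept: drop the bytes, report only the size.
      DroppedResult(resultSize)
    } else if (resultSize > maxDirectResultSize) {
      // Too big for a single RPC message: park it in the local BlockManager and send only the block id.
      IndirectResult(s"taskresult_$taskId", resultSize)
    } else {
      // Small enough: send the serialized bytes straight back to the driver.
      DirectResult(serialized)
    }
  }
}
</code>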

2. Next, let's look at the updateDependencies() function to see how resources are downloaded. The code is as follows:

<code> private def updateDependencies(newFiles: Map[String, Long], newJars: Map[String, Long]) {
lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
//Multiple task threads run inside the same CoarseGrainedExecutorBackend process, so they execute concurrently.
//The synchronized keyword is used here to serialize their access to the shared dependency state.
synchronized {
// Fetch missing dependencies
//Loop over the resource files that need to be pulled and fetch each one from the remote side via Utils.fetchFile
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
// Fetch file with useCache mode, close cache for local mode.
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentFiles(name) = timestamp
}
//Loop over the jars that need to be pulled, again fetched from the remote side via Utils.fetchFile
for ((name, timestamp) <- newJars) {
val localName = new URI(name).getPath.split("/").last
val currentTimeStamp = currentJars.get(name)
.orElse(currentJars.get(localName))
.getOrElse(-1L)
if (currentTimeStamp < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
// Fetch file with useCache mode, close cache for local mode.
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentJars(name) = timestamp
// Add it to our class loader
val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
if (!urlClassLoader.getURLs().contains(url)) {
logInfo("Adding " + url + " to class loader")
urlClassLoader.addURL(url)
}
}
}
}
}
</code>
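The jar branch above only re-fetches when the broadcast timestamp is newer than what the executor already holds, and the body is wrapped in synchronized because many TaskRunner threads share the same maps. Below is a stripped-down sketch of just that check; the actual download (Utils.fetchFile plus urlClassLoader.addURL) is stubbed out with a callback, and names like maybeFetchJar are mine, not Spark's:
<code>
import scala.collection.mutable

// Hypothetical sketch of the "fetch only if newer" check used for jars in updateDependencies().
object DependencyCache {
  private val currentJars = mutable.HashMap[String, Long]() // jar name -> timestamp already present locally

  def maybeFetchJar(name: String, timestamp: Long)(fetch: String => Unit): Unit = synchronized {
    val localName = new java.net.URI(name).getPath.split("/").last
    val currentTimeStamp = currentJars.get(name)
      .orElse(currentJars.get(localName))
      .getOrElse(-1L)                 // -1L means "never fetched"
    if (currentTimeStamp < timestamp) {
      fetch(name)                     // in Spark: Utils.fetchFile(...) followed by urlClassLoader.addURL(...)
      currentJars(name) = timestamp   // remember which version we now have
    }
  }
}

object FetchDemo extends App {
  // The first call fetches; the second call with the same timestamp is a no-op.
  DependencyCache.maybeFetchJar("hdfs:///deps/app.jar", 1000L)(n => println(s"fetching $n"))
  DependencyCache.maybeFetchJar("hdfs:///deps/app.jar", 1000L)(n => println(s"fetching $n"))
}
</code>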

3. Going back, let's continue with the run() method in Task.scala. The code is as follows:

<code>
final def run(
taskAttemptId: Long,
attemptNumber: Int,
metricsSystem: MetricsSystem): T = {
//Register the task attempt with the BlockManager (the BlockManager itself is covered in a later article)
SparkEnv.get.blockManager.registerTask(taskAttemptId)
//Instantiate the task's execution context, TaskContextImpl: which stage it belongs to (stageId), which slice of data it processes (partitionId),
//plus the task's memory manager, properties, and other configuration
context = new TaskContextImpl(
stageId,
stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
partitionId,
taskAttemptId,
attemptNumber,
taskMemoryManager,
localProperties,
metricsSystem,
metrics)
//Install the context into the thread-local TaskContext
TaskContext.setTaskContext(context)

taskThread = Thread.currentThread()

if (_reasonIfKilled != null) {
kill(interruptThread = false, _reasonIfKilled)
}

new CallerContext(
"TASK",
SparkEnv.get.conf.get(APP_CALLER_CONTEXT),
appId,
appAttemptId,
jobId,
Option(stageId),
Option(stageAttemptId),
Option(taskAttemptId),
Option(attemptNumber)).setCurrentContext()

try {
//This is the important part: runTask() is an abstract method with two implementations, ShuffleMapTask
//and ResultTask. As mentioned when discussing task launching: if the stage is the finalStage, its tasks
//are ResultTasks; otherwise they are ShuffleMapTasks. Keep this distinction in mind.
runTask(context)
} catch {
case e: Throwable =>
// Catch all errors; run task failure callbacks, and rethrow the exception.
try {
context.markTaskFailed(e)
} catch {
case t: Throwable =>
e.addSuppressed(t)
}
context.markTaskCompleted(Some(e))
throw e
} finally {
try {
// Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
// one is no-op.
context.markTaskCompleted(None)
} finally {
try {
Utils.tryLogNonFatalError {
//Release the unroll memory this task acquired from the BlockManager's memory store
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)

SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
MemoryMode.OFF_HEAP)
// Notify any tasks waiting for execution memory to be freed to wake up and try to
// acquire memory again. This makes impossible the scenario where a task sleeps forever
// because there are no other tasks left to notify it. Since this is safe to do but may
// not be strictly necessary, we should revisit whether we can remove this in the
// future.
val memoryManager = SparkEnv.get.memoryManager
memoryManager.synchronized { memoryManager.notifyAll() }
}
} finally {
// Though we unset the ThreadLocal here, the context member variable itself is still
// queried directly in the TaskRunner to check for FetchFailedExceptions.
//Clear the thread-local context
TaskContext.unset()
}
}
}
}
</code>
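Structurally, what matters here is that run() is a template method: it does the shared bookkeeping and then defers to the abstract runTask(), which ShuffleMapTask and ResultTask each implement. The toy model below shows only that relationship; it is not Spark's actual class hierarchy or signatures:
<code>
// Toy model of the Task / ShuffleMapTask / ResultTask split -- illustrative only.
abstract class MiniTask[T] {
  // Shared run(): set up, then defer to the concrete runTask() (cf. Task.run above).
  final def run(partitionId: Int): T = {
    // ... register with the BlockManager, build a TaskContext, etc. ...
    runTask(partitionId)
  }
  protected def runTask(partitionId: Int): T
}

// Used for every non-final stage: computes its partition and reports where the shuffle output lives.
class MiniShuffleMapTask extends MiniTask[String] {
  protected def runTask(partitionId: Int): String =
    s"MapStatus(location=executor-1, partition=$partitionId)"
}

// Used for the final stage: applies the user's function and returns the value itself.
class MiniResultTask[U](func: Iterator[Int] => U, data: Seq[Int]) extends MiniTask[U] {
  protected def runTask(partitionId: Int): U = func(data.iterator)
}
</code>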

4. Let's first look at ShuffleMapTask's runTask() implementation. The code is as follows:

<code> //Note the MapStatus return value: after the ShuffleMapTask has executed our code logic,
// the BlockManager location of the computed output is packaged into the MapStatus and returned
override def runTask(context: TaskContext): MapStatus = {

val threadMXBean = ManagementFactory.getThreadMXBean
val deserializeStartTime = System.currentTimeMillis()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
val ser = SparkEnv.get.closureSerializer.newInstance()
// Tasks of a stage run in parallel (or concurrently) across many Executors, possibly on different nodes, yet they all process the same RDD.
// So how does a task get hold of the RDD data it needs to process?
// It reads the RDD from a broadcast variable; recall that when tasks are launched, this data is broadcast.

val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
_executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
} else 0L

var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
//Ask the shuffleManager for a ShuffleWriter
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
//This is the key step: call rdd.iterator() on the partition this task is responsible for,
//which runs our own code logic over that partition's data.
//The writer writes the resulting records to shuffle files on disk; the MapStatus returned below
//(via writer.stop) carries the location of that output, which is essentially BlockManager-related information.
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
//The write succeeded: stop the writer and return the MapStatus it produced
writer.stop(success = true).get
} catch {
case e: Exception =>
try {
if (writer != null) {
writer.stop(success = false)
}
} catch {
case e: Exception =>
log.debug("Could not stop writer", e)
}
throw e
}
}
</code>
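To make the return value concrete: a MapStatus essentially pairs the writer's BlockManager location with the size of the output produced for each reduce partition; that is what tasks in the next stage consult (via the MapOutputTracker) to know where to fetch from. Below is a hypothetical, cut-down stand-in, not Spark's real class (the real MapStatus also compresses the sizes):
<code>
// Hypothetical, cut-down stand-in for MapStatus -- not Spark's real class.
case class MiniBlockManagerId(executorId: String, host: String, port: Int)

case class MiniMapStatus(location: MiniBlockManagerId, sizesByReducePartition: Array[Long]) {
  // A reducer asks: "how many bytes did this map task write for my partition?"
  def getSizeForBlock(reduceId: Int): Long = sizesByReducePartition(reduceId)
}

object MapStatusDemo extends App {
  // Example: this map task wrote output for 3 reduce partitions on executor 1.
  val status = MiniMapStatus(MiniBlockManagerId("1", "worker-a", 7337), Array(1024L, 0L, 2048L))
  println(status.getSizeForBlock(2)) // 2048 -- reducer 2 will fetch ~2 KB from worker-a
}
</code>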

Let me elaborate a little. As we saw in the article "Mastering Spark 2.x: A Deep Dive into the Job Trigger Flow Source Code (Part 1)", our RDD implementation class is generally MapPartitionsRDD, and the rdd.iterator() call above eventually ends up executing its compute() method. The code is as follows:

<code>//This one is fairly simple: the function f here can be understood as the operator function from our own code,
//so we will not dig deeper; it is enough to know that this is where our own operator is executed
override def compute(split: Partition, context: TaskContext): Iterator[U] =
f(context, split.index, firstParent[T].iterator(split, context))
</code>
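To connect this back to user code: when you call an operator such as map, Spark builds a MapPartitionsRDD whose f wraps your function around the parent partition's iterator, which is exactly what compute() applies above. The snippet below shows that closure shape in isolation (the TaskContext argument is dropped for brevity; this is an illustration, not the RDD machinery itself):
<code>
// Illustration of how a map() operator becomes the `f` of a MapPartitionsRDD.
// RDD.map(g) constructs, roughly: new MapPartitionsRDD(this, (ctx, splitIndex, iter) => iter.map(g))
object ComputeDemo extends App {
  val g: Int => Int = _ * 2                           // the user's map function
  val f: (Int, Iterator[Int]) => Iterator[Int] =      // (split index, parent iterator) => result iterator
    (splitIndex, parentIter) => parentIter.map(g)

  // Per partition, compute(split, context) effectively does this:
  val partitionData = Iterator(1, 2, 3)
  println(f(0, partitionData).toList)                 // List(2, 4, 6)
}
</code>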

5. Finally, let's look back at ResultTask's runTask() method. The code is as follows:

<code>    override def runTask(context: TaskContext): U = {
// Deserialize the RDD and the func using the broadcast variables.
val threadMXBean = ManagementFactory.getThreadMXBean
val deserializeStartTime = System.currentTimeMillis()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
val ser = SparkEnv.get.closureSerializer.newInstance()
//The principle for obtaining the data is the same as above:
// tasks run in parallel (or concurrently) across many Executors, possibly on different nodes, yet all tasks of a stage process the same RDD,
// so how does a task get hold of the RDD data it needs to process?
// It reads the RDD (and the user function) from a broadcast variable; recall that this data is broadcast when tasks are launched.
val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
_executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
} else 0L
//This method is straightforward: it calls rdd.iterator() to execute our own operator over the partition
//and then applies func to the result. Depending on the action, that result is written out to files or a
//database, or it is returned to the driver as the task result.
func(context, rdd.iterator(partition, context))
}
</code>
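For instance, for collect() the func deserialized here boils down to materializing the partition into an array, and for count() to sizing the iterator; whatever func returns becomes the task result sent back to the driver. Below is a hedged sketch of those two shapes; the dummy context and names are mine, not Spark's internals:
<code>
// Illustrative shapes of the (TaskContext, Iterator[T]) => U functions that a ResultTask runs.
object ResultFuncDemo extends App {
  type Ctx = AnyRef              // stand-in for TaskContext, unused in these examples
  val ctx: Ctx = new AnyRef

  // Roughly what collect() ships: materialize the partition into an array.
  val collectFunc: (Ctx, Iterator[Int]) => Array[Int] = (_, iter) => iter.toArray

  // Roughly what count() ships: the number of records in the partition.
  val countFunc: (Ctx, Iterator[Int]) => Long = (_, iter) => iter.size.toLong

  println(collectFunc(ctx, Iterator(10, 20, 30)).mkString(",")) // 10,20,30
  println(countFunc(ctx, Iterator(1, 2, 3)))                    // 3
}
</code>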

This concludes the source-code walkthrough of how TaskRunner runs. Thanks for reading!



