TaskContext

TaskContext is the contract for the contextual information about a task in Spark. It also lets you register task listeners.

You can access the active TaskContext instance using the TaskContext.get method.

import org.apache.spark.TaskContext
val ctx = TaskContext.get

Using TaskContext you can access local properties that were set by the driver.
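
As a minimal sketch of reading a local property inside a task, assuming a spark-shell session where sc is available (the key "demo.owner" and the value "alice" are hypothetical and chosen only for illustration):

```scala
import org.apache.spark.TaskContext

// Set on the driver thread. "demo.owner" is a hypothetical key for this example.
sc.setLocalProperty("demo.owner", "alice")

// Each task reads the property through its own TaskContext.
val ownersSeenByTasks = sc.range(0, 4, numSlices = 2)
  .map(_ => TaskContext.get.getLocalProperty("demo.owner"))
  .collect()
  .distinct
```

Every task of a job submitted from the same driver thread should see the same value, so ownersSeenByTasks is expected to hold the single element "alice".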

Note
TaskContext is serializable.

TaskContext Contract

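abstract class TaskContext extends Serializable {
  def stageId(): Int
  def partitionId(): Int
  def attemptNumber(): Int
  def taskAttemptId(): Long
  def getLocalProperty(key: String): String
  // ...and the other methods listed in Table 1
}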
Table 1. TaskContext Contract
Method Description

stageId

Id of the Stage the task belongs to.

Used when…​

partitionId

Id of the Partition computed by the task.

Used when…​

attemptNumber

Specifies how many times the task has been attempted (starting from 0).

Used when…​

taskAttemptId

Id of the attempt of the task.

Used when…​

getMetricsSources

Gives all the metrics sources by sourceName which are associated with the instance that runs the task.

getLocalProperty

Accesses local properties set by the driver using SparkContext.setLocalProperty.

Used when…

taskMetrics

TaskMetrics of the active Task.

Used when…​

taskMemoryManager

Used when…​

registerAccumulator

Used when…​

isCompleted

Used when…​

isInterrupted

A flag that is enabled when a task was killed.

Used when…​

addTaskCompletionListener

Registers a TaskCompletionListener

Used when…​

addTaskFailureListener

Registers a TaskFailureListener

Used when…​

unset Method

Caution
FIXME

setTaskContext Method

Caution
FIXME

Accessing Active TaskContext — get Method

get(): TaskContext

The get method returns the TaskContext instance for the active task (as a TaskContextImpl). There can be only one active instance per task thread, and a task can use it to access contextual information about itself. Outside a running task, e.g. on the driver, get returns null.

val rdd = sc.range(0, 3, numSlices = 3)

scala> rdd.partitions.size
res0: Int = 3

rdd.foreach { n =>
  import org.apache.spark.TaskContext
  val tc = TaskContext.get
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}
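
Since no task runs on the driver thread, TaskContext.get is expected to return null there. The following is a minimal sketch of guarding against that with Option, assuming a spark-shell session where sc is available:

```scala
import org.apache.spark.TaskContext

// On the driver no task is running, so wrap the possibly-null result in Option.
val onDriver: Option[TaskContext] = Option(TaskContext.get)

// Inside a task the context is expected to be set.
val definedInTasks = sc.range(0, 3, numSlices = 3)
  .map(_ => Option(TaskContext.get).isDefined)
  .collect()
```
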
Note
The TaskContext object keeps the active TaskContext in a ThreadLocal, i.e. it associates the context with the thread that runs the task.
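
To illustrate the thread-local idea, here is a simplified model in plain Scala. It is not Spark's source code: DemoContext, DemoContextHolder and ThreadLocalDemo are hypothetical names that only mimic the pattern of a per-thread context holder.

```scala
import java.util.concurrent.atomic.AtomicReference

// Hypothetical stand-in for TaskContext; a simplified model, not Spark's source.
final case class DemoContext(partitionId: Int)

object DemoContextHolder {
  private val current = new ThreadLocal[DemoContext]
  def get: DemoContext = current.get            // null if nothing was set on this thread
  def set(ctx: DemoContext): Unit = current.set(ctx)
  def unset(): Unit = current.remove()
}

object ThreadLocalDemo {
  // Returns what a worker thread saw and what the calling thread saw.
  def observe(): (DemoContext, DemoContext) = {
    val seenByWorker = new AtomicReference[DemoContext]
    val worker = new Thread(new Runnable {
      override def run(): Unit = {
        DemoContextHolder.set(DemoContext(partitionId = 7))
        seenByWorker.set(DemoContextHolder.get)
        DemoContextHolder.unset()
      }
    })
    worker.start()
    worker.join()
    // The calling thread never called set, so it sees no context.
    (seenByWorker.get, DemoContextHolder.get)
  }

  def main(args: Array[String]): Unit = {
    val (worker, caller) = observe()
    println(s"worker thread saw: $worker")
    println(s"calling thread saw: $caller")
  }
}
```

The worker thread sees the context it set itself, while the calling thread sees null, which is why a value stored this way is only visible to the thread that runs the task.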

Registering Task Listeners

Using the TaskContext object you can register two kinds of task listeners: one that is notified when a task completes, regardless of its final state, and one that is notified on task failures only.

addTaskCompletionListener Method

addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext

The addTaskCompletionListener methods register a TaskCompletionListener to be executed on task completion.

Note
The listener is executed regardless of the final state of a task: success, failure, or cancellation.

val rdd = sc.range(0, 5, numSlices = 1)

import org.apache.spark.TaskContext
val printTaskInfo = (tc: TaskContext) => {
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}

rdd.foreachPartition { _ =>
  val tc = TaskContext.get
  tc.addTaskCompletionListener(printTaskInfo)
}
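
A common use of a completion listener is releasing a per-partition resource. The sketch below assumes a spark-shell session where sc is available; the StringWriter is only a stand-in for a real resource such as a connection, and the explicit TaskCompletionListener form is used so the call does not depend on how a Scala function is converted to a listener.

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

val doubled = sc.range(0, 6, numSlices = 2).mapPartitions { it =>
  // Stand-in for a per-partition resource (e.g. a connection).
  val resource = new java.io.StringWriter()

  // Close the resource when the task completes, whatever its final state.
  TaskContext.get.addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(context: TaskContext): Unit = resource.close()
  })

  it.map { n =>
    resource.write(n.toString)
    n * 2
  }
}.collect()
```

The listener is registered before the partition iterator is consumed, so the resource stays open for as long as the task reads from it.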

addTaskFailureListener Method

addTaskFailureListener(listener: TaskFailureListener): TaskContext
addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext

The addTaskFailureListener methods register a TaskFailureListener to be executed on task failure only. A listener can be executed multiple times since a task can be re-attempted when it fails.

val rdd = sc.range(0, 2, numSlices = 2)

import org.apache.spark.TaskContext
val printTaskErrorInfo = (tc: TaskContext, error: Throwable) => {
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |error:         ${error.toString}
                |-------------------""".stripMargin
  println(msg)
}

val throwExceptionForOddNumber = (n: Long) => {
  if (n % 2 == 1) {
    throw new Exception(s"No way it will pass for odd number: $n")
  }
}

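// FIXME It won't work: the function below ignores the partition iterator,
// so the map function never runs, no exception is thrown, and the listener never fires.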
rdd.map(throwExceptionForOddNumber).foreachPartition { _ =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
}

// Listener registration matters: register the listener before the failing map function runs.
rdd.mapPartitions { (it: Iterator[Long]) =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
  it
}.map(throwExceptionForOddNumber).count

(Unused) Accessing Partition Id — getPartitionId Method

getPartitionId(): Int

getPartitionId gets the active TaskContext and returns its partitionId, or 0 if no TaskContext is available.

Note
getPartitionId is not used.
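
Although Spark itself does not use getPartitionId, it can be called from user code. A minimal sketch, assuming a spark-shell session where sc is available:

```scala
import org.apache.spark.TaskContext

// On the driver there is no active TaskContext, so the fallback value is returned.
val onDriver = TaskContext.getPartitionId()

// Inside tasks, each partition reports its own id.
val inTasks = sc.range(0, 3, numSlices = 3)
  .mapPartitions(_ => Iterator(TaskContext.getPartitionId()))
  .collect()
```

Here onDriver is expected to be 0, and inTasks to hold the ids 0, 1 and 2, one per partition.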
