TaskContext
TaskContext is the contract for contextual information about a Task in Spark that allows for registering task listeners.
You can access the active TaskContext instance using the TaskContext.get method.
import org.apache.spark.TaskContext
// Call inside a task; returns null when no task is active on the current thread
val ctx = TaskContext.get()
Using TaskContext you can access local properties that were set by the driver.
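To make "local properties set by the driver" concrete, here is a stdlib-only model. It is a sketch under assumptions, not Spark's implementation: LocalPropsDemo, submitJob, and the "report.name" key are hypothetical. The snapshot-on-submit step is an assumption meant to mirror the idea that a task sees the properties that were set before its job was submitted.

```scala
import java.util.Properties

// Hypothetical model of driver-side local properties (NOT Spark's implementation).
object LocalPropsDemo {
  // Driver side: one Properties object per submitting thread.
  private val driverProps: ThreadLocal[Properties] = new ThreadLocal[Properties] {
    override protected def initialValue(): Properties = new Properties()
  }

  def setLocalProperty(key: String, value: String): Unit =
    driverProps.get().setProperty(key, value)

  // Assumption: submitting a job copies the current properties into a snapshot
  // that the task reads from.
  def submitJob[T](task: Properties => T): T = {
    val snapshot = new Properties()
    snapshot.putAll(driverProps.get())
    task(snapshot)
  }
}

LocalPropsDemo.setLocalProperty("report.name", "daily")

// Task side: reads what the "driver" set before submitting the job.
val seenByTask = LocalPropsDemo.submitJob(props => props.getProperty("report.name"))
// A key that was never set comes back as null, as with java.util.Properties.
val missing = LocalPropsDemo.submitJob(props => props.getProperty("not.set"))
```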
Note: TaskContext is serializable.
TaskContext Contract
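abstract class TaskContext extends Serializable {
  def stageId(): Int
  def partitionId(): Int
  def attemptNumber(): Int
  def taskAttemptId(): Long
  def getLocalProperty(key: String): String
  // ...plus taskMetrics, getMetricsSources, isCompleted, isInterrupted and the listener registration methods
}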
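Method | Description |
---|---|
stageId | Id of the Stage the task belongs to. Used when… |
partitionId | Id of the Partition computed by the task. Used when… |
attemptNumber | Specifies how many times the task has been attempted (starting from 0). Used when… |
taskAttemptId | Id of the task attempt (unique within a SparkContext). Used when… |
getMetricsSources | Gives all the metrics sources with the given name (sourceName) that are associated with the instance that runs the task. |
getLocalProperty | Accesses local properties set by the driver using SparkContext.setLocalProperty. Used when… |
taskMetrics | TaskMetrics of the active Task. Used when… |
… | Used when… |
… | Used when… |
… | Used when… |
isInterrupted | A flag that is enabled when a task was killed. Used when… |
addTaskCompletionListener | Registers a TaskCompletionListener. Used when… |
addTaskFailureListener | Registers a TaskFailureListener. Used when… |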
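The table describes a contract that a concrete class fills in (the get method returns a TaskContextImpl). The sketch below models that contract-plus-implementation split with a hypothetical MiniTaskContext trait and MiniTaskContextImpl class. Only the accessor names mirror the table; the constructor, the markInterrupted helper, and all values are illustrative assumptions.

```scala
import java.util.Properties

// Hypothetical, trimmed-down contract mirroring a few TaskContext accessors.
trait MiniTaskContext extends java.io.Serializable {
  def stageId(): Int
  def partitionId(): Int
  def attemptNumber(): Int   // 0 for the first attempt
  def taskAttemptId(): Long  // identifies this particular attempt
  def getLocalProperty(key: String): String
  def isInterrupted(): Boolean
}

// Hypothetical implementation holding the values a scheduler would pass in.
final class MiniTaskContextImpl(
    stage: Int,
    partition: Int,
    attempt: Int,
    attemptId: Long,
    props: Properties) extends MiniTaskContext {
  @volatile private var interrupted = false

  def stageId(): Int = stage
  def partitionId(): Int = partition
  def attemptNumber(): Int = attempt
  def taskAttemptId(): Long = attemptId
  def getLocalProperty(key: String): String = props.getProperty(key) // null if unset
  def isInterrupted(): Boolean = interrupted
  def markInterrupted(): Unit = interrupted = true // models "the task was killed"
}

val props = new Properties()
props.setProperty("report.name", "daily")

val ctx = new MiniTaskContextImpl(stage = 1, partition = 2, attempt = 0, attemptId = 7L, props = props)
val killed = new MiniTaskContextImpl(stage = 1, partition = 3, attempt = 0, attemptId = 8L, props = props)
killed.markInterrupted()
```

Keeping the contract abstract lets task code depend only on the accessors, while the scheduler-side class decides where the values come from.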
unset Method
Caution: FIXME
setTaskContext Method
Caution: FIXME
Accessing Active TaskContext — get Method
get(): TaskContext
The get method returns the TaskContext instance of the active task (as a TaskContextImpl), or null when no task is running on the calling thread (e.g. on the driver). There is at most one active instance per task thread, and a task can use it to access contextual information about itself.
import org.apache.spark.TaskContext
val rdd = sc.range(0, 3, numSlices = 3)
rdd.partitions.size // 3
// The closure runs on executors; in local mode its output shows up in the driver's console
rdd.foreach { _ =>
  val tc = TaskContext.get()
  val msg = s"""|-------------------
                |partitionId: ${tc.partitionId}
                |stageId: ${tc.stageId}
                |attemptNum: ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}
Note: The TaskContext object uses a ThreadLocal to keep the active TaskContext thread-local, i.e. to associate it with the thread that runs the task.
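The ThreadLocal note explains why get can see a context on a task thread and nothing on another thread. The stdlib-only sketch below models this with a hypothetical ActiveContext object whose set, get, and unset play the roles of setTaskContext, get, and unset described above. It is not Spark's source, and a plain String stands in for a real TaskContext.

```scala
import java.util.concurrent.atomic.AtomicReference

// Hypothetical model of a thread-local "active task context" holder.
object ActiveContext {
  private val current = new ThreadLocal[String]() // reads null until set

  def set(ctx: String): Unit = current.set(ctx)
  def get(): String = current.get()   // null when no task is active on this thread
  def unset(): Unit = current.remove()
}

// A "task" sets its context on start and always clears it when done.
def runTask(ctx: String): String =
  try {
    ActiveContext.set(ctx)
    ActiveContext.get() // what the task body would see
  } finally ActiveContext.unset()

val seenInWorker = new AtomicReference[String]()
val worker = new Thread(new Runnable {
  def run(): Unit = seenInWorker.set(runTask("task-0"))
})
worker.start()
worker.join()

// The main thread never had a context set, and runTask cleans up after itself.
val seenOnMain = ActiveContext.get()
val afterRunOnMain = { runTask("task-1"); ActiveContext.get() }
```

Clearing the value in a finally block matters when threads are pooled: otherwise a later task on the same thread would see a stale context.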
Registering Task Listeners
Using the TaskContext object you can register task listeners that are executed on task completion (regardless of the final state) or on task failure only.
addTaskCompletionListener Method
addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext
The addTaskCompletionListener methods register a TaskCompletionListener to be executed on task completion.
Note: The listener is executed regardless of the final state of the task: success, failure, or cancellation.
val rdd = sc.range(0, 5, numSlices = 1)
import org.apache.spark.TaskContext
val printTaskInfo = (tc: TaskContext) => {
val msg = s"""|-------------------
|partitionId: ${tc.partitionId}
|stageId: ${tc.stageId}
|attemptNum: ${tc.attemptNumber}
|taskAttemptId: ${tc.taskAttemptId}
|-------------------""".stripMargin
println(msg)
}
// Register the listener from inside the task; it prints once the task finishes
rdd.foreachPartition { _ =>
  val tc = TaskContext.get()
  tc.addTaskCompletionListener(printTaskInfo)
}
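The "regardless of the final state" guarantee can be modelled with try/finally. The sketch below uses a hypothetical MiniContext class with its own addCompletionListener and runTask helpers. It mirrors only the idea and leaves out details a real implementation has to handle, such as errors thrown by the listeners themselves.

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.Try

// Hypothetical model of completion listeners (NOT Spark's TaskContextImpl).
final class MiniContext(val partitionId: Int) {
  private val onComplete = ArrayBuffer.empty[MiniContext => Unit]

  def addCompletionListener(f: MiniContext => Unit): MiniContext = {
    onComplete += f
    this // returns the context, as addTaskCompletionListener does
  }

  // Runs the task body, then every completion listener, whatever the outcome.
  def runTask[T](body: MiniContext => T): Try[T] =
    try Try(body(this))
    finally onComplete.foreach(listener => listener(this))
}

val events = ArrayBuffer.empty[String]

val ok = new MiniContext(0).runTask { tc =>
  tc.addCompletionListener(c => events += s"completed partition ${c.partitionId}")
  42
}

val failed = new MiniContext(1).runTask[Int] { tc =>
  tc.addCompletionListener(c => events += s"completed partition ${c.partitionId}")
  throw new RuntimeException("boom")
}
```

Both tasks record a completion event even though the second one fails, which is the behaviour the note above describes.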
addTaskFailureListener Method
addTaskFailureListener(listener: TaskFailureListener): TaskContext
addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext
The addTaskFailureListener methods register a TaskFailureListener to be executed on task failure only. Since a failed task can be re-attempted (each attempt runs with its own TaskContext), a failure listener registered inside the task body may be executed once per failed attempt.
val rdd = sc.range(0, 2, numSlices = 2)
import org.apache.spark.TaskContext
val printTaskErrorInfo = (tc: TaskContext, error: Throwable) => {
val msg = s"""|-------------------
|partitionId: ${tc.partitionId}
|stageId: ${tc.stageId}
|attemptNum: ${tc.attemptNumber}
|taskAttemptId: ${tc.taskAttemptId}
|error: ${error.toString}
|-------------------""".stripMargin
println(msg)
}
val throwExceptionForOddNumber = (n: Long) => {
if (n % 2 == 1) {
throw new Exception(s"No way it will pass for odd number: $n")
}
}
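// This does not trigger the failure listener: the partition iterator (_) is never
// consumed, so throwExceptionForOddNumber never runs and the task succeeds.
rdd.map(throwExceptionForOddNumber).foreachPartition { _ =>
  val tc = TaskContext.get()
  tc.addTaskFailureListener(printTaskErrorInfo)
}
// Listener registration matters: register the listener before the failing map runs,
// and consume the iterator (count) so that the exception is actually thrown.
rdd.mapPartitions { (it: Iterator[Long]) =>
  val tc = TaskContext.get()
  tc.addTaskFailureListener(printTaskErrorInfo)
  it
}.map(throwExceptionForOddNumber).count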
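Two points from this section can be modelled together: failure listeners fire only when the task body throws, and a listener registered inside the task body runs once per failed attempt because each attempt gets a fresh context. In the sketch, AttemptContext, runWithRetries, and maxAttempts are hypothetical. In Spark, retries are decided by the scheduler (for example via spark.task.maxFailures), not by the task itself.

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Try}

// Hypothetical per-attempt context with failure listeners (NOT Spark's code).
final class AttemptContext(val attemptNumber: Int) {
  private val onFailure = ArrayBuffer.empty[(AttemptContext, Throwable) => Unit]

  def addFailureListener(f: (AttemptContext, Throwable) => Unit): AttemptContext = {
    onFailure += f
    this
  }

  def fail(error: Throwable): Unit = onFailure.foreach(listener => listener(this, error))
}

// Re-runs the body with a fresh context per attempt, numbering attempts from 0.
def runWithRetries[T](maxAttempts: Int)(body: AttemptContext => T): Try[T] = {
  def attempt(n: Int): Try[T] = {
    val ctx = new AttemptContext(n)
    Try(body(ctx)) match {
      case Failure(e) =>
        ctx.fail(e) // failure listeners run only on this path
        if (n + 1 < maxAttempts) attempt(n + 1) else Failure(e)
      case success => success
    }
  }
  attempt(0)
}

val failures = ArrayBuffer.empty[String]

// Registers the listener first, then fails on attempts 0 and 1 and succeeds on attempt 2.
val result = runWithRetries(maxAttempts = 4) { ctx =>
  ctx.addFailureListener((c, e) => failures += s"attempt ${c.attemptNumber}: ${e.getMessage}")
  if (ctx.attemptNumber < 2) throw new IllegalStateException("flaky")
  "done"
}

// Gives up after the last allowed attempt.
val exhausted = runWithRetries[String](maxAttempts = 2) { _ =>
  throw new RuntimeException("always")
}
```

As in the Spark example above, the listener is registered before the failing code runs. A listener added after the throw would never be reached.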
(Unused) Accessing Partition Id — getPartitionId Method
getPartitionId(): Int
getPartitionId gets the active TaskContext and returns its partitionId, or 0 if there is no active TaskContext.
Note: getPartitionId is not used.
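getPartitionId is a null-safe shortcut over get. Here is a minimal stdlib sketch of the same fallback, assuming a hypothetical PartitionInfo holder and PartitionLookup object in place of TaskContext:

```scala
// Hypothetical stand-in for the active TaskContext (NOT Spark's code).
final case class PartitionInfo(partitionId: Int)

object PartitionLookup {
  private val active = new ThreadLocal[PartitionInfo]()

  def set(info: PartitionInfo): Unit = active.set(info)
  def unset(): Unit = active.remove()

  // Same fallback as described above: the active partition id, or 0 without an active context.
  def getPartitionId(): Int = Option(active.get()).map(_.partitionId).getOrElse(0)
}

val withoutContext = PartitionLookup.getPartitionId()
PartitionLookup.set(PartitionInfo(5))
val withContext = PartitionLookup.getPartitionId()
PartitionLookup.unset()
```

One consequence of the 0 fallback is that a caller cannot tell "no active task" apart from "partition 0". Code that needs that distinction can check TaskContext.get() for null instead.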