ML Persistence — Saving and Loading Models and Pipelines

MLWriter and MLReader belong to the org.apache.spark.ml.util package.

They allow you to save and load models across languages: a model saved from Scala, Java, Python or R can be loaded later on from any of them.

MLWriter

The MLWriter abstract class comes with the save(path: String) method to save an ML component to a given path.

save(path: String): Unit

It comes with another (chainable) method, overwrite, that makes save overwrite the output path if it already exists.

overwrite(): this.type

The component's metadata is saved as a JSON file (see the MLWriter Example section below).

Tip

Enable INFO logging level for the MLWriter implementation logger to see what happens inside.

Add the following line to conf/log4j.properties:

log4j.logger.org.apache.spark.ml.Pipeline$PipelineWriter=INFO

Refer to Logging.

Caution
FIXME The logging doesn't work and overwriting does not print out an INFO message to the logs :(

MLWriter Example

import org.apache.spark.ml._
// An "unfitted" pipeline with no stages.
val pipeline = new Pipeline().setStages(Array.empty[PipelineStage])
pipeline.write.overwrite.save("sample-pipeline")

The result of save for an "unfitted" pipeline is a JSON file with the metadata only (as shown below).

$ cat sample-pipeline/metadata/part-00000 | jq
{
  "class": "org.apache.spark.ml.Pipeline",
  "timestamp": 1472747720477,
  "sparkVersion": "2.1.0-SNAPSHOT",
  "uid": "pipeline_181c90b15d65",
  "paramMap": {
    "stageUids": []
  }
}
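The metadata directory is itself JSON that Spark can read back, so you can also inspect it without leaving the shell (an alternative to cat and jq):

// Read the saved pipeline metadata back as a DataFrame.
val metadata = spark.read.json("sample-pipeline/metadata")
metadata.show(truncate = false)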

The result of save for a fitted pipeline model is a JSON file with the metadata plus Parquet files with the model data, e.g. coefficients.
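The snippet below sketches one way to build such a pipeline model, assuming the three stages from the metadata shown later (RegexTokenizer, HashingTF, LogisticRegression); the column names and the tiny training set are made up for illustration only:

import spark.implicits._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, RegexTokenizer}

// A tiny labelled dataset, just enough to fit the pipeline (illustrative only).
val training = Seq(
  (0L, "spark is great", 1.0),
  (1L, "hadoop is slow", 0.0)
).toDF("id", "text", "label")

val regexTok = new RegexTokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
val logreg = new LogisticRegression()

val pipeline = new Pipeline().setStages(Array(regexTok, hashingTF, logreg))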

val model = pipeline.fit(training)
model.write.save("sample-model")
$ cat sample-model/metadata/part-00000 | jq
{
  "class": "org.apache.spark.ml.PipelineModel",
  "timestamp": 1472748168005,
  "sparkVersion": "2.1.0-SNAPSHOT",
  "uid": "pipeline_3ed598da1c4b",
  "paramMap": {
    "stageUids": [
      "regexTok_bf73e7c36e22",
      "hashingTF_ebece38da130",
      "logreg_819864aa7120"
    ]
  }
}

$ tree sample-model/stages/
sample-model/stages/
|-- 0_regexTok_bf73e7c36e22
|   `-- metadata
|       |-- _SUCCESS
|       `-- part-00000
|-- 1_hashingTF_ebece38da130
|   `-- metadata
|       |-- _SUCCESS
|       `-- part-00000
`-- 2_logreg_819864aa7120
    |-- data
    |   |-- _SUCCESS
    |   `-- part-r-00000-56423674-0208-4768-9d83-2e356ac6a8d2.snappy.parquet
    `-- metadata
        |-- _SUCCESS
        `-- part-00000

7 directories, 8 files
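The model data is plain Parquet, so it can be inspected directly; the stage directory name below is taken from the listing above (yours will contain a different uid):

// Peek at the saved logistic regression data, e.g. coefficients.
val logregData = spark.read.parquet("sample-model/stages/2_logreg_819864aa7120/data")
logregData.printSchema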

MLReader

The MLReader abstract class comes with the load(path: String) method to load an ML component from a given path.

load(path: String): T

import org.apache.spark.ml._
val pipeline = Pipeline.read.load("sample-pipeline")

scala> val stageCount = pipeline.getStages.size
stageCount: Int = 0

val pipelineModel = PipelineModel.read.load("sample-model")

scala> pipelineModel.stages
res1: Array[org.apache.spark.ml.Transformer] = Array(regexTok_bf73e7c36e22, hashingTF_ebece38da130, logreg_819864aa7120)
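The loaded pipeline model is ready for scoring right away. A minimal sketch, assuming new data with the same text column the pipeline was fitted on:

import spark.implicits._

// Score new data with the reloaded pipeline model.
val test = Seq((2L, "spark is fast")).toDF("id", "text")
pipelineModel.transform(test).select("id", "text", "prediction").show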
