Extend Spark ML for your own model/transformer types
How to use the wordcount example as a starting point (and you thought you’d escape the wordcount example).
While Spark ML pipelines have a wide variety of algorithms, you may find yourself wanting additional functionality without having to leave the pipeline model. In Spark MLlib, this isn’t much of a problem—you can manually implement your algorithm with RDD transformations and keep going from there. For Spark ML pipelines, the same approach can work, but we lose some of the nicely integrated properties of the pipeline, including the ability to automatically run meta-algorithms, such as cross-validation parameter search. In this article, you will learn how to extend the Spark ML pipeline model using the standard wordcount example as a starting point (one can never really escape the intro to big data wordcount example).
To add your own algorithm to a Spark pipeline, you need to implement either Estimator or Transformer, both of which implement the PipelineStage interface (all three live in org.apache.spark.ml). For algorithms that don't require training, you implement the Transformer interface; for algorithms that do require training, you implement the Estimator interface. Note that training is not limited to complicated machine learning models; even the MinMaxScaler requires training to determine the range, so stages that need training must be built as an Estimator rather than a Transformer.
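For a concrete sense of the difference, compare two built-in stages: Tokenizer is a Transformer (nothing to learn), while MinMaxScaler is an Estimator whose fit produces a model. A quick sketch (the DataFrame df and its column names are assumptions for illustration):

import org.apache.spark.ml.feature.{MinMaxScaler, Tokenizer}

// A Transformer can transform data directly; there is nothing to learn.
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val tokenized = tokenizer.transform(df)

// An Estimator has to be fit first; MinMaxScaler scans the data to learn the
// min/max range and returns a MinMaxScalerModel, which is itself a Transformer.
val scaler = new MinMaxScaler().setInputCol("features").setOutputCol("scaled")
val scaled = scaler.fit(df).transform(df)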
Note
Using PipelineStage directly does not work, since pipeline fitting uses reflection internally, which assumes all stages are either an Estimator or a Transformer.
In addition to the obvious transform or fit function, all pipeline stages need to provide transformSchema and a copy method, or extend a class that provides these for you. copy is used to make a copy of the current stage, with any newly specified params merged in, and can simply call defaultCopy (unless your class has special constructor considerations).
The start of a pipeline stage, along with the copy delegation, is shown below. transformSchema must produce the expected output of your pipeline stage based on any parameters set and the input schema. Most pipeline stages simply add new fields; very few drop previous fields in case they are needed later, but this can sometimes result in records containing more data than is required downstream, negatively impacting performance. If you find this is a problem in your pipeline, you can create your own stage to drop unnecessary fields.
class HardCodedWordCountStage(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("hardcodedwordcount"))

  def copy(extra: ParamMap): HardCodedWordCountStage = {
    defaultCopy(extra)
  }
In addition to producing the output schema, the transformSchema function should validate that the input schema is suitable for the stage (e.g., the input column is of the expected type). This is also where you should perform validation on your stage's parameters. A simple transformSchema for a string input column and an integer output column, with hardcoded column names, is illustrated as follows.
  override def transformSchema(schema: StructType): StructType = {
    // Check that the input type is a string
    val idx = schema.fieldIndex("happy_pandas")
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new Exception(
        s"Input type ${field.dataType} did not match input type StringType")
    }
    // Add the return field
    schema.add(StructField("happy_panda_counts", IntegerType, false))
  }
Algorithms that do not require training can be implemented very simply using the Transformer interface. Since this is the simplest kind of pipeline stage, you can start by implementing a simple transformer that counts the number of words in the input column.
  def transform(df: Dataset[_]): DataFrame = {
    val wordcount = udf { in: String => in.split(" ").size }
    df.select(col("*"),
      wordcount(df.col("happy_pandas")).as("happy_panda_counts"))
  }
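With just this much, the hardcoded stage can already be used on its own or inside a pipeline. A minimal usage sketch (the SparkSession spark, the sample data, and the column names are assumptions for illustration):

import org.apache.spark.ml.{Pipeline, PipelineStage}

val df = spark.createDataFrame(Seq(
  ("Happy pandas eat lots of bamboo", 1.0),
  ("Sad pandas do not", 0.0))).toDF("happy_pandas", "label")

val stage = new HardCodedWordCountStage()
// A Transformer can be applied directly...
val counted = stage.transform(df)
// ...or dropped into a pipeline alongside other stages.
val pipeline = new Pipeline().setStages(Array[PipelineStage](stage))
pipeline.fit(df).transform(df).show()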
To get the most out of the pipeline interface, you will want to make your pipeline stage configurable using the params interface. While the params interface is public, the common default params used inside Spark are sadly private, so you will end up with some amount of code duplication. In addition to allowing users to specify values, parameters can also contain basic validation logic (e.g., the regularization parameter must be set to a non-negative value). The two most common parameters are input column and output column, which you can add to your model relatively simply. Beyond string params, any other type can be used, including lists of strings for things like stop words and doubles for things like regularization parameters.
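For instance, a parameter of another type with validation attached might look like the following sketch (the threshold and stopWords params are illustrative names, not part of the word count example, and would live inside your stage):

import org.apache.spark.ml.param.{DoubleParam, ParamValidators, StringArrayParam}

// A double param that must be >= 0, and a string-array param for stop words.
final val threshold = new DoubleParam(this, "threshold",
  "minimum word count to keep", ParamValidators.gtEq(0.0))
final val stopWords = new StringArrayParam(this, "stopWords",
  "words to ignore when counting")

def setThreshold(value: Double): this.type = set(threshold, value)
def setStopWords(value: Array[String]): this.type = set(stopWords, value)

With that aside, here is the word count transformer again, this time with configurable input and output columns.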
class ConfigurableWordCount(override val uid: String) extends Transformer {
  final val inputCol = new Param[String](this, "inputCol", "The input column")
  final val outputCol = new Param[String](this, "outputCol", "The output column")

  def setInputCol(value: String): this.type = set(inputCol, value)

  def setOutputCol(value: String): this.type = set(outputCol, value)

  def this() = this(Identifiable.randomUID("configurablewordcount"))

  def copy(extra: ParamMap): ConfigurableWordCount = {
    defaultCopy(extra)
  }

  override def transformSchema(schema: StructType): StructType = {
    // Check that the input type is a string
    val idx = schema.fieldIndex($(inputCol))
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new Exception(
        s"Input type ${field.dataType} did not match input type StringType")
    }
    // Add the return field
    schema.add(StructField($(outputCol), IntegerType, false))
  }

  def transform(df: Dataset[_]): DataFrame = {
    val wordcount = udf { in: String => in.split(" ").size }
    df.select(col("*"),
      wordcount(df.col($(inputCol))).as($(outputCol)))
  }
}
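Using the configurable version then looks like using any built-in transformer; a quick sketch (df is assumed to be a DataFrame with a string column named happy_pandas):

val wordCount = new ConfigurableWordCount()
  .setInputCol("happy_pandas")
  .setOutputCol("happy_panda_counts")
val counted = wordCount.transform(df)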
Algorithms that do require training can be implemented using the Estimator interface, although for many algorithms the org.apache.spark.ml.Predictor or org.apache.spark.ml.classification.Classifier helper classes are easier to implement. The primary difference between the Estimator and Transformer interfaces is that rather than directly expressing your transformation on the input, you first have a training step in the form of a fit function (or train when using the Predictor or Classifier helpers). A string indexer is one of the simplest estimators you can implement, and while it's already available in Spark, it is still a good illustration of how to use the estimator interface.
trait SimpleIndexerParams extends Params {
  final val inputCol = new Param[String](this, "inputCol", "The input column")
  final val outputCol = new Param[String](this, "outputCol", "The output column")
}

class SimpleIndexer(override val uid: String)
    extends Estimator[SimpleIndexerModel] with SimpleIndexerParams {

  def setInputCol(value: String) = set(inputCol, value)

  def setOutputCol(value: String) = set(outputCol, value)

  def this() = this(Identifiable.randomUID("simpleindexer"))

  override def copy(extra: ParamMap): SimpleIndexer = {
    defaultCopy(extra)
  }

  override def transformSchema(schema: StructType): StructType = {
    // Check that the input type is a string
    val idx = schema.fieldIndex($(inputCol))
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new Exception(
        s"Input type ${field.dataType} did not match input type StringType")
    }
    // Add the return field (the index is produced as a double)
    schema.add(StructField($(outputCol), DoubleType, false))
  }

  override def fit(dataset: Dataset[_]): SimpleIndexerModel = {
    import dataset.sparkSession.implicits._
    val words = dataset.select(dataset($(inputCol)).as[String])
      .distinct
      .collect()
    // Copy the param values (e.g., input/output columns) over to the model.
    copyValues(new SimpleIndexerModel(uid, words).setParent(this))
  }
}

class SimpleIndexerModel(override val uid: String, words: Array[String])
    extends Model[SimpleIndexerModel] with SimpleIndexerParams {

  override def copy(extra: ParamMap): SimpleIndexerModel = {
    // The constructor takes more than just the uid, so defaultCopy's
    // reflection-based construction does not apply here.
    copyValues(new SimpleIndexerModel(uid, words), extra)
  }

  private val labelToIndex: Map[String, Double] = words.zipWithIndex
    .map{case (x, y) => (x, y.toDouble)}.toMap

  override def transformSchema(schema: StructType): StructType = {
    // Check that the input type is a string
    val idx = schema.fieldIndex($(inputCol))
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new Exception(
        s"Input type ${field.dataType} did not match input type StringType")
    }
    // Add the return field (the index is produced as a double)
    schema.add(StructField($(outputCol), DoubleType, false))
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    val indexer = udf { label: String => labelToIndex(label) }
    dataset.select(col("*"),
      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol)))
  }
}
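Since the indexer is an Estimator, it has to be fit before it can transform data. A short usage sketch (the DataFrame df and the column names are assumptions for illustration):

val indexer = new SimpleIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
// fit() collects the distinct input values and returns a SimpleIndexerModel,
// which then behaves like any other Transformer.
val indexed = indexer.fit(df).transform(df)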
If you are implementing an iterative algorithm, you may wish to consider automatically caching the input data if it isn't already cached, or allowing the user to specify a persistence level.
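A minimal sketch of that caching pattern inside a fit or train method, assuming the method's input parameter is named dataset and there is no user-facing persistence parameter:

import org.apache.spark.storage.StorageLevel

// Only cache if the caller has not already persisted the input.
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
if (handlePersistence) {
  dataset.persist(StorageLevel.MEMORY_AND_DISK)
}
try {
  // ... run the iterative training here ...
} finally {
  if (handlePersistence) {
    dataset.unpersist()
  }
}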
The Predictor interface adds the most common parameters for us (the input and output columns, in the form of a labels column, features column, and prediction column) and automatically handles the schema transformation.
The Classifier interface does much the same, except it also adds a rawPredictionColumn and provides tools to detect the number of classes (getNumClasses) as well as to convert the input DataFrame to an RDD of LabeledPoints (making it easier to wrap legacy MLlib classification algorithms).
If you are implementing a regression or clustering algorithm, there is no public base set of interfaces to use, so you will need to use the generic Estimator interface.
// Simple Bernoulli Naive Bayes classifier - no sanity checks for brevity.
// Example only - not for production use.

// Case class for the typed aggregation below: a (label, feature index) pair
// for each non-zero feature occurrence.
case class LabeledToken(label: Double, index: Integer)

class SimpleNaiveBayes(val uid: String)
    extends Classifier[Vector, SimpleNaiveBayes, SimpleNaiveBayesModel] {

  def this() = this(Identifiable.randomUID("simple-naive-bayes"))

  override def train(ds: Dataset[_]): SimpleNaiveBayesModel = {
    import ds.sparkSession.implicits._
    ds.cache()
    // Note: you can use getNumClasses and extractLabeledPoints to get an RDD instead.
    // Using the RDD approach is common when integrating with legacy machine learning code
    // or iterative algorithms which can create large query plans.
    // Here we use Datasets since neither of those apply.

    // Compute the number of documents
    val numDocs = ds.count
    // Get the number of classes.
    // Note this estimator assumes they start at 0 and go to numClasses
    val numClasses = getNumClasses(ds)
    // Get the number of features by peeking at the first row
    val numFeatures: Integer = ds.select(col($(featuresCol))).head
      .get(0).asInstanceOf[Vector].size
    // Determine the number of records for each class
    val groupedByLabel = ds.select(col($(labelCol)).as[Double]).groupByKey(x => x)
    val classCounts = groupedByLabel.agg(count("*").as[Long])
      .sort(col("value")).collect().toMap
    // Select the labels and features so we can more easily map over them.
    // Note: we do this as a DataFrame using the untyped API because the Vector
    // UDT is no longer public.
    val df = ds.select(col($(labelCol)).cast(DoubleType), col($(featuresCol)))
    // Figure out the non-zero frequency of each feature for each label and
    // output (label, index) pairs using a case class to make them easier to work with.
    val labelCounts: Dataset[LabeledToken] = df.flatMap {
      case Row(label: Double, features: Vector) =>
        features.toArray.zip(Stream from 0)
          .filter{case (value, _) => value == 1.0}
          .map{case (_, idx) => LabeledToken(label, idx)}
    }
    // Use the typed Dataset aggregation API to count the number of non-zero
    // features for each label-feature index.
    val aggregatedCounts: Array[((Double, Integer), Long)] = labelCounts
      .groupByKey(x => (x.label, x.index))
      .agg(count("*").as[Long])
      .collect()

    val theta = Array.fill(numClasses)(new Array[Double](numFeatures))

    // Compute the denominator for the overall priors
    val piLogDenom = math.log(numDocs + numClasses)
    // Compute the priors for each class
    val pi = classCounts.map{case (_, cc) =>
      math.log(cc.toDouble) - piLogDenom}.toArray

    // For each label/feature update the probabilities
    aggregatedCounts.foreach{case ((label, featureIndex), count) =>
      // log of number of documents for this label + 2.0 (smoothing)
      val thetaLogDenom = math.log(
        classCounts.get(label).map(_.toDouble).getOrElse(0.0) + 2.0)
      theta(label.toInt)(featureIndex) = math.log(count + 1.0) - thetaLogDenom
    }
    // Unpersist now that we are done computing everything
    ds.unpersist()
    // Construct a model
    new SimpleNaiveBayesModel(uid, numClasses, numFeatures, Vectors.dense(pi),
      new DenseMatrix(numClasses, theta(0).length, theta.flatten, true))
  }

  override def copy(extra: ParamMap) = {
    defaultCopy(extra)
  }
}
// Simplified Naive Bayes Model
case class SimpleNaiveBayesModel(
    override val uid: String,
    override val numClasses: Int,
    override val numFeatures: Int,
    val pi: Vector,
    val theta: DenseMatrix)
  extends ClassificationModel[Vector, SimpleNaiveBayesModel] {

  override def copy(extra: ParamMap) = {
    defaultCopy(extra)
  }

  // We have to do some tricks here because we are using Spark's
  // Vector/DenseMatrix calculations - but for your own model don't feel
  // limited to Spark's native ones.
  val negThetaArray = theta.values.map(v => math.log(1.0 - math.exp(v)))
  val negTheta = new DenseMatrix(numClasses, numFeatures, negThetaArray, true)
  val thetaMinusNegThetaArray = theta.values.zip(negThetaArray)
    .map{case (v, nv) => v - nv}
  val thetaMinusNegTheta = new DenseMatrix(
    numClasses, numFeatures, thetaMinusNegThetaArray, true)
  val onesVec = Vectors.dense(Array.fill(theta.numCols)(1.0))
  val negThetaSum: Array[Double] = negTheta.multiply(onesVec).toArray

  // Here is the prediction functionality you need to implement - for
  // ClassificationModels, transform automatically wraps this - but if you might
  // benefit from broadcasting your model or other optimizations you can also
  // override transform.
  def predictRaw(features: Vector): Vector = {
    // Toy implementation - BLAS or similar would be a better fit for summing
    // the three vectors, but that functionality isn't exposed here.
    Vectors.dense(thetaMinusNegTheta.multiply(features).toArray.zip(pi.toArray)
      .map{case (x, y) => x + y}.zip(negThetaSum)
      .map{case (x, y) => x + y})
  }
}
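Because SimpleNaiveBayes implements the standard classifier interfaces, it picks up the pipeline niceties mentioned at the start, such as cross-validation parameter search, for free. A usage sketch (the train DataFrame, with features and label columns, is an assumption for illustration):

import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val nb = new SimpleNaiveBayes()
val pipeline = new Pipeline().setStages(Array[PipelineStage](nb))

// Meta-algorithms such as CrossValidator treat the custom stage like any other.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new MulticlassClassificationEvaluator())
  .setEstimatorParamMaps(new ParamGridBuilder().build())
  .setNumFolds(3)
val cvModel = cv.fit(train)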
Note
If you simply need to modify an existing algorithm, you can extend it by pretending to be in the org.apache.spark project, which gives you access to its package-private internals.
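For example, placing your class in a matching package gives it access to the package-private pieces; a sketch of the idea (the class name and package placement here are illustrative only):

package org.apache.spark.ml.feature

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}

// Being inside an org.apache.spark.ml package lets us mix in the private[ml]
// shared param traits instead of redefining inputCol/outputCol by hand, at the
// cost of depending on internals that can change between Spark versions.
// Left abstract here; transform, transformSchema, and copy would be implemented
// exactly as in the earlier examples.
abstract class MyTweakedStage(override val uid: String) extends Transformer
    with HasInputCol with HasOutputCol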
Now you know how to extend Spark's ML pipeline API with your own stages. If you get lost, a good reference is the set of algorithms inside Spark itself: while they do sometimes use internal APIs, for the most part they implement the same public interfaces that you will want to use.