public abstract class PredictionModel<FeaturesType,M extends PredictionModel<FeaturesType,M>> extends Model<M>
| Constructor and Description |
|---|
| PredictionModel() |
| Modifier and Type | Method and Description |
|---|---|
| Param<java.lang.String> | featuresCol()<br>Param for features column name. |
| protected DataType | featuresDataType()<br>Returns the SQL DataType corresponding to the FeaturesType type parameter. |
| java.lang.String | getFeaturesCol() |
| java.lang.String | getLabelCol() |
| java.lang.String | getPredictionCol() |
| Param<java.lang.String> | labelCol()<br>Param for label column name. |
| int | numFeatures()<br>Returns the number of features the model was trained on. |
| protected abstract double | predict(FeaturesType features)<br>Predict label for the given features. |
| Param<java.lang.String> | predictionCol()<br>Param for prediction column name. |
| M | setFeaturesCol(java.lang.String value) |
| M | setPredictionCol(java.lang.String value) |
| Dataset<Row> | transform(Dataset<?> dataset)<br>Transforms dataset by reading from featuresCol, calling predict(), and storing the predictions as a new column predictionCol. |
| protected Dataset<Row> | transformImpl(Dataset<?> dataset) |
| StructType | transformSchema(StructType schema)<br>:: DeveloperApi :: |
| StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)<br>Validates and transforms the input schema with the provided param map. |
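The self-referential type parameter `M extends PredictionModel<FeaturesType,M>` is what lets `setFeaturesCol` and `setPredictionCol` return the concrete model type instead of the base class. The following is a minimal plain-Java sketch of that pattern, not Spark code: `SketchModel`, `DotProductModel`, the `self()` hook, and the default column names are all hypothetical names chosen for illustration.

```java
// Demo entry point: chained setters keep the concrete type DotProductModel.
class SelfTypedDemo {
    public static void main(String[] args) {
        DotProductModel model = new DotProductModel(new double[] {1.0, 2.0})
                .setFeaturesCol("f")
                .setPredictionCol("p");
        System.out.println(model.getFeaturesCol() + " -> " + model.getPredictionCol());
        System.out.println(model.predict(new double[] {3.0, 4.0}));
    }
}

// Simplified stand-in for the self-typed signature; this is NOT a Spark class.
abstract class SketchModel<F, M extends SketchModel<F, M>> {
    // Default column names are an assumption made for this sketch.
    private String featuresCol = "features";
    private String predictionCol = "prediction";

    // Hypothetical hook: each concrete subclass returns itself as type M.
    protected abstract M self();

    public M setFeaturesCol(String value) {
        this.featuresCol = value;
        return self();
    }

    public M setPredictionCol(String value) {
        this.predictionCol = value;
        return self();
    }

    public String getFeaturesCol() {
        return featuresCol;
    }

    public String getPredictionCol() {
        return predictionCol;
    }

    // Subclasses supply the per-row prediction, as in the abstract predict() above.
    protected abstract double predict(F features);
}

// Hypothetical concrete model: predicts the dot product of weights and features.
final class DotProductModel extends SketchModel<double[], DotProductModel> {
    private final double[] weights;

    DotProductModel(double[] weights) {
        this.weights = weights.clone();
    }

    @Override
    protected DotProductModel self() {
        return this;
    }

    @Override
    protected double predict(double[] features) {
        double sum = 0.0;
        for (int i = 0; i < weights.length; i++) {
            sum += weights[i] * features[i];
        }
        return sum;
    }

    int numFeatures() {
        return weights.length;
    }
}
```

Because `M` is bound to the subclass, the chained call in `main` compiles without a cast; with a plain `SketchModel` return type the result would have to be cast back to `DotProductModel`.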
Methods inherited from class Transformer: transform, transform, transform
Methods inherited from class PipelineStage: transformSchema
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Params: clear, copy, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams
Methods inherited from interface Identifiable: toString, uid
public M setFeaturesCol(java.lang.String value)
public M setPredictionCol(java.lang.String value)
public int numFeatures()
protected DataType featuresDataType()
Returns the SQL DataType corresponding to the FeaturesType type parameter.
This is used by validateAndTransformSchema().
This workaround is needed since SQL has different APIs for Scala and Java.
The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
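As a rough illustration of how an overridable hook such as `featuresDataType()` can feed schema validation, here is a plain-Java sketch that models a schema as an ordered map from column name to a type name. It is not Spark's implementation: the class `SchemaSketch`, the string type names, the default column names, and the specific checks (label column required only while fitting, prediction column must not already exist) are assumptions made for the example.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Demo entry point: validates a fitting-time schema and prints the result.
class SchemaSketchDemo {
    public static void main(String[] args) {
        Map<String, String> input = new LinkedHashMap<>();
        input.put("features", "vector");
        input.put("label", "double");
        System.out.println(new SchemaSketch().validateAndTransformSchema(input, true, "vector"));
    }
}

// Plain-Java stand-in; column and type names are assumptions for this sketch.
class SchemaSketch {
    String featuresCol = "features";
    String labelCol = "label";
    String predictionCol = "prediction";

    // Hook mirroring featuresDataType(): override when features are not vectors.
    protected String featuresDataType() {
        return "vector";
    }

    // At transform time no label is needed, so fitting is false.
    Map<String, String> transformSchema(Map<String, String> schema) {
        return validateAndTransformSchema(schema, false, featuresDataType());
    }

    Map<String, String> validateAndTransformSchema(
            Map<String, String> schema, boolean fitting, String featuresDataType) {
        String actual = schema.get(featuresCol);
        if (!featuresDataType.equals(actual)) {
            throw new IllegalArgumentException(
                    "Column " + featuresCol + " must be " + featuresDataType + " but was " + actual);
        }
        // Assumption: the label column is only required while fitting.
        if (fitting && !schema.containsKey(labelCol)) {
            throw new IllegalArgumentException("Missing label column " + labelCol);
        }
        // Assumption: an existing prediction column is treated as an error.
        if (schema.containsKey(predictionCol)) {
            throw new IllegalArgumentException("Column " + predictionCol + " already exists");
        }
        // Copy the input so the caller's schema is left untouched.
        Map<String, String> out = new LinkedHashMap<>(schema);
        out.put(predictionCol, "double");
        return out;
    }
}
```

A subclass whose features are not vectors would override `featuresDataType()` in this sketch, and `transformSchema` would then validate against the new type without further changes.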
public StructType transformSchema(StructType schema)
Description copied from class: PipelineStage
Derives the output schema from the input schema.
Specified by: transformSchema in class PipelineStage
Parameters: schema - (undocumented)
public Dataset<Row> transform(Dataset<?> dataset)
Transforms dataset by reading from featuresCol, calling predict(), and storing the predictions as a new column predictionCol.
Specified by: transform in class Transformer
Parameters: dataset - input dataset
Returns: transformed dataset with predictionCol of type Double
protected abstract double predict(FeaturesType features)
Predict label for the given features. This method is used to implement transform() and output predictionCol.
Parameters: features - (undocumented)
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters: schema - input schema
fitting - whether this is in fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.
public Param<java.lang.String> labelCol()
public java.lang.String getLabelCol()
public Param<java.lang.String> featuresCol()
public java.lang.String getFeaturesCol()
public Param<java.lang.String> predictionCol()
public java.lang.String getPredictionCol()
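The transform() contract described above (read featuresCol, call predict(), store the result as a new column predictionCol of type Double) can be sketched in plain Java. This is only an illustration under stated assumptions, not Spark code: rows are modeled as maps from column name to value, features as `double[]`, and `TransformSketch` and its `transform` helper are hypothetical names.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.ToDoubleFunction;

// Demo entry point: one row in, one row out with an added prediction column.
class TransformSketchDemo {
    public static void main(String[] args) {
        Map<String, Object> row = new LinkedHashMap<>();
        row.put("features", new double[] {1.0, 2.0});
        List<Map<String, Object>> rows = new ArrayList<>();
        rows.add(row);

        List<Map<String, Object>> out =
                TransformSketch.transform(rows, "features", "prediction", f -> f[0] + f[1]);
        System.out.println(out.get(0).get("prediction"));
    }
}

// Plain-Java stand-in for the transform step; NOT the Spark implementation.
final class TransformSketch {
    private TransformSketch() {}

    static List<Map<String, Object>> transform(
            List<Map<String, Object>> rows,
            String featuresCol,
            String predictionCol,
            ToDoubleFunction<double[]> predict) {
        List<Map<String, Object>> out = new ArrayList<>(rows.size());
        for (Map<String, Object> row : rows) {
            // Read the features column; this sketch assumes it holds a double[].
            double[] features = (double[]) row.get(featuresCol);
            // Copy the row so the input data is not modified.
            Map<String, Object> copy = new LinkedHashMap<>(row);
            // Store the prediction as a boxed Double under predictionCol.
            copy.put(predictionCol, Double.valueOf(predict.applyAsDouble(features)));
            out.add(copy);
        }
        return out;
    }
}
```

The input rows are copied rather than mutated, which loosely mirrors how a transformer returns a new dataset with an extra column instead of changing its input.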