通过 pyspark.ml.tuning.TrainValidationSplit 调整后如何获得最佳参数?
Posted
技术标签:
【中文标题】通过 pyspark.ml.tuning.TrainValidationSplit 调整后如何获得最佳参数?【英文标题】:How to get best params after tuning by pyspark.ml.tuning.TrainValidationSplit? 【发布时间】:2017-01-28 09:50:09 【问题描述】:我正在尝试通过TrainValidationSplit
调整 Spark (PySpark) ALS
模型的超参数。
效果很好,但我想知道哪种超参数组合最好。评估后如何获得最佳参数?
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder
from pyspark.ml.evaluation import RegressionEvaluator
df = sqlCtx.createDataFrame(
[(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
["user", "item", "rating"],
)
df_test = sqlCtx.createDataFrame(
[(0, 0), (0, 1), (1, 1), (1, 2), (2, 1), (2, 2)],
["user", "item"],
)
als = ALS()
param_grid = ParamGridBuilder().addGrid(
als.rank,
[10, 15],
).addGrid(
als.maxIter,
[10, 15],
).build()
evaluator = RegressionEvaluator(
metricName="rmse",
labelCol="rating",
)
tvs = TrainValidationSplit(
estimator=als,
estimatorParamMaps=param_grid,
evaluator=evaluator,
)
model = tvs.fit(df)
问题:如何获得最好的rank和maxIter?
【问题讨论】:
【参考方案1】:您可以使用TrainValidationSplitModel
的bestModel
属性访问最佳模型:
best_model = model.bestModel
可以使用ALSModel
的rank
属性直接访问排名:
best_model.rank
10
获得最大迭代次数需要更多技巧:
(best_model
._java_obj # Get Java object
.parent() # Get parent (ALS estimator)
.getMaxIter()) # Get maxIter
10
【讨论】:
以上是关于通过 pyspark.ml.tuning.TrainValidationSplit 调整后如何获得最佳参数?的主要内容,如果未能解决你的问题,请参考以下文章
java是通过值传递,也就是通过拷贝传递——通过方法操作不同类型的变量加深理解