我正在尝试重构经过训练的基于火花树的模型(RandomForest或GBT分类器),使其可以在没有火花的环境中导出。该toDebugString
方法是一个很好的起点。但是,对于RandomForestClassifier
,字符串仅显示每棵树的预测类,而没有相对概率。因此,如果对所有树木的预测取平均值,则会得到错误的结果。
一个例子。我们DecisionTree
以这种方式代表:
DecisionTreeClassificationModel (uid=dtc_884dc2111789) of depth 2 with 5 nodes If (feature 21 in {1.0}) Predict: 0.0 Else (feature 21 not in {1.0}) If (feature 10 in {0.0}) Predict: 0.0 Else (feature 10 not in {0.0}) Predict: 1.0
如我们所见,跟随这些节点,看起来预测总是为0或1。但是,如果将这棵单树应用于特征向量,则得到的概率像[0.1007, 0.8993]
,并且它们在训练中非常有意义,因为在训练中设置负数/正数的比例,该比例最终与示例矢量与输出概率匹配的位置相同。
我的问题:这些概率存储在哪里?有没有办法提取它们?如果是这样,怎么办?一个pyspark
解决方案是更好的。