I recently created a PySpark PipelineModel with several custom transformers to generate features that the native Spark transformers can't produce. Here is an example of one of my transformers. It takes a string label as input and returns that label's superclass label:
from pyspark.ml import Transformer
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import udf


class newLabelMap(Transformer, HasInputCol, HasOutputCol,
                  DefaultParamsReadable, DefaultParamsWritable):
    # Redeclared here, although HasInputCol/HasOutputCol already provide these Params
    inputCol = Param(Params._dummy(), "inputCol", "The input column",
                     TypeConverters.toString)
    outputCol = Param(Params._dummy(), "outputCol", "The output column",
                      TypeConverters.toString)

    def __init__(self, inputCol="", outputCol=""):
        super(newLabelMap, self).__init__()
        self._setDefault(inputCol="", outputCol="")
        self._set(inputCol=inputCol, outputCol=outputCol)

    def getInputCol(self):
        return self.getOrDefault(self.inputCol)

    def setInputCol(self, inputCol):
        return self._set(inputCol=inputCol)

    def getOutputCol(self):
        return self.getOrDefault(self.outputCol)

    def setOutputCol(self, outputCol):
        return self._set(outputCol=outputCol)

    def _transform(self, dataset):
        @udf("string")
        def findLabel(labelVal):
            # Map each fine-grained label to its superclass label
            new_label_dict = {'oldLabel0': 'newLabel0',
                              'oldLabel1': 'newLabel1', 'oldLabel2': 'newLabel1',
                              'oldLabel3': 'newLabel1', 'oldLabel4': 'newLabel2',
                              'oldLabel5': 'newLabel2', 'oldLabel6': 'newLabel2',
                              'oldLabel7': 'newLabel3', 'oldLabel8': 'newLabel3',
                              'oldLabel9': 'newLabel4', 'oldLabel10': 'newLabel4'}
            try:
                return new_label_dict[labelVal]
            except KeyError:
                return 'other'

        out_col = self.getOutputCol()
        in_col = dataset[self.getInputCol()]
        return dataset.withColumn(out_col, findLabel(in_col))
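For clarity, the lookup-with-fallback inside findLabel can be exercised outside Spark as a plain function (a dict.get with a default does the same job as the try/except). This is just a sketch of the mapping logic on its own, not part of the transformer:

```python
# Superclass mapping, identical to new_label_dict inside findLabel
NEW_LABEL_DICT = {
    'oldLabel0': 'newLabel0',
    'oldLabel1': 'newLabel1', 'oldLabel2': 'newLabel1', 'oldLabel3': 'newLabel1',
    'oldLabel4': 'newLabel2', 'oldLabel5': 'newLabel2', 'oldLabel6': 'newLabel2',
    'oldLabel7': 'newLabel3', 'oldLabel8': 'newLabel3',
    'oldLabel9': 'newLabel4', 'oldLabel10': 'newLabel4',
}

def find_label(label_val):
    # dict.get with a default is equivalent to the try/except KeyError pattern
    return NEW_LABEL_DICT.get(label_val, 'other')

print(find_label('oldLabel3'))  # newLabel1
print(find_label('missing'))    # other
```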
The transformer works fine in the pipeline: I can save it, load it back into a PySpark session, and transform with it without any issues. The problem arises when I try to bring it into a Scala environment. When I attempt to load the model there, I get the following error output:
Name: java.lang.IllegalArgumentException
Message: requirement failed: Error loading metadata: Expected class name org.apache.spark.ml.PipelineModel but found class name pyspark.ml.pipeline.PipelineModel
StackTrace: at scala.Predef$.require(Predef.scala:224)
at org.apache.spark.ml.util.DefaultParamsReader$.parseMetadata(ReadWrite.scala:638)
at org.apache.spark.ml.util.DefaultParamsReader$.loadMetadata(ReadWrite.scala:616)
at org.apache.spark.ml.Pipeline$SharedReadWrite$.load(Pipeline.scala:267)
at org.apache.spark.ml.PipelineModel$PipelineModelReader.load(Pipeline.scala:348)
at org.apache.spark.ml.PipelineModel$PipelineModelReader.load(Pipeline.scala:342)
If I remove the custom transformer, the model loads fine in Scala. So I'm curious: how can I write a custom transformer in PySpark, as part of a PipelineModel, that is portable to a Scala environment? Do I need to attach my code in some way? Any help would be greatly appreciated :)