Bugs when fitting multi-label text classification models

Posted: 2019-12-17 23:09:57

Question:

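I am trying to fit a classification model for a multi-label text classification problem.

I have a training set X_train, which is a list of cleaned texts, for example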

["I am constructing Markov chains with  to  states and inferring     
transition probabilities empirically by simply counting how many 
times I saw each transition in my raw data",
"I know the chips only of the  players of my table and mine obviously I 
also know the total number of chips the max and min amount chips the 
players have and the average stackIs it possible to make an 
approximation of my probability of winningI have,
...]

and a training label set y, which holds the list of labels for each text in X_train, for example

[['hypothesis-testing', 'statistical-significance', 'markov-process'],
['probability', 'normal-distribution', 'games'],
...]

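Now I want to fit a model that can predict the labels for a text set X_test, which has the same format as X_train.

I have used MultiLabelBinarizer to transform the labels and TfidfVectorizer to transform the cleaned texts in the training set.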

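from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.feature_extraction.text import TfidfVectorizer

# Encode the label lists as a binary indicator matrix (one column per label).
multilabel_binarizer = MultiLabelBinarizer()
multilabel_binarizer.fit(y)
Y = multilabel_binarizer.transform(y)

# Convert the cleaned texts to TF-IDF features; stopWordList is defined elsewhere.
vectorizer = TfidfVectorizer(stop_words=stopWordList)
vectorizer.fit(X_train)
x_train = vectorizer.transform(X_train)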
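
For reference, here is a minimal sketch of what these two transforms produce. It uses made-up toy data: the names `texts` and `labels` and their contents are assumptions for illustration, not the original dataset, and the stop word list is omitted.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MultiLabelBinarizer

# Toy stand-ins for X_train and y (hypothetical data).
texts = [
    "markov chains and transition probabilities",
    "probability of winning a poker game",
]
labels = [
    ["markov-process", "probability"],
    ["probability", "games"],
]

# One column per distinct label; a 1 marks that the text carries the label.
binarizer = MultiLabelBinarizer()
Y = binarizer.fit_transform(labels)

# Sparse TF-IDF matrix with one row per text.
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(texts)

print(binarizer.classes_)  # label names in the column order of Y
print(Y)
print(X.shape, Y.shape)    # both have one row per training text
```

The row counts of `X` and `Y` must match, since each row of `Y` is the label vector for the text in the same row of `X`.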

But whenever I try to fit a model, I run into errors. I have tried OneVsRestClassifier and LogisticRegression.

When I fit a OneVsRestClassifier model, I get an error like
Traceback (most recent call last):
  File "/opt/conda/envs/data3/lib/python3.6/socketserver.py", line 317, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/opt/conda/envs/data3/lib/python3.6/socketserver.py", line 348, in process_request
    self.finish_request(request, client_address)
  File "/opt/conda/envs/data3/lib/python3.6/socketserver.py", line 361, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/opt/conda/envs/data3/lib/python3.6/socketserver.py", line 696, in __init__
    self.handle()
  File "/usr/local/spark/python/pyspark/accumulators.py", line 268, in handle
    poll(accum_updates)
  File "/usr/local/spark/python/pyspark/accumulators.py", line 241, in poll
    if func():
  File "/usr/local/spark/python/pyspark/accumulators.py", line 245, in accum_updates
    num_updates = read_int(self.rfile)
  File "/usr/local/spark/python/pyspark/serializers.py", line 714, in read_int
    raise EOFError
EOFError

When I fit a LogisticRegression model, I get an error like
/opt/conda/envs/data3/lib/python3.6/site-packages/sklearn/linear_model/sag.py:326: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge
  "the coef_ did not converge", ConvergenceWarning)

Does anyone know where the problem is and how to fix it? Many thanks.


Answer 1:

OneVsRestClassifier fits one classifier per class. You need to tell it which type of classifier you want (for example, logistic regression).

The following code works for me:

from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression

classifier = OneVsRestClassifier(LogisticRegression())
classifier.fit(x_train, Y)

X_test= ["I play with Markov chains"]
x_test = vectorizer.transform(X_test)

classifier.predict(x_test)

Output: array([[0, 1, 1, 0, 0, 1]])
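
The output is a binary indicator row, so it still has to be mapped back to label names. Below is a minimal end-to-end sketch of one way to do that, using `MultiLabelBinarizer.inverse_transform`. The toy `texts` and `labels` are hypothetical, and `max_iter=1000` is an assumption: raising the iteration limit is a common response to a ConvergenceWarning like the one in the question, but whether it is enough for the real data is not something this sketch can confirm.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer

# Toy stand-ins for X_train and y (hypothetical data).
texts = [
    "markov chains and transition probabilities",
    "probability of winning a poker game",
    "states of a markov chain",
    "card games and chips",
]
labels = [
    ["markov-process", "probability"],
    ["probability", "games"],
    ["markov-process"],
    ["games"],
]

binarizer = MultiLabelBinarizer()
Y = binarizer.fit_transform(labels)

vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(texts)

# One binary LogisticRegression per label column.
# max_iter=1000 is an assumed value, chosen only to leave room for convergence.
classifier = OneVsRestClassifier(LogisticRegression(max_iter=1000))
classifier.fit(X, Y)

# Reuse the fitted vectorizer so the test text gets the same feature columns.
pred = classifier.predict(vectorizer.transform(["I play with Markov chains"]))

# Map each binary row back to a tuple of label names.
predicted_labels = binarizer.inverse_transform(pred)
print(pred)
print(predicted_labels)
```

Each column of `pred` corresponds to the label at the same position in `binarizer.classes_`, which is why the same fitted binarizer has to be kept around for decoding.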
