BertForSequenceClassification 与 BertForMultipleChoice 用于句子多 class classification

BertForSequenceClassification vs. BertForMultipleChoice for sentence multi-class classification

我正在处理一个文本class化问题(例如情绪分析),我需要class将一个文本字符串class化为五个class之一。

我刚开始使用 Huggingface Transformer package and BERT with PyTorch. What I need is a classifier with a softmax layer on top so that I can do 5-way classification. Confusingly, there seem to be two relevant options in the Transformer package: BertForSequenceClassification and BertForMultipleChoice

我的 5 向 classification 任务应该使用哪一个?它们的合适用例是什么?

BertForSequenceClassification 的文档根本没有提到 softmax,尽管它确实提到了交叉熵。我不确定这个 class 是否仅适用于 2-class classification(即逻辑回归)。

Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks.

  • labels (torch.LongTensor of shape (batch_size,), optional, defaults to None) – Labels for computing the sequence classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1 a regression loss is computed (Mean-Square loss), If config.num_labels > 1 a classification loss is computed (Cross-Entropy).

BertForMultipleChoice 的文档提到了 softmax,但是描述标签的方式,听起来像这样 class 是针对多标签 classification (即,多个标签的二进制 class化)。

Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.

  • labels (torch.LongTensor of shape (batch_size,), optional, defaults to None) – Labels for computing the multiple choice classification loss. Indices should be in [0, ..., num_choices] where num_choices is the size of the second dimension of the input tensors.

感谢您的帮助。

答案在于对任务内容的(诚然非常简短)描述:

[BertForMultipleChoice] [...], e.g. for RocStories/SWAG tasks.

查看 paper for SWAG 时,任务似乎实际上是在学习从不同的选项中进行选择。这与您的 "classical" 分类任务形成对比,在该任务中 "choices"(即 类)不会在您的样本中变化,这是BertForSequenceClassification 的确切用途。

这两种变体实际上可以是任意数量的类(在BertForSequenceClassification的情况下),分别选择(对于BertForMultipleChoice),通过改变labels 配置中的参数。但是,由于您似乎正在处理 "classical classification" 的情况,我建议使用 BertForSequenceClassification 模型。

快速解决 BertForSequenceClassification 中缺失的 Softmax:由于分类任务可以计算 类 独立于样本的损失(与多项选择不同,您的分布正在发生变化),这使您可以使用交叉熵损失,它在 increased numerical stability.

的反向传播步骤中考虑了 Softmax