使用 PyTorch 进行二进制分类的目标和输出 shape/type
Target and output shape/type for binary classification using PyTorch
所以我有一些带注释的图像,我想用它们来训练二值图像 classifier,但我在创建数据集和实际获取测试模型进行训练时遇到了问题。每个图像要么是某个 class,要么不是,所以我想使用 PyTorch 设置二进制 classification dataset/model。我有一些问题:
- 标签应该是浮动的还是长的?
- 我的标签应该是什么形状?
- 我正在使用来自 torchvision 模型的 resnet18 class,我的最终 softmax 层应该有一个还是两个输出?
- 如果我的批量大小是 200,在训练期间我的目标应该是什么形状?
- 我的输出应该是什么形状?
提前致谢
引用
删除
二进制 classification 与多标签 classification 略有不同:而对于多标签,您的模型预测 vector of "logits",per样本,并使用 softmax 将 logits 转换为概率;在二进制情况下,模型预测每个样本的 标量 “logit”,并使用 sigmoid 函数将其转换为 class 概率。
在 pytorch the softmax and the sigmoind are "folded" into the loss layer (for numerical stability considerations) and therefore there are different Cross Entropy loss layers for the two cases: nn.BCEWithLogitsLoss
for the binary case (with sigmoid) and nn.CrossEntropyLoss
中用于多标签情况(使用 softmax)。
在你的情况下你想使用二进制版本(带 sigmoid):nn.BCEWithLogitsLoss
.
因此,您的标签应该是 torch.float32
类型(与网络输出相同的 float
类型)而不是整数。
每个样本应该有一个 单个 标签。因此,如果您的批量大小为 200,则目标的形状应为 (200,1)
.
我将把它留在这里作为练习,以表明训练具有两个输出和 CE+softmax 的模型等同于二进制输出+sigmoid ;)
所以我有一些带注释的图像,我想用它们来训练二值图像 classifier,但我在创建数据集和实际获取测试模型进行训练时遇到了问题。每个图像要么是某个 class,要么不是,所以我想使用 PyTorch 设置二进制 classification dataset/model。我有一些问题:
- 标签应该是浮动的还是长的?
- 我的标签应该是什么形状?
- 我正在使用来自 torchvision 模型的 resnet18 class,我的最终 softmax 层应该有一个还是两个输出?
- 如果我的批量大小是 200,在训练期间我的目标应该是什么形状?
- 我的输出应该是什么形状?
提前致谢
引用 删除
二进制 classification 与多标签 classification 略有不同:而对于多标签,您的模型预测 vector of "logits",per样本,并使用 softmax 将 logits 转换为概率;在二进制情况下,模型预测每个样本的 标量 “logit”,并使用 sigmoid 函数将其转换为 class 概率。
在 pytorch the softmax and the sigmoind are "folded" into the loss layer (for numerical stability considerations) and therefore there are different Cross Entropy loss layers for the two cases: nn.BCEWithLogitsLoss
for the binary case (with sigmoid) and nn.CrossEntropyLoss
中用于多标签情况(使用 softmax)。
在你的情况下你想使用二进制版本(带 sigmoid):nn.BCEWithLogitsLoss
.
因此,您的标签应该是 torch.float32
类型(与网络输出相同的 float
类型)而不是整数。
每个样本应该有一个 单个 标签。因此,如果您的批量大小为 200,则目标的形状应为 (200,1)
.
我将把它留在这里作为练习,以表明训练具有两个输出和 CE+softmax 的模型等同于二进制输出+sigmoid ;)