如何针对 TF.Dataset 的 class 标签绘制直方图

How to plot histogram against class label for TF.Dataset

我正在使用 TF.CsvDataset 从磁盘加载数据。并将数据绘制为

#This is the transformation function applied on loaded data before displaying histogram.
def preprocess(*fields):
    print(len(fields))
    features=tf.stack(fields[:-1])
    labels=tf.stack([int(x) for x in fields[-1:]])
    return features,labels  # x, y

for features,label in train_ds.take(1000):
#  print(features[0])
plt.hist(features.numpy().flatten(), bins = 101)

我得到了这个直方图

但我想根据二进制 class 标签绘制 712 个特征值的分布。即class标签为0时,特征1,2或3的值是多少。

如何使用 pyplot 做到这一点?

我已阅读以下主题,但没有任何帮助。

Plotting histograms against classes in pandas / matplotlib

How to draw an histogram with multiple categories in python

您可以使用 np.fromiter 并获取所有标签。然后你只需将标签列表传递给 plt.hist:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

train, test = tf.keras.datasets.mnist.load_data()

ds = tf.data.Dataset.from_tensor_slices(train)

vals = np.fromiter(ds.map(lambda x, y: y), float)

plt.hist(vals)
plt.xticks(range(10))
plt.title('Label Frequency')
plt.show()