tf2.0中如何获取tensor的值?
How to get the value of a tensor in TF 2.0?
# Parsing spec for tf.io.parse_single_example: two scalar string features
# (file name, encoded image bytes) followed by variable-length per-object
# bounding-box coordinates and class-name strings.
IMAGE_FEATURE_MAP = dict([
    ('image/filename', tf.io.FixedLenFeature([], tf.string)),
    ('image/encoded', tf.io.FixedLenFeature([], tf.string)),
    ('image/object/bbox/xmin', tf.io.VarLenFeature(tf.float32)),
    ('image/object/bbox/ymin', tf.io.VarLenFeature(tf.float32)),
    ('image/object/bbox/xmax', tf.io.VarLenFeature(tf.float32)),
    ('image/object/bbox/ymax', tf.io.VarLenFeature(tf.float32)),
    ('image/object/class/text', tf.io.VarLenFeature(tf.string)),
])

# Fixed number of box rows that each label tensor is padded to.
yolo_max_boxes = 100
def parse_tfrecord(tfrecord, class_table, size):
    """Decode one serialized example into an ``(image, boxes)`` pair.

    Args:
        tfrecord: scalar ``tf.string`` tensor holding one serialized example.
        class_table: lookup table mapping class-name strings to int64 ids;
            names missing from the table map to the table's default value.
        size: side length, in pixels, that the image is resized to.

    Returns:
        ``(x_train, y_train)``:
        - ``x_train`` is a ``(size, size, 3)`` float32 image.
        - ``y_train`` holds ``[xmin, ymin, xmax, ymax, label]`` rows. It is
          truncated to at most ``yolo_max_boxes`` rows, then zero-padded to
          exactly ``yolo_max_boxes`` rows.

    Note:
        ``image/filename`` is parsed but not returned. This function is run
        through ``Dataset.map`` (graph mode), where ``Tensor.numpy()`` is not
        available. To inspect the filename while the pipeline runs, call
        ``tf.print(x['image/filename'])`` here.
    """
    x = tf.io.parse_single_example(tfrecord, IMAGE_FEATURE_MAP)
    x_train = tf.image.decode_jpeg(x['image/encoded'], channels=3)
    x_train = tf.image.resize(x_train, (size, size))

    class_text = tf.sparse.to_dense(
        x['image/object/class/text'], default_value='')
    labels = tf.cast(class_table.lookup(class_text), tf.float32)
    y_train = tf.stack([tf.sparse.to_dense(x['image/object/bbox/xmin']),
                        tf.sparse.to_dense(x['image/object/bbox/ymin']),
                        tf.sparse.to_dense(x['image/object/bbox/xmax']),
                        tf.sparse.to_dense(x['image/object/bbox/ymax']),
                        labels], axis=1)

    # Keep at most yolo_max_boxes rows. With more boxes, the padding amount
    # below would be negative and tf.pad would raise.
    y_train = y_train[:yolo_max_boxes]
    paddings = [[0, yolo_max_boxes - tf.shape(y_train)[0]], [0, 0]]
    y_train = tf.pad(y_train, paddings)
    return x_train, y_train
def load_tfrecord_dataset(file_pattern, class_file, size=416):
    """Build a lazy ``tf.data`` pipeline of parsed training examples.

    Args:
        file_pattern: path or glob pattern of the TFRecord file(s).
        class_file: text file with one class name per line. Each name is
            mapped to its line number; unknown names map to -1.
        size: side length the images are resized to.

    Returns:
        A ``tf.data.Dataset`` whose elements are whatever ``parse_tfrecord``
        returns for each record. Records are only read when the dataset is
        iterated.
    """
    line_number_index = -1  # value column = line number of the name
    names_to_ids = tf.lookup.StaticHashTable(
        tf.lookup.TextFileInitializer(
            class_file, tf.string, 0, tf.int64, line_number_index,
            delimiter="\n"),
        -1)  # default id for class names not present in class_file

    def parse_record(serialized):
        return parse_tfrecord(serialized, names_to_ids, size)

    record_files = tf.data.Dataset.list_files(file_pattern)
    return record_files.flat_map(tf.data.TFRecordDataset).map(parse_record)
def main():
    """Construct the training dataset pipeline.

    The dataset returned by load_tfrecord_dataset is lazy and is discarded
    here, so no example is parsed and nothing is printed.
    """
    record_path = '../data/facemask2020_train.tfrecord'
    names_path = '../data/mask2020.names'
    load_tfrecord_dataset(record_path, names_path, size=416)


if __name__ == '__main__':
    main()
当我打印 x['image/filename'] 时,只能看到张量的形状和数据类型。我想获取 TFRecord 文件中每条记录的文件名,但看不到这个张量的具体值。我该怎么做才能查看它的具体值呢?我是新手,请帮帮我。
您发布的代码中没有 print 语句。在 TF2 的 eager 模式下,遍历数据集即可得到具体的张量,再对张量调用 `.numpy()` 就能取得它的值:
def main():
    """Print the ``image/filename`` value of every record in the training file.

    Why not use the pipeline: the dataset from ``load_tfrecord_dataset``
    yields ``(image, boxes)`` tuples. Indexing an element with
    ``['image/filename']`` raises ``TypeError``, and the filename is not
    part of the tuple at all.

    What this does instead:
    - Reads the raw records directly and parses only the filename feature.
    - In TF2 eager mode, ``Tensor.numpy()`` returns the concrete value,
      which is ``bytes`` for a ``tf.string`` tensor.
    """
    record_path = '../data/facemask2020_train.tfrecord'
    filename_spec = {'image/filename': tf.io.FixedLenFeature([], tf.string)}
    for raw_record in tf.data.TFRecordDataset(record_path):
        example = tf.io.parse_single_example(raw_record, filename_spec)
        filename = example['image/filename'].numpy()  # bytes in eager mode
        print(filename.decode('utf-8', errors='replace'))
# Parsing spec for tf.io.parse_single_example: two scalar string features
# (file name, encoded image bytes) followed by variable-length per-object
# bounding-box coordinates and class-name strings.
IMAGE_FEATURE_MAP = dict([
    ('image/filename', tf.io.FixedLenFeature([], tf.string)),
    ('image/encoded', tf.io.FixedLenFeature([], tf.string)),
    ('image/object/bbox/xmin', tf.io.VarLenFeature(tf.float32)),
    ('image/object/bbox/ymin', tf.io.VarLenFeature(tf.float32)),
    ('image/object/bbox/xmax', tf.io.VarLenFeature(tf.float32)),
    ('image/object/bbox/ymax', tf.io.VarLenFeature(tf.float32)),
    ('image/object/class/text', tf.io.VarLenFeature(tf.string)),
])

# Fixed number of box rows that each label tensor is padded to.
yolo_max_boxes = 100
def parse_tfrecord(tfrecord, class_table, size):
    """Decode one serialized example into an ``(image, boxes)`` pair.

    Args:
        tfrecord: scalar ``tf.string`` tensor holding one serialized example.
        class_table: lookup table mapping class-name strings to int64 ids;
            names missing from the table map to the table's default value.
        size: side length, in pixels, that the image is resized to.

    Returns:
        ``(x_train, y_train)``:
        - ``x_train`` is a ``(size, size, 3)`` float32 image.
        - ``y_train`` holds ``[xmin, ymin, xmax, ymax, label]`` rows. It is
          truncated to at most ``yolo_max_boxes`` rows, then zero-padded to
          exactly ``yolo_max_boxes`` rows.

    Note:
        ``image/filename`` is parsed but not returned. This function is run
        through ``Dataset.map`` (graph mode), where ``Tensor.numpy()`` is not
        available. To inspect the filename while the pipeline runs, call
        ``tf.print(x['image/filename'])`` here.
    """
    x = tf.io.parse_single_example(tfrecord, IMAGE_FEATURE_MAP)
    x_train = tf.image.decode_jpeg(x['image/encoded'], channels=3)
    x_train = tf.image.resize(x_train, (size, size))

    class_text = tf.sparse.to_dense(
        x['image/object/class/text'], default_value='')
    labels = tf.cast(class_table.lookup(class_text), tf.float32)
    y_train = tf.stack([tf.sparse.to_dense(x['image/object/bbox/xmin']),
                        tf.sparse.to_dense(x['image/object/bbox/ymin']),
                        tf.sparse.to_dense(x['image/object/bbox/xmax']),
                        tf.sparse.to_dense(x['image/object/bbox/ymax']),
                        labels], axis=1)

    # Keep at most yolo_max_boxes rows. With more boxes, the padding amount
    # below would be negative and tf.pad would raise.
    y_train = y_train[:yolo_max_boxes]
    paddings = [[0, yolo_max_boxes - tf.shape(y_train)[0]], [0, 0]]
    y_train = tf.pad(y_train, paddings)
    return x_train, y_train
def load_tfrecord_dataset(file_pattern, class_file, size=416):
    """Build a lazy ``tf.data`` pipeline of parsed training examples.

    Args:
        file_pattern: path or glob pattern of the TFRecord file(s).
        class_file: text file with one class name per line. Each name is
            mapped to its line number; unknown names map to -1.
        size: side length the images are resized to.

    Returns:
        A ``tf.data.Dataset`` whose elements are whatever ``parse_tfrecord``
        returns for each record. Records are only read when the dataset is
        iterated.
    """
    line_number_index = -1  # value column = line number of the name
    names_to_ids = tf.lookup.StaticHashTable(
        tf.lookup.TextFileInitializer(
            class_file, tf.string, 0, tf.int64, line_number_index,
            delimiter="\n"),
        -1)  # default id for class names not present in class_file

    def parse_record(serialized):
        return parse_tfrecord(serialized, names_to_ids, size)

    record_files = tf.data.Dataset.list_files(file_pattern)
    return record_files.flat_map(tf.data.TFRecordDataset).map(parse_record)
def main():
    """Construct the training dataset pipeline.

    The dataset returned by load_tfrecord_dataset is lazy and is discarded
    here, so no example is parsed and nothing is printed.
    """
    record_path = '../data/facemask2020_train.tfrecord'
    names_path = '../data/mask2020.names'
    load_tfrecord_dataset(record_path, names_path, size=416)


if __name__ == '__main__':
    main()
当我打印 x['image/filename'] 时,只能看到张量的形状和数据类型。我想获取 TFRecord 文件中每条记录的文件名,但看不到这个张量的具体值。我该怎么做才能查看它的具体值呢?我是新手,请帮帮我。
您发布的代码中没有 print 语句。在 TF2 的 eager 模式下,遍历数据集即可得到具体的张量,再对张量调用 `.numpy()` 就能取得它的值:
def main():
    """Print the ``image/filename`` value of every record in the training file.

    Why not use the pipeline: the dataset from ``load_tfrecord_dataset``
    yields ``(image, boxes)`` tuples. Indexing an element with
    ``['image/filename']`` raises ``TypeError``, and the filename is not
    part of the tuple at all.

    What this does instead:
    - Reads the raw records directly and parses only the filename feature.
    - In TF2 eager mode, ``Tensor.numpy()`` returns the concrete value,
      which is ``bytes`` for a ``tf.string`` tensor.
    """
    record_path = '../data/facemask2020_train.tfrecord'
    filename_spec = {'image/filename': tf.io.FixedLenFeature([], tf.string)}
    for raw_record in tf.data.TFRecordDataset(record_path):
        example = tf.io.parse_single_example(raw_record, filename_spec)
        filename = example['image/filename'].numpy()  # bytes in eager mode
        print(filename.decode('utf-8', errors='replace'))