Tensorflow 2.3: How to parallelize reading text from big file?

I need to break a dataset file of size 4GB into small chunks. As part of optimizing the time consumption, I want to maximize the parallel processing. Currently, I can observe that the cores of the CPU and GPU are under-utilized. See the attached output in the image here.

My code snippet looks like this:

import json
import os
import time

import tensorflow as tf


def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def serialize_row(text, rating):
    # Create a dictionary mapping the feature name to the tf.Example-compatible data type.
    feature = {
        'text': _bytes_feature(text),
        'rating': _float_feature(rating),
    }

    # Create a Features message using tf.train.Example.
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

def transform(example):
    str_example = example.decode("utf-8")
    json_example = json.loads(str_example)
    overall = json_example.get('overall', -99)
    text = json_example.get('reviewText', '')
    if type(text) is str:
        text = bytes(text, 'utf-8')
    tf_serialized_string = serialize_row(text, overall)
    return tf_serialized_string 


line_dataset = tf.data.TextLineDataset(filenames=[file_path])
line_dataset = line_dataset.map(lambda row: tf.numpy_function(transform, [row], tf.string))
line_dataset = line_dataset.shuffle(2)
line_dataset = line_dataset.batch(NUM_OF_RECORDS_PER_BATCH_FILE)
# Perform batchwise transformation of the population.
start = time.time()
for idx, line in line_dataset.enumerate():
    FILE_NAMES = 'test{0}.tfrecord'.format(idx)
    end = time.time()
    time_taken = end - start
    tf.print('Processing for file - {0}'.format(FILE_NAMES))
    DIRECTORY_URL = '/home/gaurav.gupta/projects/practice/'
    filepath = os.path.join(DIRECTORY_URL, 'data-set', 'electronics', FILE_NAMES)
    batch_ds = tf.data.Dataset.from_tensor_slices(line)
    writer = tf.data.experimental.TFRecordWriter(filepath)
    writer.write(batch_ds)
    tf.print('Processing for file - {0} took {1}'.format(FILE_NAMES, time_taken))
    start = time.time()  # restart the timer so each file is timed on its own
tf.print('Done')

Logs showing the flow of execution:

Processing for file - test0.tfrecord took 14.350863218307495
Processing for file - test1.tfrecord took 12.695453882217407
Processing for file - test2.tfrecord took 12.904462575912476
Processing for file - test3.tfrecord took 12.344425439834595
Processing for file - test4.tfrecord took 11.188365697860718
Processing for file - test5.tfrecord took 11.319620609283447
Processing for file - test6.tfrecord took 11.285977840423584
Processing for file - test7.tfrecord took 11.169529438018799
Processing for file - test8.tfrecord took 11.289997816085815
Processing for file - test9.tfrecord took 11.431073188781738
Processing for file - test10.tfrecord took 11.428141593933105
Processing for file - test11.tfrecord took 3.223125457763672
Done

I have tried the num_parallel_reads argument, but couldn't see any difference. I believe it would be handy while reading multiple files instead of a single big file.
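
For reference, a minimal sketch of the multi-file variant, assuming the big file were first split into shards (the glob path is hypothetical); note that tf.numpy_function runs Python code under the GIL, which can limit how much a parallel map helps:

# Hypothetical shard paths; num_parallel_reads only pays off across several files.
shard_paths = tf.io.gfile.glob('/home/gaurav.gupta/projects/practice/data-set/electronics/shards/*.json')
line_dataset = tf.data.TextLineDataset(
    filenames=shard_paths,
    num_parallel_reads=tf.data.experimental.AUTOTUNE)
# Parallelize the per-record transform as well (TF 2.3 spelling of AUTOTUNE).
line_dataset = line_dataset.map(
    lambda row: tf.numpy_function(transform, [row], tf.string),
    num_parallel_calls=tf.data.experimental.AUTOTUNE)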

I am seeking your suggestions to parallelize this task and reduce the time consumption.

I would try something like this (I like to use joblib as it is quite simple to slot into existing code; you could probably do something similar with many other frameworks; besides that, joblib does not use the GPU nor does it do any JITting):

from joblib import Parallel, delayed
from tqdm import tqdm
...

def process_file(idx, line):
  FILE_NAMES = 'test{0}.tfrecord'.format(idx)
  start = time.time()  # time each file inside the worker
  tf.print('Processing for file - {0}'.format(FILE_NAMES))
  DIRECTORY_URL = '/home/gaurav.gupta/projects/practice/'
  filepath = os.path.join(DIRECTORY_URL, 'data-set', 'electronics', FILE_NAMES)
  batch_ds = tf.data.Dataset.from_tensor_slices(line)
  writer = tf.data.experimental.TFRecordWriter(filepath)
  writer.write(batch_ds)
  time_taken = time.time() - start
  #tf.print('Processing for file - {0} took {1}'.format(FILE_NAMES, time_taken))
  return FILE_NAMES, time_taken


# len() only works when the dataset's cardinality is known; for a
# TextLineDataset it may raise, in which case drop the total= argument.
times = Parallel(n_jobs=12, prefer="processes")(
    delayed(process_file)(idx, line)
    for idx, line in tqdm(line_dataset.enumerate(), total=len(line_dataset)))
print('Done.')

This is untested code and I am also not sure how it would work with the tf code, but I would give it a try.
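
One caveat with mixing joblib processes and tf: the enumerate loop yields EagerTensors, which may not pickle cleanly into worker processes. A sketch of a workaround (equally untested) is to materialize each batch to numpy in the parent before dispatching, so only plain bytes cross the process boundary:

# Materialize batches to numpy so they can be pickled to the workers;
# reading stays in the parent, but the TFRecord writing runs in parallel.
batches = ((int(idx), line.numpy()) for idx, line in line_dataset.enumerate())
times = Parallel(n_jobs=12, prefer="processes")(
    delayed(process_file)(idx, line) for idx, line in batches)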

The tqdm is totally unnecessary; it is just something I prefer to use as it provides a nice progress bar.
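
Since process_file returns a (FILE_NAMES, time_taken) tuple, times ends up as a list of those, so a quick summary at the end is straightforward:

# times is a list of (filename, seconds) tuples returned by the workers
for name, seconds in times:
    print('{0} took {1:.1f}s'.format(name, seconds))
print('total: {0:.1f}s'.format(sum(s for _, s in times)))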