了解使用数据集 api 从两个不同数据源获取数据时的张量流行为

Understanding tensorflow behavior when fetching Data from two different data sources with dataset api

我正在尝试使用 tensorflow 从两个不同的 dataset 来源获取数据。我写了下面的代码:

首先我尝试了以下方法:

import tensorflow as tf
import numpy as np

iters = []

def return_data1():
    d1 = tf.data.Dataset.range(1, 2000)
    iter1 = d1.make_initializable_iterator()
    iters.append(iter1)
    data1 = iter1.get_next()
    return data1

def return_data2():
    d2 = tf.data.Dataset.range(2000, 4000)
    iter2 = d2.make_initializable_iterator()
    iters.append(iter2)
    data2 = iter2.get_next()
    return data2

test = tf.placeholder(dtype=tf.bool)

data = tf.cond(test, lambda: return_data1(), lambda: return_data2())

iter1 = iters[0]
iter2 = iters[1]

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    sess.run([iter1.initializer, iter2.initializer])

    for i in range(2000):
        if i < 1000:
            print(sess.run(data, feed_dict={test: True}), "..")
        else:
            print(sess.run(data, feed_dict={test: False}), "--")

我收到以下错误:

ValueError: Operation 'cond/MakeIterator' has been marked as not fetchable.

1- 我想知道为什么我会出现这种行为。

然后,我尝试修复我的代码,所以我写了以下内容:

d1 = tf.data.Dataset.range(1, 2000)
d2 = tf.data.Dataset.range(2000, 4000)

iter1 = d1.make_initializable_iterator()
iter2 = d2.make_initializable_iterator()

data1 = iter1.get_next()
data2 = iter2.get_next()

def return_data1():
    return data1

def return_data2():
    return data2

test = tf.placeholder(dtype=tf.bool)

data = tf.cond(test, lambda: return_data1(), lambda: return_data2())

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    sess.run([iter1.initializer, iter2.initializer])

    for i in range(2000):
        if i < 1000:
            print(sess.run(data, feed_dict={test: True}), "..")
        else:
            print(sess.run(data, feed_dict={test: False}), "--")

并且从 print 中,我从第一个数据集中得到 1 -> 1000 的数字,但是当 i 变得大于 1000 时,它开始打印 3000 -> 4000。因此,我得出结论,因为第一个数据集已经 运行ning 或者像我一样从第二个数据集中获取了一个 1000 元素,但它们被忽略了。

后来,当我通过移动 data1 = iter1.get_next()data2 = iter2.get_next() 进入函数定义为:

def return_data1():
    data1 = iter1.get_next()
    return data1

def return_data2():
    data2 = iter2.get_next()
    return data2

代码有效,现在打印数字 1 -> 10002000 -> 3000

我想了解为什么会这样,以免以后再犯类似的错误。

我发现了 tf.control_dependency 的同类问题,它接受操作作为参数,并且不应在外部创建该操作。这种行为让我困惑了一段时间,但我想知道为什么张量流会发生这种情况。

其次,如果我想从两个以上的数据集来源中进行选择,然后 运行 分别选择它们,如何在 tensorflow 中做到这一点?

任何帮助将不胜感激!!!

下面是如何分别从多个数据集中获取数据。然而,我想知道关于 tensorflow 行为的其他问题的答案,以及为什么 data2 = iter2.get_next() 应该在方法中定义。

import tensorflow as tf
import numpy as np

d1 = tf.data.Dataset.range(1, 1000)
iter1 = d1.make_initializable_iterator()

d2 = tf.data.Dataset.range(1000, 2000)
iter2 = d2.make_initializable_iterator()

d3 = tf.data.Dataset.range(2000, 3000)
iter3 = d3.make_initializable_iterator()

d4 = tf.data.Dataset.range(3000, 4000)
iter4 = d4.make_initializable_iterator()

def return_data1_2():
    data1 = iter1.get_next()
    data2 = iter2.get_next()
    return data1, data2

def return_data2_3():
    data2 = iter2.get_next()
    data3 = iter3.get_next()
    return data2, data3

def return_data3_4():
    data3 = iter3.get_next()
    data4 = iter4.get_next()
    return data3, data4

def return_data4_1():
    data4 = iter4.get_next()
    data1 = iter1.get_next()
    return data4, data1

index1 = tf.placeholder(dtype=tf.int32)
index2 = tf.placeholder(dtype=tf.int32)

data = tf.case(pred_fn_pairs=[
    (tf.logical_and(tf.equal(index1, 1), tf.equal(index2, 2)), lambda: return_data1_2()), 
    (tf.logical_and(tf.equal(index1, 2), tf.equal(index2, 3)), lambda: return_data2_3()),
    (tf.logical_and(tf.equal(index1, 3), tf.equal(index2, 4)), lambda: return_data3_4()),
    (tf.logical_and(tf.equal(index1, 4), tf.equal(index2, 1)), lambda: return_data4_1())], exclusive=False)

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    sess.run([iter1.initializer, iter2.initializer, iter3.initializer, iter4.initializer])


    for i in range(2000):
        try:
            if i < 500:
                print(sess.run(data, feed_dict={index1: 1, index2: 2}), "1-2")
            elif i < 1000:
                print(sess.run(data, feed_dict={index1: 2, index2: 3}), "2-3")
            elif i < 1500:
                print(sess.run(data, feed_dict={index1: 3, index2: 4}), "3-4")
            elif i < 2000:
                print(sess.run(data, feed_dict={index1: 4, index2: 1}), "4-1")
        except tf.errors.OutOfRangeError as error:
            print("error")