使用输入 fn 在 Tensorflow 估计器中进行预测
Predict in Tensorflow estimator using input fn
我使用了 https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py 中的教程代码并且代码工作正常,直到我尝试做出预测而不是仅仅评估它。我试图创建另一个预测函数,看起来像这样(通过删除参数 y):
def input_fn_predict(data_file, num_epochs, shuffle):
"""Input builder function."""
df_data = pd.read_csv(
tf.gfile.Open(data_file),
names=CSV_COLUMNS,
skipinitialspace=True,
engine="python",
skiprows=1)
# remove NaN elements
df_data = df_data.dropna(how="any", axis=0)
labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
return tf.estimator.inputs.pandas_input_fn( #removed paramter y
x=df_data,
batch_size=100,
num_epochs=num_epochs,
shuffle=shuffle,
num_threads=5)
并这样称呼它:
predictions = m.predict(
input_fn=input_fn_predict(test_file_name, num_epochs=1, shuffle=True)
)
for i, p in enumerate(predictions):
print(i, p)
- 我做的对吗?
- 为什么我得到的预测是 81404 而不是 16282(测试文件中的行数)?
- 每行包含如下内容:
{'probabilities': array([ 0.78595656, 0.21404342], dtype=float32),
'logits': array([-1.3007226], dtype=float32), 'classes': array(['0'],
dtype=object), 'class_ids': array([0]), 'logistic': array([
0.21404341], dtype=float32)}
我该如何阅读?
您需要设置 shuffle=False
因为要预测新标签,您需要保持数据顺序。
下面是我 运行 预测的代码(我已经测试过了)。输入文件类似于测试数据(csv),但没有标签列。
def predict_input_fn(data_file):
global CSV_COLUMNS
CSV_COLUMNS = CSV_COLUMNS[:-1]
df_data = pd.read_csv(
tf.gfile.Open(data_file),
names=CSV_COLUMNS,
skipinitialspace=True,
engine='python',
skiprows=1
)
# remove NaN elements
df_data = df_data.dropna(how='any', axis=0)
return tf.estimator.inputs.pandas_input_fn(
x=df_data,
num_epochs=1,
shuffle=False
)
调用它:
predict_file_name = 'tutorials/data/adult.predict'
results = m.predict(
input_fn=predict_input_fn(predict_file_name)
)
for result in results:
print 'result: {}'.format(result)
一个样本的预测结果如下:
{
'probabilities': array([0.78595656, 0.21404342], dtype = float32),
'logits': array([-1.3007226], dtype = float32),
'classes': array(['0'], dtype = object),
'class_ids': array([0]),
'logistic': array([0.21404341], dtype = float32)
}
每个字段的含义是
- 'probabilities': array([0.78595656, 0.21404342], dtype = float32).
它预测输出标签是 class -0(在本例中 <=50K)
置信度 0.78595656
- 'logits': 数组([-1.3007226], dtype = float32)
等式1/(1+e^(-z))中z的值为-1.3。
- 'classes': array(['0'], dtype = object)
class 标签为 0
我使用了 https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py 中的教程代码并且代码工作正常,直到我尝试做出预测而不是仅仅评估它。我试图创建另一个预测函数,看起来像这样(通过删除参数 y):
def input_fn_predict(data_file, num_epochs, shuffle):
"""Input builder function."""
df_data = pd.read_csv(
tf.gfile.Open(data_file),
names=CSV_COLUMNS,
skipinitialspace=True,
engine="python",
skiprows=1)
# remove NaN elements
df_data = df_data.dropna(how="any", axis=0)
labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
return tf.estimator.inputs.pandas_input_fn( #removed paramter y
x=df_data,
batch_size=100,
num_epochs=num_epochs,
shuffle=shuffle,
num_threads=5)
并这样称呼它:
predictions = m.predict(
input_fn=input_fn_predict(test_file_name, num_epochs=1, shuffle=True)
)
for i, p in enumerate(predictions):
print(i, p)
- 我做的对吗?
- 为什么我得到的预测是 81404 而不是 16282(测试文件中的行数)?
- 每行包含如下内容:
{'probabilities': array([ 0.78595656, 0.21404342], dtype=float32), 'logits': array([-1.3007226], dtype=float32), 'classes': array(['0'], dtype=object), 'class_ids': array([0]), 'logistic': array([ 0.21404341], dtype=float32)}
我该如何阅读?
您需要设置 shuffle=False
因为要预测新标签,您需要保持数据顺序。
下面是我 运行 预测的代码(我已经测试过了)。输入文件类似于测试数据(csv),但没有标签列。
def predict_input_fn(data_file):
global CSV_COLUMNS
CSV_COLUMNS = CSV_COLUMNS[:-1]
df_data = pd.read_csv(
tf.gfile.Open(data_file),
names=CSV_COLUMNS,
skipinitialspace=True,
engine='python',
skiprows=1
)
# remove NaN elements
df_data = df_data.dropna(how='any', axis=0)
return tf.estimator.inputs.pandas_input_fn(
x=df_data,
num_epochs=1,
shuffle=False
)
调用它:
predict_file_name = 'tutorials/data/adult.predict'
results = m.predict(
input_fn=predict_input_fn(predict_file_name)
)
for result in results:
print 'result: {}'.format(result)
一个样本的预测结果如下:
{
'probabilities': array([0.78595656, 0.21404342], dtype = float32),
'logits': array([-1.3007226], dtype = float32),
'classes': array(['0'], dtype = object),
'class_ids': array([0]),
'logistic': array([0.21404341], dtype = float32)
}
每个字段的含义是
- 'probabilities': array([0.78595656, 0.21404342], dtype = float32).
它预测输出标签是 class -0(在本例中 <=50K) 置信度 0.78595656 - 'logits': 数组([-1.3007226], dtype = float32)
等式1/(1+e^(-z))中z的值为-1.3。 - 'classes': array(['0'], dtype = object)
class 标签为 0