Keras model.summary() 对象到字符串
Keras model.summary() object to string
我想写一个包含神经网络超参数和模型架构的 *.txt 文件。是否可以将对象 model.summary() 写入我的输出文件?
(...)
summary = str(model.summary())
(...)
out = open(filename + 'report.txt','w')
out.write(summary)
out.close
碰巧我得到了 "None",如下所示。
Hyperparameters
=========================
learning_rate: 0.01
momentum: 0.8
decay: 0.0
batch size: 128
no. epochs: 3
dropout: 0.5
-------------------------
None
val_acc: 0.232323229313
val_loss: 3.88496732712
train_acc: 0.0965207634216
train_loss: 4.07161939425
train/val loss ratio: 1.04804469418
知道如何处理吗?
我也遇到了同样的问题!
有两种可能的解决方法:
使用模型的to_json()
方法
summary = str(model.to_json())
这是你上面的情况。
否则使用keras_diagram
中的ascii方法
from keras_diagram import ascii
summary = ascii(model)
这不是最好的方法,但您可以做的一件事是重定向标准输出:
orig_stdout = sys.stdout
f = open('out.txt', 'w')
sys.stdout = f
print(model.summary())
sys.stdout = orig_stdout
f.close()
见"How to redirect 'print' output to a file using python?"
虽然不能完全替代 model.summary,但一个选项是使用 model.get_config()
导出模型的配置。来自 the docs:
model.get_config()
: returns a dictionary containing the configuration of the model. The model can be reinstantiated from its config via:
config = model.get_config()
model = Model.from_config(config)
# or, for Sequential:
model = Sequential.from_config(config)
使用我的 Keras (2.0.6
) 和 Python (3.5.0
) 版本,这对我有用:
# Create an empty model
from keras.models import Sequential
model = Sequential()
# Open the file
with open(filename + 'report.txt','w') as fh:
# Pass the file handle in as a lambda function to make it callable
model.summary(print_fn=lambda x: fh.write(x + '\n'))
这会将以下行输出到文件中:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
如果您想写入日志:
import logging
logger = logging.getLogger(__name__)
model.summary(print_fn=logger.info)
对我来说,这只是将模型摘要作为字符串获取:
stringlist = []
model.summary(print_fn=lambda x: stringlist.append(x))
short_model_summary = "\n".join(stringlist)
print(short_model_summary)
我知道 OP 已经接受了 winni2k 的回答,但是由于问题标题实际上意味着将 model.summary()
的输出保存到 string,而不是文件,所以以下代码可能会帮助访问此页面的其他人(就像我一样)。
下面的代码是 运行 使用 TensorFlow 1.12.0
,它在 Python 3.6.2
.
上随 Keras 2.1.6-tf
一起提供
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
import io
# Example model
model = Sequential([
Dense(32, input_shape=(784,)),
Activation('relu'),
Dense(10),
Activation('softmax'),
])
def get_model_summary(model):
stream = io.StringIO()
model.summary(print_fn=lambda x: stream.write(x + '\n'))
summary_string = stream.getvalue()
stream.close()
return summary_string
model_summary_string = get_model_summary(model)
print(model_summary_string)
产生(作为字符串):
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 32) 25120
_________________________________________________________________
activation (Activation) (None, 32) 0
_________________________________________________________________
dense_1 (Dense) (None, 10) 330
_________________________________________________________________
activation_1 (Activation) (None, 10) 0
=================================================================
Total params: 25,450
Trainable params: 25,450
Non-trainable params: 0
_________________________________________________________________
我遇到了同样的问题。 @Pasa 的回答非常有用,但我想我会 post 一个更小的例子:这是一个合理的假设,你此时已经有了一个 Keras 模型。
import io
s = io.StringIO()
model.summary(print_fn=lambda x: s.write(x + '\n'))
model_summary = s.getvalue()
s.close()
print("The model summary is:\n\n{}".format(model_summary))
这个字符串何时有用的例子:如果你有一个 matplotlib 图。然后你可以使用:
plt.text(0, 0.25, model_summary)
要将模型摘要写入性能图表,以供快速参考:
因为我来这里是想找到一种方法来记录摘要,所以我想与@ajb 的回答分享这个小技巧,以避免在每一行出现 INFO:
在日志文件中使用@FAnders 回答:
def get_model_summary(model: tf.keras.Model) -> str:
string_list = []
model.summary(line_length=80, print_fn=lambda x: string_list.append(x))
return "\n".join(string_list)
# some code
logging.info(get_model_summary(model)
生成的日志文件如下:
我想写一个包含神经网络超参数和模型架构的 *.txt 文件。是否可以将对象 model.summary() 写入我的输出文件?
(...)
summary = str(model.summary())
(...)
out = open(filename + 'report.txt','w')
out.write(summary)
out.close
碰巧我得到了 "None",如下所示。
Hyperparameters
=========================
learning_rate: 0.01
momentum: 0.8
decay: 0.0
batch size: 128
no. epochs: 3
dropout: 0.5
-------------------------
None
val_acc: 0.232323229313
val_loss: 3.88496732712
train_acc: 0.0965207634216
train_loss: 4.07161939425
train/val loss ratio: 1.04804469418
知道如何处理吗?
我也遇到了同样的问题! 有两种可能的解决方法:
使用模型的to_json()
方法
summary = str(model.to_json())
这是你上面的情况。
否则使用keras_diagram
中的ascii方法from keras_diagram import ascii
summary = ascii(model)
这不是最好的方法,但您可以做的一件事是重定向标准输出:
orig_stdout = sys.stdout
f = open('out.txt', 'w')
sys.stdout = f
print(model.summary())
sys.stdout = orig_stdout
f.close()
见"How to redirect 'print' output to a file using python?"
虽然不能完全替代 model.summary,但一个选项是使用 model.get_config()
导出模型的配置。来自 the docs:
model.get_config()
: returns a dictionary containing the configuration of the model. The model can be reinstantiated from its config via:config = model.get_config() model = Model.from_config(config) # or, for Sequential: model = Sequential.from_config(config)
使用我的 Keras (2.0.6
) 和 Python (3.5.0
) 版本,这对我有用:
# Create an empty model
from keras.models import Sequential
model = Sequential()
# Open the file
with open(filename + 'report.txt','w') as fh:
# Pass the file handle in as a lambda function to make it callable
model.summary(print_fn=lambda x: fh.write(x + '\n'))
这会将以下行输出到文件中:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
如果您想写入日志:
import logging
logger = logging.getLogger(__name__)
model.summary(print_fn=logger.info)
对我来说,这只是将模型摘要作为字符串获取:
stringlist = []
model.summary(print_fn=lambda x: stringlist.append(x))
short_model_summary = "\n".join(stringlist)
print(short_model_summary)
我知道 OP 已经接受了 winni2k 的回答,但是由于问题标题实际上意味着将 model.summary()
的输出保存到 string,而不是文件,所以以下代码可能会帮助访问此页面的其他人(就像我一样)。
下面的代码是 运行 使用 TensorFlow 1.12.0
,它在 Python 3.6.2
.
2.1.6-tf
一起提供
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
import io
# Example model
model = Sequential([
Dense(32, input_shape=(784,)),
Activation('relu'),
Dense(10),
Activation('softmax'),
])
def get_model_summary(model):
stream = io.StringIO()
model.summary(print_fn=lambda x: stream.write(x + '\n'))
summary_string = stream.getvalue()
stream.close()
return summary_string
model_summary_string = get_model_summary(model)
print(model_summary_string)
产生(作为字符串):
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 32) 25120
_________________________________________________________________
activation (Activation) (None, 32) 0
_________________________________________________________________
dense_1 (Dense) (None, 10) 330
_________________________________________________________________
activation_1 (Activation) (None, 10) 0
=================================================================
Total params: 25,450
Trainable params: 25,450
Non-trainable params: 0
_________________________________________________________________
我遇到了同样的问题。 @Pasa 的回答非常有用,但我想我会 post 一个更小的例子:这是一个合理的假设,你此时已经有了一个 Keras 模型。
import io
s = io.StringIO()
model.summary(print_fn=lambda x: s.write(x + '\n'))
model_summary = s.getvalue()
s.close()
print("The model summary is:\n\n{}".format(model_summary))
这个字符串何时有用的例子:如果你有一个 matplotlib 图。然后你可以使用:
plt.text(0, 0.25, model_summary)
要将模型摘要写入性能图表,以供快速参考:
因为我来这里是想找到一种方法来记录摘要,所以我想与@ajb 的回答分享这个小技巧,以避免在每一行出现 INFO:
在日志文件中使用@FAnders 回答:
def get_model_summary(model: tf.keras.Model) -> str:
string_list = []
model.summary(line_length=80, print_fn=lambda x: string_list.append(x))
return "\n".join(string_list)
# some code
logging.info(get_model_summary(model)
生成的日志文件如下: