Tensorflow:再训练期间预训练嵌入初始化问题

Tensorflow: pre-trained embeddings initialization issue during retraining

我的目标是(1)从文件中加载一个预训练的词嵌入矩阵作为初始值; (2) Fine tune the word embedding 而不是保持不变; (3) 每次我恢复模型时,加载微调的词嵌入而不是预训练的词嵌入。

我已经尝试过:

class model():
    def __init__(self):
    # ...
    def _add_word_embed(self):
        W = tf.get_variable('W', [self._vsize, self._emb_size], 
                 initializer=tf.truncated_normal_initializer(stddev=1e-4))
        W.assign(load_and_read_w2v())
        # ...
    def _add_seq2seq(self):
        # ...
    def build_graph(self):
        self._add_word_embed()
        self._add_seq2seq()

但是每当我停止并重新启动训练时,这种方法都会覆盖经过微调的词嵌入。我在调用 model.build_graph 后也尝试了 sess.run(W.assign())。但是它抛出了一个错误,图表已经完成,我不能再改变它了。你能告诉我实现它的正确方法吗?提前致谢!

编辑:

这个问题没有重复,因为它有一个新的要求:在训练开始时使用预训练的单词嵌入,然后在训练之后进行调整。我还询问如何有效地做到这一点。该问题中接受的答案是不满足此要求的 FXXKING。在将任何问题标记为重复之前,您能否三思??????????

这是一个关于如何操作的玩具示例:

# The graph

# Inputs
vocab_size = 2
embed_dim = 2
embedding_matrix = np.ones((vocab_size, embed_dim))

#The weight matrix to initialize with embeddings
W = tf.get_variable(initializer=tf.zeros([vocab_size, embed_dim]), name='embed', trainable=True)

# global step used to take care of the weight initialization 
# for the first time will be loaded from numpy array and not during retraining.
global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')

# Initialiazation of weights based on global_step
initW = tf.cond(tf.equal(global_step, 0), lambda:W.assign(embedding_matrix), lambda: W)
inc = tf.assign_add(W,[[1, 1],[1, 1]])

# Update global step
update = tf.assign_add(global_step, 1)
op = tf.group(inc, update)

# init_fn 
def init_embed(sess):
  sess.run(initW)

现在,如果我们 运行 在会话中执行上述操作:

sv = tf.train.Supervisor(logdir='tmp',init_fn=init_embed)
with sv.managed_session() as sess:
   print('global step:', sess.run(global_step))
   print('Initial weight:')
   print(sess.run(W))
   for i in range(2):  
      sess.run([op])
    _ W, g_step= sess.run([W, global_step])
   print('Final weight:')        
   print(_W)
   sv.saver.save(sess,sv.save_path, global_step=g_step)

# Output at first run
   Initial weight:
   [[ 1.  1.]
   [ 1.  1.]]

   Final weight:
   [[ 3.  3.]
   [ 3.  3.]]

#Output at second run
   Initial weight:
   [[ 3.  3.]
   [ 3.  3.]]
   Final weight:
   [[ 5.  5.]
   [ 5.  5.]]