如何获取火炬中心模型生成的翻译的对齐或注意信息?
how to get alignment or attention information for translations produced by a torch hub model?
火炬中心提供预训练模型,例如:https://pytorch.org/hub/pytorch_fairseq_translation/
这些模型可以在 python 中使用,或者与 CLI 交互使用。
使用 CLI 可以使用 --print-alignment
标志进行对齐。 The following code works in a terminal, after installing fairseq(和火炬)
curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
MODEL_DIR=wmt14.en-fr.fconv-py
fairseq-interactive \
--path $MODEL_DIR/model.pt $MODEL_DIR \
--beam 5 --source-lang en --target-lang fr \
--tokenizer moses \
--bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes \
--print-alignment
在 python 中可以指定关键字参数 verbose
和 print_alignment
:
import torch
en2fr = torch.hub.load('pytorch/fairseq', 'transformer.wmt14.en-fr', tokenizer='moses', bpe='subword_nmt')
fr = en2fr.translate('Hello world!', beam=5, verbose=True, print_alignment=True)
但是,这只会将对齐输出为日志消息。对于 fairseq 0.9,它似乎已损坏并导致错误消息 (issue)。
有没有办法从 python 代码访问对齐信息(甚至可能是完整的注意力矩阵)?
我浏览了 fairseq 代码库,发现了一种输出对齐信息的怪异方法。
因为这需要编辑 fairseq 源代码本身,所以我认为这不是一个可以接受的解决方案。但也许它对某人有帮助(我仍然对如何正确执行此操作的答案很感兴趣)。
编辑 sample() function 并重写 return 语句。
这是整个函数(为了帮助您更好地在代码中找到它),但只应更改最后一行:
def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]:
if isinstance(sentences, str):
return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
tokenized_sentences = [self.encode(sentence) for sentence in sentences]
batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
return list(zip([self.decode(hypos[0]['tokens']) for hypos in batched_hypos], [hypos[0]['alignment'] for hypos in batched_hypos]))
火炬中心提供预训练模型,例如:https://pytorch.org/hub/pytorch_fairseq_translation/
这些模型可以在 python 中使用,或者与 CLI 交互使用。
使用 CLI 可以使用 --print-alignment
标志进行对齐。 The following code works in a terminal, after installing fairseq(和火炬)
curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
MODEL_DIR=wmt14.en-fr.fconv-py
fairseq-interactive \
--path $MODEL_DIR/model.pt $MODEL_DIR \
--beam 5 --source-lang en --target-lang fr \
--tokenizer moses \
--bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes \
--print-alignment
在 python 中可以指定关键字参数 verbose
和 print_alignment
:
import torch
en2fr = torch.hub.load('pytorch/fairseq', 'transformer.wmt14.en-fr', tokenizer='moses', bpe='subword_nmt')
fr = en2fr.translate('Hello world!', beam=5, verbose=True, print_alignment=True)
但是,这只会将对齐输出为日志消息。对于 fairseq 0.9,它似乎已损坏并导致错误消息 (issue)。
有没有办法从 python 代码访问对齐信息(甚至可能是完整的注意力矩阵)?
我浏览了 fairseq 代码库,发现了一种输出对齐信息的怪异方法。 因为这需要编辑 fairseq 源代码本身,所以我认为这不是一个可以接受的解决方案。但也许它对某人有帮助(我仍然对如何正确执行此操作的答案很感兴趣)。
编辑 sample() function 并重写 return 语句。 这是整个函数(为了帮助您更好地在代码中找到它),但只应更改最后一行:
def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]:
if isinstance(sentences, str):
return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
tokenized_sentences = [self.encode(sentence) for sentence in sentences]
batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
return list(zip([self.decode(hypos[0]['tokens']) for hypos in batched_hypos], [hypos[0]['alignment'] for hypos in batched_hypos]))