Pretrained Bidirectional Distillation for Machine Translation

By Yimeng Zhuang, Samsung R&D Institute China-Beijing
By Mei Tu, Samsung R&D Institute China-Beijing


Initializing parameters with a pretrained masked language model (LM) [1] is a knowledge transfer method widely applied to natural language processing tasks. Following its success, pretrained neural machine translation (NMT) models have attracted increasing research interest [2,3,4,5].

However, the pretrain-finetune paradigm has potential drawbacks. As pointed out in [6], the finetuned model may forget some critical language generation skills learned during pretraining. This catastrophic forgetting problem [7,8] is common in transfer learning and leads to overfitting to the target domain. [9,10] also observe similar forgetting problems in pretrained NMT tasks. Moreover, because the pretrain-finetune paradigm initializes model parameters from a pretrained model, it requires a degree of structural consistency (e.g., identical dimensions, numbers of layers, attention heads, etc.) between the pretrained LM and the NMT model. As a result, a powerful pretrained LM that incorporates more language knowledge cannot be used directly if its structure differs from that of the NMT model.

In this paper, we propose Pretrained Bidirectional Distillation (PBD) for NMT, which alleviates the discrepancy between pretraining (masked language modeling on perturbed sentences) and MT fine-tuning (on full sentences) in the pretrain-finetune paradigm and boosts large-scale translation training. In PBD, language knowledge acquired during pretraining is continuously transferred to the NMT model throughout training, which addresses the forgetting problem. The pretrained language knowledge is represented by pretrained bidirectional distillation objectives: the token probabilities that the pretrained LM assigns to potential tokens matching the global context. These objectives are distilled into both the encoder and the decoder of the NMT model. Consequently, PBD does not require structural consistency between the pretrained LM and the NMT model, and bidirectional distillation enriches the NMT decoder with bidirectional semantic information.

Pretrained Bidirectional Distillation

Figure 1.  Overall training flow of pretrained bidirectional distillation for machine translation.

Algorithm 1.  Pretrained Bidirectional Distillation for NMT

Figure 1 and Algorithm 1 illustrate the overall flow of the proposed Pretrained Bidirectional Distillation (PBD) for machine translation. It consists of two processes: (1) Self-distilled masked language pretraining takes unlabeled LM training data as input and optimizes a token reconstruction loss and a self-distillation loss. The resulting self-distilled LM can generate full probability predictions for all input tokens in a single forward pass, rather than only for the masked tokens as in previous masked LMs, which makes the distillation in the second process efficient. (2) Translation training with PBD losses trains a standard encoder-decoder NMT model on parallel data, enhanced with extra PBD losses. The PBD losses are optimized jointly with the standard translation loss, and pretrained language knowledge, in the form of the full token probabilities generated by the pretrained LM, is distilled into the encoder and decoder of the NMT model.
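To make the single-pass efficiency argument concrete, the following sketch counts forward passes. It assumes, purely for illustration, that a standard masked LM would obtain a masked-position prediction for every token by cycling through disjoint masks covering 15% of the positions each; the helper names are hypothetical and are not part of PBD.

```python
def disjoint_mask_schedule(seq_len: int, mask_percent: int = 15) -> list[list[int]]:
    """Split positions 0..seq_len-1 into disjoint mask sets, one per forward pass.

    Each set holds floor(seq_len * mask_percent / 100) positions (at least 1).
    Integer arithmetic avoids floating-point rounding surprises.
    """
    per_pass = max(1, seq_len * mask_percent // 100)
    positions = list(range(seq_len))
    return [positions[i:i + per_pass] for i in range(0, seq_len, per_pass)]


def mlm_passes_for_full_prediction(seq_len: int, mask_percent: int = 15) -> int:
    """Passes a standard masked LM needs so every position is masked exactly once."""
    return len(disjoint_mask_schedule(seq_len, mask_percent))


def self_distilled_passes(seq_len: int) -> int:
    """The self-distilled LM predicts every position in one pass."""
    return 1
```

For a 100-token sequence, `mlm_passes_for_full_prediction(100)` gives 7 passes under this assumption, whereas the self-distilled LM needs 1, which is the cost gap the self-distilled design is meant to close.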

Self-distilled Masked Language Pretraining

Figure 2.  Overall architecture of the self-distilled masked language model.

This paper proposes self-distilled masked language pretraining to obtain the pretrained bidirectional distillation objective for NMT models. Pretrained masked language models predict a token probability distribution over the vocabulary for each masked position, and these token probabilities indicate the potential tokens matching the context. We assume that these token probabilities contain specific language knowledge that can be transferred to NMT models, so we use them as the distillation objective.

However, in preliminary experiments we found that the token probabilities predicted at non-masked positions tend to concentrate too heavily on the real (observed) tokens and therefore fail to reflect the long-tailed distribution of potential tokens. In standard masked language pretraining, only a small percentage of tokens (typically 15%) are masked, which prevents us from efficiently obtaining a full distillation objective that reflects the long-tailed distribution at every position of an input sequence in a single forward pass. To obtain a globally defined distillation objective, we adopt self-distillation, in which the token probabilities at non-masked positions are learned from those at the corresponding masked positions. Figure 2 illustrates the overall architecture of the proposed self-distilled masked language model, which follows the widely used masked language model framework with three architectural modifications: (1) The target tokens to be predicted are of two types: masked tokens and real tokens. (2) The input sequence is partitioned into three parts to avoid exposing information between masked tokens and real tokens. (3) Masked and real tokens have different prediction heads and loss functions.
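The two training signals of the self-distilled LM can be sketched as follows. This is a minimal illustration under stated assumptions, not the paper's implementation: it assumes that for each non-masked position a masked-position prediction of the same position is available as a teacher (the three-part input partition is not modeled), that the teacher is held constant, and that the two losses are combined with a hypothetical weight `alpha`.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over a vocabulary-sized logit vector."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: list[float], q: list[float]) -> float:
    """KL(p || q) for discrete distributions; entries with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def reconstruction_loss(masked_logits, masked_targets) -> float:
    """Mean cross-entropy of the masked-token head against the original tokens."""
    losses = [-math.log(softmax(z)[y]) for z, y in zip(masked_logits, masked_targets)]
    return sum(losses) / len(losses)


def self_distillation_loss(real_logits, teacher_logits) -> float:
    """Mean KL(teacher || student) at non-masked positions.

    teacher_logits: masked-position predictions for the same positions
        (assumed available; in an autodiff framework they would be wrapped
        in a stop-gradient so only the real-token head is updated).
    real_logits: predictions of the real-token head at those positions.
    """
    losses = [kl_divergence(softmax(t), softmax(s))
              for s, t in zip(real_logits, teacher_logits)]
    return sum(losses) / len(losses)


def self_distilled_mlm_loss(masked_logits, masked_targets,
                            real_logits, teacher_logits, alpha: float = 1.0) -> float:
    """Hypothetical total loss: reconstruction + alpha * self-distillation."""
    return (reconstruction_loss(masked_logits, masked_targets)
            + alpha * self_distillation_loss(real_logits, teacher_logits))
```

Matching the real-token head to the masked-position distribution, rather than to a one-hot label, is what lets non-masked positions keep the long tail of plausible alternatives described above.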

Pretrained Bidirectional Distillation Loss

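In this paper, the knowledge learned by the self-distilled masked language model described above is transferred to an NMT model through pretrained bidirectional distillation losses. Specifically, we concatenate the source and target sentences without masking to form the input sequence of the self-distilled LM, and take the LM's full probability prediction as the pretrained bidirectional distillation objective. This objective is distilled into the NMT model by minimizing the KL divergence between it and the corresponding predictions computed from an intermediate layer of the encoder or decoder. The distillation loss of the encoder is as follows:

$$\mathcal{L}_{\mathrm{enc}} = \sum_{t=1}^{|X|} \mathrm{KL}\left(Q(x_t \mid X, Y) \,\middle\|\, P^{\mathrm{enc}}_{t}\right) = \sum_{t=1}^{|X|} \sum_{w \in V} Q(x_t = w \mid X, Y) \log \frac{Q(x_t = w \mid X, Y)}{P^{\mathrm{enc}}_{t,w}}, \qquad P^{\mathrm{enc}} = \mathrm{softmax}\left(H^{l}_{\mathrm{enc}} E^{\top}\right)$$

Here, $X$ and $Y$ denote the sentences in the source and target language, respectively, and $x_t$ denotes the $t$-th position of $X$. $w$ is a word in the vocabulary $V$, and $Q(x_t = w \mid X, Y)$ is the probability that the self-distilled LM assigns to $w$ at position $x_t$, i.e., the pretrained bidirectional distillation objective. $H^{l}_{\mathrm{enc}}$ represents the hidden states of an intermediate layer $l$ of the encoder, $E$ is the token embedding matrix, and the softmax is applied row-wise. Because we reuse the token embedding matrix, pretrained bidirectional distillation adds no extra parameters. The entry in the $t$-th row and $w$-th column of the probability matrix $P^{\mathrm{enc}}$ is the probability the encoder assigns to $w$ at position $t$.

A similar distillation loss is applied to the decoder:

$$\mathcal{L}_{\mathrm{dec}} = \sum_{t=1}^{|Y|} \mathrm{KL}\left(Q(y_t \mid X, Y) \,\middle\|\, P^{\mathrm{dec}}_{t}\right), \qquad P^{\mathrm{dec}} = \mathrm{softmax}\left(H^{l}_{\mathrm{dec}} E^{\top}\right)$$

where $y_t$ denotes the $t$-th position of the target sentence, and $H^{l}_{\mathrm{dec}}$ represents the hidden states of an intermediate layer $l$ of the decoder. Note that these distillation losses are optimized jointly with the standard translation loss during NMT training.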
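The encoder and decoder distillation losses share the same computation: project intermediate hidden states onto the shared token embedding matrix, normalize, and compare against the LM's distribution with KL divergence. The sketch below illustrates this with plain Python lists; `pbd_loss`, `joint_loss`, and the weights `lam_enc` and `lam_dec` are hypothetical names, and the weighted-sum form of the joint objective is an assumption, since the exact weighting is not given here.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q); entries with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def tied_logits(h, E):
    """Logits h E^T: dot product of hidden state h (size d) with each of the
    |V| rows of the shared token embedding matrix E, so no new parameters."""
    return [sum(hk * ek for hk, ek in zip(h, row)) for row in E]


def pbd_loss(teacher_probs, hidden_states, E):
    """Sum over positions t of KL(Q_t || softmax(h_t E^T)).

    teacher_probs: rows Q_t from the self-distilled LM (fixed targets).
    hidden_states: rows h_t from an intermediate layer l of the encoder
        (source rows of Q) or decoder (target rows of Q).
    """
    return sum(
        kl_divergence(q_t, softmax(tied_logits(h_t, E)))
        for q_t, h_t in zip(teacher_probs, hidden_states)
    )


def joint_loss(mt_loss, enc_loss, dec_loss, lam_enc=1.0, lam_dec=1.0):
    """Translation loss plus weighted PBD losses (weights are hypothetical)."""
    return mt_loss + lam_enc * enc_loss + lam_dec * dec_loss
```

For example, with `E = [[1.0, 0.0], [0.0, 1.0]]`, a zero hidden state yields a uniform prediction, so `pbd_loss([[1.0, 0.0]], [[0.0, 0.0]], E)` equals log 2. In a real training setup the same idea would be expressed with batched tensors and an autodiff framework.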

The pretrained bidirectional distillation objective is not only globally defined but also aware of bidirectional context (i.e., it encodes bidirectional language knowledge of the complete source and target sentences). Approximating this objective is therefore challenging for the encoder, which sees only the source sentence, and for the decoder, which sees the source and only a partial target sentence; it is nevertheless reasonable, because the source sentence carries the complete semantic information. Moreover, this challenging task may force the NMT model to learn global language knowledge from the self-distilled LM, enriching the NMT decoder with bidirectional semantic information, as future context is important for machine translation.
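The information gap that makes this task challenging can be spelled out by listing which target positions each model conditions on. The sketch assumes a standard teacher-forced decoder whose state for position $t$ has seen only $y_{<t}$; all three roles see the full source sentence, and the function names are illustrative only.

```python
def visible_target_positions(role: str, t: int, tgt_len: int) -> set[int]:
    """Target positions (0-indexed) visible when producing the distribution
    for position t. Every role also sees the full source sentence X."""
    if role == "teacher":   # self-distilled LM on the unmasked concatenation of X and Y
        return set(range(tgt_len))
    if role == "encoder":   # encoder input is X only
        return set()
    if role == "decoder":   # assumed teacher-forced decoder: y_<t
        return set(range(t))
    raise ValueError(f"unknown role: {role}")


def context_gap(role: str, t: int, tgt_len: int) -> list[int]:
    """Target positions the teacher conditions on but the student cannot see."""
    teacher = visible_target_positions("teacher", t, tgt_len)
    return sorted(teacher - visible_target_positions(role, t, tgt_len))
```

For a five-token target, the decoder distribution at position 2 must approximate a teacher that also saw positions 2, 3, and 4; `context_gap("decoder", 2, 5)` lists exactly these future positions, which the distillation pushes the decoder to anticipate.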

Experimental Results

We evaluate the proposed pretrained bidirectional distillation through experiments in supervised, unsupervised, and zero-shot multilingual machine translation scenarios.

Table 1.  Performance of our model and competing approaches in the supervised translation scenario.

We trained a unified multilingual NMT model with pretrained bidirectional distillation. As shown in Table 1, our proposed PBD-MT clearly outperforms previously published approaches and achieves new state-of-the-art performance in most translation directions. It achieves an average improvement of +0.76 BLEU over mRASP2, which validates the effectiveness of the proposed pretrained bidirectional distillation.

Table 2.  Performance of unified multilingual MT models in zero-shot translation directions.

Table 3.  Performance of unified multilingual MT models in unsupervised translation scenario.

Table 2 summarizes the performance of unified multilingual models in the zero-shot translation scenario. Although the training data consists only of English-centric parallel sentences, multilingual NMT models show promising zero-shot translation performance. Compared with mRASP2, PBD-MT further improves translation quality in most zero-shot directions, with an average gain of +1.24. We also evaluate the unified multilingual models in unsupervised translation directions; the results are shown in Table 3. PBD-MT yields improvements in all but one translation direction, and its average BLEU score increases by 0.73 points. These results confirm the positive effect of the proposed pretrained bidirectional distillation not only in the supervised scenario but also in the zero-shot and unsupervised scenarios.


Conclusion

In this paper, we proposed pretrained bidirectional distillation to investigate language knowledge transfer from pretrained language models to NMT models through knowledge distillation. The proposed approach is both effective and efficient, and achieves new state-of-the-art performance in supervised, unsupervised, and zero-shot multilingual translation experiments. The model analysis also shows that the proposed self-distilled language model is critical for generating globally defined distillation objectives. In future work, we plan to further optimize the self-distilled language model and the pretrained bidirectional distillation losses.



[1]. Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of NAACL-HLT, pages 4171–4186.

[2]. Alexis Conneau and Guillaume Lample. 2019. Cross-lingual language model pretraining. Advances in Neural Information Processing Systems (NeurIPS), 32.

[3]. Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, and Tie-Yan Liu. 2019. MASS: Masked sequence to sequence pre-training for language generation. arXiv preprint arXiv:1905.02450.

[4]. Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, and Luke Zettlemoyer. 2020. Multilingual denoising pre-training for neural machine translation. TACL, 8:726–742.

[5]. Pengfei Li, Liangyou Li, Meng Zhang, Minghao Wu, and Qun Liu. 2022. Universal conditional masked language pre-training for neural machine translation. arXiv preprint arXiv:2203.09210.

[6]. Tianxing He, Jun Liu, Kyunghyun Cho, Myle Ott, Bing Liu, James Glass, and Fuchun Peng. 2021. Analyzing the forgetting problem in pretrain-finetuning of open-domain dialogue response models. In Proceedings of the 16th Conference of the European Chapter of the Association for Computational Linguistics: Main Volume, pages 1121–1133.

[7]. James Kirkpatrick, Razvan Pascanu, Neil Rabinowitz, Joel Veness, Guillaume Desjardins, Andrei A. Rusu, Kieran Milan, John Quan, Tiago Ramalho, Agnieszka Grabska-Barwinska, et al. 2017. Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences, 114(13):3521–3526.

[8]. Michael McCloskey and Neal J. Cohen. 1989. Catastrophic interference in connectionist networks: The sequential learning problem. In Psychology of Learning and Motivation, volume 24, pages 109–165. Elsevier.

[9]. Junjie Hu, Hiroaki Hayashi, Kyunghyun Cho, and Graham Neubig. 2022. DEEP: Denoising entity pre-training for neural machine translation. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1753–1766.

[10]. Qingkai Fang, Rong Ye, Lei Li, Yang Feng, and Mingxuan Wang. 2022. STEMM: Self-learning with speech-text manifold mixup for speech translation. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 7050–7062.