那棵树看起来生气了
基于Tensor2Tensor的自然语言翻译
Tensor2Tensor
前言
谷歌研究人员发布了 Tensor2Tensor(T2T),一个用于在 TensorFlow 中训练深度学习模型的开源系统。T2T 能够帮助人们为各种机器学习程序创建最先进的模型,可应用于多个领域,如翻译、语法分析、图像信息描述等,大大提高了研究和开发的速度。T2T 中也包含一个数据集和模型库,其中包括模型(Attention Is All You Need、Depthwise Separable Convolutions for Neural Machine Translation 和 One Model to Learn Them All)。
正文
安装
默认已经安装好了python环境
1.运行环境安装
安装Tensor2Tensor
$ pip install tensor2tensor
2.安装TensoTensor
pip install TensorFlow
测试
生成数据
编写脚本"general.sh"用于生成数据
# See what problems, models, and hyperparameter sets are available.
# You can easily swap between them (and add new ones).
t2t-trainer --registry_help
HOME=$PWD
PROBLEM=translate_ende_wmt32k
MODEL=transformer
HPARAMS=transformer_base
DATA_DIR=$HOME/t2t_data
TMP_DIR=$HOME/tmp/t2t_datagen
TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS
mkdir -p $DATA_DIR $TMP_DIR $TRAIN_DIR
# Generate data
t2t-datagen \
--data_dir=$DATA_DIR \
--tmp_dir=$TMP_DIR \
--problem=$PROBLEM
运行刚刚创建的脚本
$ sh general.sh
这个过程可能会持续一段时间,时间长短取决于你的网速。完成之后会生成如下文件:
$ ll t2t_data
-rw-rw-r--. 1 dyb dyb 7246355 Aug 27 13:04 translate_ende_wmt32k-train-00094-of-00100
-rw-rw-r--. 1 dyb dyb 7180757 Aug 27 13:04 translate_ende_wmt32k-train-00095-of-00100
-rw-rw-r--. 1 dyb dyb 7241253 Aug 27 13:04 translate_ende_wmt32k-train-00096-of-00100
-rw-rw-r--. 1 dyb dyb 7236535 Aug 27 13:04 translate_ende_wmt32k-train-00097-of-00100
-rw-rw-r--. 1 dyb dyb 7228109 Aug 27 13:04 translate_ende_wmt32k-train-00098-of-00100
-rw-rw-r--. 1 dyb dyb 7256118 Aug 27 13:04 translate_ende_wmt32k-train-00099-of-00100
-rw-rw-r--. 1 dyb dyb 321313 Aug 27 12:53 vocab.ende.32768
Training
编写脚本"training.sh"
HOME=$PWD
PROBLEM=translate_ende_wmt32k
MODEL=transformer
HPARAMS=transformer_base
DATA_DIR=$HOME/t2t_data
TMP_DIR=$HOME/tmp/t2t_datagen
TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS
mkdir -p $DATA_DIR $TMP_DIR $TRAIN_DIR
# Train
# * If you run out of memory, add --hparams='batch_size=1024'.
t2t-trainer \
--data_dir=$DATA_DIR \
--problem=$PROBLEM \
--model=$MODEL \
--hparams_set=$HPARAMS \
--hparams='batch_size=1024' \
--output_dir=$TRAIN_DIR
运行脚本
$ sh training.sh
这一步训练过程会持续很久,可以根据需要随时中断,TensorFlow会自动保存训练的中间过程。
下面是我训练的过程,为了节约时间,我这里训练了两轮,如果想要得到还比较满意的结果,那还是让它继续跑着吧。
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /home/dyb/work/Original/t2t_train/translate_ende_wmt32k/transformer-transformer_base/model.ckpt.
INFO:tensorflow:loss = 9.581421, step = 1
INFO:tensorflow:global_step/sec: 0.92796
INFO:tensorflow:loss = 8.066505, step = 101 (107.765 sec)
保存模型如下:
-rw-rw-r--. 1 dyb dyb 81 Aug 27 13:23 checkpoint
-rw-rw-r--. 1 dyb dyb 46911252 Aug 27 13:26 events.out.tfevents.1535390613.SKL43
-rw-rw-r--. 1 dyb dyb 938 Aug 27 13:23 flags_t2t.txt
-rw-rw-r--. 1 dyb dyb 1508 Aug 27 13:23 flags.txt
-rw-rw-r--. 1 dyb dyb 20638948 Aug 27 13:23 graph.pbtxt
-rw-rw-r--. 1 dyb dyb 3823 Aug 27 13:23 hparams.json
-rw-rw-r--. 1 dyb dyb 24 Aug 27 13:23 model.ckpt-0.data-00000-of-00002
-rw-rw-r--. 1 dyb dyb 733962248 Aug 27 13:23 model.ckpt-0.data-00001-of-00002
-rw-rw-r--. 1 dyb dyb 29961 Aug 27 13:23 model.ckpt-0.index
-rw-rw-r--. 1 dyb dyb 12296179 Aug 27 13:23 model.ckpt-0.meta
预测
编写脚本decoder.sh
HOME=$PWD
PROBLEM=translate_ende_wmt32k
MODEL=transformer
HPARAMS=transformer_base
DATA_DIR=$HOME/t2t_data
TMP_DIR=$HOME/tmp/t2t_datagen
TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS
DECODE_FILE=./decode_this.txt
echo "Hello world" >> $DECODE_FILE
echo "Goodbye world" >> $DECODE_FILE
echo -e 'Hallo Welt\nAuf Wiedersehen Welt' > ref-translation.de
BEAM_SIZE=4
ALPHA=0.6
t2t-decoder \
--data_dir=$DATA_DIR \
--problem=$PROBLEM \
--model=$MODEL \
--hparams_set=$HPARAMS \
--output_dir=$TRAIN_DIR \
--decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
--decode_from_file=$DECODE_FILE \
--decode_to_file=translation.en
运行脚本
$ sh decoder.sh
这个时候Tensor2Tensor会自动加载你之前训练时候保存的模型。
翻译结果如下:
INFO:tensorflow:Inference results INPUT: Goodbye world
INFO:tensorflow:Inference results OUTPUT: lichten lichten lichten lichten lichten lichten lichten lichten lichten lichten lichten lichten lichten lichten lichten agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda
INFO:tensorflow:Inference results INPUT: Hello world
INFO:tensorflow:Inference results OUTPUT: Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog
INFO:tensorflow:Elapsed Time: 12.43447
INFO:tensorflow:Averaged Single Token Generation Time: 0.1035313
INFO:tensorflow:Writing decodes into translation.en
这里会打印翻译时间,翻译结果,由于我这里只是做了演示,所以结果不理想,但是达到了效果。
结尾
如果想用自己的数据集,自己的模型,可以编辑自己的probelm。
详情参见Tensor2Tensor官网
除非注明,原创文章欢迎转载。
转载请注明本文地址。
三合一收款
下面三种方式都支持哦