TypechoJoeTheme

半醉残影

统计

基于Tensor2Tensor的自然语言翻译

2018-08-27
/
0 评论
/
1,143 阅读
/
正在检测是否收录...
08/27

Tensor2Tensor

前言

谷歌研究人员发布了 Tensor2Tensor(T2T),一个用于在 TensorFlow 中训练深度学习模型的开源系统。T2T 能够帮助人们为各种机器学习程序创建最先进的模型,可应用于多个领域,如翻译、语法分析、图像信息描述等,大大提高了研究和开发的速度。T2T 中也包含一个数据集和模型库,其中包括模型(Attention Is All You Need、Depthwise Separable Convolutions for Neural Machine Translation 和 One Model to Learn Them All)。

正文

安装

默认已经安装好了python环境

1.运行环境安装

安装Tensor2Tensor

$ pip install tensor2tensor

2.安装TensoTensor

pip install TensorFlow

测试

生成数据

编写脚本"general.sh"用于生成数据

# See what problems, models, and hyperparameter sets are available.
# You can easily swap between them (and add new ones).
t2t-trainer --registry_help

HOME=$PWD

PROBLEM=translate_ende_wmt32k
MODEL=transformer
HPARAMS=transformer_base

DATA_DIR=$HOME/t2t_data
TMP_DIR=$HOME/tmp/t2t_datagen
TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS

mkdir -p $DATA_DIR $TMP_DIR $TRAIN_DIR

# Generate data
t2t-datagen \
  --data_dir=$DATA_DIR \
  --tmp_dir=$TMP_DIR \
  --problem=$PROBLEM

运行刚刚创建的脚本

$ sh general.sh

这个过程可能会持续一段时间,时间长短取决于你的网速。完成之后会生成如下文件:

$ ll t2t_data

-rw-rw-r--. 1 dyb dyb 7246355 Aug 27 13:04 translate_ende_wmt32k-train-00094-of-00100
-rw-rw-r--. 1 dyb dyb 7180757 Aug 27 13:04 translate_ende_wmt32k-train-00095-of-00100
-rw-rw-r--. 1 dyb dyb 7241253 Aug 27 13:04 translate_ende_wmt32k-train-00096-of-00100
-rw-rw-r--. 1 dyb dyb 7236535 Aug 27 13:04 translate_ende_wmt32k-train-00097-of-00100
-rw-rw-r--. 1 dyb dyb 7228109 Aug 27 13:04 translate_ende_wmt32k-train-00098-of-00100
-rw-rw-r--. 1 dyb dyb 7256118 Aug 27 13:04 translate_ende_wmt32k-train-00099-of-00100
-rw-rw-r--. 1 dyb dyb  321313 Aug 27 12:53 vocab.ende.32768

Training

编写脚本"training.sh"


HOME=$PWD
PROBLEM=translate_ende_wmt32k
MODEL=transformer
HPARAMS=transformer_base

DATA_DIR=$HOME/t2t_data
TMP_DIR=$HOME/tmp/t2t_datagen
TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS

mkdir -p $DATA_DIR $TMP_DIR $TRAIN_DIR

# Train
# *  If you run out of memory, add --hparams='batch_size=1024'.
t2t-trainer \
  --data_dir=$DATA_DIR \
  --problem=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --hparams='batch_size=1024' \
  --output_dir=$TRAIN_DIR

运行脚本

$ sh training.sh

这一步训练过程会持续很久,可以根据需要随时中断,TensorFlow会自动保存训练的中间过程。

下面是我训练的过程,为了节约时间,我这里训练了两轮,如果想要得到还比较满意的结果,那还是让它继续跑着吧。


INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /home/dyb/work/Original/t2t_train/translate_ende_wmt32k/transformer-transformer_base/model.ckpt.
INFO:tensorflow:loss = 9.581421, step = 1
INFO:tensorflow:global_step/sec: 0.92796
INFO:tensorflow:loss = 8.066505, step = 101 (107.765 sec)

保存模型如下:

-rw-rw-r--. 1 dyb dyb        81 Aug 27 13:23 checkpoint
-rw-rw-r--. 1 dyb dyb  46911252 Aug 27 13:26 events.out.tfevents.1535390613.SKL43
-rw-rw-r--. 1 dyb dyb       938 Aug 27 13:23 flags_t2t.txt
-rw-rw-r--. 1 dyb dyb      1508 Aug 27 13:23 flags.txt
-rw-rw-r--. 1 dyb dyb  20638948 Aug 27 13:23 graph.pbtxt
-rw-rw-r--. 1 dyb dyb      3823 Aug 27 13:23 hparams.json
-rw-rw-r--. 1 dyb dyb        24 Aug 27 13:23 model.ckpt-0.data-00000-of-00002
-rw-rw-r--. 1 dyb dyb 733962248 Aug 27 13:23 model.ckpt-0.data-00001-of-00002
-rw-rw-r--. 1 dyb dyb     29961 Aug 27 13:23 model.ckpt-0.index
-rw-rw-r--. 1 dyb dyb  12296179 Aug 27 13:23 model.ckpt-0.meta

预测

编写脚本decoder.sh


HOME=$PWD
PROBLEM=translate_ende_wmt32k
MODEL=transformer
HPARAMS=transformer_base

DATA_DIR=$HOME/t2t_data
TMP_DIR=$HOME/tmp/t2t_datagen
TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS

DECODE_FILE=./decode_this.txt
echo "Hello world" >> $DECODE_FILE
echo "Goodbye world" >> $DECODE_FILE
echo -e 'Hallo Welt\nAuf Wiedersehen Welt' > ref-translation.de

BEAM_SIZE=4
ALPHA=0.6

t2t-decoder \
  --data_dir=$DATA_DIR \
  --problem=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --output_dir=$TRAIN_DIR \
  --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
  --decode_from_file=$DECODE_FILE \
  --decode_to_file=translation.en

运行脚本

$ sh decoder.sh

这个时候Tensor2Tensor会自动加载你之前训练时候保存的模型。

翻译结果如下:

INFO:tensorflow:Inference results INPUT: Goodbye world
INFO:tensorflow:Inference results OUTPUT: lichten lichten lichten lichten lichten lichten lichten lichten lichten lichten lichten lichten lichten lichten lichten agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda agenda
INFO:tensorflow:Inference results INPUT: Hello world
INFO:tensorflow:Inference results OUTPUT: Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog Blog
INFO:tensorflow:Elapsed Time: 12.43447
INFO:tensorflow:Averaged Single Token Generation Time: 0.1035313
INFO:tensorflow:Writing decodes into translation.en

这里会打印翻译时间,翻译结果,由于我这里只是做了演示,所以结果不理想,但是达到了效果。

结尾

如果想用自己的数据集,自己的模型,可以编辑自己的probelm。

详情参见Tensor2Tensor官网


除非注明,原创文章欢迎转载。
转载请注明本文地址。

TensorFlow深度学习Tensor2Tensor
朗读
赞(0)
赞赏
感谢您的支持,我会继续努力哒!

三合一收款

下面三种方式都支持哦

微信
QQ
支付宝
打开支付宝/微信/QQ扫一扫,即可进行扫码打赏哦
版权属于:

半醉残影

本文链接:

https://blog.dengyb.com/archives/18/(转载时请注明本文出处及文章链接)

评论 (0)

人生倒计时

今日已经过去小时
这周已经过去
本月已经过去
今年已经过去个月

最新回复

  1. 搭建自用导航网站 R11; JKblog
    2022-03-10
  2. JK
    2022-01-13

    {!{}!}

  3. 搭建自用导航网站 R11; JKblog
    2022-01-12
  4. 搭建自用导航网站 R11; JKblog
    2022-01-12
  5. MrGao
    2019-09-03

标签云