提交 f26a1c90 编写于 作者: S sneaxiy

test=develop

上级 43a67a26
......@@ -99,10 +99,10 @@ class NormGradKernel : public framework::OpKernel<T> {
auto dx_e = framework::EigenVector<T>::Flatten(*out_dx);
Eigen::DSizes<int, 3> shape(pre, n, post);
Eigen::DSizes<int, 2> norm_shape(pre, post);
Eigen::DSizes<int, 3> rshape(pre, 1, post);
auto x = x_e.reshape(shape);
auto dy = dy_e.reshape(shape);
auto norm = norm_e.reshape(norm_shape);
auto norm = norm_e.reshape(rshape);
auto dx = dx_e.reshape(shape);
framework::Tensor rsum;
......@@ -111,7 +111,6 @@ class NormGradKernel : public framework::OpKernel<T> {
Eigen::DSizes<int, 1> rdim(1);
Eigen::DSizes<int, 3> bcast(1, n, 1);
Eigen::DSizes<int, 3> rshape(pre, 1, post);
// dx = ( dy/sqrt(sum(x*x)) ) * [1 - x*sum(x) / (sum(x*x) + e)]
// = [dy - dy * x * sum(x) / (sum(x*x) + e)] / sqrt(sum(x*x))
......
......@@ -16,12 +16,10 @@ import os
import unittest
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
from test_parallel_executor_transformer import TestTransformer
class EagerDeletionTestTransformer(TestTransformer):
pass
os.environ[
'RECORDIO_FILENAME'] = '/tmp/eager_deletion_transformer.wmt16.recordio'
from test_parallel_executor_transformer import TestTransformer
if __name__ == '__main__':
unittest.main()
......@@ -24,7 +24,7 @@ import paddle.fluid.core as core
import paddle.dataset.wmt16 as wmt16
import os
WMT16_RECORDIO_FILE = "/tmp/wmt16.recordio"
WMT16_RECORDIO_FILE = os.environ.get('RECORDIO_FILENAME', '/tmp/wmt16.recordio')
class ModelHyperParams(object):
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
from functools import partial
import numpy as np
import os
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.layers.io import open_recordio_file
......@@ -408,7 +409,7 @@ def transformer(
trg_pad_idx,
pos_pad_idx, ):
file_obj = open_recordio_file(
filename='/tmp/wmt16.recordio',
filename=os.environ.get('RECORDIO_FILENAME', '/tmp/wmt16.recordio'),
shapes=[
[batch_size * max_length, 1],
[batch_size * max_length, 1],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册