From d01f6269445aa81971d175fa1931e11601e67360 Mon Sep 17 00:00:00 2001
From: furnace <34057289+windstamp@users.noreply.github.com>
Date: Mon, 28 Sep 2020 00:06:46 +0800
Subject: [PATCH] update mv op according to PR#27024 (#27474)

---
 paddle/fluid/operators/mv_op.cc                | 16 +++----
 paddle/fluid/operators/mv_op.cu                | 31 ++++++-------
 paddle/fluid/operators/mv_op.h                 | 30 ++++++------
 .../fluid/tests/unittests/test_mv_op.py        | 46 +++++++++++++------
 4 files changed, 69 insertions(+), 54 deletions(-)

diff --git a/paddle/fluid/operators/mv_op.cc b/paddle/fluid/operators/mv_op.cc
index 1339982ada..cce066ec40 100644
--- a/paddle/fluid/operators/mv_op.cc
+++ b/paddle/fluid/operators/mv_op.cc
@@ -42,21 +42,21 @@ class MVOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "mv");
 
     auto dim_x = context->GetInputDim("X");
-    auto dim_y = context->GetInputDim("Vec");
+    auto dim_vec = context->GetInputDim("Vec");
     PADDLE_ENFORCE_EQ(
         dim_x.size(), 2,
         platform::errors::InvalidArgument(
             "The rank of input X should be 2, but is %d", dim_x.size()));
     PADDLE_ENFORCE_EQ(
-        dim_y.size(), 1,
+        dim_vec.size(), 1,
         platform::errors::InvalidArgument(
-            "The rank of input Vec should be 1, but is %d", dim_y.size()));
-    PADDLE_ENFORCE_EQ(dim_x[1] == dim_y[0], true,
+            "The rank of input Vec should be 1, but is %d", dim_vec.size()));
+    PADDLE_ENFORCE_EQ(dim_x[1], dim_vec[0],
                       platform::errors::InvalidArgument(
-                          "The length of input X' second dim should equal the "
-                          "length of input Vec,"
-                          " but X[%d, %d], Vec[%d]",
-                          dim_x[0], dim_x[1], dim_y[0]));
+                          "X's second dimension is expected to be equal to "
+                          "Vec's first dimension, "
+                          "but received X's shape = [%s], Vec's shape = [%s]",
+                          dim_x, dim_vec));
 
     framework::DDim dim_out = framework::make_ddim({dim_x[0]});
 
diff --git a/paddle/fluid/operators/mv_op.cu b/paddle/fluid/operators/mv_op.cu
index 9a16fe025c..b6d829392e 100644
--- a/paddle/fluid/operators/mv_op.cu
+++ b/paddle/fluid/operators/mv_op.cu
@@ -19,8 +19,8 @@ namespace paddle {
 namespace operators {
 
 template <typename T>
-__global__ void MVGradCUDAKernel(const int m, const int n, const T *dout,
-                                 const T *vec, T *dx) {
+__global__ void MVGradDxCUDAKernel(const int m, const int n, const T *dout,
+                                   const T *vec, T *dx) {
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   for (; idx < m * n; idx += blockDim.x * gridDim.x) {
     int i = idx / n;
@@ -52,32 +52,31 @@ class MVGradKernel
     int m = dim_x[0];
     int n = dim_x[1];
 
-    dx->Resize(framework::make_ddim({m * n}));
-
     // get data ptr
     const T *x_data = x->data<T>();
     const T *vec_data = vec->data<T>();
     const T *dout_data = dout->data<T>();
 
-    T *dx_data = dx->mutable_data<T>(context.GetPlace());
-    T *dvec_data = dvec->mutable_data<T>(context.GetPlace());
-
     auto &dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
     auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
-
-    // calculate dx
     auto stream = context.cuda_device_context().stream();
     auto config = GetGpuLaunchConfig1D(dev_ctx, m * n);
 
-    MVGradCUDAKernel<
-        T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
-        m, n, dout_data, vec_data, dx_data);
-    dx->Resize(framework::make_ddim({m, n}));
+    if (dx) {
+      T *dx_data = dx->mutable_data<T>(context.GetPlace());
+
+      MVGradDxCUDAKernel<
+          T><<<config.block_per_grid.x, config.thread_per_block.x, 0,
+               stream>>>(m, n, dout_data, vec_data, dx_data);
+    }
+
+    if (dvec) {
+      T *dvec_data = dvec->mutable_data<T>(context.GetPlace());
 
-    // calculate dvec
-    blas.GEMV(true, dim_x[0], dim_x[1], static_cast<T>(1), x_data, dout_data,
-              static_cast<T>(0), dvec_data);
+      blas.GEMV(true, dim_x[0], dim_x[1], static_cast<T>(1), x_data, dout_data,
+                static_cast<T>(0), dvec_data);
+    }
   }
 };
 
diff --git a/paddle/fluid/operators/mv_op.h b/paddle/fluid/operators/mv_op.h
index 3c63f3640f..e294499629 100644
--- a/paddle/fluid/operators/mv_op.h
+++ b/paddle/fluid/operators/mv_op.h
@@ -74,30 +74,30 @@ class MVGradKernel : public framework::OpKernel<T> {
     int m = dim_x[0];
     int n = dim_x[1];
 
-    dx->Resize(framework::make_ddim({m * n}));
-
     // get data ptr
     const T *x_data = x->data<T>();
     const T *vec_data = vec->data<T>();
     const T *dout_data = dout->data<T>();
 
-    T *dx_data = dx->mutable_data<T>(context.GetPlace());
-    T *dvec_data = dvec->mutable_data<T>(context.GetPlace());
-
-    auto &dev_ctx = context.template device_context<DeviceContext>();
-    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+    if (dx) {
+      T *dx_data = dx->mutable_data<T>(context.GetPlace());
 
-    // calculate dx
-    for (int i = 0; i < m; ++i) {
-      for (int j = 0; j < n; ++j)
-        dx_data[i * n + j] = dout_data[i] * vec_data[j];
+      for (int i = 0; i < m; ++i) {
+        for (int j = 0; j < n; ++j) {
+          dx_data[i * n + j] = dout_data[i] * vec_data[j];
+        }
+      }
     }
-    dx->Resize(framework::make_ddim({m, n}));
 
-    // calculate dvec
-    blas.GEMV(true, dim_x[0], dim_x[1], static_cast<T>(1), x_data, dout_data,
-              static_cast<T>(0), dvec_data);
+    if (dvec) {
+      T *dvec_data = dvec->mutable_data<T>(context.GetPlace());
+
+      auto &dev_ctx = context.template device_context<DeviceContext>();
+      auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+
+      blas.GEMV(true, dim_x[0], dim_x[1], static_cast<T>(1), x_data, dout_data,
+                static_cast<T>(0), dvec_data);
+    }
   }
 };
 
 }  // namespace operators
diff --git a/python/paddle/fluid/tests/unittests/test_mv_op.py b/python/paddle/fluid/tests/unittests/test_mv_op.py
index 6b930e59aa..e0d23e7871 100644
--- a/python/paddle/fluid/tests/unittests/test_mv_op.py
+++ b/python/paddle/fluid/tests/unittests/test_mv_op.py
@@ -20,6 +20,7 @@ import paddle
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 import paddle.fluid.core as core
+from paddle.static import program_guard, Program
 from op_test import OpTest
 
 
@@ -37,7 +38,7 @@ class TestMVOp(OpTest):
         self.check_grad(['X', 'Vec'], 'Out')
 
     def init_config(self):
-        self.x = np.random.random((5, 100)).astype("float64")
+        self.x = np.random.random((2, 100)).astype("float64")
         self.vec = np.random.random((100)).astype("float64")
 
 
@@ -57,21 +58,36 @@ class TestMVAPI(unittest.TestCase):
         paddle.enable_static()
 
     def test_static_graph(self):
-        paddle.enable_static()
+        for x_stop_gradient in [False, True]:
+            for vec_stop_gradient in [False, True]:
+
+                paddle.enable_static()
+
+                train_program = Program()
+                startup_program = Program()
+
+                self.input_x = np.random.rand(5, 100).astype("float64")
+                self.input_vec = np.random.rand(100).astype("float64")
+
+                with program_guard(train_program, startup_program):
+                    data_x = paddle.static.data(
+                        "x", shape=[5, 100], dtype="float64")
+                    data_vec = paddle.static.data(
+                        "vec", shape=[100], dtype="float64")
+
+                    data_x.stop_gradient = x_stop_gradient
+                    data_vec.stop_gradient = vec_stop_gradient
+
+                    result_vec = paddle.mv(data_x, data_vec)
 
-        self.input_x = np.random.rand(5, 100).astype("float64")
-        self.input_vec = np.random.rand(100).astype("float64")
-
-        data_x = paddle.static.data("x", shape=[5, 100], dtype="float64")
-        data_vec = paddle.static.data("vec", shape=[100], dtype="float64")
-        result_vec = paddle.mv(data_x, data_vec)
-        self.place = paddle.CPUPlace()
-        exe = paddle.static.Executor(self.place)
-        res, = exe.run(feed={"x": self.input_x,
-                             "vec": self.input_vec},
-                       fetch_list=[result_vec])
-        z_expected = np.array(np.dot(self.input_x, self.input_vec))
-        self.assertTrue(np.allclose(res, z_expected))
+                    self.place = paddle.CPUPlace()
+                    exe = paddle.static.Executor(self.place)
+                    res, = exe.run(
+                        feed={"x": self.input_x,
+                              "vec": self.input_vec},
+                        fetch_list=[result_vec])
+                    z_expected = np.array(np.dot(self.input_x, self.input_vec))
+                    self.assertTrue(np.allclose(res, z_expected))
 
 
 class TestMVError(unittest.TestCase):
-- 
GitLab
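
Usage note (an illustrative sketch, not part of the patch): after this change the
mv grad kernels guard on null gradient outputs, so marking an input with
stop_gradient simply skips that gradient instead of writing through a null
pointer. A minimal dygraph example, assuming a paddle 2.x install where dygraph
is the default mode:

    import numpy as np
    import paddle  # assumption: paddle 2.x, dygraph enabled by default

    # X requires a gradient, Vec does not: the mv grad kernel then receives
    # a null dVec and, with the guards above, skips its GEMV entirely.
    x = paddle.to_tensor(np.random.rand(5, 100), stop_gradient=False)
    vec = paddle.to_tensor(np.random.rand(100))  # stop_gradient defaults to True

    out = paddle.mv(x, vec)  # shape [5]: out[i] = sum_j X[i][j] * Vec[j]
    out.sum().backward()     # reduce to a scalar before calling backward

    assert x.grad is not None  # dX was computed: dX[i][j] = dOut[i] * Vec[j]
    assert vec.grad is None    # dVec was skipped for the stop_gradient input

The static-graph test above exercises the same four stop_gradient combinations
through Program/program_guard rather than dygraph.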