diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index b875149ec63c88563eba3ffaed42143d864add4f..953c3a555fa4b7517bb909323082d1f64a1ae9e3 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -228,6 +228,59 @@ class MatMulV2GradOpMaker : public framework::SingleGradOpMaker { } }; +class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* context) const override { + OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matmul"); + OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul"); + OP_INOUT_CHECK(context->HasInput("DOut"), "Input", "DOut", "matmul"); + + if (context->HasOutput("DX") && context->HasInput("DDY")) { + context->ShareDim("X", "DX"); + } + + if (context->HasOutput("DY") && context->HasInput("DDX")) { + context->ShareDim("Y", "DY"); + } + + if (context->HasOutput("DDOut") && + (context->HasInput("DDY") || context->HasInput("DDX"))) { + context->ShareDim("DOut", "DDOut"); + } + } +}; + +template +class MatMulV2OpDoubleGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("matmul_v2_grad_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Y", this->Input("Y")); + op->SetInput("DOut", this->Input(framework::GradVarName("Out"))); + op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X"))); + op->SetInput("DDY", this->OutputGrad(framework::GradVarName("Y"))); + + auto ddx = this->OutputGrad(framework::GradVarName("X")); + auto ddy = this->OutputGrad(framework::GradVarName("Y")); + + if (!ddx.empty() || !ddy.empty()) { + op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out"))); + } + op->SetOutput("DX", + ddy.empty() ? this->EmptyInputGrad() : this->InputGrad("X")); + op->SetOutput("DY", + ddx.empty() ? this->EmptyInputGrad() : this->InputGrad("Y")); + + op->SetAttrMap(this->Attrs()); + } +}; } // namespace operators } // namespace paddle @@ -236,7 +289,11 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker, ops::MatMulV2GradOpMaker, ops::MatMulV2GradOpMaker); -REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad); +REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad, + ops::MatMulV2OpDoubleGradMaker, + ops::MatMulV2OpDoubleGradMaker); + +REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad); REGISTER_OP_CPU_KERNEL( matmul_v2, ops::MatMulV2Kernel, @@ -254,3 +311,11 @@ REGISTER_OP_CPU_KERNEL( paddle::platform::complex>, ops::MatMulV2GradKernel>); +REGISTER_OP_CPU_KERNEL( + matmul_v2_grad_grad, + ops::MatMulV2DoubleGradKernel, + ops::MatMulV2DoubleGradKernel, + ops::MatMulV2DoubleGradKernel>, + ops::MatMulV2DoubleGradKernel>); diff --git a/paddle/fluid/operators/matmul_v2_op.cu b/paddle/fluid/operators/matmul_v2_op.cu index 2176ab79dd919dec17ca15c0297c87bf2a47e85e..b258077456e1ed54483448bd395f6330447e7621 100644 --- a/paddle/fluid/operators/matmul_v2_op.cu +++ b/paddle/fluid/operators/matmul_v2_op.cu @@ -30,3 +30,13 @@ REGISTER_OP_CUDA_KERNEL( ops::MatMulV2GradKernel, ops::MatMulV2GradKernel>, ops::MatMulV2GradKernel>); + +REGISTER_OP_CUDA_KERNEL( + matmul_v2_grad_grad, + ops::MatMulV2DoubleGradKernel, + ops::MatMulV2DoubleGradKernel, + ops::MatMulV2DoubleGradKernel, + ops::MatMulV2DoubleGradKernel>, + ops::MatMulV2DoubleGradKernel>); diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index 5b114f381996e610f8d220e37661a3bfa059104d..58e57c3914f411c9dcc2d3e2e7d184c81ae2b58d 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -117,11 +117,12 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, const std::vector& x_dims, const std::vector& y_dims, Tensor* Out, bool trans_x, bool trans_y, - const paddle::framework::ExecutionContext& ctx) { + const paddle::framework::ExecutionContext& ctx, + bool flag = false) { const int x_ndim = x_dims.size(); const int y_ndim = y_dims.size(); - // get data ptr + // Get data ptr const T* x_data = X->data(); const T* y_data = Y->data(); @@ -141,7 +142,11 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, auto y_eigen = framework::EigenVector::Flatten(*Y); auto& dev = *ctx.template device_context().eigen_device(); - out_eigen.device(dev) = (x_eigen * y_eigen).sum(); + if (flag) { + out_eigen.device(dev) = (x_eigen * y_eigen).sum() + out_eigen; + } else { + out_eigen.device(dev) = (x_eigen * y_eigen).sum(); + } return; } @@ -178,18 +183,18 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, const int M = Y->numel() / N; VLOG(3) << "MatMul's case 2"; blas.GEMV(false, M, N, static_cast(1), y_data, x_data, - static_cast(0), Out->data()); + static_cast(flag), Out->data()); } else { const int M = y_dims[y_ndim - 1]; const int batch_size = Y->numel() / (M * N); if (batch_size == 1) { VLOG(3) << "MatMul's case 3"; blas.GEMV(true, N, M, static_cast(1), y_data, x_data, - static_cast(0), Out->data()); + static_cast(flag), Out->data()); } else { VLOG(3) << "MatMul's case 4"; blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast(1), - y_data, x_data, static_cast(0), Out->data(), + y_data, x_data, static_cast(flag), Out->data(), batch_size, M * N, 0); } } @@ -229,18 +234,18 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, if (batch_size == 1) { VLOG(3) << "MatMul's case 5"; blas.GEMV(true, N, M, static_cast(1), x_data, y_data, - static_cast(0), Out->data()); + static_cast(flag), Out->data()); } else { VLOG(3) << "MatMul's case 6"; blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast(1), - x_data, y_data, static_cast(0), Out->data(), + x_data, y_data, static_cast(flag), Out->data(), batch_size, M * N, 0); } } else { const int M = X->numel() / N; VLOG(3) << "MatMul's case 7"; blas.GEMV(false, M, N, static_cast(1), x_data, y_data, - static_cast(0), Out->data()); + static_cast(flag), Out->data()); } return; } @@ -298,17 +303,17 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, VLOG(3) << "MatMul's case 8"; blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K, static_cast(1), - x_data, y_data, static_cast(0), Out->data()); + x_data, y_data, static_cast(flag), Out->data()); } else if (x_batch_size == 1) { if (M == 1 && trans_y) { VLOG(3) << "MatMul's case 9"; blas.GEMV(false, y_batch_size * N, K, static_cast(1), y_data, x_data, - static_cast(0), Out->data()); + static_cast(flag), Out->data()); } else { VLOG(3) << "MatMul's case 10"; blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_data, y_data, static_cast(0), + static_cast(1), x_data, y_data, static_cast(flag), Out->data(), out_batch_size, 0, K * N); } } else if (y_batch_size == 1) { @@ -316,18 +321,18 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, VLOG(3) << "MatMul's case 11"; blas.GEMM(CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans, x_batch_size * M, N, K, static_cast(1), x_data, y_data, - static_cast(0), Out->data()); + static_cast(flag), Out->data()); } else { VLOG(3) << "MatMul's case 12"; blas.BatchedGEMM(CblasTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_data, y_data, static_cast(0), + static_cast(1), x_data, y_data, static_cast(flag), Out->data(), out_batch_size, M * K, 0); } } else if (!is_broadcast_dims) { VLOG(3) << "MatMul's case 13"; blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_data, y_data, static_cast(0), + static_cast(1), x_data, y_data, static_cast(flag), Out->data(), out_batch_size, M * K, K * N); } else { // in the case, can't use stridedgemm @@ -351,18 +356,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y, blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K, static_cast(1), x_ptr.data(), y_ptr.data(), - static_cast(0), out_ptr.data(), out_batch_size); + static_cast(flag), out_ptr.data(), out_batch_size); } } template void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, bool trans_y, - const paddle::framework::ExecutionContext& ctx) { + const paddle::framework::ExecutionContext& ctx, + bool flag = false) { const std::vector x_dims = vectorize(X->dims()); const std::vector y_dims = vectorize(Y->dims()); MatMulFunction(X, Y, x_dims, y_dims, Out, trans_x, trans_y, - ctx); + ctx, flag); } template @@ -526,6 +532,245 @@ struct ConjHelper> { const framework::ExecutionContext& ctx_; }; +template +struct DotDoubleGradFunction { + void operator()(const Tensor* tensor_x, const Tensor* tensor_y, + Tensor* tensor_dx, Tensor* tensor_dy, + const Tensor* tensor_dout, const Tensor* tensor_ddx, + const Tensor* tensor_ddy, Tensor* tensor_ddout, + const paddle::framework::ExecutionContext& ctx); +}; + +template +struct DotDoubleGradFunction> { + void operator()(const Tensor* tensor_x, const Tensor* tensor_y, + Tensor* tensor_dx, Tensor* tensor_dy, + const Tensor* tensor_dout, const Tensor* tensor_ddx, + const Tensor* tensor_ddy, Tensor* tensor_ddout, + const paddle::framework::ExecutionContext& ctx) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == tensor_dout->dims().size()) { + framework::Tensor tensor_dout_help; + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + if (tensor_dx || tensor_dy) { + tensor_dout_help.Resize(tensor_dout->dims()); + tensor_dout_help.mutable_data(ctx.GetPlace()); + paddle::platform::ForRange for_range( + dev_raw, tensor_dout->numel()); + math::ConjFunctor functor(tensor_dout->data(), + tensor_dout->numel(), + tensor_dout_help.data()); + for_range(functor); + } + if (tensor_dx) { + auto ddy = framework::EigenVector::Flatten(*tensor_ddy); + Eigen::DSizes size(tensor_ddy->numel()); + auto dx = framework::EigenVector::Flatten(*tensor_dx); + auto dout = framework::EigenVector::Flatten(tensor_dout_help); + dx.device(dev) = ddy * dout.broadcast(size); + } + + if (tensor_dy) { + auto ddx = framework::EigenVector::Flatten(*tensor_ddx); + Eigen::DSizes size(tensor_ddx->numel()); + auto dy = framework::EigenVector::Flatten(*tensor_dy); + auto dout = framework::EigenVector::Flatten(tensor_dout_help); + dy.device(dev) = ddx * dout.broadcast(size); + } + + if (tensor_ddout) { + framework::Tensor tensor_x_help, tensor_y_help; + tensor_x_help.Resize(tensor_x->dims()); + tensor_x_help.mutable_data(ctx.GetPlace()); + tensor_y_help.Resize(tensor_y->dims()); + tensor_y_help.mutable_data(ctx.GetPlace()); + + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + paddle::platform::ForRange for_range(dev_raw, + tensor_x->numel()); + math::ConjFunctor functor_x(tensor_x->data(), tensor_x->numel(), + tensor_x_help.data()); + for_range(functor_x); + math::ConjFunctor functor_y(tensor_y->data(), tensor_y->numel(), + tensor_y_help.data()); + for_range(functor_y); + auto x = framework::EigenVector::Flatten(tensor_x_help); + auto y = framework::EigenVector::Flatten(tensor_y_help); + auto ddx = framework::EigenVector::Flatten(*tensor_ddx); + auto ddy = framework::EigenVector::Flatten(*tensor_ddy); + auto ddout = framework::EigenVector::Flatten(*tensor_ddout); + ddout.device(dev) = (x * ddy + y * ddx).sum(); + } + } +#else + const auto* data_dout = tensor_dout->data(); + + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); + const auto* data_ddy = tensor_ddy->data(); + const framework::DDim& dim = tensor_dx->dims(); + size_t N = static_cast(framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddy[i]; + } + } + + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); + const auto* data_ddx = tensor_ddx->data(); + const framework::DDim& dim = tensor_dy->dims(); + size_t N = static_cast(framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddx[i]; + } + } + + if (tensor_ddout) { + auto* data_ddout = tensor_ddout->mutable_data(ctx.GetPlace()); + auto* data_x = tensor_x->data(); + auto* data_y = tensor_y->data(); + auto* data_ddx = tensor_ddx->data(); + auto* data_ddy = tensor_ddy->data(); + + const framework::DDim& dim = tensor_dy->dims(); + size_t N = static_cast(framework::product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_ddout[s] = T(data_x[i].real, -data_x[i].imag) * data_ddy[i] + + T(data_y[i].real, -data_y[i].imag) * data_ddx[i]; + } else { + data_ddout[s] += T(data_x[i].real, -data_x[i].imag) * data_ddy[i] + + T(data_y[i].real, -data_y[i].imag) * data_ddx[i]; + } + new_s = false; + } + } +#endif + } +}; + +template +struct DotDoubleGradFunction> { + void operator()(const Tensor* tensor_x, const Tensor* tensor_y, + Tensor* tensor_dx, Tensor* tensor_dy, + const Tensor* tensor_dout, const Tensor* tensor_ddx, + const Tensor* tensor_ddy, Tensor* tensor_ddout, + const paddle::framework::ExecutionContext& ctx) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == tensor_dout->dims().size()) { + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + auto dout = framework::EigenVector::Flatten(*tensor_dout); + if (tensor_dx) { + tensor_dx->mutable_data(ctx.GetPlace()); + auto ddy = framework::EigenVector::Flatten(*tensor_ddy); + Eigen::DSizes size(tensor_ddy->numel()); + auto dx = framework::EigenVector::Flatten(*tensor_dx); + dx.device(dev) = ddy * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(ctx.GetPlace()); + auto ddx = framework::EigenVector::Flatten(*tensor_ddx); + Eigen::DSizes size(tensor_ddx->numel()); + + auto dy = framework::EigenVector::Flatten(*tensor_dy); + dy.device(dev) = ddx * dout.broadcast(size); + } + + if (tensor_ddout) { + tensor_ddout->mutable_data(ctx.GetPlace()); + auto x = framework::EigenVector::Flatten(*tensor_x); + auto y = framework::EigenVector::Flatten(*tensor_y); + auto ddx = framework::EigenVector::Flatten(*tensor_ddx); + auto ddy = framework::EigenVector::Flatten(*tensor_ddy); + auto ddout = framework::EigenVector::Flatten(*tensor_ddout); + ddout.device(dev) = (x * ddy + y * ddx).sum(); + } + } +#else + const auto* data_dout = tensor_dout->data(); + + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); + const auto* data_ddy = tensor_ddy->data(); + const framework::DDim& dim = tensor_dx->dims(); + size_t N = static_cast(framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = data_dout[s] * data_ddy[i]; + } + } + + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); + const auto* data_ddx = tensor_ddx->data(); + const framework::DDim& dim = tensor_dy->dims(); + size_t N = static_cast(framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = data_dout[s] * data_ddx[i]; + } + } + + if (tensor_ddout) { + auto* data_ddout = tensor_ddout->mutable_data(ctx.GetPlace()); + auto* data_x = tensor_x->data(); + auto* data_y = tensor_y->data(); + auto* data_ddx = tensor_ddx->data(); + auto* data_ddy = tensor_ddy->data(); + + const framework::DDim& dim = tensor_dy->dims(); + size_t N = static_cast(framework::product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_ddout[s] = data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i]; + } else { + data_ddout[s] += data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i]; + } + new_s = false; + } + } +#endif + } +}; + template class MatMulV2GradKernel : public framework::OpKernel { public: @@ -573,10 +818,10 @@ class MatMulV2GradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { bool transpose_x = ctx.Attr("trans_x"); bool transpose_y = ctx.Attr("trans_y"); - auto x = *ctx.Input("X"); auto y = *ctx.Input("Y"); auto dout = *ctx.Input(framework::GradVarName("Out")); + framework::Tensor y_conj(y.type()); framework::Tensor x_conj(y.type()); @@ -757,9 +1002,327 @@ class MatMulV2GradKernel : public framework::OpKernel { } dy->Resize(y.dims()); } + + // Get the OutputGrad(out) } } }; +template +class MatMulV2DoubleGradKernel : public framework::OpKernel { + public: + void MatMul(const framework::ExecutionContext& context, + const framework::Tensor& a, bool trans_a, + const framework::Tensor& b, bool trans_b, framework::Tensor* out, + bool flag) const { + out->mutable_data(context.GetPlace()); + auto blas = math::GetBlas(context); + auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); + auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); + if (a.dims().size() == 3 && b.dims().size() <= 2) { + // the transpose_X must be false, if is true, the transpose cost much time + if (!trans_a) { + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + } + } + blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast(1), out, + static_cast(flag)); + } + + void CalcInputGrad(const framework::ExecutionContext& context, + const framework::Tensor& a, bool trans_a, + bool is_fold_init_dims_a, const framework::Tensor& b, + bool trans_b, bool is_fold_init_dims_b, + framework::Tensor* out, bool flag) const { + if (out == nullptr) return; + bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && + out->dims().size() == 2; + if (!need_combine) { + MatMul(context, a, trans_a, b, trans_b, out, flag); + } else { + auto& ctx = context.template device_context(); + MatMul(context, is_fold_init_dims_a + ? FoldInitDims(a) + : FoldHeadAndLastDims(ctx, a), + trans_a, is_fold_init_dims_b + ? FoldInitDims(b) + : FoldHeadAndLastDims(ctx, b), + trans_b, out, flag); + } + } + + void Compute(const framework::ExecutionContext& context) const override { + auto x = *context.Input("X"); + auto y = *context.Input("Y"); + auto dout = *context.Input("DOut"); + auto* ddx = context.Input("DDX"); + auto* ddy = context.Input("DDY"); + + auto* dx = context.Output("DX"); + auto* dy = context.Output("DY"); + auto* ddout = context.Output("DDOut"); + + bool transpose_x = context.Attr("trans_x"); + bool transpose_y = context.Attr("trans_y"); + + // Get dims from the input x, y, output_grad + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + std::vector dout_dims = vectorize(dout.dims()); + framework::Tensor x_conj(x.type()); + framework::Tensor y_conj(y.type()); + framework::Tensor dout_conj(dout.type()); + + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + int ndim = dout_dims.size(); + + // Case1 : x's or y's dim = 1 + if (x_ndim == 1 && y_ndim == 1) { + DotDoubleGradFunction()(&x, &y, dx, dy, &dout, ddx, ddy, + ddout, context); + return; + } + + bool is_broadcast = true; + if (x_ndim <= 2 || y_ndim <= 2) { + is_broadcast = false; + } else if (x_ndim != y_ndim) { + is_broadcast = true; + } else { + is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, + y_dims.cbegin()); + } + + if (!is_broadcast) { + // Case2: no broadcast or no batch size + ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); + framework::DDim dx_dims; + + ConjHelper conj_helper(context); + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x.dims()) { + dx->Resize(x.dims()); + } + } + + framework::DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y.dims()) { + dy->Resize(y.dims()); + } + } + + framework::DDim ddout_dims; + if (ddout) { + ddout_dims = ddout->dims(); + if (ddout_dims != dout.dims()) { + ddout->Resize(dout.dims()); + } + } + + if (ddx || ddy) { + ConjHelper conj_helper(context); + conj_helper(dout, dout_conj); + } + if (ddout) { + ConjHelper conj_helper(context); + conj_helper(x, x_conj); + conj_helper(y, y_conj); + } + bool ddout_flag = false; + if (ddx) { + auto ddx_mat = *ddx; + if (ddx_mat.dims() != x.dims()) { + ddx_mat.Resize(x.dims()); + } + if (dy) { + if (transpose_x && transpose_y) { + // dy = dout' * ddx' + CalcInputGrad(context, dout_conj, true, true, ddx_mat, true, false, + dy, false); + } else if (transpose_x) { + // dy = ddx * dout + CalcInputGrad(context, ddx_mat, false, false, dout_conj, false, + true, dy, false); + } else if (transpose_y) { + // dy = dout' * ddx + CalcInputGrad(context, dout_conj, true, true, ddx_mat, false, true, + dy, false); + } else { + // dy = ddx' * dout + CalcInputGrad(context, ddx_mat, true, true, dout_conj, false, true, + dy, false); + } + } + + if (ddout) { + CalcInputGrad(context, ddx_mat, transpose_x, true, y_conj, + transpose_y, false, ddout, ddout_flag); + ddout_flag = true; + } + } + + if (ddy) { + auto ddy_mat = *ddy; + if (ddy_mat.dims() != y.dims()) { + ddy_mat.Resize(y.dims()); + } + if (dx) { + if (transpose_x && transpose_y) { + // dx = ddy' * dout' + CalcInputGrad(context, ddy_mat, true, true, dout_conj, true, false, + dx, false); + } else if (transpose_x) { + // dx = ddy * dout' + CalcInputGrad(context, ddy_mat, false, false, dout_conj, true, + false, dx, false); + } else if (transpose_y) { + // dx = dout * ddy + CalcInputGrad(context, dout_conj, false, false, ddy_mat, false, + true, dx, false); + } else { + // dx = dout * ddy' + CalcInputGrad(context, dout_conj, false, false, ddy_mat, true, + false, dx, false); + } + } + + if (ddout) { + CalcInputGrad(context, x_conj, transpose_x, true, ddy_mat, + transpose_y, false, ddout, ddout_flag); + } + } + + if (dx) { + if (dx_dims != x.dims()) { + dx->Resize(dx_dims); + } + } + + if (dy) { + if (dy_dims != y.dims()) { + dy->Resize(dy_dims); + } + } + + if (ddout) { + if (ddout_dims != dout.dims()) { + ddout->Resize(ddout_dims); + } + } + } else { + // Case3: broadcast. It need cost much time to reduce sum for the + // broadcast and wastes the memory. + // So we should avoid the case in reality. + VLOG(3) << "It need cost much time to reduce sum for the broadcast and " + "wastes the memory. So we should avoid the case in reality"; + framework::Tensor ddy_conj(ddx->type()); + framework::Tensor ddx_conj(ddy->type()); + + Tensor dx_help, dy_help; + if (dx || dy) { + ConjHelper conj_helper(context); + conj_helper(dout, dout_conj); + } + if (ddout) { + ConjHelper conj_helper(context); + conj_helper(x, x_conj); + conj_helper(y, y_conj); + } + if (transpose_x) { + if (transpose_y) { + if (dx) + MatMulFunction(ddy, &dout_conj, y_dims, dout_dims, + &dx_help, true, true, context); + if (dy) + MatMulFunction(&dout_conj, ddx, dout_dims, x_dims, + &dy_help, true, true, context); + } else { + if (dx) + MatMulFunction(ddy, &dout_conj, y_dims, dout_dims, + &dx_help, false, true, context); + if (dy) + MatMulFunction(ddx, &dout_conj, x_dims, dout_dims, + &dy_help, false, false, context); + } + } else { + if (transpose_y) { + if (dx) + MatMulFunction(&dout_conj, ddy, dout_dims, y_dims, + &dx_help, false, false, context); + if (dy) + MatMulFunction(&dout_conj, ddx, dout_dims, x_dims, + &dy_help, true, false, context); + } else { + if (dx) + MatMulFunction(&dout_conj, ddy, dout_dims, y_dims, + &dx_help, false, true, context); + if (dy) + MatMulFunction(ddx, &dout_conj, x_dims, dout_dims, + &dy_help, true, false, context); + } + } + + // get help dims + const std::vector dx_help_dims = vectorize(dx_help.dims()); + const std::vector dy_help_dims = vectorize(dy_help.dims()); + + std::vector dx_broadcast_dims(ndim); + std::vector dy_broadcast_dims(ndim); + + std::fill(dx_broadcast_dims.data(), + dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::fill(dy_broadcast_dims.data(), + dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(x_dims.data(), x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + std::copy(y_dims.data(), y_dims.data() + y_ndim, + dy_broadcast_dims.data() + ndim - y_ndim); + + std::vector dx_reduce_dims; + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { + dy_reduce_dims.push_back(idx); + } + } + // Reduce sum to get grad by ReduceSum + if (dx) { + if (dx_reduce_dims.empty()) { + *dx = std::move(dx_help); + } else { + ReduceSumForMatmulGrad(&dx_help, dx, dx_reduce_dims, + context); + } + dx->Resize(x.dims()); + } + if (dy) { + if (dy_reduce_dims.empty()) { + *dy = std::move(dy_help); + } else { + ReduceSumForMatmulGrad(&dy_help, dy, dy_reduce_dims, + context); + } + dy->Resize(y.dims()); + } + + if (ddout) { + // Caluate the gradient of OutputGrad(Out) + MatMulFunction(ddx, &y_conj, x_dims, y_dims, ddout, + transpose_x, transpose_y, context); + MatMulFunction(&x_conj, ddy, x_dims, y_dims, ddout, + transpose_x, transpose_y, context, + true); + } + } + } +}; } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_nn_matmul_v2_grad.py b/python/paddle/fluid/tests/unittests/test_nn_matmul_v2_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..16e8a4a8b00fbe96ba419496e6c1fc5e72a71eb7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nn_matmul_v2_grad.py @@ -0,0 +1,150 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core +import gradient_checker +from decorator_helper import prog_scope +paddle.enable_static() + + +class TestMatmulDoubleGradCheck(unittest.TestCase): + def setUp(self): + self.init_test() + + def init_test(self): + self.x_shape = [2] + self.y_shape = [2] + self.transpose_x = False + self.transpose_y = False + + @prog_scope() + def func(self, place): + eps = 0.005 + dtype = np.float64 + typename = "float64" + x = paddle.static.create_parameter( + dtype=typename, shape=self.x_shape, name='x') + y = paddle.static.create_parameter( + dtype=typename, shape=self.y_shape, name='y') + out = paddle.matmul( + x, y, self.transpose_x, self.transpose_y, name='out') + + x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype) + y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype) + gradient_checker.double_grad_check( + [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestMatmulDoubleGradCheckCase1(TestMatmulDoubleGradCheck): + def init_test(self): + self.x_shape = [2, 3] + self.y_shape = [3, 2] + self.transpose_x = True + self.transpose_y = True + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestMatmulDoubleGradCheck2(unittest.TestCase): + def setUp(self): + self.init_test() + + def init_test(self): + self.x_shape = [2, 4, 3] + self.y_shape = [2, 4, 5] + self.transpose_x = True + self.transpose_y = False + + @prog_scope() + def func(self, place): + eps = 0.005 + dtype = np.float64 + typename = "float64" + x = paddle.static.create_parameter( + dtype=typename, shape=self.x_shape, name='x') + y = paddle.static.create_parameter( + dtype=typename, shape=self.y_shape, name='y') + out = paddle.matmul( + x, y, self.transpose_x, self.transpose_y, name='out') + + x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype) + y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype) + gradient_checker.double_grad_check( + [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestMatmulDoubleGradCheckCase3(unittest.TestCase): + def setUp(self): + self.init_test() + + def init_test(self): + self.x_shape = [1, 1, 4, 25] + self.y_shape = [1, 2, 25, 4] + self.transpose_x = False + self.transpose_y = False + + @prog_scope() + def func(self, place): + eps = 0.005 + dtype = np.float64 + typename = "float64" + x = paddle.static.create_parameter( + dtype=typename, shape=self.x_shape, name='x') + y = paddle.static.create_parameter( + dtype=typename, shape=self.y_shape, name='y') + out = paddle.matmul( + x, y, self.transpose_x, self.transpose_y, name='out') + + x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype) + y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype) + gradient_checker.double_grad_check( + [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +if __name__ == "__main__": + unittest.main()