diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 1b609b15d6e56934a460b6d2ec249f7dc6a916d6..bd32af1c8f623c0938804debe42d2d3d9ac9cdb4 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -347,6 +347,76 @@ class MatMulV2OpDoubleGradMaker : public framework::SingleGradOpMaker { op->SetAttrMap(this->Attrs()); } }; +class MatMulV2OpTripleGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* context) const override { + OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", + "matmul_v2_triple_grad"); + OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", + "matmul_v2_triple_grad"); + OP_INOUT_CHECK(context->HasInput("DOut"), "Input", "DOut", + "matmul_v2_triple_grad"); + OP_INOUT_CHECK(context->HasInput("DDX"), "Input", "DDX", + "matmul_v2_triple_grad"); + OP_INOUT_CHECK(context->HasInput("DDY"), "Input", "DDY", + "matmul_v2_triple_grad"); + OP_INOUT_CHECK(context->HasInput("D_DX"), "Input", "D_DX", + "matmul_v2_triple_grad"); + OP_INOUT_CHECK(context->HasInput("D_DY"), "Input", "D_DY", + "matmul_v2_triple_grad"); + OP_INOUT_CHECK(context->HasInput("D_DDOut"), "Input", "D_DDOut", + "matmul_v2_triple_grad"); + + if (context->HasOutput("D_X_out")) { + context->ShareDim("X", "D_X_out"); + } + if (context->HasOutput("D_Y_out")) { + context->ShareDim("Y", "D_Y_out"); + } + if (context->HasOutput("D_DOut_out")) { + context->ShareDim("DOut", "D_DOut_out"); + } + if (context->HasOutput("D_DDX_out")) { + context->ShareDim("X", "D_DDX_out"); + } + if (context->HasOutput("D_DDY_out")) { + context->ShareDim("Y", "D_DDY_out"); + } + } +}; + +template +class MatMulV2OpTripleGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("matmul_v2_triple_grad"); + + // get input from double grad + op->SetInput("X", this->Input("X")); + op->SetInput("Y", this->Input("Y")); + op->SetInput("DOut", this->Input("DOut")); + op->SetInput("DDX", this->Input("DDX")); + op->SetInput("DDY", this->Input("DDY")); + op->SetInput("D_DX", this->OutputGrad("DX")); + op->SetInput("D_DY", this->OutputGrad("DY")); + op->SetInput("D_DDOut", this->OutputGrad("DDOut")); + + // set outputs + op->SetOutput("D_X_out", this->InputGrad("X")); + op->SetOutput("D_Y_out", this->InputGrad("Y")); + op->SetOutput("D_DOut_out", this->InputGrad("DOut")); + op->SetOutput("D_DDX_out", this->InputGrad("DDX")); + op->SetOutput("D_DDY_out", this->InputGrad("DDY")); + + op->SetAttrMap(this->Attrs()); + } +}; } // namespace operators } // namespace paddle @@ -359,7 +429,11 @@ REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad, ops::MatMulV2OpDoubleGradMaker, ops::MatMulV2OpDoubleGradMaker); -REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad); +REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad, + ops::MatMulV2OpTripleGradMaker, + ops::MatMulV2OpTripleGradMaker); + +REGISTER_OPERATOR(matmul_v2_triple_grad, ops::MatMulV2OpTripleGrad); REGISTER_OP_CPU_KERNEL( matmul_v2, ops::MatMulV2Kernel, @@ -385,3 +459,12 @@ REGISTER_OP_CPU_KERNEL( paddle::platform::complex>, ops::MatMulV2DoubleGradKernel>); + +REGISTER_OP_CPU_KERNEL( + matmul_v2_triple_grad, + ops::MatMulV2TripleGradKernel, + ops::MatMulV2TripleGradKernel, + ops::MatMulV2TripleGradKernel>, + ops::MatMulV2TripleGradKernel>); diff 
--git a/paddle/fluid/operators/matmul_v2_op.cu b/paddle/fluid/operators/matmul_v2_op.cu index b258077456e1ed54483448bd395f6330447e7621..c9602a1eab93197d14cb186c150e82b2e04e3e2d 100644 --- a/paddle/fluid/operators/matmul_v2_op.cu +++ b/paddle/fluid/operators/matmul_v2_op.cu @@ -40,3 +40,13 @@ REGISTER_OP_CUDA_KERNEL( paddle::platform::complex>, ops::MatMulV2DoubleGradKernel>); + +REGISTER_OP_CUDA_KERNEL( + matmul_v2_triple_grad, + ops::MatMulV2TripleGradKernel, + ops::MatMulV2TripleGradKernel, + ops::MatMulV2TripleGradKernel, + ops::MatMulV2TripleGradKernel>, + ops::MatMulV2TripleGradKernel>); diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index dd9940db29f7739b54a5fe26d89746f0eceb2b2c..2ba82243fe6b5b5bf8597ad2230fcce595174681 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -779,6 +779,421 @@ struct DotDoubleGradFunction> { } }; +template +struct DotTripleGradFunction { + void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y, + const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy, + const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy, + const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout, + Tensor* out_tensor_d_x, Tensor* out_tensor_d_y, + Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx, + Tensor* out_tensor_d_ddy, + const paddle::framework::ExecutionContext& ctx); +}; + +// TODO(wuweilong): enable this function when the unittests framewark for multi +// grad is ok (dtype: complex64 or complex128). +template +struct DotTripleGradFunction> { + void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y, + const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy, + const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy, + const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout, + Tensor* out_tensor_d_x, Tensor* out_tensor_d_y, + Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx, + Tensor* out_tensor_d_ddy, + const paddle::framework::ExecutionContext& ctx) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == in_tensor_d_ddout->dims().size()) { + framework::Tensor in_tensor_d_ddout_help; + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + if (out_tensor_d_x || out_tensor_d_y) { + in_tensor_d_ddout_help.Resize(in_tensor_d_ddout->dims()); + in_tensor_d_ddout_help.mutable_data(ctx.GetPlace()); + paddle::platform::ForRange for_range( + dev_raw, in_tensor_d_ddout->numel()); + math::ConjFunctor functor(in_tensor_d_ddout->data(), + in_tensor_d_ddout->numel(), + in_tensor_d_ddout_help.data()); + for_range(functor); + } + if (out_tensor_d_x) { + auto ddy = framework::EigenVector::Flatten(*in_tensor_ddy); + Eigen::DSizes size(in_tensor_ddy->numel()); + auto d_x = framework::EigenVector::Flatten(*out_tensor_d_x); + auto d_ddout = + framework::EigenVector::Flatten(in_tensor_d_ddout_help); + d_x.device(dev) = ddy * d_ddout.broadcast(size); + } + + if (out_tensor_d_y) { + auto ddx = framework::EigenVector::Flatten(*in_tensor_ddx); + Eigen::DSizes size(in_tensor_ddx->numel()); + auto d_y = framework::EigenVector::Flatten(*out_tensor_d_y); + auto d_ddout = + framework::EigenVector::Flatten(in_tensor_d_ddout_help); + d_y.device(dev) = ddx * d_ddout.broadcast(size); + } + + if (out_tensor_d_dout) { + framework::Tensor in_tensor_ddx_help, in_tensor_ddy_help; + in_tensor_ddx_help.Resize(in_tensor_ddx->dims()); + in_tensor_ddx_help.mutable_data(ctx.GetPlace()); + in_tensor_ddy_help.Resize(in_tensor_ddy->dims()); + 
in_tensor_ddy_help.mutable_data(ctx.GetPlace()); + + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + paddle::platform::ForRange for_range( + dev_raw, in_tensor_ddx->numel()); + math::ConjFunctor functor_ddx(in_tensor_ddx->data(), + in_tensor_ddx->numel(), + in_tensor_ddx_help.data()); + for_range(functor_ddx); + math::ConjFunctor functor_ddy(in_tensor_ddy->data(), + in_tensor_ddy->numel(), + in_tensor_ddy_help.data()); + for_range(functor_ddy); + auto ddx = framework::EigenVector::Flatten(in_tensor_ddx_help); + auto ddy = framework::EigenVector::Flatten(in_tensor_ddy_help); + auto d_dx = framework::EigenVector::Flatten(*in_tensor_d_dx); + auto d_dy = framework::EigenVector::Flatten(*in_tensor_d_dy); + auto d_dout = framework::EigenVector::Flatten(*out_tensor_d_dout); + d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum(); + } + if (out_tensor_d_ddx) { + framework::Tensor in_tensor_dout_help, in_tensor_y_help; + in_tensor_dout_help.Resize(in_tensor_dout->dims()); + in_tensor_dout_help.mutable_data(ctx.GetPlace()); + in_tensor_y_help.Resize(in_tensor_y->dims()); + in_tensor_y_help.mutable_data(ctx.GetPlace()); + + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + paddle::platform::ForRange for_range( + dev_raw, in_tensor_dout->numel()); + math::ConjFunctor functor_dout(in_tensor_dout->data(), + in_tensor_dout->numel(), + in_tensor_dout_help.data()); + for_range(functor_dout); + math::ConjFunctor functor_y(in_tensor_y->data(), + in_tensor_y->numel(), + in_tensor_y_help.data()); + for_range(functor_y); + auto dout = framework::EigenVector::Flatten(in_tensor_dout_help); + auto y = framework::EigenVector::Flatten(in_tensor_y_help); + auto d_ddout = framework::EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dy = framework::EigenVector::Flatten(*in_tensor_d_dy); + auto d_ddx = framework::EigenVector::Flatten(*out_tensor_d_ddx); + Eigen::DSizes size(in_tensor_y->numel()); + d_ddx.device(dev) = + (dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size)); + } + if (out_tensor_d_ddy) { + framework::Tensor in_tensor_dout_help, in_tensor_x_help; + in_tensor_dout_help.Resize(in_tensor_dout->dims()); + in_tensor_dout_help.mutable_data(ctx.GetPlace()); + in_tensor_x_help.Resize(in_tensor_x->dims()); + in_tensor_x_help.mutable_data(ctx.GetPlace()); + + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + paddle::platform::ForRange for_range( + dev_raw, in_tensor_dout->numel()); + math::ConjFunctor functor_dout(in_tensor_dout->data(), + in_tensor_dout->numel(), + in_tensor_dout_help.data()); + for_range(functor_dout); + math::ConjFunctor functor_x(in_tensor_x->data(), + in_tensor_x->numel(), + in_tensor_x_help.data()); + for_range(functor_x); + auto dout = framework::EigenVector::Flatten(in_tensor_dout_help); + auto x = framework::EigenVector::Flatten(in_tensor_x_help); + auto d_ddout = framework::EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dx = framework::EigenVector::Flatten(*in_tensor_d_dx); + auto d_ddy = framework::EigenVector::Flatten(*out_tensor_d_ddy); + Eigen::DSizes size(in_tensor_x->numel()); + d_ddy.device(dev) = + (dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size)); + } + } +#else + const auto* data_d_ddout = in_tensor_d_ddout->data(); + + if (out_tensor_d_x) { + auto* data_d_x = out_tensor_d_x->mutable_data(ctx.GetPlace()); + const auto* data_ddy = in_tensor_ddy->data(); + + const framework::DDim& dim = out_tensor_d_x->dims(); + size_t N = 
static_cast(framework::product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_x[i] = T(data_ddy[i].real, -data_ddy[i].imag) * data_d_ddout[s]; + } + } + + if (out_tensor_d_y) { + auto* data_d_y = out_tensor_d_y->mutable_data(ctx.GetPlace()); + const auto* data_ddx = in_tensor_ddx->data(); + + const framework::DDim& dim = out_tensor_d_y->dims(); + size_t N = static_cast(framework::product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_y[i] = T(data_ddx[i].real, -data_ddx[i].imag) * data_d_ddout[s]; + } + } + + if (out_tensor_d_dout) { + auto* data_d_dout = out_tensor_d_dout->mutable_data(ctx.GetPlace()); + auto* data_ddx = in_tensor_ddx->data(); + auto* data_ddy = in_tensor_ddy->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + + const framework::DDim& dim = out_tensor_d_dout->dims(); + size_t N = static_cast(framework::product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_d_dout[s] = + T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] + + T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i]; + } else { + data_d_dout[s] += + T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] + + T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i]; + } + new_s = false; + } + } + + if (out_tensor_d_ddx) { + auto* data_d_ddx = out_tensor_d_ddx->mutable_data(ctx.GetPlace()); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + auto* data_y = in_tensor_y->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const framework::DDim& dim = out_tensor_d_ddx->dims(); + size_t N = static_cast(framework::product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddx[i] = + T(data_dout[s].real, -data_dout[s].imag) * data_d_dy[i] + + T(data_y[i].real, -data_y[i].imag) * data_d_ddout[s]; + } + } + + if (out_tensor_d_ddy) { + auto* data_d_ddy = out_tensor_d_ddy->mutable_data(ctx.GetPlace()); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_x = in_tensor_x->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const framework::DDim& dim = out_tensor_d_ddy->dims(); + size_t N = static_cast(framework::product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddy[i] = + T(data_dout[s].real, -data_dout[s].imag) * data_d_dx[i] + + T(data_x[i].real, -data_x[i].imag) * data_d_ddout[s]; + } + } +#endif + } +}; + +template +struct DotTripleGradFunction> { + void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y, + const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy, + const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy, + const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout, + Tensor* out_tensor_d_x, Tensor* out_tensor_d_y, + Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx, + Tensor* out_tensor_d_ddy, + const paddle::framework::ExecutionContext& ctx) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == in_tensor_d_ddout->dims().size()) { + auto& dev_raw = ctx.template device_context(); + auto& dev = *dev_raw.eigen_device(); + auto d_ddout = 
framework::EigenVector::Flatten(*in_tensor_d_ddout); + if (out_tensor_d_x) { + out_tensor_d_x->mutable_data(ctx.GetPlace()); + auto ddy = framework::EigenVector::Flatten(*in_tensor_ddy); + Eigen::DSizes size(in_tensor_ddy->numel()); + auto d_x = framework::EigenVector::Flatten(*out_tensor_d_x); + d_x.device(dev) = ddy * d_ddout.broadcast(size); + } + + if (out_tensor_d_y) { + out_tensor_d_y->mutable_data(ctx.GetPlace()); + auto ddx = framework::EigenVector::Flatten(*in_tensor_ddx); + Eigen::DSizes size(in_tensor_ddx->numel()); + + auto d_y = framework::EigenVector::Flatten(*out_tensor_d_y); + d_y.device(dev) = ddx * d_ddout.broadcast(size); + } + + if (out_tensor_d_dout) { + out_tensor_d_dout->mutable_data(ctx.GetPlace()); + auto ddx = framework::EigenVector::Flatten(*in_tensor_ddx); + auto ddy = framework::EigenVector::Flatten(*in_tensor_ddy); + auto d_dx = framework::EigenVector::Flatten(*in_tensor_d_dx); + auto d_dy = framework::EigenVector::Flatten(*in_tensor_d_dy); + auto d_dout = framework::EigenVector::Flatten(*out_tensor_d_dout); + d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum(); + } + + if (out_tensor_d_ddx) { + out_tensor_d_ddx->mutable_data(ctx.GetPlace()); + auto dout = framework::EigenVector::Flatten(*in_tensor_dout); + auto y = framework::EigenVector::Flatten(*in_tensor_y); + auto d_ddout = framework::EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dy = framework::EigenVector::Flatten(*in_tensor_d_dy); + auto d_ddx = framework::EigenVector::Flatten(*out_tensor_d_ddx); + Eigen::DSizes size(in_tensor_y->numel()); + d_ddx.device(dev) = + (dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size)); + } + + if (out_tensor_d_ddy) { + out_tensor_d_ddy->mutable_data(ctx.GetPlace()); + auto dout = framework::EigenVector::Flatten(*in_tensor_dout); + auto x = framework::EigenVector::Flatten(*in_tensor_x); + auto d_ddout = framework::EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dx = framework::EigenVector::Flatten(*in_tensor_d_dx); + auto d_ddy = framework::EigenVector::Flatten(*out_tensor_d_ddy); + Eigen::DSizes size(in_tensor_x->numel()); + d_ddy.device(dev) = + (dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size)); + } + } +#else + const auto* data_d_ddout = in_tensor_d_ddout->data(); + + if (out_tensor_d_x) { + auto* data_d_x = out_tensor_d_x->mutable_data(ctx.GetPlace()); + const auto* data_ddy = in_tensor_ddy->data(); + + const framework::DDim& dim = out_tensor_d_x->dims(); + size_t N = static_cast(framework::product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_x[i] = data_ddy[i] * data_d_ddout[s]; + } + } + + if (out_tensor_d_y) { + auto* data_d_y = out_tensor_d_y->mutable_data(ctx.GetPlace()); + const auto* data_ddx = in_tensor_ddx->data(); + + const framework::DDim& dim = out_tensor_d_y->dims(); + size_t N = static_cast(framework::product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_y[i] = data_ddx[i] * data_d_ddout[s]; + } + } + + if (out_tensor_d_dout) { + auto* data_d_dout = out_tensor_d_dout->mutable_data(ctx.GetPlace()); + auto* data_ddx = in_tensor_ddx->data(); + auto* data_ddy = in_tensor_ddy->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + + const framework::DDim& dim = in_tensor_ddx->dims(); + size_t N = static_cast(framework::product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + for (size_t i = 0; i < N; 
++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_d_dout[s] = + data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i]; + } else { + data_d_dout[s] += + data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i]; + } + new_s = false; + } + } + + if (out_tensor_d_ddx) { + auto* data_d_ddx = out_tensor_d_ddx->mutable_data(ctx.GetPlace()); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + auto* data_y = in_tensor_y->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const framework::DDim& dim = out_tensor_d_ddx->dims(); + size_t N = static_cast(framework::product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddx[i] = + data_dout[s] * data_d_dy[i] + data_y[i] * data_d_ddout[s]; + } + } + + if (out_tensor_d_ddy) { + auto* data_d_ddy = out_tensor_d_ddy->mutable_data(ctx.GetPlace()); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_x = in_tensor_x->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const framework::DDim& dim = out_tensor_d_ddy->dims(); + size_t N = static_cast(framework::product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddy[i] = + data_dout[s] * data_d_dx[i] + data_x[i] * data_d_ddout[s]; + } + } +#endif + } +}; + template class MatMulV2GradKernel : public framework::OpKernel { public: @@ -1322,7 +1737,7 @@ class MatMulV2DoubleGradKernel : public framework::OpKernel { } if (ddout) { - // Caluate the gradient of OutputGrad(Out) + // Calculate the gradient of OutputGrad(Out) MatMulFunction(ddx, &y_conj, x_dims, y_dims, ddout, transpose_x, transpose_y, context); MatMulFunction(&x_conj, ddy, x_dims, y_dims, ddout, @@ -1332,5 +1747,609 @@ class MatMulV2DoubleGradKernel : public framework::OpKernel { } } }; + +template +class MatMulV2TripleGradKernel : public framework::OpKernel { + public: + void MatMul(const framework::ExecutionContext& context, + const framework::Tensor& a, bool trans_a, + const framework::Tensor& b, bool trans_b, framework::Tensor* out, + bool flag) const { + out->mutable_data(context.GetPlace()); + auto blas = math::GetBlas(context); + auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); + auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); + if (a.dims().size() == 3 && b.dims().size() <= 2) { + // the transpose_X must be false, if is true, the transpose cost much time + if (!trans_a) { + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + } + } + blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast(1), out, + static_cast(flag)); + } + + void CalcInputGrad(const framework::ExecutionContext& context, + const framework::Tensor& a, bool trans_a, + bool is_fold_init_dims_a, const framework::Tensor& b, + bool trans_b, bool is_fold_init_dims_b, + framework::Tensor* out, bool flag) const { + if (out == nullptr) return; + bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && + out->dims().size() == 2; + if (!need_combine) { + MatMul(context, a, trans_a, b, trans_b, out, flag); + } else { + auto& ctx = context.template device_context(); + MatMul(context, is_fold_init_dims_a + ? FoldInitDims(a) + : FoldHeadAndLastDims(ctx, a), + trans_a, is_fold_init_dims_b + ? 
FoldInitDims(b) + : FoldHeadAndLastDims(ctx, b), + trans_b, out, flag); + } + } + + void Compute(const framework::ExecutionContext& context) const override { + // get input + auto x = *context.Input("X"); + auto y = *context.Input("Y"); + auto dout = *context.Input("DOut"); + auto ddx = *context.Input("DDX"); + auto ddy = *context.Input("DDY"); + + auto* d_dx = context.Input("D_DX"); + auto* d_dy = context.Input("D_DY"); + auto* d_ddout = context.Input("D_DDOut"); + + // get output + auto* out_d_x = context.Output("D_X_out"); + auto* out_d_y = context.Output("D_Y_out"); + auto* out_d_dout = context.Output("D_DOut_out"); + + auto* out_d_ddx = context.Output("D_DDX_out"); + auto* out_d_ddy = context.Output("D_DDY_out"); + + bool transpose_x = context.Attr("trans_x"); + bool transpose_y = context.Attr("trans_y"); + + // Get dims from the input x, y, output_grad + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + std::vector dout_dims = vectorize(dout.dims()); + framework::Tensor x_conj(x.type()); + framework::Tensor y_conj(y.type()); + framework::Tensor dout_conj(dout.type()); + framework::Tensor ddx_conj(ddx.type()); + framework::Tensor ddy_conj(ddy.type()); + + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + int ndim = dout_dims.size(); + + // Case1 : x's and y's dim = 1 + if (x_ndim == 1 && y_ndim == 1) { + VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 1"; + + DotTripleGradFunction()( + &x, &y, &ddx, &ddy, d_dx, d_dy, &dout, d_ddout, out_d_x, out_d_y, + out_d_dout, out_d_ddx, out_d_ddy, context); + return; + } + + bool is_broadcast = true; + if (x_ndim <= 2 || y_ndim <= 2) { + is_broadcast = false; + } else if (x_ndim != y_ndim) { + is_broadcast = true; + } else { + is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, + y_dims.cbegin()); + } + + if (!is_broadcast) { + // Case2: no broadcast or no batch size + VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 2"; + ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); + + if (ddx.dims() != x.dims()) { + ddx.Resize(x.dims()); + } + + if (ddy.dims() != y.dims()) { + ddy.Resize(y.dims()); + } + + ConjHelper conj_helper(context); + + framework::DDim out_dx_dims; + if (out_d_x) { + out_dx_dims = out_d_x->dims(); + if (out_dx_dims != x.dims()) { + out_d_x->Resize(x.dims()); + } + } + + framework::DDim out_dy_dims; + if (out_d_y) { + out_dy_dims = out_d_y->dims(); + if (out_dy_dims != y.dims()) { + out_d_y->Resize(y.dims()); + } + } + + framework::DDim out_d_dout_dims; + if (out_d_dout) { + out_d_dout_dims = out_d_dout->dims(); + if (out_d_dout_dims != dout.dims()) { + out_d_dout->Resize(dout.dims()); + } + } + + framework::DDim out_d_ddx_dims; + if (out_d_ddx) { + out_d_ddx_dims = out_d_ddx->dims(); + if (out_d_ddx_dims != x.dims()) { + out_d_ddx->Resize(x.dims()); + } + } + + framework::DDim out_d_ddy_dims; + if (out_d_ddy) { + out_d_ddy_dims = out_d_ddy->dims(); + if (out_d_ddy_dims != y.dims()) { + out_d_ddy->Resize(y.dims()); + } + } + + if (out_d_dout) { + ConjHelper conj_helper(context); + conj_helper(ddx, ddx_conj); + conj_helper(ddy, ddy_conj); + } + + if (out_d_ddx || out_d_ddy) { + ConjHelper conj_helper(context); + conj_helper(x, x_conj); + conj_helper(y, y_conj); + conj_helper(dout, dout_conj); + } + + bool d_dout_flag = false; + bool d_ddx_flag = false; + bool d_ddy_flag = false; + + if (d_ddout) { + auto d_ddout_mat = *d_ddout; + if (d_ddout_mat.dims() != dout.dims()) { + d_ddout_mat.Resize(dout.dims()); + 
} + + if (out_d_y) { + if (transpose_x && transpose_y) { + // out_d_y = d_ddout' * ddx' + CalcInputGrad(context, d_ddout_mat, true, true, ddx_conj, true, + false, out_d_y, false); + } else if (transpose_x) { + // out_d_y = ddx * d_ddout + CalcInputGrad(context, ddx_conj, false, false, d_ddout_mat, false, + true, out_d_y, false); + } else if (transpose_y) { + // out_d_y = d_ddout' * ddx + CalcInputGrad(context, d_ddout_mat, true, true, ddx_conj, false, + true, out_d_y, false); + } else { + // out_d_y = ddx' * d_ddout + CalcInputGrad(context, ddx_conj, true, true, d_ddout_mat, false, + true, out_d_y, false); + } + } + + if (out_d_x) { + if (transpose_x && transpose_y) { + // out_d_x = ddy' * d_ddout' + CalcInputGrad(context, ddy_conj, true, true, d_ddout_mat, true, + false, out_d_x, false); + } else if (transpose_x) { + // out_d_x = ddy * d_ddout' + CalcInputGrad(context, ddy_conj, false, false, d_ddout_mat, true, + false, out_d_x, false); + } else if (transpose_y) { + // out_d_x = d_ddout * ddy + CalcInputGrad(context, d_ddout_mat, false, false, ddy_conj, false, + true, out_d_x, false); + } else { + // out_d_x = d_ddout * ddy' + CalcInputGrad(context, d_ddout_mat, false, false, ddy_conj, true, + false, out_d_x, false); + } + } + + // equations: + // d_ddx = DOut * D_DY + Y * D_DDOut + // Let: d_ddx1 = Y * D_DDOut + // Let: d_ddx2 = DOut * D_DY + + // d_ddy = DOut * D_DX + X * D_DDOut + // Let: d_ddy1 = X * D_DDOut + // Let: d_ddy2 = DOut * D_DX + + // d_dout = DDY * D_DX + DDX * D_DY + // Let: d_dout1 = DDX * D_DY + // Let: d_dout2 = DDY * D_DX + + // compute d_ddx1 + if (out_d_ddx) { + if (transpose_x && transpose_y) { + // out_d_ddx1 = y' * d_ddout' + CalcInputGrad(context, y_conj, true, true, d_ddout_mat, true, false, + out_d_ddx, d_ddx_flag); + } else if (transpose_x) { + // out_d_ddx1 = y * d_ddout' + CalcInputGrad(context, y_conj, false, false, d_ddout_mat, true, + false, out_d_ddx, d_ddx_flag); + } else if (transpose_y) { + // out_d_ddx1 = d_ddout * y + CalcInputGrad(context, d_ddout_mat, false, false, y_conj, false, + true, out_d_ddx, d_ddx_flag); + } else { + // out_d_ddx1 = d_ddout * y' + CalcInputGrad(context, d_ddout_mat, false, false, y_conj, true, + false, out_d_ddx, d_ddx_flag); + } + d_ddx_flag = true; + } + + // compute d_ddy1 + if (out_d_ddy) { + if (transpose_x && transpose_y) { + // out_d_ddy1 = d_ddout' * x' + CalcInputGrad(context, d_ddout_mat, true, true, x_conj, true, false, + out_d_ddy, false); + } else if (transpose_x) { + // out_d_ddy1 = x * d_ddout + CalcInputGrad(context, x_conj, false, false, d_ddout_mat, false, + true, out_d_ddy, false); + } else if (transpose_y) { + // out_d_ddy1 = d_ddout' * x + CalcInputGrad(context, d_ddout_mat, true, true, x_conj, false, true, + out_d_ddy, false); + } else { + // out_d_ddy1 = x' * d_ddout + CalcInputGrad(context, x_conj, true, true, d_ddout_mat, false, true, + out_d_ddy, false); + } + d_ddy_flag = true; + } + } + + if (d_dy) { + auto d_dy_mat = *d_dy; + if (d_dy_mat.dims() != y.dims()) { + d_dy_mat.Resize(y.dims()); + } + + // compute d_dout1 + if (out_d_dout) { + CalcInputGrad(context, ddx_conj, transpose_x, true, d_dy_mat, + transpose_y, false, out_d_dout, d_dout_flag); + d_dout_flag = true; + } + + // compute d_ddx2 + if (out_d_ddx) { + if (transpose_x && transpose_y) { + // out_d_ddx2 = D_DY' * DOut' + CalcInputGrad(context, d_dy_mat, true, true, dout_conj, true, false, + out_d_ddx, d_ddx_flag); + } else if (transpose_x) { + // out_d_ddx2 = D_DY * Dout' + CalcInputGrad(context, d_dy_mat, false, false, dout_conj, true, 
+ false, out_d_ddx, d_ddx_flag); + } else if (transpose_y) { + // out_d_ddx2 = Dout * D_DY + CalcInputGrad(context, dout_conj, false, false, d_dy_mat, false, + true, out_d_ddx, d_ddx_flag); + } else { + // out_d_ddx2 = Dout * D_DY' + CalcInputGrad(context, dout_conj, false, false, d_dy_mat, true, + false, out_d_ddx, d_ddx_flag); + } + } + } + + if (d_dx) { + auto d_dx_mat = *d_dx; + if (d_dx_mat.dims() != x.dims()) { + d_dx_mat.Resize(x.dims()); + } + + // compute d_dout2 + if (out_d_dout) { + CalcInputGrad(context, d_dx_mat, transpose_x, true, ddy_conj, + transpose_y, false, out_d_dout, d_dout_flag); + } + + // compute d_ddy2 + if (out_d_ddy) { + if (transpose_x && transpose_y) { + // out_d_ddy2 = dout' * d_dx' + CalcInputGrad(context, dout_conj, true, true, d_dx_mat, true, false, + out_d_ddy, d_ddy_flag); + } else if (transpose_x) { + // out_d_ddy2 = d_dx * dout + CalcInputGrad(context, d_dx_mat, false, false, dout_conj, false, + true, out_d_ddy, d_ddy_flag); + } else if (transpose_y) { + // out_d_ddy2 = dout' * d_dx + CalcInputGrad(context, dout_conj, true, true, d_dx_mat, false, true, + out_d_ddy, d_ddy_flag); + } else { + // out_d_ddy2 = d_dx' * dout + CalcInputGrad(context, d_dx_mat, true, true, dout_conj, false, true, + out_d_ddy, d_ddy_flag); + } + } + } + + if (out_d_x) { + if (out_dx_dims != x.dims()) { + out_d_x->Resize(out_dx_dims); + } + } + + if (out_d_y) { + if (out_dy_dims != y.dims()) { + out_d_y->Resize(out_dy_dims); + } + } + + if (out_d_dout) { + if (out_d_dout_dims != dout.dims()) { + out_d_dout->Resize(out_d_dout_dims); + } + } + + if (out_d_ddx) { + if (out_d_ddx_dims != x.dims()) { + out_d_ddx->Resize(out_d_ddx_dims); + } + } + + if (out_d_ddy) { + if (out_d_ddy_dims != x.dims()) { + out_d_ddy->Resize(out_d_ddy_dims); + } + } + + } else { + // Case3: broadcast. It need cost much time to reduce sum for the + // broadcast and wastes the memory. + // So we should avoid the case in reality. + VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 3"; + VLOG(3) << "It need cost much time to reduce sum for the broadcast and " + "wastes the memory. 
So we should avoid the case in reality"; + + Tensor out_dx_help, out_dy_help; + Tensor out_d_ddx_help, out_d_ddy_help; + if (out_d_dout) { + ConjHelper conj_helper(context); + conj_helper(ddx, ddx_conj); + conj_helper(ddy, ddy_conj); + } + if (out_d_ddx || out_d_ddy) { + ConjHelper conj_helper(context); + conj_helper(x, x_conj); + conj_helper(y, y_conj); + conj_helper(dout, dout_conj); + } + + if (transpose_x) { + if (transpose_y) { + // dX = ddY' d_ddout’, dY = d_ddout’ ddX' + if (out_d_x) + MatMulFunction(&ddy_conj, d_ddout, y_dims, + dout_dims, &out_dx_help, true, + true, context); + if (out_d_y) + MatMulFunction(d_ddout, &ddx_conj, dout_dims, + x_dims, &out_dy_help, true, true, + context); + } else { + // dX = ddY d_ddout', dY = ddX d_ddout + if (out_d_x) + MatMulFunction(&ddy_conj, d_ddout, y_dims, + dout_dims, &out_dx_help, false, + true, context); + if (out_d_y) + MatMulFunction(&ddx_conj, d_ddout, x_dims, + dout_dims, &out_dy_help, false, + false, context); + } + } else { + if (transpose_y) { + // dX = d_ddout ddY, dY = d_ddout’ ddX + if (out_d_x) + MatMulFunction(d_ddout, &ddy_conj, dout_dims, + y_dims, &out_dx_help, false, false, + context); + if (out_d_y) + MatMulFunction(d_ddout, &ddx_conj, dout_dims, + x_dims, &out_dy_help, true, false, + context); + } else { + // dX = d_ddout ddY', dY = ddX' d_ddout + if (out_d_x) + MatMulFunction(d_ddout, &ddy_conj, dout_dims, + y_dims, &out_dx_help, false, true, + context); + if (out_d_y) + MatMulFunction(&ddx_conj, d_ddout, x_dims, + dout_dims, &out_dy_help, true, + false, context); + } + } + + // get help dims + const std::vector dx_help_dims = + vectorize(out_dx_help.dims()); + const std::vector dy_help_dims = + vectorize(out_dx_help.dims()); + + std::vector dx_broadcast_dims(ndim); + std::vector dy_broadcast_dims(ndim); + + std::fill(dx_broadcast_dims.data(), + dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::fill(dy_broadcast_dims.data(), + dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(x_dims.data(), x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + std::copy(y_dims.data(), y_dims.data() + y_ndim, + dy_broadcast_dims.data() + ndim - y_ndim); + + std::vector dx_reduce_dims; + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { + dy_reduce_dims.push_back(idx); + } + } + // Reduce sum to get grad by ReduceSum + if (out_d_x) { + if (dx_reduce_dims.empty()) { + *out_d_x = std::move(out_dx_help); + } else { + ReduceSumForMatmulGrad(&out_dx_help, out_d_x, + dx_reduce_dims, context); + } + out_d_x->Resize(x.dims()); + } + + if (out_d_y) { + if (dy_reduce_dims.empty()) { + *out_d_y = std::move(out_dy_help); + } else { + ReduceSumForMatmulGrad(&out_dy_help, out_d_y, + dy_reduce_dims, context); + } + out_d_y->Resize(y.dims()); + } + + // compute d_dout + if (out_d_dout) { + MatMulFunction(d_dx, &ddy_conj, x_dims, y_dims, + out_d_dout, transpose_x, transpose_y, + context); + MatMulFunction(&ddx_conj, d_dy, x_dims, y_dims, + out_d_dout, transpose_x, transpose_y, + context, true); + } + + // compute d_ddx + if (out_d_ddx) { + if (transpose_x && transpose_y) { + // out_d_ddx1 = y' * d_ddout' + MatMulFunction(&y_conj, d_ddout, y_dims, dout_dims, + &out_d_ddx_help, true, true, + context); + // out_d_ddx2 = D_DY' * DOut' + MatMulFunction(d_dy, &dout_conj, y_dims, dout_dims, + &out_d_ddx_help, true, true, context, + true); + } else if 
(transpose_x) { + // out_d_ddx1 = y * d_ddout' + MatMulFunction(&y_conj, d_ddout, y_dims, dout_dims, + &out_d_ddx_help, false, true, + context); + // out_d_ddx2 = D_DY * Dout' + MatMulFunction(d_dy, &dout_conj, y_dims, dout_dims, + &out_d_ddx_help, false, true, + context, true); + } else if (transpose_y) { + // out_d_ddx1 = d_ddout * y + MatMulFunction(d_ddout, &y_conj, dout_dims, y_dims, + &out_d_ddx_help, false, false, + context); + // out_d_ddx2 = Dout * D_DY + MatMulFunction(&dout_conj, d_dy, dout_dims, y_dims, + &out_d_ddx_help, false, false, + context, true); + } else { + // out_d_ddx1 = d_ddout * y' + MatMulFunction(d_ddout, &y_conj, dout_dims, y_dims, + &out_d_ddx_help, false, true, + context); + // out_d_ddx2 = Dout * D_DY' + MatMulFunction(&dout_conj, d_dy, dout_dims, y_dims, + &out_d_ddx_help, false, true, + context, true); + } + if (dx_reduce_dims.empty()) { + *out_d_ddx = std::move(out_d_ddx_help); + } else { + ReduceSumForMatmulGrad(&out_d_ddx_help, out_d_ddx, + dx_reduce_dims, context); + } + out_d_ddx->Resize(x.dims()); + } + + // compute d_ddy + if (out_d_ddy) { + if (transpose_x && transpose_y) { + // out_d_ddy1 = d_ddout' * x' + MatMulFunction(d_ddout, &x_conj, dout_dims, x_dims, + &out_d_ddy_help, true, true, + context); + // out_d_ddy2 = dout' * d_dx' + MatMulFunction(&dout_conj, d_dx, dout_dims, x_dims, + &out_d_ddy_help, true, true, context, + true); + } else if (transpose_x) { + // out_d_ddy1 = x * d_ddout + MatMulFunction(&x_conj, d_ddout, x_dims, dout_dims, + &out_d_ddy_help, false, false, + context); + // out_d_ddy2 = d_dx * dout + MatMulFunction(d_dx, &dout_conj, x_dims, dout_dims, + &out_d_ddy_help, false, false, + context, true); + } else if (transpose_y) { + // out_d_ddy1 = d_ddout' * x + MatMulFunction(d_ddout, &x_conj, dout_dims, x_dims, + &out_d_ddy_help, true, false, + context); + // out_d_ddy2 = dout' * d_dx + MatMulFunction(&dout_conj, d_dx, dout_dims, x_dims, + &out_d_ddy_help, true, false, + context, true); + } else { + // out_d_ddy1 = x' * d_ddout + MatMulFunction(&x_conj, d_ddout, x_dims, dout_dims, + &out_d_ddy_help, true, false, + context); + // out_d_ddy2 = d_dx' * dout + MatMulFunction(d_dx, &dout_conj, x_dims, dout_dims, + &out_d_ddy_help, true, false, + context, true); + } + + if (dy_reduce_dims.empty()) { + *out_d_ddy = std::move(out_d_ddy_help); + } else { + ReduceSumForMatmulGrad(&out_d_ddy_help, out_d_ddy, + dy_reduce_dims, context); + } + out_d_ddy->Resize(y.dims()); + } + } + } +}; + } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/gradient_checker.py b/python/paddle/fluid/tests/unittests/gradient_checker.py index b56bbc07a7f44f9cc465a316559d40b14ae54e93..dff2b7aa8d8d600f92f83480790712842c12b61c 100644 --- a/python/paddle/fluid/tests/unittests/gradient_checker.py +++ b/python/paddle/fluid/tests/unittests/gradient_checker.py @@ -304,7 +304,6 @@ def grad_check(x, if b.has_var(xi.name): clone_x.append(b.var(xi.name)) break - analytical.append( _compute_analytical_jacobian(prog, clone_x, clone_y, place, scope)) @@ -486,7 +485,6 @@ def triple_grad_check(x, var_to_np_array_in_scope(scope, place, v.name) for v in x_grads_grads ] - x += y_grads x_init = _as_list(x_init) x_init += y_grads_init diff --git a/python/paddle/fluid/tests/unittests/test_nn_matmul_v2_grad.py b/python/paddle/fluid/tests/unittests/test_nn_matmul_v2_grad.py index 16e8a4a8b00fbe96ba419496e6c1fc5e72a71eb7..6dbabda1f4c342cb2cdf4d15c3d97235fc3e7ca3 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_matmul_v2_grad.py +++ 
b/python/paddle/fluid/tests/unittests/test_nn_matmul_v2_grad.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -146,5 +146,427 @@ class TestMatmulDoubleGradCheckCase3(unittest.TestCase):
             self.func(p)
 
+
+class TestMatmulTripleGradCheckDotCase(unittest.TestCase):
+    def setUp(self):
+        self.init_test()
+
+    def init_test(self):
+        self.x_shape = [2]
+        self.y_shape = [2]
+        self.transpose_x = False
+        self.transpose_y = False
+
+    @prog_scope()
+    def func(self, place):
+        eps = 0.005
+        dtype = np.float64
+        typename = "float64"
+        x = paddle.static.create_parameter(
+            dtype=typename, shape=self.x_shape, name='x')
+        y = paddle.static.create_parameter(
+            dtype=typename, shape=self.y_shape, name='y')
+        out = paddle.matmul(x, y, self.transpose_x, self.transpose_y, name='out')
+        np.random.seed(2021)
+        x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestMatmulTripleGradCheckNormalCase1(unittest.TestCase):
+    def setUp(self):
+        self.init_test()
+
+    def init_test(self):
+        self.x_shape = [2, 2]
+        self.y_shape = [2, 2]
+        self.transpose_x = False
+        self.transpose_y = False
+
+    @prog_scope()
+    def func(self, place):
+        eps = 0.005
+        dtype = np.float64
+        typename = "float64"
+        x = paddle.static.create_parameter(
+            dtype=typename, shape=self.x_shape, name='x')
+        y = paddle.static.create_parameter(
+            dtype=typename, shape=self.y_shape, name='y')
+        out = paddle.matmul(
+            x, y, self.transpose_x, self.transpose_y, name='out')
+        np.random.seed(2021)
+        x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestMatmulTripleGradCheckNormalCase2(unittest.TestCase):
+    def setUp(self):
+        self.init_test()
+
+    def init_test(self):
+        self.x_shape = [2, 2]
+        self.y_shape = [2, 2]
+        self.transpose_x = True
+        self.transpose_y = False
+
+    @prog_scope()
+    def func(self, place):
+        eps = 0.005
+        dtype = np.float64
+        typename = "float64"
+        x = paddle.static.create_parameter(
+            dtype=typename, shape=self.x_shape, name='x')
+        y = paddle.static.create_parameter(
+            dtype=typename, shape=self.y_shape, name='y')
+        out = paddle.matmul(
+            x, y, self.transpose_x, self.transpose_y, name='out')
+        np.random.seed(2021)
+        x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestMatmulTripleGradCheckNormalCase3(unittest.TestCase):
+    def setUp(self):
+        self.init_test()
+
+    def init_test(self):
+        self.x_shape = [2, 2]
+        self.y_shape = [2, 2]
+        self.transpose_x = False
+        self.transpose_y = True
+
+    @prog_scope()
+    def func(self, place):
+        eps = 0.005
+        dtype = np.float64
+        typename = "float64"
+        x = paddle.static.create_parameter(
+            dtype=typename, shape=self.x_shape, name='x')
+        y = paddle.static.create_parameter(
+            dtype=typename, shape=self.y_shape, name='y')
+        out = paddle.matmul(
+            x, y, self.transpose_x, self.transpose_y, name='out')
+        np.random.seed(2021)
+        x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestMatmulTripleGradCheckNormalCase4(unittest.TestCase):
+    def setUp(self):
+        self.init_test()
+
+    def init_test(self):
+        self.x_shape = [2, 2]
+        self.y_shape = [2, 2]
+        self.transpose_x = True
+        self.transpose_y = True
+
+    @prog_scope()
+    def func(self, place):
+        eps = 0.005
+        dtype = np.float64
+        typename = "float64"
+        x = paddle.static.create_parameter(
+            dtype=typename, shape=self.x_shape, name='x')
+        y = paddle.static.create_parameter(
+            dtype=typename, shape=self.y_shape, name='y')
+        out = paddle.matmul(
+            x, y, self.transpose_x, self.transpose_y, name='out')
+        np.random.seed(2021)
+        x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestMatmulTripleGradCheckBroadcastCase1(unittest.TestCase):
+    def setUp(self):
+        self.init_test()
+
+    def init_test(self):
+        self.x_shape = [3, 2, 2]
+        self.y_shape = [1, 2, 2]
+        self.transpose_x = False
+        self.transpose_y = False
+
+    @prog_scope()
+    def func(self, place):
+        eps = 0.005
+        dtype = np.float64
+        typename = "float64"
+        x = paddle.static.create_parameter(
+            dtype=typename, shape=self.x_shape, name='x')
+        y = paddle.static.create_parameter(
+            dtype=typename, shape=self.y_shape, name='y')
+        out = paddle.matmul(
+            x, y, self.transpose_x, self.transpose_y, name='out')
+        np.random.seed(2021)
+        x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestMatmulTripleGradCheckBroadcastCase2(unittest.TestCase):
+    def setUp(self):
+        self.init_test()
+
+    def init_test(self):
+        self.x_shape = [1, 2, 2]
+        self.y_shape = [3, 2, 2]
+        self.transpose_x = False
+        self.transpose_y = False
+
+    @prog_scope()
+    def func(self, place):
+        eps = 0.005
+        dtype = np.float64
+        typename = "float64"
+        x = paddle.static.create_parameter(
+            dtype=typename, shape=self.x_shape, name='x')
+        y = paddle.static.create_parameter(
+            dtype=typename, shape=self.y_shape, name='y')
+        out = paddle.matmul(
+            x, y, self.transpose_x, self.transpose_y, name='out')
+        np.random.seed(2021)
+        x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestMatmulTripleGradCheckBroadcastCase3(unittest.TestCase):
+    def setUp(self):
+        self.init_test()
+
+    def init_test(self):
+        self.x_shape = [1, 2, 2]
+        self.y_shape = [3, 2, 2]
+        self.transpose_x = True
+        self.transpose_y = False
+
+    @prog_scope()
+    def func(self, place):
+        eps = 0.005
+        dtype = np.float64
+        typename = "float64"
+        x = paddle.static.create_parameter(
+            dtype=typename, shape=self.x_shape, name='x')
+        y = paddle.static.create_parameter(
+            dtype=typename, shape=self.y_shape, name='y')
+        out = paddle.matmul(
+            x, y, self.transpose_x, self.transpose_y, name='out')
+        np.random.seed(2021)
+        x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestMatmulTripleGradCheckBroadcastCase4(unittest.TestCase):
+    def setUp(self):
+        self.init_test()
+
+    def init_test(self):
+        self.x_shape = [1, 2, 2]
+        self.y_shape = [3, 2, 2]
+        self.transpose_x = False
+        self.transpose_y = True
+
+    @prog_scope()
+    def func(self, place):
+        eps = 0.005
+        dtype = np.float64
+        typename = "float64"
+        x = paddle.static.create_parameter(
+            dtype=typename, shape=self.x_shape, name='x')
+        y = paddle.static.create_parameter(
+            dtype=typename, shape=self.y_shape, name='y')
+        out = paddle.matmul(
+            x, y, self.transpose_x, self.transpose_y, name='out')
+        np.random.seed(2021)
+        x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestMatmulTripleGradCheckBroadcastCase5(unittest.TestCase):
+    def setUp(self):
+        self.init_test()
+
+    def init_test(self):
+        self.x_shape = [1, 2, 2]
+        self.y_shape = [3, 2, 2]
+        self.transpose_x = True
+        self.transpose_y = True
+
+    @prog_scope()
+    def func(self, place):
+        eps = 0.005
+        dtype = np.float64
+        typename = "float64"
+        x = paddle.static.create_parameter(
+            dtype=typename, shape=self.x_shape, name='x')
+        y = paddle.static.create_parameter(
+            dtype=typename, shape=self.y_shape, name='y')
+        out = paddle.matmul(
+            x, y, self.transpose_x, self.transpose_y, name='out')
+        np.random.seed(2021)
+        x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestMatmulTripleGradCheckSpecialCase1(unittest.TestCase):
+    def setUp(self):
+        self.init_test()
+
+    def init_test(self):
+        self.x_shape = [3, 4, 5]
+        self.y_shape = [5]
+        self.transpose_x = False
+        self.transpose_y = False
+
+    @prog_scope()
+    def func(self, place):
+        eps = 0.005
+        dtype = np.float64
+        typename = "float64"
+        x = paddle.static.create_parameter(
+            dtype=typename, shape=self.x_shape, name='x')
+        y = paddle.static.create_parameter(
+            dtype=typename, shape=self.y_shape, name='y')
+        out = paddle.matmul(
+            x, y, self.transpose_x, self.transpose_y, name='out')
+        np.random.seed(2021)
+        x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestMatmulTripleGradCheckSpecialCase2(unittest.TestCase):
+    def setUp(self):
+        self.init_test()
+
+    def init_test(self):
+        self.x_shape = [4, 5, 5]
+        self.y_shape = [5]
+        self.transpose_x = True
+        self.transpose_y = False
+
+    @prog_scope()
+    def func(self, place):
+        eps = 0.005
+        dtype = np.float64
+        typename = "float64"
+        x = paddle.static.create_parameter(
+            dtype=typename, shape=self.x_shape, name='x')
+        y = paddle.static.create_parameter(
+            dtype=typename, shape=self.y_shape, name='y')
+        out = paddle.matmul(
+            x, y, self.transpose_x, self.transpose_y, name='out')
+        np.random.seed(2021)
+        x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
 if __name__ == "__main__":
     unittest.main()
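For reference, the 1-D (dot product) branch handled by DotTripleGradFunction computes the closed-form expressions stated in the kernel's comments. The following standalone NumPy sketch is not part of the patch; the function name and argument names are chosen here only to mirror the kernel's inputs and outputs. It restates the real-valued formulas and can be used to sanity-check the CPU/CUDA kernels on small inputs:

import numpy as np

def dot_triple_grad_reference(x, y, dout, ddx, ddy, d_dx, d_dy, d_ddout):
    # Triple grad of out = dot(x, y) for 1-D real inputs; dout and d_ddout
    # are scalars. Mirrors the real-valued branch of DotTripleGradFunction.
    d_x = ddy * d_ddout                        # D_X_out
    d_y = ddx * d_ddout                        # D_Y_out
    d_dout = np.sum(ddx * d_dy + ddy * d_dx)   # D_DOut_out (scalar)
    d_ddx = dout * d_dy + y * d_ddout          # D_DDX_out
    d_ddy = dout * d_dx + x * d_ddout          # D_DDY_out
    return d_x, d_y, d_dout, d_ddx, d_ddy

The complex specialization uses the same expressions with the non-differential factors (x, y, dout, ddx, ddy) conjugated, matching the ConjFunctor calls in the kernel.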