未验证 提交 203a0e3e 编写于 作者: W Weilong Wu 提交者: GitHub

Support matmul_v2 triple grad Kernel (#36459)

* native commit for triple grad of sigmod

* Updated unittests files

* init functional jacobian api

* Updated trible_test func

* Updated gradient_checker & test_script

* finish test with dtype float32

* add float64 test case

* polish code

* use atol=1e-5 with dtype float64

* fix for ci

* set timeout for test_jacobian

* fix dygraph grad to support high differential

* polish API docstring

* Updated gradient checker and some related files

* fix double grad strip error for high differential

* fix double grad strip error for high differential

* Add Sigmoid triple grad tests

* fix dygraph double grad dtype error when calling for high differential senario

* Updated triple grad teses func

* Use np.random to initialize ddx

* Updated triple_grad_check func

* add todo for gradient checker and refine some comments

* remove additional code

* add test for warnging in backward.py

* format python code

* support multi input in triple gradient checker

* Add matmul triple grad kernel

* Updated comments of TODO

* Supported some special tests

* Change code-format to follow CI std

* Updated gradient_checker.py

* Fix conflicts

* Removed unnecessary printing log

* Change code style to follow CI std
Co-authored-by: Nlevi131 <limaolin01@baidu.com>
Co-authored-by: NJiabin Yang <360788950@qq.com>
上级 b9fdd3bc
......@@ -347,6 +347,76 @@ class MatMulV2OpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
}
};
class MatMulV2OpTripleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X",
"matmul_v2_triple_grad");
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y",
"matmul_v2_triple_grad");
OP_INOUT_CHECK(context->HasInput("DOut"), "Input", "DOut",
"matmul_v2_triple_grad");
OP_INOUT_CHECK(context->HasInput("DDX"), "Input", "DDX",
"matmul_v2_triple_grad");
OP_INOUT_CHECK(context->HasInput("DDY"), "Input", "DDY",
"matmul_v2_triple_grad");
OP_INOUT_CHECK(context->HasInput("D_DX"), "Input", "D_DX",
"matmul_v2_triple_grad");
OP_INOUT_CHECK(context->HasInput("D_DY"), "Input", "D_DY",
"matmul_v2_triple_grad");
OP_INOUT_CHECK(context->HasInput("D_DDOut"), "Input", "D_DDOut",
"matmul_v2_triple_grad");
if (context->HasOutput("D_X_out")) {
context->ShareDim("X", "D_X_out");
}
if (context->HasOutput("D_Y_out")) {
context->ShareDim("Y", "D_Y_out");
}
if (context->HasOutput("D_DOut_out")) {
context->ShareDim("DOut", "D_DOut_out");
}
if (context->HasOutput("D_DDX_out")) {
context->ShareDim("X", "D_DDX_out");
}
if (context->HasOutput("D_DDY_out")) {
context->ShareDim("Y", "D_DDY_out");
}
}
};
template <typename T>
class MatMulV2OpTripleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("matmul_v2_triple_grad");
// get input from double grad
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput("DOut", this->Input("DOut"));
op->SetInput("DDX", this->Input("DDX"));
op->SetInput("DDY", this->Input("DDY"));
op->SetInput("D_DX", this->OutputGrad("DX"));
op->SetInput("D_DY", this->OutputGrad("DY"));
op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
// set outputs
op->SetOutput("D_X_out", this->InputGrad("X"));
op->SetOutput("D_Y_out", this->InputGrad("Y"));
op->SetOutput("D_DOut_out", this->InputGrad("DOut"));
op->SetOutput("D_DDX_out", this->InputGrad("DDX"));
op->SetOutput("D_DDY_out", this->InputGrad("DDY"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
......@@ -359,7 +429,11 @@ REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad,
ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>,
ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad);
REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad,
ops::MatMulV2OpTripleGradMaker<paddle::framework::OpDesc>,
ops::MatMulV2OpTripleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(matmul_v2_triple_grad, ops::MatMulV2OpTripleGrad);
REGISTER_OP_CPU_KERNEL(
matmul_v2, ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, float>,
......@@ -385,3 +459,12 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex<float>>,
ops::MatMulV2DoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
matmul_v2_triple_grad,
ops::MatMulV2TripleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulV2TripleGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MatMulV2TripleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::MatMulV2TripleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
......@@ -40,3 +40,13 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::complex<float>>,
ops::MatMulV2DoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
matmul_v2_triple_grad,
ops::MatMulV2TripleGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::MatMulV2TripleGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::MatMulV2TripleGradKernel<plf::CUDADeviceContext, plf::float16>,
ops::MatMulV2TripleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::MatMulV2TripleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
......@@ -779,6 +779,421 @@ struct DotDoubleGradFunction<DeviceContext, T, math::DisableComplex<T>> {
}
};
template <typename DeviceContext, typename T, typename Enabel = void>
struct DotTripleGradFunction {
void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y,
const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy,
const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy,
const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout,
Tensor* out_tensor_d_x, Tensor* out_tensor_d_y,
Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx,
Tensor* out_tensor_d_ddy,
const paddle::framework::ExecutionContext& ctx);
};
// TODO(wuweilong): enable this function when the unittests framewark for multi
// grad is ok (dtype: complex64 or complex128).
template <typename DeviceContext, typename T>
struct DotTripleGradFunction<DeviceContext, T, math::EnableComplex<T>> {
void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y,
const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy,
const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy,
const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout,
Tensor* out_tensor_d_x, Tensor* out_tensor_d_y,
Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx,
Tensor* out_tensor_d_ddy,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == in_tensor_d_ddout->dims().size()) {
framework::Tensor in_tensor_d_ddout_help;
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
if (out_tensor_d_x || out_tensor_d_y) {
in_tensor_d_ddout_help.Resize(in_tensor_d_ddout->dims());
in_tensor_d_ddout_help.mutable_data<T>(ctx.GetPlace());
paddle::platform::ForRange<DeviceContext> for_range(
dev_raw, in_tensor_d_ddout->numel());
math::ConjFunctor<T> functor(in_tensor_d_ddout->data<T>(),
in_tensor_d_ddout->numel(),
in_tensor_d_ddout_help.data<T>());
for_range(functor);
}
if (out_tensor_d_x) {
auto ddy = framework::EigenVector<T>::Flatten(*in_tensor_ddy);
Eigen::DSizes<int, 1> size(in_tensor_ddy->numel());
auto d_x = framework::EigenVector<T>::Flatten(*out_tensor_d_x);
auto d_ddout =
framework::EigenVector<T>::Flatten(in_tensor_d_ddout_help);
d_x.device(dev) = ddy * d_ddout.broadcast(size);
}
if (out_tensor_d_y) {
auto ddx = framework::EigenVector<T>::Flatten(*in_tensor_ddx);
Eigen::DSizes<int, 1> size(in_tensor_ddx->numel());
auto d_y = framework::EigenVector<T>::Flatten(*out_tensor_d_y);
auto d_ddout =
framework::EigenVector<T>::Flatten(in_tensor_d_ddout_help);
d_y.device(dev) = ddx * d_ddout.broadcast(size);
}
if (out_tensor_d_dout) {
framework::Tensor in_tensor_ddx_help, in_tensor_ddy_help;
in_tensor_ddx_help.Resize(in_tensor_ddx->dims());
in_tensor_ddx_help.mutable_data<T>(ctx.GetPlace());
in_tensor_ddy_help.Resize(in_tensor_ddy->dims());
in_tensor_ddy_help.mutable_data<T>(ctx.GetPlace());
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
paddle::platform::ForRange<DeviceContext> for_range(
dev_raw, in_tensor_ddx->numel());
math::ConjFunctor<T> functor_ddx(in_tensor_ddx->data<T>(),
in_tensor_ddx->numel(),
in_tensor_ddx_help.data<T>());
for_range(functor_ddx);
math::ConjFunctor<T> functor_ddy(in_tensor_ddy->data<T>(),
in_tensor_ddy->numel(),
in_tensor_ddy_help.data<T>());
for_range(functor_ddy);
auto ddx = framework::EigenVector<T>::Flatten(in_tensor_ddx_help);
auto ddy = framework::EigenVector<T>::Flatten(in_tensor_ddy_help);
auto d_dx = framework::EigenVector<T>::Flatten(*in_tensor_d_dx);
auto d_dy = framework::EigenVector<T>::Flatten(*in_tensor_d_dy);
auto d_dout = framework::EigenVector<T>::Flatten(*out_tensor_d_dout);
d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum();
}
if (out_tensor_d_ddx) {
framework::Tensor in_tensor_dout_help, in_tensor_y_help;
in_tensor_dout_help.Resize(in_tensor_dout->dims());
in_tensor_dout_help.mutable_data<T>(ctx.GetPlace());
in_tensor_y_help.Resize(in_tensor_y->dims());
in_tensor_y_help.mutable_data<T>(ctx.GetPlace());
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
paddle::platform::ForRange<DeviceContext> for_range(
dev_raw, in_tensor_dout->numel());
math::ConjFunctor<T> functor_dout(in_tensor_dout->data<T>(),
in_tensor_dout->numel(),
in_tensor_dout_help.data<T>());
for_range(functor_dout);
math::ConjFunctor<T> functor_y(in_tensor_y->data<T>(),
in_tensor_y->numel(),
in_tensor_y_help.data<T>());
for_range(functor_y);
auto dout = framework::EigenVector<T>::Flatten(in_tensor_dout_help);
auto y = framework::EigenVector<T>::Flatten(in_tensor_y_help);
auto d_ddout = framework::EigenVector<T>::Flatten(*in_tensor_d_ddout);
auto d_dy = framework::EigenVector<T>::Flatten(*in_tensor_d_dy);
auto d_ddx = framework::EigenVector<T>::Flatten(*out_tensor_d_ddx);
Eigen::DSizes<int, 1> size(in_tensor_y->numel());
d_ddx.device(dev) =
(dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size));
}
if (out_tensor_d_ddy) {
framework::Tensor in_tensor_dout_help, in_tensor_x_help;
in_tensor_dout_help.Resize(in_tensor_dout->dims());
in_tensor_dout_help.mutable_data<T>(ctx.GetPlace());
in_tensor_x_help.Resize(in_tensor_x->dims());
in_tensor_x_help.mutable_data<T>(ctx.GetPlace());
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
paddle::platform::ForRange<DeviceContext> for_range(
dev_raw, in_tensor_dout->numel());
math::ConjFunctor<T> functor_dout(in_tensor_dout->data<T>(),
in_tensor_dout->numel(),
in_tensor_dout_help.data<T>());
for_range(functor_dout);
math::ConjFunctor<T> functor_x(in_tensor_x->data<T>(),
in_tensor_x->numel(),
in_tensor_x_help.data<T>());
for_range(functor_x);
auto dout = framework::EigenVector<T>::Flatten(in_tensor_dout_help);
auto x = framework::EigenVector<T>::Flatten(in_tensor_x_help);
auto d_ddout = framework::EigenVector<T>::Flatten(*in_tensor_d_ddout);
auto d_dx = framework::EigenVector<T>::Flatten(*in_tensor_d_dx);
auto d_ddy = framework::EigenVector<T>::Flatten(*out_tensor_d_ddy);
Eigen::DSizes<int, 1> size(in_tensor_x->numel());
d_ddy.device(dev) =
(dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size));
}
}
#else
const auto* data_d_ddout = in_tensor_d_ddout->data<T>();
if (out_tensor_d_x) {
auto* data_d_x = out_tensor_d_x->mutable_data<T>(ctx.GetPlace());
const auto* data_ddy = in_tensor_ddy->data<T>();
const framework::DDim& dim = out_tensor_d_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_x[i] = T(data_ddy[i].real, -data_ddy[i].imag) * data_d_ddout[s];
}
}
if (out_tensor_d_y) {
auto* data_d_y = out_tensor_d_y->mutable_data<T>(ctx.GetPlace());
const auto* data_ddx = in_tensor_ddx->data<T>();
const framework::DDim& dim = out_tensor_d_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_y[i] = T(data_ddx[i].real, -data_ddx[i].imag) * data_d_ddout[s];
}
}
if (out_tensor_d_dout) {
auto* data_d_dout = out_tensor_d_dout->mutable_data<T>(ctx.GetPlace());
auto* data_ddx = in_tensor_ddx->data<T>();
auto* data_ddy = in_tensor_ddy->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
auto* data_d_dy = in_tensor_d_dy->data<T>();
const framework::DDim& dim = out_tensor_d_dout->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
bool new_s = false;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) {
++s;
new_s = true;
}
if (new_s) {
data_d_dout[s] =
T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] +
T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i];
} else {
data_d_dout[s] +=
T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] +
T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i];
}
new_s = false;
}
}
if (out_tensor_d_ddx) {
auto* data_d_ddx = out_tensor_d_ddx->mutable_data<T>(ctx.GetPlace());
auto* data_dout = in_tensor_dout->data<T>();
auto* data_d_dy = in_tensor_d_dy->data<T>();
auto* data_y = in_tensor_y->data<T>();
auto* data_d_ddout = in_tensor_d_ddout->data<T>();
const framework::DDim& dim = out_tensor_d_ddx->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_ddx[i] =
T(data_dout[s].real, -data_dout[s].imag) * data_d_dy[i] +
T(data_y[i].real, -data_y[i].imag) * data_d_ddout[s];
}
}
if (out_tensor_d_ddy) {
auto* data_d_ddy = out_tensor_d_ddy->mutable_data<T>(ctx.GetPlace());
auto* data_dout = in_tensor_dout->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
auto* data_x = in_tensor_x->data<T>();
auto* data_d_ddout = in_tensor_d_ddout->data<T>();
const framework::DDim& dim = out_tensor_d_ddy->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_ddy[i] =
T(data_dout[s].real, -data_dout[s].imag) * data_d_dx[i] +
T(data_x[i].real, -data_x[i].imag) * data_d_ddout[s];
}
}
#endif
}
};
template <typename DeviceContext, typename T>
struct DotTripleGradFunction<DeviceContext, T, math::DisableComplex<T>> {
void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y,
const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy,
const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy,
const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout,
Tensor* out_tensor_d_x, Tensor* out_tensor_d_y,
Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx,
Tensor* out_tensor_d_ddy,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == in_tensor_d_ddout->dims().size()) {
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
auto d_ddout = framework::EigenVector<T>::Flatten(*in_tensor_d_ddout);
if (out_tensor_d_x) {
out_tensor_d_x->mutable_data<T>(ctx.GetPlace());
auto ddy = framework::EigenVector<T>::Flatten(*in_tensor_ddy);
Eigen::DSizes<int, 1> size(in_tensor_ddy->numel());
auto d_x = framework::EigenVector<T>::Flatten(*out_tensor_d_x);
d_x.device(dev) = ddy * d_ddout.broadcast(size);
}
if (out_tensor_d_y) {
out_tensor_d_y->mutable_data<T>(ctx.GetPlace());
auto ddx = framework::EigenVector<T>::Flatten(*in_tensor_ddx);
Eigen::DSizes<int, 1> size(in_tensor_ddx->numel());
auto d_y = framework::EigenVector<T>::Flatten(*out_tensor_d_y);
d_y.device(dev) = ddx * d_ddout.broadcast(size);
}
if (out_tensor_d_dout) {
out_tensor_d_dout->mutable_data<T>(ctx.GetPlace());
auto ddx = framework::EigenVector<T>::Flatten(*in_tensor_ddx);
auto ddy = framework::EigenVector<T>::Flatten(*in_tensor_ddy);
auto d_dx = framework::EigenVector<T>::Flatten(*in_tensor_d_dx);
auto d_dy = framework::EigenVector<T>::Flatten(*in_tensor_d_dy);
auto d_dout = framework::EigenVector<T>::Flatten(*out_tensor_d_dout);
d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum();
}
if (out_tensor_d_ddx) {
out_tensor_d_ddx->mutable_data<T>(ctx.GetPlace());
auto dout = framework::EigenVector<T>::Flatten(*in_tensor_dout);
auto y = framework::EigenVector<T>::Flatten(*in_tensor_y);
auto d_ddout = framework::EigenVector<T>::Flatten(*in_tensor_d_ddout);
auto d_dy = framework::EigenVector<T>::Flatten(*in_tensor_d_dy);
auto d_ddx = framework::EigenVector<T>::Flatten(*out_tensor_d_ddx);
Eigen::DSizes<int, 1> size(in_tensor_y->numel());
d_ddx.device(dev) =
(dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size));
}
if (out_tensor_d_ddy) {
out_tensor_d_ddy->mutable_data<T>(ctx.GetPlace());
auto dout = framework::EigenVector<T>::Flatten(*in_tensor_dout);
auto x = framework::EigenVector<T>::Flatten(*in_tensor_x);
auto d_ddout = framework::EigenVector<T>::Flatten(*in_tensor_d_ddout);
auto d_dx = framework::EigenVector<T>::Flatten(*in_tensor_d_dx);
auto d_ddy = framework::EigenVector<T>::Flatten(*out_tensor_d_ddy);
Eigen::DSizes<int, 1> size(in_tensor_x->numel());
d_ddy.device(dev) =
(dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size));
}
}
#else
const auto* data_d_ddout = in_tensor_d_ddout->data<T>();
if (out_tensor_d_x) {
auto* data_d_x = out_tensor_d_x->mutable_data<T>(ctx.GetPlace());
const auto* data_ddy = in_tensor_ddy->data<T>();
const framework::DDim& dim = out_tensor_d_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_x[i] = data_ddy[i] * data_d_ddout[s];
}
}
if (out_tensor_d_y) {
auto* data_d_y = out_tensor_d_y->mutable_data<T>(ctx.GetPlace());
const auto* data_ddx = in_tensor_ddx->data<T>();
const framework::DDim& dim = out_tensor_d_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_y[i] = data_ddx[i] * data_d_ddout[s];
}
}
if (out_tensor_d_dout) {
auto* data_d_dout = out_tensor_d_dout->mutable_data<T>(ctx.GetPlace());
auto* data_ddx = in_tensor_ddx->data<T>();
auto* data_ddy = in_tensor_ddy->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
auto* data_d_dy = in_tensor_d_dy->data<T>();
const framework::DDim& dim = in_tensor_ddx->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
bool new_s = false;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) {
++s;
new_s = true;
}
if (new_s) {
data_d_dout[s] =
data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i];
} else {
data_d_dout[s] +=
data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i];
}
new_s = false;
}
}
if (out_tensor_d_ddx) {
auto* data_d_ddx = out_tensor_d_ddx->mutable_data<T>(ctx.GetPlace());
auto* data_dout = in_tensor_dout->data<T>();
auto* data_d_dy = in_tensor_d_dy->data<T>();
auto* data_y = in_tensor_y->data<T>();
auto* data_d_ddout = in_tensor_d_ddout->data<T>();
const framework::DDim& dim = out_tensor_d_ddx->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_ddx[i] =
data_dout[s] * data_d_dy[i] + data_y[i] * data_d_ddout[s];
}
}
if (out_tensor_d_ddy) {
auto* data_d_ddy = out_tensor_d_ddy->mutable_data<T>(ctx.GetPlace());
auto* data_dout = in_tensor_dout->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
auto* data_x = in_tensor_x->data<T>();
auto* data_d_ddout = in_tensor_d_ddout->data<T>();
const framework::DDim& dim = out_tensor_d_ddy->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_ddy[i] =
data_dout[s] * data_d_dx[i] + data_x[i] * data_d_ddout[s];
}
}
#endif
}
};
template <typename DeviceContext, typename T>
class MatMulV2GradKernel : public framework::OpKernel<T> {
public:
......@@ -1322,7 +1737,7 @@ class MatMulV2DoubleGradKernel : public framework::OpKernel<T> {
}
if (ddout) {
// Caluate the gradient of OutputGrad(Out)
// Calculate the gradient of OutputGrad(Out)
MatMulFunction<DeviceContext, T>(ddx, &y_conj, x_dims, y_dims, ddout,
transpose_x, transpose_y, context);
MatMulFunction<DeviceContext, T>(&x_conj, ddy, x_dims, y_dims, ddout,
......@@ -1332,5 +1747,609 @@ class MatMulV2DoubleGradKernel : public framework::OpKernel<T> {
}
}
};
template <typename DeviceContext, typename T>
class MatMulV2TripleGradKernel : public framework::OpKernel<T> {
public:
void MatMul(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
const framework::Tensor& b, bool trans_b, framework::Tensor* out,
bool flag) const {
out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
if (a.dims().size() == 3 && b.dims().size() <= 2) {
// the transpose_X must be false, if is true, the transpose cost much time
if (!trans_a) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
}
}
blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast<T>(1), out,
static_cast<T>(flag));
}
void CalcInputGrad(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
bool is_fold_init_dims_a, const framework::Tensor& b,
bool trans_b, bool is_fold_init_dims_b,
framework::Tensor* out, bool flag) const {
if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2;
if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, out, flag);
} else {
auto& ctx = context.template device_context<DeviceContext>();
MatMul(context, is_fold_init_dims_a
? FoldInitDims(a)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
trans_a, is_fold_init_dims_b
? FoldInitDims(b)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, b),
trans_b, out, flag);
}
}
void Compute(const framework::ExecutionContext& context) const override {
// get input
auto x = *context.Input<framework::Tensor>("X");
auto y = *context.Input<framework::Tensor>("Y");
auto dout = *context.Input<framework::Tensor>("DOut");
auto ddx = *context.Input<framework::Tensor>("DDX");
auto ddy = *context.Input<framework::Tensor>("DDY");
auto* d_dx = context.Input<framework::Tensor>("D_DX");
auto* d_dy = context.Input<framework::Tensor>("D_DY");
auto* d_ddout = context.Input<framework::Tensor>("D_DDOut");
// get output
auto* out_d_x = context.Output<framework::Tensor>("D_X_out");
auto* out_d_y = context.Output<framework::Tensor>("D_Y_out");
auto* out_d_dout = context.Output<framework::Tensor>("D_DOut_out");
auto* out_d_ddx = context.Output<framework::Tensor>("D_DDX_out");
auto* out_d_ddy = context.Output<framework::Tensor>("D_DDY_out");
bool transpose_x = context.Attr<bool>("trans_x");
bool transpose_y = context.Attr<bool>("trans_y");
// Get dims from the input x, y, output_grad
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
std::vector<std::int64_t> dout_dims = vectorize(dout.dims());
framework::Tensor x_conj(x.type());
framework::Tensor y_conj(y.type());
framework::Tensor dout_conj(dout.type());
framework::Tensor ddx_conj(ddx.type());
framework::Tensor ddy_conj(ddy.type());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int ndim = dout_dims.size();
// Case1 : x's and y's dim = 1
if (x_ndim == 1 && y_ndim == 1) {
VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 1";
DotTripleGradFunction<DeviceContext, T>()(
&x, &y, &ddx, &ddy, d_dx, d_dy, &dout, d_ddout, out_d_x, out_d_y,
out_d_dout, out_d_ddx, out_d_ddy, context);
return;
}
bool is_broadcast = true;
if (x_ndim <= 2 || y_ndim <= 2) {
is_broadcast = false;
} else if (x_ndim != y_ndim) {
is_broadcast = true;
} else {
is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2,
y_dims.cbegin());
}
if (!is_broadcast) {
// Case2: no broadcast or no batch size
VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 2";
ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
if (ddx.dims() != x.dims()) {
ddx.Resize(x.dims());
}
if (ddy.dims() != y.dims()) {
ddy.Resize(y.dims());
}
ConjHelper<DeviceContext, T> conj_helper(context);
framework::DDim out_dx_dims;
if (out_d_x) {
out_dx_dims = out_d_x->dims();
if (out_dx_dims != x.dims()) {
out_d_x->Resize(x.dims());
}
}
framework::DDim out_dy_dims;
if (out_d_y) {
out_dy_dims = out_d_y->dims();
if (out_dy_dims != y.dims()) {
out_d_y->Resize(y.dims());
}
}
framework::DDim out_d_dout_dims;
if (out_d_dout) {
out_d_dout_dims = out_d_dout->dims();
if (out_d_dout_dims != dout.dims()) {
out_d_dout->Resize(dout.dims());
}
}
framework::DDim out_d_ddx_dims;
if (out_d_ddx) {
out_d_ddx_dims = out_d_ddx->dims();
if (out_d_ddx_dims != x.dims()) {
out_d_ddx->Resize(x.dims());
}
}
framework::DDim out_d_ddy_dims;
if (out_d_ddy) {
out_d_ddy_dims = out_d_ddy->dims();
if (out_d_ddy_dims != y.dims()) {
out_d_ddy->Resize(y.dims());
}
}
if (out_d_dout) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(ddx, ddx_conj);
conj_helper(ddy, ddy_conj);
}
if (out_d_ddx || out_d_ddy) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(x, x_conj);
conj_helper(y, y_conj);
conj_helper(dout, dout_conj);
}
bool d_dout_flag = false;
bool d_ddx_flag = false;
bool d_ddy_flag = false;
if (d_ddout) {
auto d_ddout_mat = *d_ddout;
if (d_ddout_mat.dims() != dout.dims()) {
d_ddout_mat.Resize(dout.dims());
}
if (out_d_y) {
if (transpose_x && transpose_y) {
// out_d_y = d_ddout' * ddx'
CalcInputGrad(context, d_ddout_mat, true, true, ddx_conj, true,
false, out_d_y, false);
} else if (transpose_x) {
// out_d_y = ddx * d_ddout
CalcInputGrad(context, ddx_conj, false, false, d_ddout_mat, false,
true, out_d_y, false);
} else if (transpose_y) {
// out_d_y = d_ddout' * ddx
CalcInputGrad(context, d_ddout_mat, true, true, ddx_conj, false,
true, out_d_y, false);
} else {
// out_d_y = ddx' * d_ddout
CalcInputGrad(context, ddx_conj, true, true, d_ddout_mat, false,
true, out_d_y, false);
}
}
if (out_d_x) {
if (transpose_x && transpose_y) {
// out_d_x = ddy' * d_ddout'
CalcInputGrad(context, ddy_conj, true, true, d_ddout_mat, true,
false, out_d_x, false);
} else if (transpose_x) {
// out_d_x = ddy * d_ddout'
CalcInputGrad(context, ddy_conj, false, false, d_ddout_mat, true,
false, out_d_x, false);
} else if (transpose_y) {
// out_d_x = d_ddout * ddy
CalcInputGrad(context, d_ddout_mat, false, false, ddy_conj, false,
true, out_d_x, false);
} else {
// out_d_x = d_ddout * ddy'
CalcInputGrad(context, d_ddout_mat, false, false, ddy_conj, true,
false, out_d_x, false);
}
}
// equations:
// d_ddx = DOut * D_DY + Y * D_DDOut
// Let: d_ddx1 = Y * D_DDOut
// Let: d_ddx2 = DOut * D_DY
// d_ddy = DOut * D_DX + X * D_DDOut
// Let: d_ddy1 = X * D_DDOut
// Let: d_ddy2 = DOut * D_DX
// d_dout = DDY * D_DX + DDX * D_DY
// Let: d_dout1 = DDX * D_DY
// Let: d_dout2 = DDY * D_DX
// compute d_ddx1
if (out_d_ddx) {
if (transpose_x && transpose_y) {
// out_d_ddx1 = y' * d_ddout'
CalcInputGrad(context, y_conj, true, true, d_ddout_mat, true, false,
out_d_ddx, d_ddx_flag);
} else if (transpose_x) {
// out_d_ddx1 = y * d_ddout'
CalcInputGrad(context, y_conj, false, false, d_ddout_mat, true,
false, out_d_ddx, d_ddx_flag);
} else if (transpose_y) {
// out_d_ddx1 = d_ddout * y
CalcInputGrad(context, d_ddout_mat, false, false, y_conj, false,
true, out_d_ddx, d_ddx_flag);
} else {
// out_d_ddx1 = d_ddout * y'
CalcInputGrad(context, d_ddout_mat, false, false, y_conj, true,
false, out_d_ddx, d_ddx_flag);
}
d_ddx_flag = true;
}
// compute d_ddy1
if (out_d_ddy) {
if (transpose_x && transpose_y) {
// out_d_ddy1 = d_ddout' * x'
CalcInputGrad(context, d_ddout_mat, true, true, x_conj, true, false,
out_d_ddy, false);
} else if (transpose_x) {
// out_d_ddy1 = x * d_ddout
CalcInputGrad(context, x_conj, false, false, d_ddout_mat, false,
true, out_d_ddy, false);
} else if (transpose_y) {
// out_d_ddy1 = d_ddout' * x
CalcInputGrad(context, d_ddout_mat, true, true, x_conj, false, true,
out_d_ddy, false);
} else {
// out_d_ddy1 = x' * d_ddout
CalcInputGrad(context, x_conj, true, true, d_ddout_mat, false, true,
out_d_ddy, false);
}
d_ddy_flag = true;
}
}
if (d_dy) {
auto d_dy_mat = *d_dy;
if (d_dy_mat.dims() != y.dims()) {
d_dy_mat.Resize(y.dims());
}
// compute d_dout1
if (out_d_dout) {
CalcInputGrad(context, ddx_conj, transpose_x, true, d_dy_mat,
transpose_y, false, out_d_dout, d_dout_flag);
d_dout_flag = true;
}
// compute d_ddx2
if (out_d_ddx) {
if (transpose_x && transpose_y) {
// out_d_ddx2 = D_DY' * DOut'
CalcInputGrad(context, d_dy_mat, true, true, dout_conj, true, false,
out_d_ddx, d_ddx_flag);
} else if (transpose_x) {
// out_d_ddx2 = D_DY * Dout'
CalcInputGrad(context, d_dy_mat, false, false, dout_conj, true,
false, out_d_ddx, d_ddx_flag);
} else if (transpose_y) {
// out_d_ddx2 = Dout * D_DY
CalcInputGrad(context, dout_conj, false, false, d_dy_mat, false,
true, out_d_ddx, d_ddx_flag);
} else {
// out_d_ddx2 = Dout * D_DY'
CalcInputGrad(context, dout_conj, false, false, d_dy_mat, true,
false, out_d_ddx, d_ddx_flag);
}
}
}
if (d_dx) {
auto d_dx_mat = *d_dx;
if (d_dx_mat.dims() != x.dims()) {
d_dx_mat.Resize(x.dims());
}
// compute d_dout2
if (out_d_dout) {
CalcInputGrad(context, d_dx_mat, transpose_x, true, ddy_conj,
transpose_y, false, out_d_dout, d_dout_flag);
}
// compute d_ddy2
if (out_d_ddy) {
if (transpose_x && transpose_y) {
// out_d_ddy2 = dout' * d_dx'
CalcInputGrad(context, dout_conj, true, true, d_dx_mat, true, false,
out_d_ddy, d_ddy_flag);
} else if (transpose_x) {
// out_d_ddy2 = d_dx * dout
CalcInputGrad(context, d_dx_mat, false, false, dout_conj, false,
true, out_d_ddy, d_ddy_flag);
} else if (transpose_y) {
// out_d_ddy2 = dout' * d_dx
CalcInputGrad(context, dout_conj, true, true, d_dx_mat, false, true,
out_d_ddy, d_ddy_flag);
} else {
// out_d_ddy2 = d_dx' * dout
CalcInputGrad(context, d_dx_mat, true, true, dout_conj, false, true,
out_d_ddy, d_ddy_flag);
}
}
}
if (out_d_x) {
if (out_dx_dims != x.dims()) {
out_d_x->Resize(out_dx_dims);
}
}
if (out_d_y) {
if (out_dy_dims != y.dims()) {
out_d_y->Resize(out_dy_dims);
}
}
if (out_d_dout) {
if (out_d_dout_dims != dout.dims()) {
out_d_dout->Resize(out_d_dout_dims);
}
}
if (out_d_ddx) {
if (out_d_ddx_dims != x.dims()) {
out_d_ddx->Resize(out_d_ddx_dims);
}
}
if (out_d_ddy) {
if (out_d_ddy_dims != x.dims()) {
out_d_ddy->Resize(out_d_ddy_dims);
}
}
} else {
// Case3: broadcast. It need cost much time to reduce sum for the
// broadcast and wastes the memory.
// So we should avoid the case in reality.
VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 3";
VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
"wastes the memory. So we should avoid the case in reality";
Tensor out_dx_help, out_dy_help;
Tensor out_d_ddx_help, out_d_ddy_help;
if (out_d_dout) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(ddx, ddx_conj);
conj_helper(ddy, ddy_conj);
}
if (out_d_ddx || out_d_ddy) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(x, x_conj);
conj_helper(y, y_conj);
conj_helper(dout, dout_conj);
}
if (transpose_x) {
if (transpose_y) {
// dX = ddY' d_ddout’, dY = d_ddout’ ddX'
if (out_d_x)
MatMulFunction<DeviceContext, T>(&ddy_conj, d_ddout, y_dims,
dout_dims, &out_dx_help, true,
true, context);
if (out_d_y)
MatMulFunction<DeviceContext, T>(d_ddout, &ddx_conj, dout_dims,
x_dims, &out_dy_help, true, true,
context);
} else {
// dX = ddY d_ddout', dY = ddX d_ddout
if (out_d_x)
MatMulFunction<DeviceContext, T>(&ddy_conj, d_ddout, y_dims,
dout_dims, &out_dx_help, false,
true, context);
if (out_d_y)
MatMulFunction<DeviceContext, T>(&ddx_conj, d_ddout, x_dims,
dout_dims, &out_dy_help, false,
false, context);
}
} else {
if (transpose_y) {
// dX = d_ddout ddY, dY = d_ddout’ ddX
if (out_d_x)
MatMulFunction<DeviceContext, T>(d_ddout, &ddy_conj, dout_dims,
y_dims, &out_dx_help, false, false,
context);
if (out_d_y)
MatMulFunction<DeviceContext, T>(d_ddout, &ddx_conj, dout_dims,
x_dims, &out_dy_help, true, false,
context);
} else {
// dX = d_ddout ddY', dY = ddX' d_ddout
if (out_d_x)
MatMulFunction<DeviceContext, T>(d_ddout, &ddy_conj, dout_dims,
y_dims, &out_dx_help, false, true,
context);
if (out_d_y)
MatMulFunction<DeviceContext, T>(&ddx_conj, d_ddout, x_dims,
dout_dims, &out_dy_help, true,
false, context);
}
}
// get help dims
const std::vector<std::int64_t> dx_help_dims =
vectorize(out_dx_help.dims());
const std::vector<std::int64_t> dy_help_dims =
vectorize(out_dx_help.dims());
std::vector<std::int64_t> dx_broadcast_dims(ndim);
std::vector<std::int64_t> dy_broadcast_dims(ndim);
std::fill(dx_broadcast_dims.data(),
dx_broadcast_dims.data() + ndim - x_ndim, 1);
std::fill(dy_broadcast_dims.data(),
dy_broadcast_dims.data() + ndim - y_ndim, 1);
std::copy(x_dims.data(), x_dims.data() + x_ndim,
dx_broadcast_dims.data() + ndim - x_ndim);
std::copy(y_dims.data(), y_dims.data() + y_ndim,
dy_broadcast_dims.data() + ndim - y_ndim);
std::vector<int> dx_reduce_dims;
std::vector<int> dy_reduce_dims;
for (int idx = 0; idx <= ndim - 3; idx++) {
if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) {
dx_reduce_dims.push_back(idx);
}
if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) {
dy_reduce_dims.push_back(idx);
}
}
// Reduce sum to get grad by ReduceSum
if (out_d_x) {
if (dx_reduce_dims.empty()) {
*out_d_x = std::move(out_dx_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&out_dx_help, out_d_x,
dx_reduce_dims, context);
}
out_d_x->Resize(x.dims());
}
if (out_d_y) {
if (dy_reduce_dims.empty()) {
*out_d_y = std::move(out_dy_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&out_dy_help, out_d_y,
dy_reduce_dims, context);
}
out_d_y->Resize(y.dims());
}
// compute d_dout
if (out_d_dout) {
MatMulFunction<DeviceContext, T>(d_dx, &ddy_conj, x_dims, y_dims,
out_d_dout, transpose_x, transpose_y,
context);
MatMulFunction<DeviceContext, T>(&ddx_conj, d_dy, x_dims, y_dims,
out_d_dout, transpose_x, transpose_y,
context, true);
}
// compute d_ddx
if (out_d_ddx) {
if (transpose_x && transpose_y) {
// out_d_ddx1 = y' * d_ddout'
MatMulFunction<DeviceContext, T>(&y_conj, d_ddout, y_dims, dout_dims,
&out_d_ddx_help, true, true,
context);
// out_d_ddx2 = D_DY' * DOut'
MatMulFunction<DeviceContext, T>(d_dy, &dout_conj, y_dims, dout_dims,
&out_d_ddx_help, true, true, context,
true);
} else if (transpose_x) {
// out_d_ddx1 = y * d_ddout'
MatMulFunction<DeviceContext, T>(&y_conj, d_ddout, y_dims, dout_dims,
&out_d_ddx_help, false, true,
context);
// out_d_ddx2 = D_DY * Dout'
MatMulFunction<DeviceContext, T>(d_dy, &dout_conj, y_dims, dout_dims,
&out_d_ddx_help, false, true,
context, true);
} else if (transpose_y) {
// out_d_ddx1 = d_ddout * y
MatMulFunction<DeviceContext, T>(d_ddout, &y_conj, dout_dims, y_dims,
&out_d_ddx_help, false, false,
context);
// out_d_ddx2 = Dout * D_DY
MatMulFunction<DeviceContext, T>(&dout_conj, d_dy, dout_dims, y_dims,
&out_d_ddx_help, false, false,
context, true);
} else {
// out_d_ddx1 = d_ddout * y'
MatMulFunction<DeviceContext, T>(d_ddout, &y_conj, dout_dims, y_dims,
&out_d_ddx_help, false, true,
context);
// out_d_ddx2 = Dout * D_DY'
MatMulFunction<DeviceContext, T>(&dout_conj, d_dy, dout_dims, y_dims,
&out_d_ddx_help, false, true,
context, true);
}
if (dx_reduce_dims.empty()) {
*out_d_ddx = std::move(out_d_ddx_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&out_d_ddx_help, out_d_ddx,
dx_reduce_dims, context);
}
out_d_ddx->Resize(x.dims());
}
// compute d_ddy
if (out_d_ddy) {
if (transpose_x && transpose_y) {
// out_d_ddy1 = d_ddout' * x'
MatMulFunction<DeviceContext, T>(d_ddout, &x_conj, dout_dims, x_dims,
&out_d_ddy_help, true, true,
context);
// out_d_ddy2 = dout' * d_dx'
MatMulFunction<DeviceContext, T>(&dout_conj, d_dx, dout_dims, x_dims,
&out_d_ddy_help, true, true, context,
true);
} else if (transpose_x) {
// out_d_ddy1 = x * d_ddout
MatMulFunction<DeviceContext, T>(&x_conj, d_ddout, x_dims, dout_dims,
&out_d_ddy_help, false, false,
context);
// out_d_ddy2 = d_dx * dout
MatMulFunction<DeviceContext, T>(d_dx, &dout_conj, x_dims, dout_dims,
&out_d_ddy_help, false, false,
context, true);
} else if (transpose_y) {
// out_d_ddy1 = d_ddout' * x
MatMulFunction<DeviceContext, T>(d_ddout, &x_conj, dout_dims, x_dims,
&out_d_ddy_help, true, false,
context);
// out_d_ddy2 = dout' * d_dx
MatMulFunction<DeviceContext, T>(&dout_conj, d_dx, dout_dims, x_dims,
&out_d_ddy_help, true, false,
context, true);
} else {
// out_d_ddy1 = x' * d_ddout
MatMulFunction<DeviceContext, T>(&x_conj, d_ddout, x_dims, dout_dims,
&out_d_ddy_help, true, false,
context);
// out_d_ddy2 = d_dx' * dout
MatMulFunction<DeviceContext, T>(d_dx, &dout_conj, x_dims, dout_dims,
&out_d_ddy_help, true, false,
context, true);
}
if (dy_reduce_dims.empty()) {
*out_d_ddy = std::move(out_d_ddy_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&out_d_ddy_help, out_d_ddy,
dy_reduce_dims, context);
}
out_d_ddy->Resize(y.dims());
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -304,7 +304,6 @@ def grad_check(x,
if b.has_var(xi.name):
clone_x.append(b.var(xi.name))
break
analytical.append(
_compute_analytical_jacobian(prog, clone_x, clone_y, place, scope))
......@@ -486,7 +485,6 @@ def triple_grad_check(x,
var_to_np_array_in_scope(scope, place, v.name)
for v in x_grads_grads
]
x += y_grads
x_init = _as_list(x_init)
x_init += y_grads_init
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -146,5 +146,427 @@ class TestMatmulDoubleGradCheckCase3(unittest.TestCase):
self.func(p)
class TestMatmulTripleGradCheckDotCase(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [2]
self.y_shape = [2]
self.transpose_x = False
self.transpose_y = False
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(x, y, self.transpose_x, self.transpose_y, name='out')
np.random.seed(2021)
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.triple_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMatmulTripleGradCheckNormalCase1(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [2, 2]
self.y_shape = [2, 2]
self.transpose_x = False
self.transpose_y = False
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(
x, y, self.transpose_x, self.transpose_y, name='out')
np.random.seed(2021)
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.triple_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMatmulTripleGradCheckNormalCase2(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [2, 2]
self.y_shape = [2, 2]
self.transpose_x = True
self.transpose_y = False
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(
x, y, self.transpose_x, self.transpose_y, name='out')
np.random.seed(2021)
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.triple_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMatmulTripleGradCheckNormalCase3(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [2, 2]
self.y_shape = [2, 2]
self.transpose_x = False
self.transpose_y = True
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(
x, y, self.transpose_x, self.transpose_y, name='out')
np.random.seed(2021)
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.triple_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMatmulTripleGradCheckNormalCase4(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [2, 2]
self.y_shape = [2, 2]
self.transpose_x = True
self.transpose_y = True
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(
x, y, self.transpose_x, self.transpose_y, name='out')
np.random.seed(2021)
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.triple_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMatmulTripleGradCheckBroadcastCase1(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [3, 2, 2]
self.y_shape = [1, 2, 2]
self.transpose_x = False
self.transpose_y = False
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(
x, y, self.transpose_x, self.transpose_y, name='out')
np.random.seed(2021)
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.triple_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMatmulTripleGradCheckBroadcastCase2(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [1, 2, 2]
self.y_shape = [3, 2, 2]
self.transpose_x = False
self.transpose_y = False
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(
x, y, self.transpose_x, self.transpose_y, name='out')
np.random.seed(2021)
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.triple_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMatmulTripleGradCheckBroadcastCase3(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [1, 2, 2]
self.y_shape = [3, 2, 2]
self.transpose_x = True
self.transpose_y = False
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(
x, y, self.transpose_x, self.transpose_y, name='out')
np.random.seed(2021)
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.triple_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMatmulTripleGradCheckBroadcastCase4(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [1, 2, 2]
self.y_shape = [3, 2, 2]
self.transpose_x = False
self.transpose_y = True
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(
x, y, self.transpose_x, self.transpose_y, name='out')
np.random.seed(2021)
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.triple_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMatmulTripleGradCheckBroadcastCase5(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [1, 2, 2]
self.y_shape = [3, 2, 2]
self.transpose_x = True
self.transpose_y = True
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(
x, y, self.transpose_x, self.transpose_y, name='out')
np.random.seed(2021)
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.triple_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMatmulTripleGradCheckSpecialCase1(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [3, 4, 5]
self.y_shape = [5]
self.transpose_x = False
self.transpose_y = False
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(
x, y, self.transpose_x, self.transpose_y, name='out')
np.random.seed(2021)
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.triple_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMatmulTripleGradCheckSpecialCase2(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [4, 5, 5]
self.y_shape = [5]
self.transpose_x = True
self.transpose_y = False
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(
x, y, self.transpose_x, self.transpose_y, name='out')
np.random.seed(2021)
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.triple_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册