未验证 提交 b3787d1b 编写于 作者: W wawltor 提交者: GitHub

add the matmul v2 grad kernel

* add the matmul v2 grad kernel

* relief the test case time

* update the test case for the matmul double grad

* remove the unsed code for the matmul double grad

* update the test case for the double grad matmul

* remove the unused code in dot
上级 c727ec4a
......@@ -228,6 +228,59 @@ class MatMulV2GradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matmul");
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul");
OP_INOUT_CHECK(context->HasInput("DOut"), "Input", "DOut", "matmul");
if (context->HasOutput("DX") && context->HasInput("DDY")) {
context->ShareDim("X", "DX");
}
if (context->HasOutput("DY") && context->HasInput("DDX")) {
context->ShareDim("Y", "DY");
}
if (context->HasOutput("DDOut") &&
(context->HasInput("DDY") || context->HasInput("DDX"))) {
context->ShareDim("DOut", "DDOut");
}
}
};
template <typename T>
class MatMulV2OpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("matmul_v2_grad_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
op->SetInput("DDY", this->OutputGrad(framework::GradVarName("Y")));
auto ddx = this->OutputGrad(framework::GradVarName("X"));
auto ddy = this->OutputGrad(framework::GradVarName("Y"));
if (!ddx.empty() || !ddy.empty()) {
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
}
op->SetOutput("DX",
ddy.empty() ? this->EmptyInputGrad() : this->InputGrad("X"));
op->SetOutput("DY",
ddx.empty() ? this->EmptyInputGrad() : this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
......@@ -236,7 +289,11 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker,
ops::MatMulV2GradOpMaker<paddle::framework::OpDesc>,
ops::MatMulV2GradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad);
REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad,
ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>,
ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad);
REGISTER_OP_CPU_KERNEL(
matmul_v2, ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, float>,
......@@ -254,3 +311,11 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex<float>>,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
matmul_v2_grad_grad,
ops::MatMulV2DoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulV2DoubleGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MatMulV2DoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::MatMulV2DoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
......@@ -30,3 +30,13 @@ REGISTER_OP_CUDA_KERNEL(
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::float16>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex<float>>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
matmul_v2_grad_grad,
ops::MatMulV2DoubleGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::MatMulV2DoubleGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::MatMulV2DoubleGradKernel<plf::CUDADeviceContext, plf::float16>,
ops::MatMulV2DoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::MatMulV2DoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
......@@ -117,11 +117,12 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
const std::vector<std::int64_t>& x_dims,
const std::vector<std::int64_t>& y_dims, Tensor* Out,
bool trans_x, bool trans_y,
const paddle::framework::ExecutionContext& ctx) {
const paddle::framework::ExecutionContext& ctx,
bool flag = false) {
const int x_ndim = x_dims.size();
const int y_ndim = y_dims.size();
// get data ptr
// Get data ptr
const T* x_data = X->data<T>();
const T* y_data = Y->data<T>();
......@@ -141,7 +142,11 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
auto y_eigen = framework::EigenVector<T>::Flatten(*Y);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
out_eigen.device(dev) = (x_eigen * y_eigen).sum();
if (flag) {
out_eigen.device(dev) = (x_eigen * y_eigen).sum() + out_eigen;
} else {
out_eigen.device(dev) = (x_eigen * y_eigen).sum();
}
return;
}
......@@ -178,18 +183,18 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
const int M = Y->numel() / N;
VLOG(3) << "MatMul's case 2";
blas.GEMV(false, M, N, static_cast<T>(1), y_data, x_data,
static_cast<T>(0), Out->data<T>());
static_cast<T>(flag), Out->data<T>());
} else {
const int M = y_dims[y_ndim - 1];
const int batch_size = Y->numel() / (M * N);
if (batch_size == 1) {
VLOG(3) << "MatMul's case 3";
blas.GEMV(true, N, M, static_cast<T>(1), y_data, x_data,
static_cast<T>(0), Out->data<T>());
static_cast<T>(flag), Out->data<T>());
} else {
VLOG(3) << "MatMul's case 4";
blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast<T>(1),
y_data, x_data, static_cast<T>(0), Out->data<T>(),
y_data, x_data, static_cast<T>(flag), Out->data<T>(),
batch_size, M * N, 0);
}
}
......@@ -229,18 +234,18 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
if (batch_size == 1) {
VLOG(3) << "MatMul's case 5";
blas.GEMV(true, N, M, static_cast<T>(1), x_data, y_data,
static_cast<T>(0), Out->data<T>());
static_cast<T>(flag), Out->data<T>());
} else {
VLOG(3) << "MatMul's case 6";
blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast<T>(1),
x_data, y_data, static_cast<T>(0), Out->data<T>(),
x_data, y_data, static_cast<T>(flag), Out->data<T>(),
batch_size, M * N, 0);
}
} else {
const int M = X->numel() / N;
VLOG(3) << "MatMul's case 7";
blas.GEMV(false, M, N, static_cast<T>(1), x_data, y_data,
static_cast<T>(0), Out->data<T>());
static_cast<T>(flag), Out->data<T>());
}
return;
}
......@@ -298,17 +303,17 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
VLOG(3) << "MatMul's case 8";
blas.GEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K, static_cast<T>(1),
x_data, y_data, static_cast<T>(0), Out->data<T>());
x_data, y_data, static_cast<T>(flag), Out->data<T>());
} else if (x_batch_size == 1) {
if (M == 1 && trans_y) {
VLOG(3) << "MatMul's case 9";
blas.GEMV(false, y_batch_size * N, K, static_cast<T>(1), y_data, x_data,
static_cast<T>(0), Out->data<T>());
static_cast<T>(flag), Out->data<T>());
} else {
VLOG(3) << "MatMul's case 10";
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K,
static_cast<T>(1), x_data, y_data, static_cast<T>(0),
static_cast<T>(1), x_data, y_data, static_cast<T>(flag),
Out->data<T>(), out_batch_size, 0, K * N);
}
} else if (y_batch_size == 1) {
......@@ -316,18 +321,18 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
VLOG(3) << "MatMul's case 11";
blas.GEMM(CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans,
x_batch_size * M, N, K, static_cast<T>(1), x_data, y_data,
static_cast<T>(0), Out->data<T>());
static_cast<T>(flag), Out->data<T>());
} else {
VLOG(3) << "MatMul's case 12";
blas.BatchedGEMM(CblasTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K,
static_cast<T>(1), x_data, y_data, static_cast<T>(0),
static_cast<T>(1), x_data, y_data, static_cast<T>(flag),
Out->data<T>(), out_batch_size, M * K, 0);
}
} else if (!is_broadcast_dims) {
VLOG(3) << "MatMul's case 13";
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K,
static_cast<T>(1), x_data, y_data, static_cast<T>(0),
static_cast<T>(1), x_data, y_data, static_cast<T>(flag),
Out->data<T>(), out_batch_size, M * K, K * N);
} else {
// in the case, can't use stridedgemm
......@@ -351,18 +356,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K,
static_cast<T>(1), x_ptr.data(), y_ptr.data(),
static_cast<T>(0), out_ptr.data(), out_batch_size);
static_cast<T>(flag), out_ptr.data(), out_batch_size);
}
}
template <typename DeviceContext, typename T>
void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x,
bool trans_y,
const paddle::framework::ExecutionContext& ctx) {
const paddle::framework::ExecutionContext& ctx,
bool flag = false) {
const std::vector<std::int64_t> x_dims = vectorize(X->dims());
const std::vector<std::int64_t> y_dims = vectorize(Y->dims());
MatMulFunction<DeviceContext, T>(X, Y, x_dims, y_dims, Out, trans_x, trans_y,
ctx);
ctx, flag);
}
template <typename DeviceContext, typename T>
......@@ -526,6 +532,245 @@ struct ConjHelper<DeviceContext, paddle::platform::complex<double>> {
const framework::ExecutionContext& ctx_;
};
template <typename DeviceContext, typename T, typename Enabel = void>
struct DotDoubleGradFunction {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
Tensor* tensor_dx, Tensor* tensor_dy,
const Tensor* tensor_dout, const Tensor* tensor_ddx,
const Tensor* tensor_ddy, Tensor* tensor_ddout,
const paddle::framework::ExecutionContext& ctx);
};
template <typename DeviceContext, typename T>
struct DotDoubleGradFunction<DeviceContext, T, math::EnableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
Tensor* tensor_dx, Tensor* tensor_dy,
const Tensor* tensor_dout, const Tensor* tensor_ddx,
const Tensor* tensor_ddy, Tensor* tensor_ddout,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
framework::Tensor tensor_dout_help;
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
if (tensor_dx || tensor_dy) {
tensor_dout_help.Resize(tensor_dout->dims());
tensor_dout_help.mutable_data<T>(ctx.GetPlace());
paddle::platform::ForRange<DeviceContext> for_range(
dev_raw, tensor_dout->numel());
math::ConjFunctor<T> functor(tensor_dout->data<T>(),
tensor_dout->numel(),
tensor_dout_help.data<T>());
for_range(functor);
}
if (tensor_dx) {
auto ddy = framework::EigenVector<T>::Flatten(*tensor_ddy);
Eigen::DSizes<int, 1> size(tensor_ddy->numel());
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto dout = framework::EigenVector<T>::Flatten(tensor_dout_help);
dx.device(dev) = ddy * dout.broadcast(size);
}
if (tensor_dy) {
auto ddx = framework::EigenVector<T>::Flatten(*tensor_ddx);
Eigen::DSizes<int, 1> size(tensor_ddx->numel());
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto dout = framework::EigenVector<T>::Flatten(tensor_dout_help);
dy.device(dev) = ddx * dout.broadcast(size);
}
if (tensor_ddout) {
framework::Tensor tensor_x_help, tensor_y_help;
tensor_x_help.Resize(tensor_x->dims());
tensor_x_help.mutable_data<T>(ctx.GetPlace());
tensor_y_help.Resize(tensor_y->dims());
tensor_y_help.mutable_data<T>(ctx.GetPlace());
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_x->numel());
math::ConjFunctor<T> functor_x(tensor_x->data<T>(), tensor_x->numel(),
tensor_x_help.data<T>());
for_range(functor_x);
math::ConjFunctor<T> functor_y(tensor_y->data<T>(), tensor_y->numel(),
tensor_y_help.data<T>());
for_range(functor_y);
auto x = framework::EigenVector<T>::Flatten(tensor_x_help);
auto y = framework::EigenVector<T>::Flatten(tensor_y_help);
auto ddx = framework::EigenVector<T>::Flatten(*tensor_ddx);
auto ddy = framework::EigenVector<T>::Flatten(*tensor_ddy);
auto ddout = framework::EigenVector<T>::Flatten(*tensor_ddout);
ddout.device(dev) = (x * ddy + y * ddx).sum();
}
}
#else
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_ddy = tensor_ddy->data<T>();
const framework::DDim& dim = tensor_dx->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddy[i];
}
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_ddx = tensor_ddx->data<T>();
const framework::DDim& dim = tensor_dy->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddx[i];
}
}
if (tensor_ddout) {
auto* data_ddout = tensor_ddout->mutable_data<T>(ctx.GetPlace());
auto* data_x = tensor_x->data<T>();
auto* data_y = tensor_y->data<T>();
auto* data_ddx = tensor_ddx->data<T>();
auto* data_ddy = tensor_ddy->data<T>();
const framework::DDim& dim = tensor_dy->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
bool new_s = false;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) {
++s;
new_s = true;
}
if (new_s) {
data_ddout[s] = T(data_x[i].real, -data_x[i].imag) * data_ddy[i] +
T(data_y[i].real, -data_y[i].imag) * data_ddx[i];
} else {
data_ddout[s] += T(data_x[i].real, -data_x[i].imag) * data_ddy[i] +
T(data_y[i].real, -data_y[i].imag) * data_ddx[i];
}
new_s = false;
}
}
#endif
}
};
template <typename DeviceContext, typename T>
struct DotDoubleGradFunction<DeviceContext, T, math::DisableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
Tensor* tensor_dx, Tensor* tensor_dy,
const Tensor* tensor_dout, const Tensor* tensor_ddx,
const Tensor* tensor_ddy, Tensor* tensor_ddout,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto ddy = framework::EigenVector<T>::Flatten(*tensor_ddy);
Eigen::DSizes<int, 1> size(tensor_ddy->numel());
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
dx.device(dev) = ddy * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto ddx = framework::EigenVector<T>::Flatten(*tensor_ddx);
Eigen::DSizes<int, 1> size(tensor_ddx->numel());
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
dy.device(dev) = ddx * dout.broadcast(size);
}
if (tensor_ddout) {
tensor_ddout->mutable_data<T>(ctx.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto ddx = framework::EigenVector<T>::Flatten(*tensor_ddx);
auto ddy = framework::EigenVector<T>::Flatten(*tensor_ddy);
auto ddout = framework::EigenVector<T>::Flatten(*tensor_ddout);
ddout.device(dev) = (x * ddy + y * ddx).sum();
}
}
#else
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_ddy = tensor_ddy->data<T>();
const framework::DDim& dim = tensor_dx->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = data_dout[s] * data_ddy[i];
}
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_ddx = tensor_ddx->data<T>();
const framework::DDim& dim = tensor_dy->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = data_dout[s] * data_ddx[i];
}
}
if (tensor_ddout) {
auto* data_ddout = tensor_ddout->mutable_data<T>(ctx.GetPlace());
auto* data_x = tensor_x->data<T>();
auto* data_y = tensor_y->data<T>();
auto* data_ddx = tensor_ddx->data<T>();
auto* data_ddy = tensor_ddy->data<T>();
const framework::DDim& dim = tensor_dy->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
bool new_s = false;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) {
++s;
new_s = true;
}
if (new_s) {
data_ddout[s] = data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i];
} else {
data_ddout[s] += data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i];
}
new_s = false;
}
}
#endif
}
};
template <typename DeviceContext, typename T>
class MatMulV2GradKernel : public framework::OpKernel<T> {
public:
......@@ -573,10 +818,10 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
bool transpose_x = ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.Attr<bool>("trans_y");
auto x = *ctx.Input<framework::Tensor>("X");
auto y = *ctx.Input<framework::Tensor>("Y");
auto dout = *ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor y_conj(y.type());
framework::Tensor x_conj(y.type());
......@@ -757,9 +1002,327 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
}
dy->Resize(y.dims());
}
// Get the OutputGrad(out)
}
}
};
template <typename DeviceContext, typename T>
class MatMulV2DoubleGradKernel : public framework::OpKernel<T> {
public:
void MatMul(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
const framework::Tensor& b, bool trans_b, framework::Tensor* out,
bool flag) const {
out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
if (a.dims().size() == 3 && b.dims().size() <= 2) {
// the transpose_X must be false, if is true, the transpose cost much time
if (!trans_a) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
}
}
blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast<T>(1), out,
static_cast<T>(flag));
}
void CalcInputGrad(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
bool is_fold_init_dims_a, const framework::Tensor& b,
bool trans_b, bool is_fold_init_dims_b,
framework::Tensor* out, bool flag) const {
if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2;
if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, out, flag);
} else {
auto& ctx = context.template device_context<DeviceContext>();
MatMul(context, is_fold_init_dims_a
? FoldInitDims(a)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
trans_a, is_fold_init_dims_b
? FoldInitDims(b)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, b),
trans_b, out, flag);
}
}
void Compute(const framework::ExecutionContext& context) const override {
auto x = *context.Input<framework::Tensor>("X");
auto y = *context.Input<framework::Tensor>("Y");
auto dout = *context.Input<framework::Tensor>("DOut");
auto* ddx = context.Input<framework::Tensor>("DDX");
auto* ddy = context.Input<framework::Tensor>("DDY");
auto* dx = context.Output<framework::Tensor>("DX");
auto* dy = context.Output<framework::Tensor>("DY");
auto* ddout = context.Output<framework::Tensor>("DDOut");
bool transpose_x = context.Attr<bool>("trans_x");
bool transpose_y = context.Attr<bool>("trans_y");
// Get dims from the input x, y, output_grad
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
std::vector<std::int64_t> dout_dims = vectorize(dout.dims());
framework::Tensor x_conj(x.type());
framework::Tensor y_conj(y.type());
framework::Tensor dout_conj(dout.type());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int ndim = dout_dims.size();
// Case1 : x's or y's dim = 1
if (x_ndim == 1 && y_ndim == 1) {
DotDoubleGradFunction<DeviceContext, T>()(&x, &y, dx, dy, &dout, ddx, ddy,
ddout, context);
return;
}
bool is_broadcast = true;
if (x_ndim <= 2 || y_ndim <= 2) {
is_broadcast = false;
} else if (x_ndim != y_ndim) {
is_broadcast = true;
} else {
is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2,
y_dims.cbegin());
}
if (!is_broadcast) {
// Case2: no broadcast or no batch size
ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
framework::DDim dx_dims;
ConjHelper<DeviceContext, T> conj_helper(context);
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
}
framework::DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
}
framework::DDim ddout_dims;
if (ddout) {
ddout_dims = ddout->dims();
if (ddout_dims != dout.dims()) {
ddout->Resize(dout.dims());
}
}
if (ddx || ddy) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(dout, dout_conj);
}
if (ddout) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(x, x_conj);
conj_helper(y, y_conj);
}
bool ddout_flag = false;
if (ddx) {
auto ddx_mat = *ddx;
if (ddx_mat.dims() != x.dims()) {
ddx_mat.Resize(x.dims());
}
if (dy) {
if (transpose_x && transpose_y) {
// dy = dout' * ddx'
CalcInputGrad(context, dout_conj, true, true, ddx_mat, true, false,
dy, false);
} else if (transpose_x) {
// dy = ddx * dout
CalcInputGrad(context, ddx_mat, false, false, dout_conj, false,
true, dy, false);
} else if (transpose_y) {
// dy = dout' * ddx
CalcInputGrad(context, dout_conj, true, true, ddx_mat, false, true,
dy, false);
} else {
// dy = ddx' * dout
CalcInputGrad(context, ddx_mat, true, true, dout_conj, false, true,
dy, false);
}
}
if (ddout) {
CalcInputGrad(context, ddx_mat, transpose_x, true, y_conj,
transpose_y, false, ddout, ddout_flag);
ddout_flag = true;
}
}
if (ddy) {
auto ddy_mat = *ddy;
if (ddy_mat.dims() != y.dims()) {
ddy_mat.Resize(y.dims());
}
if (dx) {
if (transpose_x && transpose_y) {
// dx = ddy' * dout'
CalcInputGrad(context, ddy_mat, true, true, dout_conj, true, false,
dx, false);
} else if (transpose_x) {
// dx = ddy * dout'
CalcInputGrad(context, ddy_mat, false, false, dout_conj, true,
false, dx, false);
} else if (transpose_y) {
// dx = dout * ddy
CalcInputGrad(context, dout_conj, false, false, ddy_mat, false,
true, dx, false);
} else {
// dx = dout * ddy'
CalcInputGrad(context, dout_conj, false, false, ddy_mat, true,
false, dx, false);
}
}
if (ddout) {
CalcInputGrad(context, x_conj, transpose_x, true, ddy_mat,
transpose_y, false, ddout, ddout_flag);
}
}
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
}
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
}
}
if (ddout) {
if (ddout_dims != dout.dims()) {
ddout->Resize(ddout_dims);
}
}
} else {
// Case3: broadcast. It need cost much time to reduce sum for the
// broadcast and wastes the memory.
// So we should avoid the case in reality.
VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
"wastes the memory. So we should avoid the case in reality";
framework::Tensor ddy_conj(ddx->type());
framework::Tensor ddx_conj(ddy->type());
Tensor dx_help, dy_help;
if (dx || dy) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(dout, dout_conj);
}
if (ddout) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(x, x_conj);
conj_helper(y, y_conj);
}
if (transpose_x) {
if (transpose_y) {
if (dx)
MatMulFunction<DeviceContext, T>(ddy, &dout_conj, y_dims, dout_dims,
&dx_help, true, true, context);
if (dy)
MatMulFunction<DeviceContext, T>(&dout_conj, ddx, dout_dims, x_dims,
&dy_help, true, true, context);
} else {
if (dx)
MatMulFunction<DeviceContext, T>(ddy, &dout_conj, y_dims, dout_dims,
&dx_help, false, true, context);
if (dy)
MatMulFunction<DeviceContext, T>(ddx, &dout_conj, x_dims, dout_dims,
&dy_help, false, false, context);
}
} else {
if (transpose_y) {
if (dx)
MatMulFunction<DeviceContext, T>(&dout_conj, ddy, dout_dims, y_dims,
&dx_help, false, false, context);
if (dy)
MatMulFunction<DeviceContext, T>(&dout_conj, ddx, dout_dims, x_dims,
&dy_help, true, false, context);
} else {
if (dx)
MatMulFunction<DeviceContext, T>(&dout_conj, ddy, dout_dims, y_dims,
&dx_help, false, true, context);
if (dy)
MatMulFunction<DeviceContext, T>(ddx, &dout_conj, x_dims, dout_dims,
&dy_help, true, false, context);
}
}
// get help dims
const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
std::vector<std::int64_t> dx_broadcast_dims(ndim);
std::vector<std::int64_t> dy_broadcast_dims(ndim);
std::fill(dx_broadcast_dims.data(),
dx_broadcast_dims.data() + ndim - x_ndim, 1);
std::fill(dy_broadcast_dims.data(),
dy_broadcast_dims.data() + ndim - y_ndim, 1);
std::copy(x_dims.data(), x_dims.data() + x_ndim,
dx_broadcast_dims.data() + ndim - x_ndim);
std::copy(y_dims.data(), y_dims.data() + y_ndim,
dy_broadcast_dims.data() + ndim - y_ndim);
std::vector<int> dx_reduce_dims;
std::vector<int> dy_reduce_dims;
for (int idx = 0; idx <= ndim - 3; idx++) {
if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) {
dx_reduce_dims.push_back(idx);
}
if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) {
dy_reduce_dims.push_back(idx);
}
}
// Reduce sum to get grad by ReduceSum
if (dx) {
if (dx_reduce_dims.empty()) {
*dx = std::move(dx_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&dx_help, dx, dx_reduce_dims,
context);
}
dx->Resize(x.dims());
}
if (dy) {
if (dy_reduce_dims.empty()) {
*dy = std::move(dy_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&dy_help, dy, dy_reduce_dims,
context);
}
dy->Resize(y.dims());
}
if (ddout) {
// Caluate the gradient of OutputGrad(Out)
MatMulFunction<DeviceContext, T>(ddx, &y_conj, x_dims, y_dims, ddout,
transpose_x, transpose_y, context);
MatMulFunction<DeviceContext, T>(&x_conj, ddy, x_dims, y_dims, ddout,
transpose_x, transpose_y, context,
true);
}
}
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
import gradient_checker
from decorator_helper import prog_scope
paddle.enable_static()
class TestMatmulDoubleGradCheck(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [2]
self.y_shape = [2]
self.transpose_x = False
self.transpose_y = False
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(
x, y, self.transpose_x, self.transpose_y, name='out')
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.double_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMatmulDoubleGradCheckCase1(TestMatmulDoubleGradCheck):
def init_test(self):
self.x_shape = [2, 3]
self.y_shape = [3, 2]
self.transpose_x = True
self.transpose_y = True
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMatmulDoubleGradCheck2(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [2, 4, 3]
self.y_shape = [2, 4, 5]
self.transpose_x = True
self.transpose_y = False
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(
x, y, self.transpose_x, self.transpose_y, name='out')
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.double_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestMatmulDoubleGradCheckCase3(unittest.TestCase):
def setUp(self):
self.init_test()
def init_test(self):
self.x_shape = [1, 1, 4, 25]
self.y_shape = [1, 2, 25, 4]
self.transpose_x = False
self.transpose_y = False
@prog_scope()
def func(self, place):
eps = 0.005
dtype = np.float64
typename = "float64"
x = paddle.static.create_parameter(
dtype=typename, shape=self.x_shape, name='x')
y = paddle.static.create_parameter(
dtype=typename, shape=self.y_shape, name='y')
out = paddle.matmul(
x, y, self.transpose_x, self.transpose_y, name='out')
x_arr = np.random.uniform(-1, 1, self.x_shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, self.y_shape).astype(dtype)
gradient_checker.double_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册