Commit 982e61f5 authored by Leo Chen, committed by Zeng Jinle

Update elementwise double grad to save gpu memory (#19509)

* update elementwise double grad to save gpu memory, test=develop

* update elementwise_mul/div_grad_grad to save memory, test=develop

* remove eval function in eigen statement to save memory, test=develop

* add unittest for elementwise_div_grad_grad without dout, test=develop

* add unittest for elementwise_add_grad_grad without ddx, test=develop

* add float16 cuda kernel for elementwise double grad op, test=develop
Parent db26de83
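The memory savings below come from three related changes: DDOut is declared in-place with DDX via DECLARE_INPLACE_OP_INFERER, the double-grad kernels are reordered so that everything that still reads DDX is computed before DDOut overwrites the shared buffer, and existing outputs (dOut, dX) are reused as scratch space instead of allocating temporary tensors (dropping .eval() likewise avoids materializing an extra Eigen temporary). A minimal sketch in plain C++ (not Paddle's API) of the sqrt case, showing why the write order matters once DDOut aliases DDX:

#include <cassert>
#include <cstddef>
#include <vector>

// Toy sqrt double grad: given out = sqrt(x), the first-order grad dx and the
// incoming ddx, compute dout = -dx * ddx / out and ddout = 0.5 * ddx / out.
// Here ddout reuses ddx's storage (the in-place pair DDX -> DDOut), so dout
// must be produced first, while ddx is still intact.
void SqrtDoubleGrad(const std::vector<double>& out,
                    const std::vector<double>& dx,
                    std::vector<double>* ddx_and_ddout,  // aliased buffer
                    std::vector<double>* dout) {
  auto& buf = *ddx_and_ddout;
  for (std::size_t i = 0; i < buf.size(); ++i) {
    (*dout)[i] = -dx[i] * buf[i] / out[i];  // reads ddx before it is clobbered
  }
  for (std::size_t i = 0; i < buf.size(); ++i) {
    buf[i] = 0.5 * buf[i] / out[i];  // now safe to overwrite ddx with ddout
  }
}

int main() {
  std::vector<double> out{2.0}, dx{0.25}, ddx{1.0}, dout(1);
  SqrtDoubleGrad(out, dx, &ddx, &dout);
  assert(dout[0] == -0.125 && ddx[0] == 0.25);  // ddx now holds ddout
  return 0;
}

Swapping the two loops would compute dout from the already-overwritten buffer, which is exactly the hazard the reordering in this commit avoids.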
@@ -782,6 +782,8 @@ class SquareDoubleGradMaker
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInference,
{"DDX", "DDOut"});
class PowGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
@@ -896,7 +898,8 @@ REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(
relu_grad_grad,
ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>);
ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference);
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
@@ -921,7 +924,8 @@ REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
ops::LeakyReluDoubleGradMaker);
REGISTER_OPERATOR(
leaky_relu_grad_grad,
ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference);
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
LeakyReluGradFunctor);
@@ -945,7 +949,9 @@ REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
sqrt_grad_grad,
ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>);
ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference);
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
@@ -967,7 +973,8 @@ REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
square_grad_grad,
ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>);
ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference);
REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
SquareGradFunctor);
@@ -1437,15 +1437,17 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
}
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx / y
// calculate dy first, so ddy can be computed in place of ddx
if (dOut) {
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
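For reference, these two branches follow from y = sqrt(x): the first-order backward is dx = dout / (2 * y), so differentiating that map and contracting with the incoming ddx gives ddout = ddx / (2 * y) = 0.5 * ddx / y, and the extra gradient flowing back to y is -(dout / (2 * y^2)) * ddx = -dx * ddx / y. Since ddout may now share ddx's buffer, the dout branch has to run first, which is what the reordering above does.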
@@ -1459,15 +1461,17 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
ddout.device(*d) = ddx * static_cast<T>(2) * x;
}
// square GradGrad: ddy = 2 * x * ddx, dx = 2 * dout * ddx
// calculate dx first, so ddy can be computed in place of ddx
if (dX) {
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
dx.device(*d) = ddx * static_cast<T>(2) * dout;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
ddout.device(*d) = ddx * static_cast<T>(2) * x;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
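The square case is analogous: y = x^2 gives the first-order backward dx = 2 * x * dout, so the double grad is ddout = 2 * x * ddx and the extra gradient w.r.t. x is 2 * dout * ddx; computing dx before ddout keeps ddx intact for the in-place DDX -> DDOut reuse.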
@@ -2,3 +2,5 @@ include(operators)
register_operators()
cc_test(test_elementwise_add_op_inplace SRCS test_elementwise_add_op_inplace.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
cc_test(test_elementwise_div_grad_grad SRCS test_elementwise_div_grad_grad.cc DEPS op_registry elementwise_div_op scope device_context enforce executor)
cc_test(test_elementwise_add_grad_grad SRCS test_elementwise_add_grad_grad.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
@@ -54,7 +54,9 @@ REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpExplicitGrad,
ops::ElementwiseGradNoBufVarsInference,
ops::ElementwiseAddDoubleGradDescMaker);
REGISTER_OPERATOR(elementwise_add_grad_grad,
ops::ElementwiseOpDoubleGradWithoutDXDY);
ops::ElementwiseOpDoubleGradWithoutDXDY,
ops::ElementwiseDoubleGradOpInplace,
ops::ElementwiseDoubleGradNoBufVarsInference);
REGISTER_OP_CPU_KERNEL(
elementwise_add,
@@ -36,4 +36,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int64_t>);
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
plat::float16>);
@@ -80,7 +80,8 @@ REGISTER_OPERATOR(elementwise_div, ops::ElementwiseOp,
REGISTER_OPERATOR(elementwise_div_grad, ops::ElementwiseOpGrad,
ops::ElementwiseDivDoubleGradDescMaker);
REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad);
REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad,
ops::ElementwiseDivDoubleGradOpInplace);
REGISTER_OP_CPU_KERNEL(
elementwise_div,
@@ -37,6 +37,8 @@ REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad_grad,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
double>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
@@ -148,20 +148,21 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Out, ddX, &ddX_safe);
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Y, ddY, &ddY_safe);
// ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
// dY = Out * dX * ddY / Y - dX * ddX / Y
// dOut = - dX * ddY
// To save memory: (1) dOut can be used as the 'tmp' tensor, (2) ddOut can
// be computed in place of ddX
Tensor tmp;
if (dOut) {
// dOut = - dX * ddY
default_elementwise_mul<DeviceContext, T>(ctx, dX, &ddY_safe, dOut);
auto& place =
*ctx.template device_context<DeviceContext>().eigen_device();
auto dout = framework::EigenVector<T>::Flatten(*dOut);
dout.device(place) = static_cast<T>(-1) * dout;
tmp = *dOut;
} else {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
tmp = ctx.AllocateTmpTensor<T, DeviceContext>(Out->dims(), dev_ctx);
}
if (dY) {
// dX_div_Y = dX / Y;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
Tensor dX_div_Y =
ctx.AllocateTmpTensor<T, DeviceContext>(Out->dims(), dev_ctx);
Tensor dX_div_Y = tmp;
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
ctx, dX, Y, axis, DivFunctor<T>(), &dX_div_Y);
@@ -179,14 +180,25 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
if (ddOut) {
// ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
default_elementwise_mul<DeviceContext, T>(ctx, Out, &ddY_safe, ddOut);
default_elementwise_mul<DeviceContext, T>(ctx, Out, &ddY_safe, &tmp);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &ddX_safe, ddOut, 0, SubFunctor<T>(), ddOut);
ctx, &ddX_safe, &tmp, 0, SubFunctor<T>(), &tmp);
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
ctx, ddOut, Y, axis, DivFunctor<T>(), ddOut);
ctx, &tmp, Y, axis, DivFunctor<T>(), ddOut);
}
if (dOut) {
// dOut = - dX * ddY
default_elementwise_mul<DeviceContext, T>(ctx, dX, &ddY_safe, dOut);
auto& place =
*ctx.template device_context<DeviceContext>().eigen_device();
auto dout = framework::EigenVector<T>::Flatten(*dOut);
dout.device(place) = static_cast<T>(-1) * dout;
}
}
};
DECLARE_INPLACE_OP_INFERER(ElementwiseDivDoubleGradOpInplace, {"DDX", "DDOut"});
} // namespace operators
} // namespace paddle
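In this kernel the reuse works slightly differently: dOut's storage doubles as the scratch tensor tmp, first holding dX / Y for the dY branch, then Out * ddY and ddX - Out * ddY for the ddOut branch, and only at the end receiving dOut's own value -dX * ddY. A scalar sketch in plain C++ (hypothetical names, not Paddle's API) of that ordering:

#include <cassert>
#include <cmath>

int main() {
  const double y = 2.0, out = 3.0, dx = 0.25, ddx = 0.3, ddy = 0.7;

  double tmp_and_dout;                  // dOut's storage, reused as 'tmp'
  tmp_and_dout = dx / y;                // tmp = dX / Y           (dY branch)
  const double dy = tmp_and_dout * (out * ddy - ddx);
  tmp_and_dout = out * ddy;             // tmp = Out * ddY        (ddOut branch)
  const double ddout = (ddx - tmp_and_dout) / y;
  tmp_and_dout = -dx * ddy;             // finally dOut = -dX * ddY

  // the reordered computation matches the formulas in the comments above
  assert(std::abs(dy - (out * dx * ddy / y - dx * ddx / y)) < 1e-12);
  assert(std::abs(ddout - (ddx / y - out * ddy / y)) < 1e-12);
  assert(tmp_and_dout == -dx * ddy);
  return 0;
}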
@@ -77,7 +77,8 @@ REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseOp,
ops::ElementwiseMulOpGradDescMaker);
REGISTER_OPERATOR(elementwise_mul_grad, ops::ElementwiseOpGrad,
ops::ElementwiseMulDoubleGradDescMaker);
REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad);
REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad,
ops::ElementwiseMulDoubleGradOpInplace);
REGISTER_OP_CPU_KERNEL(
elementwise_mul,
@@ -94,4 +94,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>);
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::float16>);
@@ -146,37 +146,48 @@ class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> {
if (ddout) ddout->mutable_data<T>(ctx.GetPlace());
// dx = dout * ddy
// dy = dout * ddx
Tensor ddx_safe, ddy_safe;
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, x, ddx, &ddx_safe);
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
ctx, ddx_safe, ddy_safe, *dout, *dout, axis, dx, dy, MulGradDX<T>(),
MulGradDY<T>());
// dx = dout * ddy
// dy = dout * ddx
// ddout = ddx * y + x * ddy
// change the computation sequence to save memory: ddout can be computed in
// place of ddx, and dx can be used as the 'tmp' tensor
// (1) dx = x * ddy
// (2) dy = dout * ddx
// (3) ddout = ddx * y
// (4) ddout = ddout + dx
// (5) dx = dout * ddy
if (ddout) {
if (ddx && ddy) {
Tensor ddout_tmp;
ddout_tmp.mutable_data<T>(ddout->dims(), ctx.GetPlace());
default_elementwise_mul<DeviceContext, T>(ctx, ddx, y, ddout);
default_elementwise_mul<DeviceContext, T>(ctx, x, ddy, &ddout_tmp);
auto& place =
*ctx.template device_context<DeviceContext>().eigen_device();
auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
auto ddout_tmp_t = framework::EigenVector<T>::Flatten(ddout_tmp);
ddout_t.device(place) = ddout_t + ddout_tmp_t;
} else {
if (ddx) default_elementwise_mul<DeviceContext, T>(ctx, ddx, y, ddout);
if (ddy) default_elementwise_mul<DeviceContext, T>(ctx, x, ddy, ddout);
}
// reuse dx to save memory instead of allocating a tmp tensor
Tensor* ddout_tmp = dx;
default_elementwise_mul<DeviceContext, T>(ctx, x, &ddy_safe, ddout_tmp);
int axis = ctx.Attr<int>("axis");
// NOTE: in the following ElemwiseGradCompute, the first output tensor is
// nullptr, so the branch that computes the first output is skipped and
// MulGradDX is never called; that branch has little effect on running
// speed.
ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
ctx, ddx_safe, ddy_safe, *dout, *dout, axis, nullptr, dy,
MulGradDX<T>(), MulGradDY<T>());
default_elementwise_mul<DeviceContext, T>(ctx, &ddx_safe, y, ddout);
auto& place =
*ctx.template device_context<DeviceContext>().eigen_device();
auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
auto ddout_tmp_t = framework::EigenVector<T>::Flatten(*ddout_tmp);
ddout_t.device(place) = ddout_t + ddout_tmp_t;
default_elementwise_mul<DeviceContext, T>(ctx, dout, &ddy_safe, dx);
}
}
};
DECLARE_INPLACE_OP_INFERER(ElementwiseMulDoubleGradOpInplace, {"DDX", "DDOut"},
{"X", framework::GradVarName("X")},
{"Y", framework::GradVarName("Y")});
} // namespace operators
} // namespace paddle
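A scalar sketch in plain C++ (hypothetical names, not Paddle's API) of the five-step order above: dX's storage is the scratch buffer and DDOut may alias DDX, so ddx is only overwritten after both dy and the ddx * y term have been taken from it.

#include <cassert>
#include <cmath>

int main() {
  const double x = 2.0, y = 3.0, dout = 0.5;
  const double ddx_in = 0.3, ddy = 0.7;
  double ddx_and_ddout = ddx_in;             // DDX and DDOut share storage
  double dx_scratch;                         // dX's storage, reused as scratch

  dx_scratch = x * ddy;                      // (1) dx = x * ddy (partial ddout)
  const double dy = dout * ddx_and_ddout;    // (2) dy = dout * ddx
  ddx_and_ddout = ddx_and_ddout * y;         // (3) ddout = ddx * y (clobbers ddx)
  ddx_and_ddout += dx_scratch;               // (4) ddout = ddout + dx
  dx_scratch = dout * ddy;                   // (5) dx = dout * ddy (final value)

  assert(std::abs(ddx_and_ddout - (ddx_in * y + x * ddy)) < 1e-12);
  assert(std::abs(dy - dout * ddx_in) < 1e-12);
  assert(std::abs(dx_scratch - dout * ddy) < 1e-12);
  return 0;
}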
@@ -264,7 +264,18 @@ class ElementwiseOpDoubleGradWithoutDXDY
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = ctx.Input<Tensor>("DOut")->type();
framework::proto::VarType::Type input_data_type;
if (ctx.HasInput("DDX") == false) {
PADDLE_ENFORCE_EQ(ctx.HasInput("DDY"), true,
"Input(DDY) should not be null");
input_data_type = ctx.Input<Tensor>("DDY")->type();
} else if (ctx.HasInput("DDY") == false) {
PADDLE_ENFORCE_EQ(ctx.HasInput("DDX"), true,
"Input(DDX) should not be null");
input_data_type = ctx.Input<Tensor>("DDX")->type();
} else {
input_data_type = ctx.Input<Tensor>("DDX")->type();
}
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
@@ -321,8 +332,11 @@ DECLARE_INPLACE_OP_INFERER(ElementwiseOpInplace, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplace,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplace, {"DDX", "DDOut"});
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseGradNoBufVarsInference, "Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
"Y", "DOut");
} // namespace operators
} // namespace paddle
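A likely reason for the GetExpectedKernelType change above: Y and DOut are now declared no-buffer inputs (ElementwiseDoubleGradNoBufVarsInference), and DDX may be absent altogether (see the new test below), so the kernel data type can no longer be read unconditionally from DOut and is instead taken from whichever of DDX / DDY is actually fed. A minimal sketch of that fallback in plain C++ (hypothetical names, not Paddle's API):

#include <optional>
#include <stdexcept>

enum class DType { kFP16, kFP32, kFP64 };

// Pick the kernel dtype from whichever optional double-grad input is present.
DType SelectKernelDType(std::optional<DType> ddx, std::optional<DType> ddy) {
  if (ddx) return *ddx;  // DDX present: use its dtype
  if (ddy) return *ddy;  // otherwise fall back to DDY
  throw std::runtime_error("neither DDX nor DDY is provided");
}

int main() {
  return SelectKernelDType(std::nullopt, DType::kFP32) == DType::kFP32 ? 0 : 1;
}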
@@ -54,7 +54,9 @@ REGISTER_OPERATOR(elementwise_sub_grad, ops::ElementwiseOpExplicitGrad,
ops::ElementwiseGradNoBufVarsInference,
ops::ElementwiseSubDoubleGradDescMaker);
REGISTER_OPERATOR(elementwise_sub_grad_grad,
ops::ElementwiseOpDoubleGradWithoutDXDY);
ops::ElementwiseOpDoubleGradWithoutDXDY,
ops::ElementwiseDoubleGradOpInplace,
ops::ElementwiseDoubleGradNoBufVarsInference);
REGISTER_OP_CPU_KERNEL(
elementwise_sub,
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cstdlib>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/elementwise/test_elementwise_op_grad_grad.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
USE_OP(elementwise_add);
namespace paddle {
namespace operators {
template <typename T>
class TestElementwiseAddGradGradWithoutDDX
: public TestElementwiseOpGradGrad<T> {
public:
TestElementwiseAddGradGradWithoutDDX(const platform::Place &place,
const framework::DDim &dims)
: TestElementwiseOpGradGrad<T>("elementwise_add_grad_grad", place, dims,
{"Y", "DOut", "DDY"}, {"DDOut"}) {}
using TestElementwiseOpGradGrad<T>::feed_datas_;
using TestElementwiseOpGradGrad<T>::expected_outs_;
using TestElementwiseOpGradGrad<T>::dims_;
void ComputeExpectedOuts() override {
size_t numel = static_cast<size_t>(framework::product(dims_));
std::vector<T> dy(numel);
std::vector<T> ddout(numel);
for (size_t i = 0; i < numel; ++i) {
// ddOut = ddX + ddY = ddY when ddX is absent
ddout[i] = feed_datas_["DDY"][i];
}
expected_outs_["DDOut"] = ddout;
}
std::unique_ptr<framework::OperatorBase> CreateTestOp() override {
auto op = framework::OpRegistry::CreateOp(
this->op_type_, {{"Y", {"Y"}}, {"DOut", {"DOut"}}, {"DDY", {"DDY"}}},
{{"DDOut", {"DDOut"}}}, {{"use_mkldnn", false}, {"axis", 0}});
return op;
}
};
TEST(test_elementwise_add_grad_grad_without_ddx, cpu_place) {
framework::DDim dims({32, 64});
platform::CPUPlace p;
TestElementwiseAddGradGradWithoutDDX<float> test(p, dims);
ASSERT_TRUE(test.Check());
}
#ifdef PADDLE_WITH_CUDA
TEST(test_elementwise_add_grad_grad_without_ddx, gpu_place) {
framework::DDim dims({32, 64});
platform::CUDAPlace p(0);
TestElementwiseAddGradGradWithoutDDX<float> test(p, dims);
ASSERT_TRUE(test.Check());
}
#endif
} // namespace operators
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cstdlib>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/elementwise/test_elementwise_op_grad_grad.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
USE_OP(elementwise_div);
namespace paddle {
namespace operators {
template <typename T>
class TestElementwiseDivGradGradWithoutDout
: public TestElementwiseOpGradGrad<T> {
public:
TestElementwiseDivGradGradWithoutDout(const platform::Place &place,
const framework::DDim &dims)
: TestElementwiseOpGradGrad<T>("elementwise_div_grad_grad", place, dims,
{"Y", "Out", "DDX", "DDY", "DX"},
{"Y@GRAD", "DDOut"}) {}
using TestElementwiseOpGradGrad<T>::feed_datas_;
using TestElementwiseOpGradGrad<T>::expected_outs_;
using TestElementwiseOpGradGrad<T>::dims_;
void ComputeExpectedOuts() override {
size_t numel = static_cast<size_t>(framework::product(dims_));
std::vector<T> dy(numel);
std::vector<T> ddout(numel);
for (size_t i = 0; i < numel; ++i) {
// dY(Y@GRAD) = Out * dX * ddY / Y - dX * ddX / Y
dy[i] = (feed_datas_["DX"][i] / feed_datas_["Y"][i]) *
(feed_datas_["Out"][i] * feed_datas_["DDY"][i] -
feed_datas_["DDX"][i]);
// ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
ddout[i] = (feed_datas_["DDX"][i] -
feed_datas_["Out"][i] * feed_datas_["DDY"][i]) /
(feed_datas_["Y"][i]);
}
expected_outs_["Y@GRAD"] = dy;
expected_outs_["DDOut"] = ddout;
}
std::unique_ptr<framework::OperatorBase> CreateTestOp() override {
auto op = framework::OpRegistry::CreateOp(
this->op_type_, {{"Y", {"Y"}},
{"Out", {"Out"}},
{"DDX", {"DDX"}},
{"DDY", {"DDY"}},
{"DX", {"DX"}}},
{{"Y@GRAD", {"Y@GRAD"}}, {"DDOut", {"DDOut"}}},
{{"use_mkldnn", false}, {"axis", 0}});
return op;
}
};
TEST(test_elementwise_div_grad_grad_without_dout, cpu_place) {
framework::DDim dims({32, 64});
platform::CPUPlace p;
TestElementwiseDivGradGradWithoutDout<float> test(p, dims);
ASSERT_TRUE(test.Check());
}
#ifdef PADDLE_WITH_CUDA
TEST(test_elementwise_div_grad_grad_without_dout, gpu_place) {
framework::DDim dims({32, 64});
platform::CUDAPlace p(0);
TestElementwiseDivGradGradWithoutDout<float> test(p, dims);
ASSERT_TRUE(test.Check());
}
#endif
} // namespace operators
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cstdlib>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
// currently, this test class only supports inputs with the same dims
template <typename T>
class TestElementwiseOpGradGrad {
public:
TestElementwiseOpGradGrad(const std::string &op_type,
const platform::Place &place,
const framework::DDim &dims,
const std::vector<std::string> &inputs,
const std::vector<std::string> &outputs)
: op_type_(op_type),
place_(place),
dims_(dims),
inputs_(inputs),
outputs_(outputs) {}
void InitVarInScope(std::string var_name) {
in_out_tensors_[var_name] =
scope_.Var(var_name)->template GetMutable<framework::LoDTensor>();
in_out_tensors_[var_name]->Resize(dims_);
in_out_tensors_[var_name]->template mutable_data<T>(place_);
}
void InitFeedData(std::string var_name, size_t size) {
// generate random data
std::uniform_real_distribution<T> dist(static_cast<T>(10.0),
static_cast<T>(20.0));
std::mt19937 engine;
std::vector<T> data(size);
for (size_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
feed_datas_[var_name] = data;
}
void Setup() {
size_t numel = static_cast<size_t>(framework::product(dims_));
// init vars in scope and feed inputs
for (auto in_name : inputs_) {
InitVarInScope(in_name);
InitFeedData(in_name, numel);
}
for (auto out_name : outputs_) {
InitVarInScope(out_name);
}
// feeding: copy data into the input tensors; output tensors don't need init
auto bytes = sizeof(T) * numel;
for (auto &in_name : inputs_) {
auto dst = in_out_tensors_[in_name]->template data<T>();
auto src = feed_datas_[in_name].data();
auto src_place = platform::CPUPlace();
if (platform::is_cpu_place(place_)) {
auto dst_place = boost::get<platform::CPUPlace>(place_);
memory::Copy(dst_place, dst, src_place, src, bytes);
} else if (platform::is_gpu_place(place_)) {
#ifdef PADDLE_WITH_CUDA
auto dst_place = boost::get<platform::CUDAPlace>(place_);
memory::Copy(dst_place, dst, src_place, src, bytes, nullptr);
#else
PADDLE_THROW("Not compiled with cuda");
#endif
}
}
// calculate expected outputs
ComputeExpectedOuts();
}
bool Check() {
Setup();
auto op = CreateTestOp();
op->Run(scope_, place_);
platform::DeviceContextPool::Instance().Get(place_)->Wait();
framework::LoDTensor cpu_out;
PADDLE_ENFORCE_EQ(scope_.kids().empty(), true, "scope has child scopes");
// get outputs from scope and compare them with expected_outs
bool all_equal = true;
for (auto &out_name : outputs_) {
auto &out_tensor =
scope_.FindVar(out_name)->template Get<framework::LoDTensor>();
if (platform::is_gpu_place(place_)) {
framework::TensorCopySync(out_tensor, platform::CPUPlace(), &cpu_out);
} else {
cpu_out = out_tensor;
}
auto *out_ptr = cpu_out.data<T>();
size_t numel = static_cast<size_t>(framework::product(dims_));
auto is_equal =
std::equal(out_ptr, out_ptr + numel, expected_outs_[out_name].data());
if (!is_equal) {
all_equal = false;
break;
}
}
return all_equal;
}
virtual std::unique_ptr<framework::OperatorBase> CreateTestOp() = 0;
virtual void ComputeExpectedOuts() = 0;
virtual ~TestElementwiseOpGradGrad() {}
protected:
std::string op_type_;
platform::Place place_;
framework::DDim dims_;
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
std::map<std::string, paddle::framework::LoDTensor *> in_out_tensors_;
std::map<std::string, std::vector<T>> feed_datas_;
std::map<std::string, std::vector<T>> expected_outs_;
framework::Scope scope_;
};
} // namespace operators
} // namespace paddle