“cf6d9e90cb42b28086c1a14b37eccca62ae0a95a”上不存在“doc/howto/optimization/cpu_profiling.html”
未验证 提交 6fe9dfb2 编写于 作者: C Charles-hit 提交者: GitHub

support pow double grad op (#47691)

* support pow_double_grad op

* add unit test for pow double grad

* fix pow double grad

* optimize pow double grad kernel

* fix pow double grad kernel
上级 18adbbd0
...@@ -24,6 +24,7 @@ limitations under the License. */ ...@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/operators/common_infer_shape_functions.h" #include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h" #include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/phi/backends/dynload/port.h" #include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/infermeta/backward.h"
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
...@@ -434,7 +435,24 @@ class PowGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -434,7 +435,24 @@ class PowGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetType("pow_grad"); op->SetType("pow_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework ::GradVarName("X"), this->InputGrad("X"));
op->SetInput("FactorTensor", this->Input("FactorTensor"));
op->SetAttrMap(this->Attrs());
}
};
template <typename T>
class PowDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("pow_double_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
op->SetInput("DDX", this->OutputGrad(framework ::GradVarName("X")));
op->SetOutput("DX", this->InputGrad("X"));
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
op->SetInput("FactorTensor", this->Input("FactorTensor")); op->SetInput("FactorTensor", this->Input("FactorTensor"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
} }
...@@ -493,6 +511,18 @@ class PowOpGrad : public framework::OperatorWithKernel { ...@@ -493,6 +511,18 @@ class PowOpGrad : public framework::OperatorWithKernel {
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); expected_kernel_type.data_type_, tensor.place(), tensor.layout());
} }
}; };
// Operator class registered for `pow_double_grad`. It reuses the
// OperatorWithKernel constructors as-is and only customizes how the expected
// kernel type is chosen.
class PowOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Chooses the kernel type by delegating to GetKernelType, keyed on the
  // input named "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};
DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"});
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -542,6 +572,9 @@ REGISTER_ACTIVATION_OP(hard_swish, ...@@ -542,6 +572,9 @@ REGISTER_ACTIVATION_OP(hard_swish,
REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor); REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);
/* ========================== pow register ============================ */ /* ========================== pow register ============================ */
// Declares PowDoubleGradInferShapeFunctor for the pow_double_grad operator,
// built from phi::GeneralBinaryGradInferMeta. The functor is handed to the
// REGISTER_OPERATOR(pow_double_grad, ...) call below.
DECLARE_INFER_SHAPE_FUNCTOR(pow_double_grad,
PowDoubleGradInferShapeFunctor,
PD_INFER_META(phi::GeneralBinaryGradInferMeta));
REGISTER_OPERATOR( REGISTER_OPERATOR(
pow, pow,
...@@ -555,7 +588,13 @@ REGISTER_OPERATOR( ...@@ -555,7 +588,13 @@ REGISTER_OPERATOR(
void>::type); void>::type);
REGISTER_OPERATOR(pow_grad, REGISTER_OPERATOR(pow_grad,
ops::PowOpGrad, ops::PowOpGrad,
ops::ActivationGradOpInplaceInferer); ops::ActivationGradOpInplaceInferer,
ops::PowDoubleGradOpMaker<paddle::framework::OpDesc>,
ops::PowDoubleGradOpMaker<paddle::imperative::OpBase>);
// Registers the pow_double_grad operator with its operator class, an inplace
// inferer, and the infer-shape functor declared for it above.
REGISTER_OPERATOR(pow_double_grad,
ops::PowOpDoubleGrad,
ops::ActivationDoubleGradOpInplaceInferer,
PowDoubleGradInferShapeFunctor);
/* ========================================================================== */ /* ========================================================================== */
/* ========================== register checkpoint ===========================*/ /* ========================== register checkpoint ===========================*/
......
...@@ -27,6 +27,7 @@ limitations under the License. */ ...@@ -27,6 +27,7 @@ limitations under the License. */
#include <type_traits> #include <type_traits>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
......
...@@ -1364,6 +1364,17 @@ ...@@ -1364,6 +1364,17 @@
param : [x, out, out_grad, kernel_size, strides, paddings, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm] param : [x, out, out_grad, kernel_size, strides, paddings, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm]
use_gpudnn : use_gpudnn use_gpudnn : use_gpudnn
# Backward entry for pow_grad (i.e. the second-order gradient of pow).
# Inputs: x, grad_out, grad_x_grad and the scalar exponent y.
# Outputs: x_grad and grad_out_grad.
# infer_meta passes x for both parameters of GeneralBinaryGradInferMeta.
# NOTE(review): the legacy operator registration uses
# ActivationDoubleGradOpInplaceInferer; confirm it declares the same
# input/output reuse as the `inplace` mapping below.
- backward_op : pow_double_grad
forward : pow_grad(Tensor x, Tensor grad_out, Scalar y) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_out, Tensor grad_x_grad, Scalar y)
output : Tensor(x_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param: [x, x]
kernel :
func : pow_double_grad
inplace : (grad_x_grad -> x_grad)
- backward_op : pow_grad - backward_op : pow_grad
forward : pow(Tensor x, Scalar y) -> Tensor(out) forward : pow(Tensor x, Scalar y) -> Tensor(out)
args : (Tensor x, Tensor out_grad, Scalar y=-1) args : (Tensor x, Tensor out_grad, Scalar y=-1)
...@@ -1373,6 +1384,7 @@ ...@@ -1373,6 +1384,7 @@
param: [x] param: [x]
kernel : kernel :
func : pow_grad func : pow_grad
backward: pow_double_grad
inplace : (out_grad -> x_grad) inplace : (out_grad -> x_grad)
- backward_op : prelu_grad - backward_op : prelu_grad
......
...@@ -207,6 +207,14 @@ void PowGradKernel(const Context& dev_ctx, ...@@ -207,6 +207,14 @@ void PowGradKernel(const Context& dev_ctx,
const Scalar& factor, const Scalar& factor,
DenseTensor* dx); DenseTensor* dx);
// Declaration of the pow double-grad kernel.
//
//   dev_ctx : device context the kernel runs with.
//   x       : tensor input `x`.
//   dout    : tensor input `dout`.
//   ddx     : tensor input `ddx`.
//   factor  : scalar exponent.
//   dx      : output tensor (written by the kernel).
//   ddout   : output tensor (written by the kernel).
template <typename T, typename Context>
void PowDoubleGradKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& dout,
                         const DenseTensor& ddx,
                         const Scalar& factor,
                         DenseTensor* dx,
                         DenseTensor* ddout);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Cos); DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Cos);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Tan); DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Tan);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Acos); DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Acos);
......
...@@ -373,3 +373,11 @@ PD_REGISTER_KERNEL(pow_grad, ...@@ -373,3 +373,11 @@ PD_REGISTER_KERNEL(pow_grad,
double, double,
int, int,
int64_t) {} int64_t) {}
// Registers phi::PowDoubleGradKernel as the CPU kernel named pow_double_grad
// for all layouts, instantiated for float, double, int and int64_t.
PD_REGISTER_KERNEL(pow_double_grad,
CPU,
ALL_LAYOUT,
phi::PowDoubleGradKernel,
float,
double,
int,
int64_t) {}
...@@ -462,3 +462,14 @@ PD_REGISTER_KERNEL(pow_grad, ...@@ -462,3 +462,14 @@ PD_REGISTER_KERNEL(pow_grad,
int64_t, int64_t,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16) {} phi::dtype::bfloat16) {}
// Registers phi::PowDoubleGradKernel as the GPU kernel named pow_double_grad
// for all layouts, instantiated for float, double, int, int64_t,
// phi::dtype::float16 and phi::dtype::bfloat16.
PD_REGISTER_KERNEL(pow_double_grad,
GPU,
ALL_LAYOUT,
phi::PowDoubleGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -16,7 +16,11 @@ ...@@ -16,7 +16,11 @@
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/activation_functor.h" #include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/scale_kernel.h"
namespace phi { namespace phi {
...@@ -334,6 +338,34 @@ void PowGradKernel(const Context& dev_ctx, ...@@ -334,6 +338,34 @@ void PowGradKernel(const Context& dev_ctx,
functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten); functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten);
} }
template <typename T, typename Context>
void PowDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
const Scalar& factor,
DenseTensor* dx,
DenseTensor* ddout) {
PADDLE_ENFORCE_NOT_NULL(
dx, errors::NotFound("The output DenseTensor dx can not be nullptr"));
PADDLE_ENFORCE_NOT_NULL(
ddout,
errors::NotFound("The output DenseTensor ddout can not be nullptr"));
float exponent = factor.to<float>();
if (exponent == 1) {
*dx = phi::FullLike<T, Context>(dev_ctx, x, static_cast<T>(0));
} else {
DenseTensor dx_tmp1 = phi::Multiply<T, Context>(dev_ctx, dout, ddx);
DenseTensor dx_tmp2 = phi::Multiply<T, Context>(
dev_ctx, dx_tmp1, phi::Pow<T, Context>(dev_ctx, x, exponent - 2));
*dx = phi::Scale<T, Context>(
dev_ctx, dx_tmp2, exponent * (exponent - 1), 0.0, true);
}
DenseTensor ddout_tmp = phi::Multiply<T, Context>(
dev_ctx, ddx, phi::Pow<T, Context>(dev_ctx, x, exponent - 1));
*ddout = phi::Scale<T, Context>(dev_ctx, ddout_tmp, exponent, 0.0, true);
}
template <typename T, typename Context> template <typename T, typename Context>
void SqrtDoubleGradKernel(const Context& dev_ctx, void SqrtDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out, const DenseTensor& out,
......
...@@ -71,6 +71,18 @@ KernelSignature PowGradOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -71,6 +71,18 @@ KernelSignature PowGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
} }
} }
// Maps the pow_double_grad operator's arguments onto the "pow_double_grad"
// kernel signature: inputs X, DOut, DDX; outputs DX, DDOut. The single
// attribute slot is "FactorTensor" when that input is present, otherwise
// "factor".
KernelSignature PowDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  // No FactorTensor input: fall back to the "factor" attribute.
  if (!ctx.HasInput("FactorTensor")) {
    return KernelSignature(
        "pow_double_grad", {"X", "DOut", "DDX"}, {"factor"}, {"DX", "DDOut"});
  }
  return KernelSignature("pow_double_grad",
                         {"X", "DOut", "DDX"},
                         {"FactorTensor"},
                         {"DX", "DDOut"});
}
} // namespace phi } // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(brelu, hard_tanh); PD_REGISTER_BASE_KERNEL_NAME(brelu, hard_tanh);
...@@ -86,4 +98,6 @@ PD_REGISTER_ARG_MAPPING_FN(hard_swish_grad, ...@@ -86,4 +98,6 @@ PD_REGISTER_ARG_MAPPING_FN(hard_swish_grad,
phi::HardSwishGradOpArgumentMapping); phi::HardSwishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(swish_grad, phi::SwishGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(swish_grad, phi::SwishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow_grad, phi::PowGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(pow_grad, phi::PowGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow_double_grad,
phi::PowDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow, phi::PowOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(pow, phi::PowOpArgumentMapping);
...@@ -503,5 +503,67 @@ class TestSinDoubleGradCheck(unittest.TestCase): ...@@ -503,5 +503,67 @@ class TestSinDoubleGradCheck(unittest.TestCase):
self.func(p) self.func(p)
class TestPowDoubleGradCheck1(unittest.TestCase):
    """Double-gradient check for ``paddle.pow`` with exponent 2."""

    def pow_wrapper(self, x):
        """Callable handed to the dygraph checker: ``pow`` of ``x[0]`` by 2."""
        return paddle.pow(x[0], 2)

    @prog_scope()
    def func(self, place):
        """Run the static and dygraph double-grad checks on ``place``."""
        shape = [2, 3, 7, 9]
        eps = 1e-6
        dtype = np.float64

        x = layers.data('x', shape, False, dtype=dtype)
        x.persistable = True
        y = paddle.pow(x, 2)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)

        gradient_checker.double_grad_check(
            [x], y, x_init=x_arr, place=place, eps=eps
        )
        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
        try:
            gradient_checker.double_grad_check_for_dygraph(
                self.pow_wrapper, [x], y, x_init=x_arr, place=place
            )
        finally:
            # Always switch the flag back off, even when the check raises,
            # so a failure here does not leave it enabled for later tests.
            fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})

    def test_grad(self):
        """Run ``func`` on CPU and, when compiled with CUDA, on GPU 0."""
        paddle.enable_static()
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)
class TestPowDoubleGradCheck2(unittest.TestCase):
    """Double-gradient check for ``paddle.pow`` with exponent 1."""

    def pow_wrapper(self, x):
        """Callable handed to the dygraph checker: ``pow`` of ``x[0]`` by 1."""
        return paddle.pow(x[0], 1)

    @prog_scope()
    def func(self, place):
        """Run the static and dygraph double-grad checks on ``place``."""
        shape = [2, 3, 7, 9]
        eps = 1e-6
        dtype = np.float64

        x = layers.data('x', shape, False, dtype=dtype)
        x.persistable = True
        y = paddle.pow(x, 1)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)

        gradient_checker.double_grad_check(
            [x], y, x_init=x_arr, place=place, eps=eps
        )
        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
        try:
            gradient_checker.double_grad_check_for_dygraph(
                self.pow_wrapper, [x], y, x_init=x_arr, place=place
            )
        finally:
            # Always switch the flag back off, even when the check raises,
            # so a failure here does not leave it enabled for later tests.
            fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})

    def test_grad(self):
        """Run ``func`` on CPU and, when compiled with CUDA, on GPU 0."""
        paddle.enable_static()
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册