Unverified commit 7964119b authored by Charles-hit and committed by GitHub

support pow_triple_grad op (#47799)

Parent 658387b0
...@@ -20,6 +20,7 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h" #include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h" #include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
...@@ -457,6 +458,26 @@ class PowDoubleGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -457,6 +458,26 @@ class PowDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
} }
}; };
template <typename T>
class PowTripleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("pow_triple_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("DOut", this->Input("DOut"));
op->SetInput("DDX", this->Input("DDX"));
op->SetInput("D_DX", this->OutputGrad("DX"));
op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
op->SetOutput("D_X", this->InputGrad("X"));
op->SetOutput("D_DOut", this->InputGrad("DOut"));
op->SetOutput("D_DDX", this->InputGrad("DDX"));
op->SetInput("FactorTensor", this->Input("FactorTensor"));
op->SetAttrMap(this->Attrs());
}
};
class PowOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -523,6 +544,16 @@ class PowOpDoubleGrad : public framework::OperatorWithKernel {
}
};
class PowOpTripleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X");
}
};
DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
...@@ -575,6 +606,9 @@ REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(pow_double_grad,
PowDoubleGradInferShapeFunctor,
PD_INFER_META(phi::GeneralBinaryGradInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(pow_triple_grad,
PowTripleGradInferShapeFunctor,
PD_INFER_META(phi::GeneralTernaryGradInferMeta));
REGISTER_OPERATOR(
pow,
...@@ -594,7 +628,12 @@ REGISTER_OPERATOR(pow_grad,
REGISTER_OPERATOR(pow_double_grad,
ops::PowOpDoubleGrad,
ops::ActivationDoubleGradOpInplaceInferer,
ops::PowTripleGradOpMaker<paddle::framework::OpDesc>,
ops::PowTripleGradOpMaker<paddle::imperative::OpBase>,
PowDoubleGradInferShapeFunctor);
REGISTER_OPERATOR(pow_triple_grad,
ops::PowOpTripleGrad,
PowTripleGradInferShapeFunctor);
/* ========================================================================== */
/* ========================== register checkpoint ===========================*/
...
...@@ -27,7 +27,6 @@ limitations under the License. */
#include <type_traits>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
......
...@@ -1330,9 +1330,10 @@
output : Tensor(x_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param: [x, x]
param: [x, grad_out]
kernel :
func : pow_double_grad
backward : pow_triple_grad
inplace : (grad_x_grad -> x_grad)
- backward_op : pow_grad
...@@ -1347,6 +1348,16 @@
backward: pow_double_grad
inplace : (out_grad -> x_grad)
- backward_op : pow_triple_grad
forward : pow_double_grad(Tensor x, Tensor grad_out, Tensor grad_grad_x, Scalar y) -> Tensor(grad_x), Tensor(grad_grad_out)
args : (Tensor x, Tensor grad_out, Tensor grad_grad_x, Tensor grad_x_grad, Tensor grad_grad_out_grad, Scalar y)
output : Tensor(x_grad), Tensor(grad_out_grad), Tensor(grad_grad_x_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param: [x, grad_out, grad_grad_x]
kernel :
func : pow_triple_grad
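For reference, the entry above follows directly from the derivatives of y = x^b, where b is the pow factor; the quantities the double- and triple-grad kernels propagate are:

```latex
y = x^{b}, \quad
\frac{\partial y}{\partial x} = b\,x^{b-1}, \quad
\frac{\partial^{2} y}{\partial x^{2}} = b(b-1)\,x^{b-2}, \quad
\frac{\partial^{3} y}{\partial x^{3}} = b(b-1)(b-2)\,x^{b-3}
```

pow_double_grad produces DX = DDX * DOut * b(b-1) x^(b-2) and DDOut = DDX * b x^(b-1); differentiating those two outputs once more with respect to X, DOut and DDX yields the three outputs of pow_triple_grad implemented further down.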
- backward_op : prelu_grad
forward : prelu(Tensor x, Tensor alpha, str data_format, str mode) -> Tensor(out)
args : (Tensor x, Tensor alpha, Tensor out_grad, str data_format, str mode)
...
...@@ -226,6 +226,18 @@ void PowDoubleGradKernel(const Context& dev_ctx,
const Scalar& factor,
DenseTensor* dx,
DenseTensor* ddout);
template <typename T, typename Context>
void PowTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
const DenseTensor& d_dx,
const DenseTensor& d_ddout,
const Scalar& factor,
DenseTensor* out_d_x,
DenseTensor* out_d_dout,
DenseTensor* out_d_ddx);
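The declaration fixes the kernel contract: x, dout and ddx are the inputs of pow_double_grad, d_dx and d_ddout are the gradients arriving at its outputs DX and DDOut, and the three outputs are the gradients with respect to those inputs. A minimal NumPy sketch of the general case (factor not equal to 1 or 2), mirroring the formulas documented in the kernel body further down; the helper name `pow_triple_grad_ref` is hypothetical and only for illustration:

```python
import numpy as np

def pow_triple_grad_ref(x, dout, ddx, d_dx, d_ddout, b):
    """Hypothetical NumPy reference for pow_triple_grad, general case b not in {1, 2}."""
    # D_X = D_DX*DDX*DOut*b*(b-1)*(b-2)*X^(b-3) + D_DDOut*DDX*b*(b-1)*X^(b-2)
    d_x = (d_dx * ddx * dout * b * (b - 1) * (b - 2) * x ** (b - 3)
           + d_ddout * ddx * b * (b - 1) * x ** (b - 2))
    # D_DOut = D_DX*DDX*b*(b-1)*X^(b-2)
    d_dout = d_dx * ddx * b * (b - 1) * x ** (b - 2)
    # D_DDX = D_DX*DOut*b*(b-1)*X^(b-2) + D_DDOut*b*X^(b-1)
    d_ddx = (d_dx * dout * b * (b - 1) * x ** (b - 2)
             + d_ddout * b * x ** (b - 1))
    return d_x, d_dout, d_ddx

rng = np.random.default_rng(0)
x, dout, ddx, d_dx, d_ddout = rng.uniform(0.5, 1.5, size=(5, 4))
print(pow_triple_grad_ref(x, dout, ddx, d_dx, d_ddout, b=4.0))
```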
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Cos);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Tan);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Acos);
...
...@@ -390,3 +390,11 @@ PD_REGISTER_KERNEL(pow_double_grad,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(pow_triple_grad,
CPU,
ALL_LAYOUT,
phi::PowTripleGradKernel,
float,
double,
int,
int64_t) {}
...@@ -472,7 +472,6 @@ PD_REGISTER_KERNEL(pow_grad,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(pow_double_grad,
GPU,
ALL_LAYOUT,
...@@ -483,3 +482,13 @@ PD_REGISTER_KERNEL(pow_double_grad,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(pow_triple_grad,
GPU,
ALL_LAYOUT,
phi::PowTripleGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -17,6 +17,7 @@
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h" #include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/activation_functor.h" #include "paddle/phi/kernels/funcs/activation_functor.h"
...@@ -347,10 +348,10 @@ void PowDoubleGradKernel(const Context& dev_ctx, ...@@ -347,10 +348,10 @@ void PowDoubleGradKernel(const Context& dev_ctx,
DenseTensor* dx, DenseTensor* dx,
DenseTensor* ddout) { DenseTensor* ddout) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
dx, errors::NotFound("The output DenseTensor dx can not be nullptr")); dx, errors::NotFound("The output DenseTensor DX can not be nullptr"));
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
ddout, ddout,
errors::NotFound("The output DenseTensor ddout can not be nullptr")); errors::NotFound("The output DenseTensor DDOut can not be nullptr"));
float exponent = factor.to<float>(); float exponent = factor.to<float>();
if (exponent == 1) { if (exponent == 1) {
*dx = phi::FullLike<T, Context>(dev_ctx, x, static_cast<T>(0)); *dx = phi::FullLike<T, Context>(dev_ctx, x, static_cast<T>(0));
...@@ -366,6 +367,150 @@ void PowDoubleGradKernel(const Context& dev_ctx, ...@@ -366,6 +367,150 @@ void PowDoubleGradKernel(const Context& dev_ctx,
*ddout = phi::Scale<T, Context>(dev_ctx, ddout_tmp, exponent, 0.0, true); *ddout = phi::Scale<T, Context>(dev_ctx, ddout_tmp, exponent, 0.0, true);
} }
template <typename T, typename Context>
void PowTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
const DenseTensor& d_dx,
const DenseTensor& d_ddout,
const Scalar& factor,
DenseTensor* out_d_x,
DenseTensor* out_d_dout,
DenseTensor* out_d_ddx) {
PADDLE_ENFORCE_NOT_NULL(
out_d_x,
errors::NotFound("The output DenseTensor D_X can not be nullptr"));
PADDLE_ENFORCE_NOT_NULL(
out_d_dout,
errors::NotFound("The output DenseTensor D_DOut can not be nullptr"));
PADDLE_ENFORCE_NOT_NULL(
out_d_ddx,
errors::NotFound("The output DenseTensor D_DDX can not be nullptr"));
float exponent = factor.to<float>();
if (exponent != 2 && exponent != 1) {
// case1: b != 2 and b != 1
// D_X = D_DX * DDX * DOut * b * (b-1) * (b-2) * X^(b-3)
// + D_DDOut * DDX * b * (b-1) * X^(b-2)
DenseTensor out_d_x_tmp1 = phi::Multiply<T, Context>(dev_ctx, d_dx, ddx);
DenseTensor out_d_x_tmp2 =
phi::Scale<T, Context>(dev_ctx,
phi::Pow<T, Context>(dev_ctx, x, exponent - 3),
exponent * (exponent - 1) * (exponent - 2),
0.0,
true);
DenseTensor out_d_x_part1 = phi::Multiply<T, Context>(
dev_ctx,
phi::Multiply<T, Context>(dev_ctx, out_d_x_tmp1, dout),
out_d_x_tmp2);
DenseTensor out_d_x_tmp3 = phi::Multiply<T, Context>(dev_ctx, d_ddout, ddx);
DenseTensor out_d_x_tmp4 =
phi::Scale<T, Context>(dev_ctx,
phi::Pow<T, Context>(dev_ctx, x, exponent - 2),
exponent * (exponent - 1),
0.0,
true);
DenseTensor out_d_x_part2 =
phi::Multiply<T, Context>(dev_ctx, out_d_x_tmp3, out_d_x_tmp4);
*out_d_x = phi::Add<T, Context>(dev_ctx, out_d_x_part1, out_d_x_part2);
// D_DOut = D_DX * DDX * b * (b-1) * X^(b-2)
DenseTensor out_d_dout_tmp =
phi::Scale<T, Context>(dev_ctx,
phi::Pow<T, Context>(dev_ctx, x, exponent - 2),
exponent * (exponent - 1),
0.0,
true);
*out_d_dout =
phi::Multiply<T, Context>(dev_ctx, out_d_x_tmp1, out_d_dout_tmp);
// D_DDX = D_DX * DOut * b * (b-1) * X^(b-2) + D_DDOut * b * X^(b-1)
DenseTensor out_d_ddx_tmp1 = phi::Multiply<T, Context>(dev_ctx, d_dx, dout);
DenseTensor out_d_ddx_part1 =
phi::Multiply<T, Context>(dev_ctx, out_d_ddx_tmp1, out_d_dout_tmp);
DenseTensor out_d_ddx_tmp2 =
phi::Scale<T, Context>(dev_ctx,
phi::Pow<T, Context>(dev_ctx, x, exponent - 1),
exponent,
0.0,
true);
DenseTensor out_d_ddx_part2 =
phi::Multiply<T, Context>(dev_ctx, d_ddout, out_d_ddx_tmp2);
*out_d_ddx =
phi::Add<T, Context>(dev_ctx, out_d_ddx_part1, out_d_ddx_part2);
} else if (exponent == 2) {
// case2: b = 2
// D_X = D_DDOut * DDX * b * (b-1) * X^(b-2)
DenseTensor out_d_x_tmp1 = phi::Multiply<T, Context>(dev_ctx, d_ddout, ddx);
DenseTensor out_d_x_tmp2 =
phi::Scale<T, Context>(dev_ctx,
phi::Pow<T, Context>(dev_ctx, x, exponent - 2),
exponent * (exponent - 1),
0.0,
true);
*out_d_x = phi::Multiply<T, Context>(dev_ctx, out_d_x_tmp1, out_d_x_tmp2);
// D_DOut = D_DX * DDX * b * (b-1) * X^(b-2)
DenseTensor out_d_dout_tmp1 = phi::Multiply<T, Context>(dev_ctx, d_dx, ddx);
DenseTensor out_d_dout_tmp2 =
phi::Scale<T, Context>(dev_ctx,
phi::Pow<T, Context>(dev_ctx, x, exponent - 2),
exponent * (exponent - 1),
0.0,
true);
*out_d_dout =
phi::Multiply<T, Context>(dev_ctx, out_d_dout_tmp1, out_d_dout_tmp2);
// D_DDX = D_DX * DOut * b * (b-1) * X^(b-2) + D_DDOut * b * X^(b-1)
DenseTensor out_d_ddx_tmp1 = phi::Multiply<T, Context>(dev_ctx, d_dx, dout);
DenseTensor out_d_ddx_part1 =
phi::Multiply<T, Context>(dev_ctx, out_d_ddx_tmp1, out_d_dout_tmp2);
DenseTensor out_d_ddx_tmp2 =
phi::Scale<T, Context>(dev_ctx,
phi::Pow<T, Context>(dev_ctx, x, exponent - 1),
exponent,
0.0,
true);
DenseTensor out_d_ddx_part2 =
phi::Multiply<T, Context>(dev_ctx, d_ddout, out_d_ddx_tmp2);
*out_d_ddx =
phi::Add<T, Context>(dev_ctx, out_d_ddx_part1, out_d_ddx_part2);
} else {
// case3: b = 1
// D_X = D_DX * DDX * DOut * b * (b-1) * (b-2) * X^(b-3)
DenseTensor out_d_x_tmp1 = phi::Multiply<T, Context>(dev_ctx, d_dx, ddx);
DenseTensor out_d_x_tmp2 =
phi::Scale<T, Context>(dev_ctx,
phi::Pow<T, Context>(dev_ctx, x, exponent - 3),
exponent * (exponent - 1) * (exponent - 2),
0.0,
true);
*out_d_x = phi::Multiply<T, Context>(
dev_ctx,
phi::Multiply<T, Context>(dev_ctx, out_d_x_tmp1, dout),
out_d_x_tmp2);
// D_DOut = 0
*out_d_dout = phi::FullLike<T, Context>(dev_ctx, dout, static_cast<T>(0));
// D_DDX = D_DDOut * b * X^(b-1)
DenseTensor out_d_ddx_tmp =
phi::Scale<T, Context>(dev_ctx,
phi::Pow<T, Context>(dev_ctx, x, exponent - 1),
exponent,
0.0,
true);
*out_d_ddx = phi::Multiply<T, Context>(dev_ctx, d_ddout, out_d_ddx_tmp);
}
}
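As a sanity check on the closed forms above: by definition the triple grad is the gradient of D_DX * DX + D_DDOut * DDOut, where DX and DDOut are the outputs of PowDoubleGradKernel, taken with respect to X, DOut and DDX. A small SymPy sketch with scalar symbols (assuming SymPy is available; it is not a Paddle dependency) confirms that the expressions in the general branch match this definition:

```python
import sympy as sp

x, dout, ddx, d_dx, d_ddout, b = sp.symbols('x dout ddx d_dx d_ddout b', positive=True)

# Outputs of pow_double_grad for y = x**b (see PowDoubleGradKernel above).
DX = ddx * dout * b * (b - 1) * x ** (b - 2)
DDOut = ddx * b * x ** (b - 1)

# The triple grad differentiates <d_dx, DX> + <d_ddout, DDOut> w.r.t. x, dout, ddx.
L = d_dx * DX + d_ddout * DDOut

d_x = (d_dx * ddx * dout * b * (b - 1) * (b - 2) * x ** (b - 3)
       + d_ddout * ddx * b * (b - 1) * x ** (b - 2))          # out_d_x
d_dout_ref = d_dx * ddx * b * (b - 1) * x ** (b - 2)          # out_d_dout
d_ddx_ref = (d_dx * dout * b * (b - 1) * x ** (b - 2)
             + d_ddout * b * x ** (b - 1))                    # out_d_ddx

print(sp.simplify(sp.diff(L, x) - d_x))           # 0
print(sp.simplify(sp.diff(L, dout) - d_dout_ref)) # 0
print(sp.simplify(sp.diff(L, ddx) - d_ddx_ref))   # 0
```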
template <typename T, typename Context>
void SqrtDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
...
...@@ -83,6 +83,21 @@ KernelSignature PowDoubleGradOpArgumentMapping(
"pow_double_grad", {"X", "DOut", "DDX"}, {"factor"}, {"DX", "DDOut"});
}
}
KernelSignature PowTripleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.HasInput("FactorTensor")) {
return KernelSignature("pow_triple_grad",
{"X", "DOut", "DDX", "D_DX", "D_DDOut"},
{"FactorTensor"},
{"D_X", "D_DOut", "D_DDX"});
} else {
return KernelSignature("pow_triple_grad",
{"X", "DOut", "DDX", "D_DX", "D_DDOut"},
{"factor"},
{"D_X", "D_DOut", "D_DDX"});
}
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(brelu, hard_tanh);
...@@ -100,4 +115,6 @@ PD_REGISTER_ARG_MAPPING_FN(swish_grad, phi::SwishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow_grad, phi::PowGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow_double_grad,
phi::PowDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow_triple_grad,
phi::PowTripleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow, phi::PowOpArgumentMapping);
...@@ -597,5 +597,98 @@ class TestSinTripleGradCheck(unittest.TestCase):
self.func(p)
class TestPowTripleGradCheck1(unittest.TestCase):
def pow_wrapper(self, x):
return paddle.pow(x[0], 1)
@prog_scope()
def func(self, place):
shape = [2, 3, 7, 9]
eps = 1e-6
dtype = np.float64
x = layers.data('x', shape, False, dtype=dtype)
x.persistable = True
y = paddle.pow(x, 1)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
gradient_checker.triple_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
gradient_checker.triple_grad_check_for_dygraph(
self.pow_wrapper, [x], y, x_init=x_arr, place=place
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestPowTripleGradCheck2(unittest.TestCase):
def pow_wrapper(self, x):
return paddle.pow(x[0], 2)
@prog_scope()
def func(self, place):
shape = [2, 3, 7, 9]
eps = 1e-6
dtype = np.float64
x = layers.data('x', shape, False, dtype=dtype)
x.persistable = True
y = paddle.pow(x, 2)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
gradient_checker.triple_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
gradient_checker.triple_grad_check_for_dygraph(
self.pow_wrapper, [x], y, x_init=x_arr, place=place
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestPowTripleGradCheck3(unittest.TestCase):
def pow_wrapper(self, x):
return paddle.pow(x[0], 4)
@prog_scope()
def func(self, place):
shape = [2, 3, 7, 9]
eps = 1e-6
dtype = np.float64
x = layers.data('x', shape, False, dtype=dtype)
x.persistable = True
y = paddle.pow(x, 4)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
gradient_checker.triple_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
gradient_checker.triple_grad_check_for_dygraph(
self.pow_wrapper, [x], y, x_init=x_arr, place=place
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
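With pow_triple_grad registered, the third derivative of pow is also reachable from dynamic-graph code by nesting paddle.grad calls. A minimal sketch for factor 4, where the analytic third derivative of x^4 is 24x; whether the last call dispatches to pow_triple_grad directly or to a composed graph depends on the framework configuration, but the value should match:

```python
import paddle

x = paddle.rand([3], dtype='float64')
x.stop_gradient = False
y = paddle.pow(x, 4)

dy_dx = paddle.grad(y, x, create_graph=True)[0]        # 4 * x^3   (pow_grad)
d2y_dx2 = paddle.grad(dy_dx, x, create_graph=True)[0]  # 12 * x^2  (pow_double_grad)
d3y_dx3 = paddle.grad(d2y_dx2, x)[0]                   # 24 * x    (pow_triple_grad)

print(paddle.allclose(d3y_dx3, 24.0 * x))  # Tensor containing True
```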
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()