Unverified commit a3cd9cb9 authored by zhangyuqin1998, committed by GitHub

delete axis from elementwise_grad (#53202)

* remove axis from elementwise_grad

* Update elementwise_sig.cc
Parent e123b98e
@@ -86,15 +86,8 @@ class ElementwiseMaxCompositeGradOpMaker
     paddle::Tensor dy = this->GetSingleInputGrad("Y");
     auto* dy_ptr = this->GetOutputPtr(&dy);
     std::string dy_name = this->GetOutputName(dy);
-    int axis = static_cast<int>(this->Attr<int>("axis"));
-    PADDLE_ENFORCE_EQ(
-        axis,
-        -1,
-        phi::errors::InvalidArgument(
-            "We only support axis = -1 in composite maximum_grad but we got: ",
-            axis));
     VLOG(6) << "Runing maximum_grad composite func";
-    prim::maximum_grad<prim::DescTensor>(x, y, out_grad, axis, dx_ptr, dy_ptr);
+    prim::maximum_grad<prim::DescTensor>(x, y, out_grad, dx_ptr, dy_ptr);
     this->RecoverOutputName(dx, dx_name);
     this->RecoverOutputName(dy, dy_name);
   }
...
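Note: the maker above read the `axis` attribute only to assert it was -1. With axis fixed at -1, operand shapes are right-aligned, NumPy style, so the broadcast, and hence the output dimensions a gradient must be summed over, is fully determined by the shapes. A minimal standalone sketch of that inference (hypothetical helper, not Paddle code):

#include <cstdio>
#include <vector>

// Right-align `small` against `big` and report which dims of `big` are
// broadcast, i.e. where the gradient w.r.t. the small operand must be summed.
std::vector<int> broadcast_dims(const std::vector<int>& big,
                                const std::vector<int>& small) {
  std::vector<int> dims;
  int offset = static_cast<int>(big.size() - small.size());
  for (int i = 0; i < static_cast<int>(big.size()); ++i) {
    int s = (i < offset) ? 1 : small[i - offset];
    if (s == 1 && big[i] != 1) dims.push_back(i);
  }
  return dims;
}

int main() {
  // x: [2, 3, 4], y: [4]  ->  dy must be reduced over dims 0 and 1.
  for (int d : broadcast_dims({2, 3, 4}, {4})) std::printf("%d ", d);
  std::printf("\n");
}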
@@ -60,15 +60,8 @@ class ElementwisePowCompositeGradOpMaker
     paddle::Tensor dy = this->GetSingleInputGrad("Y");
     auto dy_ptr = this->GetOutputPtr(&dy);
     std::string dy_name = this->GetOutputName(dy);
-    int axis = static_cast<int>(this->Attr<int>("axis"));
-    PADDLE_ENFORCE_EQ(
-        axis,
-        -1,
-        phi::errors::InvalidArgument(
-            "We only support axis = -1 in composite pow but we got: ", axis));
-    VLOG(6) << "Runing pow_grad composite func";
     prim::elementwise_pow_grad<prim::DescTensor>(
-        x, y, out_grad, axis, dx_ptr, dy_ptr);
+        x, y, out_grad, dx_ptr, dy_ptr);
     this->RecoverOutputName(dx, dx_name);
     this->RecoverOutputName(dy, dy_name);
   }
...
@@ -391,7 +391,6 @@ template <typename T>
 void elementwise_pow_grad(const Tensor& x,
                           const Tensor& y,
                           const Tensor& out_grad,
-                          int axis,
                           Tensor* dx,
                           Tensor* dy) {
   if (dy) {
@@ -1380,7 +1379,6 @@ template <typename T>
 void maximum_grad(const Tensor& x,
                   const Tensor& y,
                   const Tensor& out_grad,
-                  int axis,
                   Tensor* x_grad,
                   Tensor* y_grad) {
   if (x_grad) {
...
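The prim signatures above lose `axis` with no change to the math. For reference, maximum's gradient routes `out_grad` to whichever input wins the comparison; the tie convention below (ties go to x) is an illustrative assumption, not taken from this diff. Standalone scalar sketch:

#include <cstdio>

// Scalar form of the maximum gradient rule: dx = dout * (x >= y),
// dy = dout * (x < y). Tie handling here is an assumption for illustration.
void maximum_grad_scalar(double x, double y, double dout,
                         double* dx, double* dy) {
  if (dx) *dx = (x >= y) ? dout : 0.0;
  if (dy) *dy = (x < y) ? dout : 0.0;
}

int main() {
  double dx = 0.0, dy = 0.0;
  maximum_grad_scalar(3.0, 5.0, 1.0, &dx, &dy);
  std::printf("dx=%g dy=%g\n", dx, dy);  // y is larger, so dx=0, dy=1
}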
@@ -332,7 +332,7 @@
 - backward_op : elementwise_pow_grad
   forward : elementwise_pow(Tensor x, Tensor y) -> Tensor(out)
-  args : (Tensor x, Tensor y, Tensor out_grad, int axis=-1)
+  args : (Tensor x, Tensor y, Tensor out_grad)
   output : Tensor(x_grad), Tensor(y_grad)
   infer_meta :
     func : GeneralBinaryGradInferMeta
@@ -577,7 +577,7 @@
 - backward_op : maximum_grad
   forward : maximum(Tensor x, Tensor y) -> Tensor(out)
-  args : (Tensor x, Tensor y, Tensor out_grad, int axis=-1)
+  args : (Tensor x, Tensor y, Tensor out_grad)
   output : Tensor(x_grad), Tensor(y_grad)
   infer_meta :
     func : GeneralBinaryGradInferMeta
@@ -616,7 +616,7 @@
 - backward_op : minimum_grad
   forward : minimum(Tensor x, Tensor y) -> Tensor(out)
-  args : (Tensor x, Tensor y, Tensor out_grad, int axis=-1)
+  args : (Tensor x, Tensor y, Tensor out_grad)
   output : Tensor(x_grad), Tensor(y_grad)
   infer_meta :
     func : GeneralBinaryGradInferMeta
...
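The `args` lines above drive the generated backward APIs, which are now tensors-only; the gradient rules themselves never needed an axis attribute. For example, a standalone scalar sketch of the elementwise_pow gradient (assuming x > 0 so log(x) is defined):

#include <cmath>
#include <cstdio>

// d/dx x^y = y * x^(y-1),  d/dy x^y = x^y * log(x)
void elementwise_pow_grad_scalar(double x, double y, double dout,
                                 double* dx, double* dy) {
  if (dx) *dx = dout * y * std::pow(x, y - 1.0);
  if (dy) *dy = dout * std::pow(x, y) * std::log(x);  // requires x > 0
}

int main() {
  double dx = 0.0, dy = 0.0;
  elementwise_pow_grad_scalar(2.0, 3.0, 1.0, &dx, &dy);
  std::printf("dx=%g dy=%g\n", dx, dy);  // dx = 3 * 2^2 = 12, dy = 8 * log(2)
}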
@@ -28,10 +28,10 @@ void MaximumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
   funcs::ElementwiseGradPreProcess(dout, dx);
+  int axis = -1;
   phi::funcs::ElemwiseGradCompute<Context, T, MaxGradDx<T>, MaxGradDy<T>>(
       dev_ctx, x, y, dout, dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
 }
@@ -41,10 +41,10 @@ void MinimumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
   funcs::ElementwiseGradPreProcess(dout, dx);
+  int axis = -1;
   phi::funcs::ElemwiseGradCompute<Context, T, MinGradDx<T>, MinGradDy<T>>(
       dev_ctx, x, y, dout, dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>());
 }
...
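These kernels still call `ElemwiseGradCompute` with an axis argument, now the local constant -1, and delegate the per-element math to small functors such as `MaxGradDx<T>` and `MinGradDy<T>`. A standalone sketch of that functor pattern, with simplified hypothetical signatures and no broadcasting:

#include <cstdio>
#include <vector>

// Simplified stand-in for a grad functor: dX receives dout where x < y.
struct MinGradDx {
  double operator()(double x, double y, double dout) const {
    return (x < y) ? dout : 0.0;
  }
};

// Apply a grad functor elementwise (same-shape operands only).
template <typename Functor>
std::vector<double> elementwise_grad(const std::vector<double>& x,
                                     const std::vector<double>& y,
                                     const std::vector<double>& dout,
                                     Functor f) {
  std::vector<double> dx(x.size());
  for (size_t i = 0; i < x.size(); ++i) dx[i] = f(x[i], y[i], dout[i]);
  return dx;
}

int main() {
  auto dx = elementwise_grad({1.0, 5.0}, {2.0, 3.0}, {1.0, 1.0}, MinGradDx{});
  std::printf("%g %g\n", dx[0], dx[1]);  // 1 0: only x[0] attains the minimum
}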
@@ -40,7 +40,6 @@ void MaximumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy);
@@ -49,7 +48,6 @@ void MinimumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy);
@@ -66,7 +64,6 @@ void ElementwisePowGradKernel(const Context& dev_ctx,
                               const DenseTensor& x,
                               const DenseTensor& y,
                               const DenseTensor& dout,
-                              int axis,
                               DenseTensor* dx,
                               DenseTensor* dy);
 } // namespace phi
@@ -31,11 +31,10 @@ void MaximumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
   const auto place = dev_ctx.GetPlace();
+  int axis = -1;
   if (dx != nullptr && dy != nullptr) {
     std::vector<const DenseTensor*> ins = {&x, &y, &dout};
     GetGradXAndYOut<ElementwiseType::kTernary, T>(
@@ -63,10 +62,10 @@ void MinimumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
   const auto place = dev_ctx.GetPlace();
+  int axis = -1;
   if (dx != nullptr && dy != nullptr) {
     std::vector<const DenseTensor*> ins = {&x, &y, &dout};
     GetGradXAndYOut<ElementwiseType::kTernary, T>(
...
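As on CPU, the GPU kernels no longer receive `axis` from the framework; shared helpers such as `GetGradXAndYOut` still accept an axis parameter, so the kernels pass the single supported value as a local constant. A trivial standalone sketch of this adapter pattern (hypothetical names):

#include <cstdio>

// Stand-in for a shared helper that still accepts an axis.
void compute_with_axis(int axis) { std::printf("axis=%d\n", axis); }

// Stand-in for the new tensors-only kernel entry point.
void maximum_grad_kernel() {
  int axis = -1;  // -1 means: infer the broadcast from operand shapes
  compute_with_axis(axis);
}

int main() { maximum_grad_kernel(); }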
@@ -958,10 +958,10 @@ void ElementwisePowGradKernel(const Context& dev_ctx,
                               const DenseTensor& x,
                               const DenseTensor& y,
                               const DenseTensor& dout,
-                              int axis,
                               DenseTensor* dx,
                               DenseTensor* dy) {
   funcs::ElementwiseGradPreProcess(dout, dx);
+  int axis = -1;
   phi::funcs::ElemwiseGradCompute<Context, T, PowGradDX<T>, PowGradDY<T>>(
       dev_ctx, x, y, dout, dout, axis, dx, dy, PowGradDX<T>(), PowGradDY<T>());
 }
...
@@ -25,11 +25,10 @@ void MaximumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
   using XPUType = typename XPUTypeTrait<T>::Type;
+  int axis = -1;
   auto f = [](xpu::Context* ctx,
               const XPUType* x,
               const XPUType* y,
@@ -51,11 +50,10 @@ void MinimumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
   using XPUType = typename XPUTypeTrait<T>::Type;
+  int axis = -1;
   auto f = [](xpu::Context* ctx,
               const XPUType* x,
               const XPUType* y,
...
@@ -210,13 +210,13 @@ KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
 KernelSignature ElementwiseMaxGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
-      "maximum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
+      "maximum_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
 }
 
 KernelSignature ElementwiseMinGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
-      "minimum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
+      "minimum_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
 }
 
 KernelSignature ElementwiseHeavisideGradOpArgumentMapping(
@@ -227,10 +227,8 @@ KernelSignature ElementwiseHeavisideGradOpArgumentMapping(
 KernelSignature ElementwisePowGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  return KernelSignature("elementwise_pow_grad",
-                         {"X", "Y", "Out@GRAD"},
-                         {"axis"},
-                         {"X@GRAD", "Y@GRAD"});
+  return KernelSignature(
+      "elementwise_pow_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
 }
 
 } // namespace phi
...
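A `KernelSignature` maps a legacy operator onto a phi kernel as (kernel name, input slots, attribute slots, output slots); this commit simply empties the attribute list. A standalone sketch with hypothetical stand-in types, mirroring the mapping after the change:

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical simplified model of a kernel signature.
struct KernelSignatureSketch {
  std::string name;
  std::vector<std::string> inputs, attrs, outputs;
};

KernelSignatureSketch ElementwisePowGradMapping() {
  // After this commit the attrs slot is empty; "axis" is gone.
  return {"elementwise_pow_grad",
          {"X", "Y", "Out@GRAD"},
          {},
          {"X@GRAD", "Y@GRAD"}};
}

int main() {
  KernelSignatureSketch sig = ElementwisePowGradMapping();
  std::printf("%s: %zu inputs, %zu attrs, %zu outputs\n", sig.name.c_str(),
              sig.inputs.size(), sig.attrs.size(), sig.outputs.size());
}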