Unverified commit a3cd9cb9, authored by zhangyuqin1998, committed by GitHub

delete axis from elementwise_grad (#53202)

* remove axis from elementwise_grad

* Update elementwise_sig.cc
Parent: e123b98e
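The change is mechanical: as the removed PADDLE_ENFORCE_EQ checks below show, only axis = -1 was ever supported in these grad paths, so the commit deletes `axis` from the composite rules, the YAML op definitions, the argument mappings, and the CPU/GPU/XPU kernels, hard-coding a local `axis = -1` where the shared broadcast helpers still take one. A condensed before/after view of the kernel interface, mirroring the header diff below (declaration fragment, not compilable on its own):

// Before: every backend took an axis that was required to be -1.
template <typename T, typename Context>
void MaximumGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& y,
                       const DenseTensor& dout,
                       int axis,  // removed by this commit
                       DenseTensor* dx,
                       DenseTensor* dy);

// After: axis is gone from the interface; kernels that still dispatch
// to ElemwiseGradCompute pass a local `int axis = -1;` instead.
template <typename T, typename Context>
void MaximumGradKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& y,
                       const DenseTensor& dout,
                       DenseTensor* dx,
                       DenseTensor* dy);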
@@ -86,15 +86,8 @@ class ElementwiseMaxCompositeGradOpMaker
     paddle::Tensor dy = this->GetSingleInputGrad("Y");
     auto* dy_ptr = this->GetOutputPtr(&dy);
     std::string dy_name = this->GetOutputName(dy);
-    int axis = static_cast<int>(this->Attr<int>("axis"));
-    PADDLE_ENFORCE_EQ(
-        axis,
-        -1,
-        phi::errors::InvalidArgument(
-            "We only support axis = -1 in composite maximum_grad but we got: ",
-            axis));
     VLOG(6) << "Runing maximum_grad composite func";
-    prim::maximum_grad<prim::DescTensor>(x, y, out_grad, axis, dx_ptr, dy_ptr);
+    prim::maximum_grad<prim::DescTensor>(x, y, out_grad, dx_ptr, dy_ptr);
     this->RecoverOutputName(dx, dx_name);
     this->RecoverOutputName(dy, dy_name);
   }
@@ -60,15 +60,8 @@ class ElementwisePowCompositeGradOpMaker
     paddle::Tensor dy = this->GetSingleInputGrad("Y");
     auto dy_ptr = this->GetOutputPtr(&dy);
     std::string dy_name = this->GetOutputName(dy);
-    int axis = static_cast<int>(this->Attr<int>("axis"));
-    PADDLE_ENFORCE_EQ(
-        axis,
-        -1,
-        phi::errors::InvalidArgument(
-            "We only support axis = -1 in composite pow but we got: ", axis));
     VLOG(6) << "Runing pow_grad composite func";
     prim::elementwise_pow_grad<prim::DescTensor>(
-        x, y, out_grad, axis, dx_ptr, dy_ptr);
+        x, y, out_grad, dx_ptr, dy_ptr);
     this->RecoverOutputName(dx, dx_name);
     this->RecoverOutputName(dy, dy_name);
   }
@@ -391,7 +391,6 @@ template <typename T>
 void elementwise_pow_grad(const Tensor& x,
                           const Tensor& y,
                           const Tensor& out_grad,
-                          int axis,
                           Tensor* dx,
                           Tensor* dy) {
   if (dy) {

@@ -1380,7 +1379,6 @@ template <typename T>
 void maximum_grad(const Tensor& x,
                   const Tensor& y,
                   const Tensor& out_grad,
-                  int axis,
                   Tensor* x_grad,
                   Tensor* y_grad) {
   if (x_grad) {
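For reference, the calculus implemented by these prim rules is untouched by the signature change: `axis` only ever described broadcast alignment, never the math. For out = x^y,

\[
\frac{\partial\,\text{out}}{\partial x} = y\,x^{y-1},
\qquad
\frac{\partial\,\text{out}}{\partial y} = x^{y}\ln x,
\]

and for out = max(x, y), with the common subgradient convention that routes ties to y,

\[
\frac{\partial\,\text{out}}{\partial x} = \mathbf{1}[x > y],
\qquad
\frac{\partial\,\text{out}}{\partial y} = \mathbf{1}[x \le y].
\]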
@@ -332,7 +332,7 @@
 - backward_op : elementwise_pow_grad
   forward : elementwise_pow(Tensor x, Tensor y) -> Tensor(out)
-  args : (Tensor x, Tensor y, Tensor out_grad, int axis=-1)
+  args : (Tensor x, Tensor y, Tensor out_grad)
   output : Tensor(x_grad), Tensor(y_grad)
   infer_meta :
     func : GeneralBinaryGradInferMeta

@@ -577,7 +577,7 @@
 - backward_op : maximum_grad
   forward : maximum(Tensor x, Tensor y) -> Tensor(out)
-  args : (Tensor x, Tensor y, Tensor out_grad, int axis=-1)
+  args : (Tensor x, Tensor y, Tensor out_grad)
   output : Tensor(x_grad), Tensor(y_grad)
   infer_meta :
     func : GeneralBinaryGradInferMeta

@@ -616,7 +616,7 @@
 - backward_op : minimum_grad
   forward : minimum(Tensor x, Tensor y) -> Tensor(out)
-  args : (Tensor x, Tensor y, Tensor out_grad, int axis=-1)
+  args : (Tensor x, Tensor y, Tensor out_grad)
   output : Tensor(x_grad), Tensor(y_grad)
   infer_meta :
     func : GeneralBinaryGradInferMeta
@@ -28,10 +28,10 @@ void MaximumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
   funcs::ElementwiseGradPreProcess(dout, dx);
+  int axis = -1;
   phi::funcs::ElemwiseGradCompute<Context, T, MaxGradDx<T>, MaxGradDy<T>>(
       dev_ctx, x, y, dout, dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
 }

@@ -41,10 +41,10 @@ void MinimumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
   funcs::ElementwiseGradPreProcess(dout, dx);
+  int axis = -1;
   phi::funcs::ElemwiseGradCompute<Context, T, MinGradDx<T>, MinGradDy<T>>(
       dev_ctx, x, y, dout, dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>());
 }
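The MaxGradDx/MaxGradDy functors referenced above apply exactly those indicator rules elementwise; ElemwiseGradCompute then handles broadcasting and reduces the gradient back to each input's shape, with axis = -1 meaning "align trailing dimensions". A standalone sketch of the functor shape (the real definitions live in phi's elementwise functor headers and carry HOSTDEVICE annotations; tie-breaking at x == y is assumed to favor y, matching the x <= y mask):

#include <iostream>

template <typename T>
struct MaxGradDx {
  T operator()(T x, T y, T /*out*/, T dout) const {
    return dout * static_cast<T>(x > y);  // upstream grad flows to x where x won
  }
};

template <typename T>
struct MaxGradDy {
  T operator()(T x, T y, T /*out*/, T dout) const {
    return dout * static_cast<T>(x <= y);  // otherwise it flows to y
  }
};

int main() {
  MaxGradDx<float> dx_op;
  MaxGradDy<float> dy_op;
  // max(2, 5) = 5, so a unit upstream gradient goes entirely to y.
  std::cout << dx_op(2.f, 5.f, 5.f, 1.f) << " "    // prints 0
            << dy_op(2.f, 5.f, 5.f, 1.f) << "\n";  // prints 1
}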
@@ -40,7 +40,6 @@ void MaximumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy);

@@ -49,7 +48,6 @@ void MinimumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy);

@@ -66,7 +64,6 @@ void ElementwisePowGradKernel(const Context& dev_ctx,
                               const DenseTensor& x,
                               const DenseTensor& y,
                               const DenseTensor& dout,
-                              int axis,
                               DenseTensor* dx,
                               DenseTensor* dy);
 }  // namespace phi
@@ -31,11 +31,10 @@ void MaximumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
   const auto place = dev_ctx.GetPlace();
+  int axis = -1;
   if (dx != nullptr && dy != nullptr) {
     std::vector<const DenseTensor*> ins = {&x, &y, &dout};
     GetGradXAndYOut<ElementwiseType::kTernary, T>(

@@ -63,10 +62,10 @@ void MinimumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
   const auto place = dev_ctx.GetPlace();
+  int axis = -1;
   if (dx != nullptr && dy != nullptr) {
     std::vector<const DenseTensor*> ins = {&x, &y, &dout};
     GetGradXAndYOut<ElementwiseType::kTernary, T>(
@@ -958,10 +958,10 @@ void ElementwisePowGradKernel(const Context& dev_ctx,
                               const DenseTensor& x,
                               const DenseTensor& y,
                               const DenseTensor& dout,
-                              int axis,
                               DenseTensor* dx,
                               DenseTensor* dy) {
   funcs::ElementwiseGradPreProcess(dout, dx);
+  int axis = -1;
   phi::funcs::ElemwiseGradCompute<Context, T, PowGradDX<T>, PowGradDY<T>>(
       dev_ctx, x, y, dout, dout, axis, dx, dy, PowGradDX<T>(), PowGradDY<T>());
 }
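PowGradDX/PowGradDY follow the same pattern for the power rule. A standalone sketch of their math (hedged: Paddle's real functors are device-annotated templates and also cover integer types):

#include <cmath>
#include <iostream>

// For out = x^y: dx = dout * y * x^(y-1), dy = dout * x^y * ln(x).
template <typename T>
struct PowGradDX {
  T operator()(T x, T y, T /*out*/, T dout) const {
    return dout * y * std::pow(x, y - 1);
  }
};

template <typename T>
struct PowGradDY {
  T operator()(T x, T y, T /*out*/, T dout) const {
    return dout * std::log(x) * std::pow(x, y);
  }
};

int main() {
  PowGradDX<double> dx_op;
  PowGradDY<double> dy_op;
  // out = 2^3 = 8: dx = 3 * 2^2 = 12, dy = 8 * ln 2 ~= 5.545
  std::cout << dx_op(2.0, 3.0, 8.0, 1.0) << " "
            << dy_op(2.0, 3.0, 8.0, 1.0) << "\n";
}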
@@ -25,11 +25,10 @@ void MaximumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
   using XPUType = typename XPUTypeTrait<T>::Type;
+  int axis = -1;
   auto f = [](xpu::Context* ctx,
               const XPUType* x,
               const XPUType* y,

@@ -51,11 +50,10 @@ void MinimumGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
-                       int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
   using XPUType = typename XPUTypeTrait<T>::Type;
+  int axis = -1;
   auto f = [](xpu::Context* ctx,
               const XPUType* x,
               const XPUType* y,
@@ -210,13 +210,13 @@ KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
 KernelSignature ElementwiseMaxGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
-      "maximum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
+      "maximum_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
 }

 KernelSignature ElementwiseMinGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
-      "minimum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
+      "minimum_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
 }

 KernelSignature ElementwiseHeavisideGradOpArgumentMapping(
@@ -227,10 +227,8 @@ KernelSignature ElementwiseHeavisideGradOpArgumentMapping(
 KernelSignature ElementwisePowGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  return KernelSignature("elementwise_pow_grad",
-                         {"X", "Y", "Out@GRAD"},
-                         {"axis"},
-                         {"X@GRAD", "Y@GRAD"});
+  return KernelSignature(
+      "elementwise_pow_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
 }

 }  // namespace phi
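Finally, the argument-mapping functions above are what translate the legacy fluid operators into phi kernel calls: the four brace lists of KernelSignature are the kernel name, the input names, the forwarded attribute names, and the output names. Emptying the attribute list is what actually stops the old `axis` attribute from reaching the now axis-free kernels. An annotated version of the pattern (the registration macro is shown as a comment; its exact placement is assumed from the usual elementwise_sig.cc layout):

KernelSignature ElementwiseMaxGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("maximum_grad",          // phi kernel name
                         {"X", "Y", "Out@GRAD"},  // fluid op inputs
                         {},                      // attributes: axis no longer forwarded
                         {"X@GRAD", "Y@GRAD"});   // fluid op outputs
}

// Registered with phi's argument-mapping macro, e.g.:
// PD_REGISTER_ARG_MAPPING_FN(elementwise_max_grad,
//                            phi::ElementwiseMaxGradOpArgumentMapping);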