未验证 提交 57564bdf 编写于 作者: L lzydev 提交者: GitHub

Auto generate code for elementwise_max (#54412)

* auto generate code for elementwise_max

* auto generate code for elementwise_max

* fix composite ops

* fix bug of fmax
上级 f3d4c78f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle {
namespace framework {
class OpDesc;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle
namespace paddle {
namespace operators {
class ElementwiseMaxOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "Max"; }
std::string GetEquation() const override { return "Out = max(X, Y)"; }
void AddInputX() override {
AddInput("X", "The first tensor holding the elements to be compared.");
}
void AddInputY() override {
AddInput("Y", "The second tensor holding the elements to be compared.");
}
std::string GetOpFunctionality() const override {
return "Compare two tensors and returns a new tensor containing the "
"element-wise maxima.";
}
};
class ElementwiseFMaxOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "FMax"; }
std::string GetEquation() const override { return "Out = fmax(X, Y)"; }
void AddInputX() override {
AddInput("X", "The first tensor holding the elements to be compared.");
}
void AddInputY() override {
AddInput("Y", "The second tensor holding the elements to be compared.");
}
std::string GetOpFunctionality() const override {
return "Compare two tensors and returns a new tensor containing the "
"element-wise maxima. If the element of one tensor is nan, "
"return the element value of the other tensor, if both are nan, "
"return the first nan";
}
};
class ElementwiseMaxCompositeGradOpMaker
: public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
paddle::Tensor x = this->GetSingleForwardInput("X");
paddle::Tensor y = this->GetSingleForwardInput("Y");
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::Tensor dx = this->GetSingleInputGrad("X");
auto* dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(dx);
paddle::Tensor dy = this->GetSingleInputGrad("Y");
auto* dy_ptr = this->GetOutputPtr(&dy);
std::string dy_name = this->GetOutputName(dy);
VLOG(6) << "Runing maximum_grad composite func";
int axis = static_cast<int>(this->Attr<int>("axis"));
PADDLE_ENFORCE_EQ(
axis,
-1,
phi::errors::InvalidArgument(
"We only support axis = -1 in composite maximum_grad but we got: ",
axis));
prim::maximum_grad<prim::DescTensor>(x, y, out_grad, dx_ptr, dy_ptr);
this->RecoverOutputName(dx, dx_name);
this->RecoverOutputName(dy, dy_name);
}
};
template <typename T>
class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_max_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
}
};
template <typename T>
class ElementwiseFMaxGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_fmax_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_max,
ops::ElementwiseOp,
ops::ElementwiseMaxOpMaker,
ops::ElementwiseOpInferVarType,
ops::ElementwiseMaxGradOpMaker<paddle::framework::OpDesc>,
ops::ElementwiseMaxGradOpMaker<paddle::imperative::OpBase>,
ops::ElementwiseMaxCompositeGradOpMaker);
REGISTER_OPERATOR(elementwise_max_grad, ops::ElementwiseOpGrad);
REGISTER_OP_VERSION(elementwise_max)
.AddCheckpoint(
R"ROC(Register elementwise_max for adding the attribute of Scale_y)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Scale_y",
"In order to support the function of scaling the input Y when "
"using the operator of elementwise_max.",
1.0f));
REGISTER_OPERATOR(elementwise_fmax,
ops::ElementwiseOp,
ops::ElementwiseFMaxOpMaker,
ops::ElementwiseOpInferVarType,
ops::ElementwiseFMaxGradOpMaker<paddle::framework::OpDesc>,
ops::ElementwiseFMaxGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_fmax_grad, ops::ElementwiseOpGrad);
...@@ -9,7 +9,6 @@ register_unity_group( ...@@ -9,7 +9,6 @@ register_unity_group(
elementwise_add_op.cc elementwise_add_op.cc
elementwise_div_op.cc elementwise_div_op.cc
elementwise_floordiv_op.cc elementwise_floordiv_op.cc
elementwise_max_op.cc
elementwise_min_op.cc elementwise_min_op.cc
elementwise_mod_op.cc elementwise_mod_op.cc
elementwise_mul_op.cc elementwise_mul_op.cc
......
...@@ -776,6 +776,17 @@ ...@@ -776,6 +776,17 @@
composite : floor_grad(out_grad, x_grad) composite : floor_grad(out_grad, x_grad)
inplace : (out_grad -> x_grad) inplace : (out_grad -> x_grad)
- backward_op : fmax_grad
forward : fmax(Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param: [x, y]
kernel :
func : fmax_grad
data_type : out_grad
- backward_op : fold_grad - backward_op : fold_grad
forward: fold (Tensor x, int[] output_sizes, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations) -> Tensor(out) forward: fold (Tensor x, int[] output_sizes, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations) -> Tensor(out)
args: (Tensor x, Tensor out_grad, int[] output_sizes, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations) args: (Tensor x, Tensor out_grad, int[] output_sizes, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)
......
...@@ -293,16 +293,6 @@ ...@@ -293,16 +293,6 @@
func : fill_grad func : fill_grad
inplace : (out_grad -> x_grad) inplace : (out_grad -> x_grad)
- backward_op : fmax_grad
forward : fmax(Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param: [x, y]
kernel :
func : fmax_grad
- backward_op : fmin_grad - backward_op : fmin_grad
forward : fmin(Tensor x, Tensor y) -> Tensor(out) forward : fmin(Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad) args : (Tensor x, Tensor y, Tensor out_grad)
......
...@@ -361,16 +361,6 @@ ...@@ -361,16 +361,6 @@
kernel : kernel :
func : floor_divide func : floor_divide
- op : fmax
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
param: [x, y]
func : ElementwiseInferMeta
kernel :
func : fmax
backward : fmax_grad
- op : fmin - op : fmin
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
output : Tensor(out) output : Tensor(out)
......
...@@ -1028,9 +1028,15 @@ ...@@ -1028,9 +1028,15 @@
- op : fmax (elementwise_fmax) - op : fmax (elementwise_fmax)
backward : fmax_grad (elementwise_fmax_grad) backward : fmax_grad (elementwise_fmax_grad)
inputs :
{x : X, y : Y}
outputs :
{out : Out}
extra : extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32",
bool use_quantizer = false, float Scale_x = 1.0f, float Scale_y = 1.0f, float Scale_out = 1.0f] bool use_quantizer = false, float Scale_x = 1.0f, float Scale_y = 1.0f, float Scale_out = 1.0f]
complex_promote : [X, Y]
manual_signature : [fmax]
- op : fmin (elementwise_fmin) - op : fmin (elementwise_fmin)
backward : fmin_grad (elementwise_fmin_grad) backward : fmin_grad (elementwise_fmin_grad)
...@@ -1628,9 +1634,15 @@ ...@@ -1628,9 +1634,15 @@
- op : maximum (elementwise_max) - op : maximum (elementwise_max)
backward : maximum_grad (elementwise_max_grad) backward : maximum_grad (elementwise_max_grad)
inputs :
{x : X, y : Y}
outputs :
{out : Out}
extra : extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32",
bool use_quantizer = false, float Scale_x = 1.0f, float Scale_y = 1.0f, float Scale_out = 1.0f] bool use_quantizer = false, float Scale_x = 1.0f, float Scale_y = 1.0f, float Scale_out = 1.0f]
complex_promote : [X, Y]
manual_signature : [maximum]
- op : maxout - op : maxout
inputs : inputs :
......
...@@ -157,6 +157,14 @@ ...@@ -157,6 +157,14 @@
comment : In order to add additional size to one side of each dimension in the output. comment : In order to add additional size to one side of each dimension in the output.
default : "std::vector<int>{}" default : "std::vector<int>{}"
- op : elementwise_max
version :
- checkpoint : Register elementwise_max for adding the attribute of Scale_y.
action :
- add_attr : Scale_y
comment : In order to support the function of scaling the input Y when using the operator of elementwise_max.
default : 1.0
- op : embedding - op : embedding
version : version :
- checkpoint : Upgrade flip, add new attr [axis] and delete attr [dims] - checkpoint : Upgrade flip, add new attr [axis] and delete attr [dims]
......
...@@ -862,6 +862,16 @@ ...@@ -862,6 +862,16 @@
inplace : (x -> out) inplace : (x -> out)
backward : floor_grad backward : floor_grad
- op : fmax
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
param: [x, y]
func : ElementwiseInferMeta
kernel :
func : fmax
backward : fmax_grad
- op : fold - op : fold
args: (Tensor x, int[] output_sizes, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations) args: (Tensor x, int[] output_sizes, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)
output: Tensor(out) output: Tensor(out)
......
...@@ -123,6 +123,18 @@ ...@@ -123,6 +123,18 @@
func : max_grad func : max_grad
composite: max_grad(x, out, out_grad, axis, keepdim, reduce_all, x_grad) composite: max_grad(x, out, out_grad, axis, keepdim, reduce_all, x_grad)
- backward_op : maximum_grad
forward : maximum(Tensor x, Tensor y, int axis = -1) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param: [x, y]
kernel :
func : maximum_grad
data_type : out_grad
composite : maximum_grad(x, y, out_grad, x_grad, y_grad)
- backward_op : min_grad - backward_op : min_grad
forward: min (Tensor x, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, int out_dtype=-1) -> Tensor(out) forward: min (Tensor x, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, int out_dtype=-1) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, IntArray axis={}, bool keepdim=false, bool reduce_all=false) args : (Tensor x, Tensor out, Tensor out_grad, IntArray axis={}, bool keepdim=false, bool reduce_all=false)
......
...@@ -309,6 +309,15 @@ ...@@ -309,6 +309,15 @@
param : [x, axis, keepdim, reduce_all] param : [x, axis, keepdim, reduce_all]
backward : max_grad backward : max_grad
- op : maximum
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor(out)
infer_meta :
func : ElementwiseRawInferMeta
kernel :
func : maximum
backward : maximum_grad
- op : min - op : min
args : (Tensor x, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, int out_dtype=-1) args : (Tensor x, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, int out_dtype=-1)
output : Tensor(out) output : Tensor(out)
......
...@@ -29,6 +29,10 @@ limitations under the License. */ ...@@ -29,6 +29,10 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/common_shape.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace phi { namespace phi {
namespace detail { namespace detail {
...@@ -1170,7 +1174,8 @@ void ElementwiseInferMeta(const MetaTensor& x, ...@@ -1170,7 +1174,8 @@ void ElementwiseInferMeta(const MetaTensor& x,
void ElementwiseRawInferMeta(const MetaTensor& x, void ElementwiseRawInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
int axis, int axis,
MetaTensor* out) { MetaTensor* out,
MetaConfig config) {
if (x.dims() != y.dims()) { if (x.dims() != y.dims()) {
auto x_dims = x.dims(); auto x_dims = x.dims();
auto y_dims = y.dims(); auto y_dims = y.dims();
...@@ -1199,6 +1204,25 @@ void ElementwiseRawInferMeta(const MetaTensor& x, ...@@ -1199,6 +1204,25 @@ void ElementwiseRawInferMeta(const MetaTensor& x,
std::vector<int> x_dims_array(max_dim); std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim); std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim); std::vector<int> out_dims_array(max_dim);
#ifdef PADDLE_WITH_MKLDNN
bool should_rotate =
config.is_run_mkldnn_kernel &&
(phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
phi::DataLayout::kNHWC) &&
(x_dims.size() >= 3 || y_dims.size() >= 3);
if (should_rotate) {
// Pick bigger shape and rotate this one
bool x_over_y = (x_dims.size() > y_dims.size());
auto vdims =
x_over_y ? phi::vectorize<int>(x_dims) : phi::vectorize<int>(y_dims);
std::rotate(vdims.begin() + 1, vdims.begin() + 2, vdims.end());
if (x_over_y) {
x_dims = phi::make_ddim(vdims);
} else {
y_dims = phi::make_ddim(vdims);
}
}
#endif
funcs::GetBroadcastDimsArrays(x_dims, funcs::GetBroadcastDimsArrays(x_dims,
y_dims, y_dims,
x_dims_array.data(), x_dims_array.data(),
...@@ -1206,6 +1230,13 @@ void ElementwiseRawInferMeta(const MetaTensor& x, ...@@ -1206,6 +1230,13 @@ void ElementwiseRawInferMeta(const MetaTensor& x,
out_dims_array.data(), out_dims_array.data(),
max_dim, max_dim,
axis); axis);
#ifdef PADDLE_WITH_MKLDNN
if (should_rotate) {
std::rotate(out_dims_array.begin() + 1,
out_dims_array.end() - 1,
out_dims_array.end());
}
#endif
auto out_dims = phi::make_ddim(out_dims_array); auto out_dims = phi::make_ddim(out_dims_array);
out->set_dims(out_dims); out->set_dims(out_dims);
} else { } else {
......
...@@ -203,7 +203,8 @@ void ElementwiseInferMeta(const MetaTensor& x, ...@@ -203,7 +203,8 @@ void ElementwiseInferMeta(const MetaTensor& x,
void ElementwiseRawInferMeta(const MetaTensor& x_meta, void ElementwiseRawInferMeta(const MetaTensor& x_meta,
const MetaTensor& y_meta, const MetaTensor& y_meta,
int axis, int axis,
MetaTensor* out); MetaTensor* out,
MetaConfig config = MetaConfig());
void EmbeddingInferMeta(const MetaTensor& x, void EmbeddingInferMeta(const MetaTensor& x,
const MetaTensor& weight, const MetaTensor& weight,
......
...@@ -22,6 +22,32 @@ ...@@ -22,6 +22,32 @@
namespace phi { namespace phi {
KernelKey ElementwiseGetKernelTypeForVar(
const GetKernelTypeForVarContext* ctx) {
const DenseTensor& tensor = ctx->GetTensor();
const KernelKey& expected_kernel_type = ctx->GetKernelKey();
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if (expected_kernel_type.dtype() == phi::DataType::COMPLEX64 ||
expected_kernel_type.dtype() == phi::DataType::COMPLEX128) {
// only promote inputs’s types when contains complex input
return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
} else {
// When elementwise is first oneDNN op (there was some non oneDNN op
// previously)
// then we also need to rotate shape NHWC -> NCWH
if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN) &&
phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
phi::DataLayout::kNHWC) {
return phi::KernelKey(
tensor.place(), phi::DataLayout::kNHWC, expected_kernel_type.dtype());
}
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
}
template <typename T, dnnl::algorithm BINARY_OP> template <typename T, dnnl::algorithm BINARY_OP>
void ElementwiseKernel(const OneDNNContext& dev_ctx, void ElementwiseKernel(const OneDNNContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
...@@ -135,7 +161,9 @@ PD_REGISTER_KERNEL(add_raw, ...@@ -135,7 +161,9 @@ PD_REGISTER_KERNEL(add_raw,
float, float,
phi::dtype::bfloat16, phi::dtype::bfloat16,
int8_t, int8_t,
uint8_t) {} uint8_t) {
kernel->get_kerneltype_forvar_fn_ = phi::ElementwiseGetKernelTypeForVar;
}
PD_REGISTER_KERNEL(add, PD_REGISTER_KERNEL(add,
OneDNN, OneDNN,
...@@ -144,7 +172,9 @@ PD_REGISTER_KERNEL(add, ...@@ -144,7 +172,9 @@ PD_REGISTER_KERNEL(add,
float, float,
phi::dtype::bfloat16, phi::dtype::bfloat16,
int8_t, int8_t,
uint8_t) {} uint8_t) {
kernel->get_kerneltype_forvar_fn_ = phi::ElementwiseGetKernelTypeForVar;
}
PD_REGISTER_KERNEL(subtract_raw, PD_REGISTER_KERNEL(subtract_raw,
OneDNN, OneDNN,
...@@ -153,7 +183,9 @@ PD_REGISTER_KERNEL(subtract_raw, ...@@ -153,7 +183,9 @@ PD_REGISTER_KERNEL(subtract_raw,
float, float,
phi::dtype::bfloat16, phi::dtype::bfloat16,
int8_t, int8_t,
uint8_t) {} uint8_t) {
kernel->get_kerneltype_forvar_fn_ = phi::ElementwiseGetKernelTypeForVar;
}
PD_REGISTER_KERNEL(subtract, PD_REGISTER_KERNEL(subtract,
OneDNN, OneDNN,
...@@ -162,7 +194,9 @@ PD_REGISTER_KERNEL(subtract, ...@@ -162,7 +194,9 @@ PD_REGISTER_KERNEL(subtract,
float, float,
phi::dtype::bfloat16, phi::dtype::bfloat16,
int8_t, int8_t,
uint8_t) {} uint8_t) {
kernel->get_kerneltype_forvar_fn_ = phi::ElementwiseGetKernelTypeForVar;
}
PD_REGISTER_KERNEL(multiply_raw, PD_REGISTER_KERNEL(multiply_raw,
OneDNN, OneDNN,
...@@ -171,7 +205,9 @@ PD_REGISTER_KERNEL(multiply_raw, ...@@ -171,7 +205,9 @@ PD_REGISTER_KERNEL(multiply_raw,
float, float,
phi::dtype::bfloat16, phi::dtype::bfloat16,
int8_t, int8_t,
uint8_t) {} uint8_t) {
kernel->get_kerneltype_forvar_fn_ = phi::ElementwiseGetKernelTypeForVar;
}
PD_REGISTER_KERNEL(multiply, PD_REGISTER_KERNEL(multiply,
OneDNN, OneDNN,
...@@ -180,7 +216,9 @@ PD_REGISTER_KERNEL(multiply, ...@@ -180,7 +216,9 @@ PD_REGISTER_KERNEL(multiply,
float, float,
phi::dtype::bfloat16, phi::dtype::bfloat16,
int8_t, int8_t,
uint8_t) {} uint8_t) {
kernel->get_kerneltype_forvar_fn_ = phi::ElementwiseGetKernelTypeForVar;
}
PD_REGISTER_KERNEL(divide_raw, PD_REGISTER_KERNEL(divide_raw,
OneDNN, OneDNN,
...@@ -190,4 +228,6 @@ PD_REGISTER_KERNEL(divide_raw, ...@@ -190,4 +228,6 @@ PD_REGISTER_KERNEL(divide_raw,
phi::dtype::bfloat16) {} phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
divide, OneDNN, ONEDNN, phi::DivideKernel, float, phi::dtype::bfloat16) {} divide, OneDNN, ONEDNN, phi::DivideKernel, float, phi::dtype::bfloat16) {
kernel->get_kerneltype_forvar_fn_ = phi::ElementwiseGetKernelTypeForVar;
}
...@@ -66,6 +66,9 @@ KernelSignature ElementwiseDivOpArgumentMapping( ...@@ -66,6 +66,9 @@ KernelSignature ElementwiseDivOpArgumentMapping(
KernelSignature ElementwiseMaxOpArgumentMapping( KernelSignature ElementwiseMaxOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
if (ctx.IsForInferShape()) {
return KernelSignature("maximum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
int axis = paddle::any_cast<int>(ctx.Attr("axis")); int axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (axis == -1) { if (axis == -1) {
return KernelSignature("maximum", {"X", "Y"}, {}, {"Out"}); return KernelSignature("maximum", {"X", "Y"}, {}, {"Out"});
...@@ -184,12 +187,6 @@ KernelSignature ElementwiseFMinOpArgumentMapping( ...@@ -184,12 +187,6 @@ KernelSignature ElementwiseFMinOpArgumentMapping(
return KernelSignature("fmin", {"X", "Y"}, {}, {"Out"}); return KernelSignature("fmin", {"X", "Y"}, {}, {"Out"});
} }
KernelSignature ElementwiseFMaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature(
"fmax_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseMulDoubleGradOpArgumentMapping( KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) { const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("multiply_double_grad", return KernelSignature("multiply_double_grad",
...@@ -207,12 +204,6 @@ KernelSignature ElementwiseMulTripleGradOpArgumentMapping( ...@@ -207,12 +204,6 @@ KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
{"D_X", "D_Y", "D_DOut", "D_DDX", "D_DDY"}); {"D_X", "D_Y", "D_DOut", "D_DDX", "D_DDY"});
} }
KernelSignature ElementwiseMaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature(
"maximum_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseMinGradOpArgumentMapping( KernelSignature ElementwiseMinGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) { const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature( return KernelSignature(
...@@ -253,9 +244,7 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad); ...@@ -253,9 +244,7 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad); PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax); PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin, fmin); PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin, fmin);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax_grad, fmax_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin_grad, fmin_grad); PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin_grad, fmin_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max_grad, maximum_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min_grad, minimum_grad); PD_REGISTER_BASE_KERNEL_NAME(elementwise_min_grad, minimum_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_heaviside_grad, heaviside_grad); PD_REGISTER_BASE_KERNEL_NAME(elementwise_heaviside_grad, heaviside_grad);
...@@ -303,12 +292,8 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax, ...@@ -303,12 +292,8 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax,
phi::ElementwiseFMaxOpArgumentMapping); phi::ElementwiseFMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin, PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin,
phi::ElementwiseFMinOpArgumentMapping); phi::ElementwiseFMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax_grad,
phi::ElementwiseFMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad, PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad,
phi::ElementwiseFMinGradOpArgumentMapping); phi::ElementwiseFMinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_max_grad,
phi::ElementwiseMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min_grad, PD_REGISTER_ARG_MAPPING_FN(elementwise_min_grad,
phi::ElementwiseMinGradOpArgumentMapping); phi::ElementwiseMinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_heaviside_grad, PD_REGISTER_ARG_MAPPING_FN(elementwise_heaviside_grad,
......
...@@ -27,7 +27,6 @@ if(WITH_GPU ...@@ -27,7 +27,6 @@ if(WITH_GPU
reduce_mean_op reduce_mean_op
activation_op activation_op
sum_op sum_op
elementwise_max_op
elementwise_div_op elementwise_div_op
generated_op generated_op
generated_static_op generated_static_op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册