From ec1d2a16b66f3525dc775cb557fec1aea74534e8 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 21 Apr 2022 11:25:38 +0800 Subject: [PATCH] [Cherry-pick] Optimize dygraph scheduling performance (#42010) * [Phi] Support setting size of vector for out in yaml (#41576) * support setting vector out size in yaml * support setting size of vector for out in yaml * resolve conflict Co-authored-by: zyfncg --- .../final_state_generator/codegen_utils.py | 2 +- paddle/fluid/framework/infershape_utils.cc | 160 +++++-- paddle/fluid/framework/infershape_utils.h | 83 +++- .../new_executor/new_executor_defs.cc | 9 +- .../new_executor/new_executor_defs.h | 8 +- paddle/fluid/framework/op_desc.cc | 12 +- paddle/fluid/framework/operator.cc | 24 +- paddle/fluid/framework/operator.h | 1 + paddle/fluid/framework/phi_utils.cc | 73 +-- paddle/fluid/framework/phi_utils.h | 6 +- paddle/fluid/framework/shape_inference.h | 10 +- paddle/fluid/imperative/infer_shape_context.h | 15 +- paddle/fluid/imperative/prepared_operator.cc | 33 +- paddle/fluid/imperative/prepared_operator.h | 8 +- .../fluid/inference/api/analysis_predictor.cc | 1 + paddle/fluid/inference/api/api_impl.cc | 1 + .../fluid/operators/controlflow/while_op.cc | 6 +- .../detection/collect_fpn_proposals_op.cc | 6 +- paddle/fluid/pybind/imperative.cc | 50 +- paddle/fluid/pybind/pybind.cc | 2 + .../infrt/dialect/phi/pass/kernel_op_desc.cc | 9 +- paddle/infrt/host_context/value.h | 1 + paddle/phi/api/lib/api_custom_impl.cc | 440 +----------------- paddle/phi/api/lib/api_custom_impl.h | 51 +- paddle/phi/api/lib/api_gen_utils.cc | 10 + paddle/phi/api/lib/api_gen_utils.h | 3 + paddle/phi/common/int_array.h | 2 + paddle/phi/core/compat/arg_map_context.h | 40 +- paddle/phi/core/compat/convert_utils.cc | 2 +- paddle/phi/core/compat/convert_utils.h | 2 +- paddle/phi/core/compat/op_utils.h | 8 +- paddle/phi/core/infermeta_utils.cc | 40 +- paddle/phi/core/infermeta_utils.h | 57 +-- paddle/phi/core/kernel_context.cc | 4 +- paddle/phi/core/kernel_factory.cc | 24 +- paddle/phi/core/kernel_factory.h | 39 +- paddle/phi/core/meta_tensor.cc | 2 + paddle/phi/core/meta_tensor.h | 8 +- paddle/phi/core/type_defs.h | 7 +- paddle/phi/infermeta/backward.cc | 6 +- paddle/phi/infermeta/backward.h | 6 +- paddle/phi/infermeta/multiary.cc | 24 +- paddle/phi/infermeta/multiary.h | 24 +- paddle/phi/kernels/concat_kernel.h | 2 +- paddle/phi/ops/compat/abs_sig.cc | 3 +- paddle/phi/ops/compat/activation_sig.cc | 39 +- paddle/phi/ops/compat/addmm_sig.cc | 9 +- paddle/phi/ops/compat/argsort_sig.cc | 4 +- paddle/phi/ops/compat/atan2_sig.cc | 6 +- paddle/phi/ops/compat/batch_norm_sig.cc | 41 +- paddle/phi/ops/compat/bce_loss_sig.cc | 6 +- .../ops/compat/bilinear_tensor_product_sig.cc | 7 +- .../phi/ops/compat/broadcast_tensors_sig.cc | 2 +- paddle/phi/ops/compat/cholesky_sig.cc | 6 +- paddle/phi/ops/compat/cholesky_solve_sig.cc | 4 +- paddle/phi/ops/compat/clip_sig.cc | 26 +- paddle/phi/ops/compat/complex_sig.cc | 6 +- paddle/phi/ops/compat/concat_sig.cc | 12 +- paddle/phi/ops/compat/conv2d_sig.cc | 4 +- paddle/phi/ops/compat/conv3d_sig.cc | 4 +- paddle/phi/ops/compat/conv_transpose_sig.cc | 12 +- paddle/phi/ops/compat/cross_sig.cc | 6 +- paddle/phi/ops/compat/cumprod_sig.cc | 6 +- paddle/phi/ops/compat/deformable_conv_sig.cc | 7 +- paddle/phi/ops/compat/depthwise_conv2d_sig.cc | 4 +- paddle/phi/ops/compat/determinant_sig.cc | 6 +- paddle/phi/ops/compat/diag_sig.cc | 2 +- paddle/phi/ops/compat/diagonal_sig.cc | 4 +- paddle/phi/ops/compat/digamma_sig.cc | 3 +- paddle/phi/ops/compat/dist_sig.cc 
| 6 +- paddle/phi/ops/compat/dot_sig.cc | 6 +- paddle/phi/ops/compat/dropout_sig.cc | 4 +- paddle/phi/ops/compat/eigh_sig.cc | 12 +- paddle/phi/ops/compat/elementwise_sig.cc | 54 +-- paddle/phi/ops/compat/embedding_sig.cc | 16 +- paddle/phi/ops/compat/erf_sig.cc | 3 +- paddle/phi/ops/compat/erfinv_sig.cc | 3 +- paddle/phi/ops/compat/expand_as_sig.cc | 6 +- paddle/phi/ops/compat/expand_sig.cc | 18 +- paddle/phi/ops/compat/flatten_sig.cc | 2 +- paddle/phi/ops/compat/frobenius_norm_sig.cc | 4 +- paddle/phi/ops/compat/gather_scatter_sig.cc | 14 +- paddle/phi/ops/compat/gather_sig.cc | 8 +- paddle/phi/ops/compat/gelu_sig.cc | 6 +- paddle/phi/ops/compat/graph_send_recv_sig.cc | 4 +- paddle/phi/ops/compat/grid_sampler_sig.cc | 4 +- paddle/phi/ops/compat/gumbel_softmax_sig.cc | 6 +- .../ops/compat/hierarchical_sigmoid_sig.cc | 74 ++- paddle/phi/ops/compat/huber_loss_sig.cc | 4 +- paddle/phi/ops/compat/index_sample_sig.cc | 6 +- paddle/phi/ops/compat/index_select_sig.cc | 6 +- paddle/phi/ops/compat/interpolate_sig.cc | 115 +++-- paddle/phi/ops/compat/kldiv_loss_sig.cc | 4 +- paddle/phi/ops/compat/kron_sig.cc | 6 +- paddle/phi/ops/compat/kthvalue_sig.cc | 4 +- paddle/phi/ops/compat/label_smooth_sig.cc | 6 +- paddle/phi/ops/compat/layer_norm_sig.cc | 9 +- paddle/phi/ops/compat/lerp_sig.cc | 4 +- paddle/phi/ops/compat/lgamma_sig.cc | 3 +- paddle/phi/ops/compat/log_loss_sig.cc | 4 +- paddle/phi/ops/compat/log_softmax_sig.cc | 6 +- paddle/phi/ops/compat/logsumexp_sig.cc | 4 +- paddle/phi/ops/compat/masked_select_sig.cc | 6 +- paddle/phi/ops/compat/matmul_sig.cc | 8 +- paddle/phi/ops/compat/matrix_power_sig.cc | 6 +- paddle/phi/ops/compat/maxout_sig.cc | 6 +- paddle/phi/ops/compat/mean_sig.cc | 3 +- paddle/phi/ops/compat/meshgrid_sig.cc | 3 +- paddle/phi/ops/compat/mode_sig.cc | 4 +- paddle/phi/ops/compat/mul_sig.cc | 4 +- paddle/phi/ops/compat/multi_dot_sig.cc | 3 +- paddle/phi/ops/compat/multiplex_sig.cc | 3 +- paddle/phi/ops/compat/mv_sig.cc | 6 +- paddle/phi/ops/compat/nll_loss_sig.cc | 9 +- paddle/phi/ops/compat/norm_sig.cc | 4 +- paddle/phi/ops/compat/p_norm_sig.cc | 4 +- paddle/phi/ops/compat/pad3d_sig.cc | 8 +- paddle/phi/ops/compat/pad_sig.cc | 6 +- paddle/phi/ops/compat/pixel_shuffle_sig.cc | 4 +- paddle/phi/ops/compat/poisson_sig.cc | 3 +- paddle/phi/ops/compat/pool_sig.cc | 16 +- paddle/phi/ops/compat/prelu_sig.cc | 4 +- paddle/phi/ops/compat/psroi_pool_sig.cc | 4 +- paddle/phi/ops/compat/put_along_axis_sig.cc | 4 +- paddle/phi/ops/compat/reduce_sig.cc | 20 +- paddle/phi/ops/compat/reshape_sig.cc | 3 +- paddle/phi/ops/compat/rnn_sig.cc | 8 +- paddle/phi/ops/compat/roi_align_sig.cc | 4 +- paddle/phi/ops/compat/roi_pool_sig.cc | 4 +- paddle/phi/ops/compat/roll_sig.cc | 6 +- paddle/phi/ops/compat/segment_pool_sig.cc | 13 +- paddle/phi/ops/compat/selu_sig.cc | 6 +- paddle/phi/ops/compat/set_value_sig.cc | 137 +++--- .../sigmoid_cross_entropy_with_logits_sig.cc | 4 +- paddle/phi/ops/compat/slice_sig.cc | 36 +- paddle/phi/ops/compat/softmax_sig.cc | 6 +- .../compat/softmax_with_cross_entropy_sig.cc | 4 +- paddle/phi/ops/compat/squeeze_sig.cc | 6 +- paddle/phi/ops/compat/stack_sig.cc | 3 +- paddle/phi/ops/compat/strided_slice_sig.cc | 28 +- paddle/phi/ops/compat/take_along_axis_sig.cc | 4 +- paddle/phi/ops/compat/temporal_shift_sig.cc | 4 +- paddle/phi/ops/compat/tile_sig.cc | 18 +- paddle/phi/ops/compat/top_k_sig.cc | 4 +- paddle/phi/ops/compat/trace_sig.cc | 4 +- paddle/phi/ops/compat/transpose_sig.cc | 3 +- paddle/phi/ops/compat/triangular_solve_sig.cc | 4 +- 
paddle/phi/ops/compat/tril_triu_sig.cc | 6 +- paddle/phi/ops/compat/trunc_sig.cc | 3 +- paddle/phi/ops/compat/unfold_sig.cc | 4 +- paddle/phi/ops/compat/unsqueeze_sig.cc | 2 +- paddle/phi/ops/compat/unstack_sig.cc | 3 +- paddle/phi/ops/compat/warpctc_sig.cc | 9 +- paddle/phi/ops/compat/where_grad_sig.cc | 4 +- paddle/phi/ops/compat/yolov3_loss_sig.cc | 36 +- paddle/phi/tests/core/test_meta_fn_utils.cc | 16 +- paddle/testing/CMakeLists.txt | 2 +- paddle/testing/paddle_gtest_main.cc | 2 + python/paddle/fluid/__init__.py | 1 + python/paddle/utils/code_gen/api.yaml | 14 +- python/paddle/utils/code_gen/api_base.py | 74 +-- python/paddle/utils/code_gen/api_gen.py | 20 +- python/paddle/utils/code_gen/backward.yaml | 42 +- .../paddle/utils/code_gen/backward_api_gen.py | 23 +- 164 files changed, 1181 insertions(+), 1508 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py index 2330c84ea09..ab8c28c33e7 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py @@ -235,7 +235,7 @@ def ParseYamlReturns(string): returns = [x.strip() for x in string.strip().split(",")] for i in range(len(returns)): - ret = returns[i] + ret = returns[i].split("{")[0].strip() ret_name = "" if "(" in ret and ")" in ret: diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 17acbde2a09..bd71ade7e93 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -308,10 +308,100 @@ void CompatMetaTensor::share_meta(const MetaTensor& meta_tensor) { share_lod(meta_tensor); } -phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, - const std::string& op_type) { +void CompatInferMetaContext::EmplaceBackInput(CompatMetaTensor input) { + int index = compat_inputs_.size(); + compat_inputs_.emplace_back(std::move(input)); + input_range_.emplace_back(std::pair(index, index + 1)); +} +void CompatInferMetaContext::EmplaceBackOutput(CompatMetaTensor output) { + int index = compat_outputs_.size(); + compat_outputs_.emplace_back(std::move(output)); + output_range_.emplace_back(std::pair(index, index + 1)); +} + +void CompatInferMetaContext::EmplaceBackInputs( + paddle::SmallVector inputs) { + int index = compat_inputs_.size(); + input_range_.emplace_back(std::pair(index, index + inputs.size())); + compat_inputs_.insert(compat_inputs_.end(), + std::make_move_iterator(inputs.begin()), + std::make_move_iterator(inputs.end())); +} + +void CompatInferMetaContext::EmplaceBackOutputs( + paddle::SmallVector + outputs) { + int index = compat_outputs_.size(); + output_range_.emplace_back( + std::pair(index, index + outputs.size())); + compat_outputs_.insert(compat_outputs_.end(), + std::make_move_iterator(outputs.begin()), + std::make_move_iterator(outputs.end())); +} + +const phi::MetaTensor& CompatInferMetaContext::InputAt(size_t idx) const { + return compat_inputs_.at(idx); +} + +paddle::optional +CompatInferMetaContext::OptionalInputAt(size_t idx) const { + const auto& input = compat_inputs_.at(idx); + return input.initialized() + ? 
paddle::optional{input} + : paddle::optional{paddle::none}; +} + +std::vector CompatInferMetaContext::InputsBetween( + size_t start, size_t end) const { + std::vector result; + result.reserve(end - start); + + for (size_t i = start; i < end; ++i) { + auto& in = compat_inputs_.at(i); + result.emplace_back(in.initialized() ? &in : nullptr); + } + + return result; +} + +paddle::optional> +CompatInferMetaContext::OptionalInputsBetween(size_t start, size_t end) const { + const auto& first = compat_inputs_.at(start); + + if (first.initialized()) { + std::vector result; + result.reserve(end - start); + + for (size_t i = start; i < end; ++i) { + auto& in = compat_inputs_.at(i); + result.emplace_back(in.initialized() ? &in : nullptr); + } + + return paddle::optional>(result); + } + return paddle::optional>( + paddle::none); +} + +phi::MetaTensor* CompatInferMetaContext::MutableOutputAt(size_t idx) { + auto& out = compat_outputs_.at(idx); + return out.initialized() ? &out : nullptr; +} + +std::vector CompatInferMetaContext::MutableOutputBetween( + size_t start, size_t end) { + std::vector result; + result.reserve(end - start); + for (size_t i = start; i < end; ++i) { + auto& out = compat_outputs_.at(i); + result.emplace_back(out.initialized() ? &out : nullptr); + } + return result; +} + +CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, + const std::string& op_type) { // 1. get kernel args - InitDefaultKernelSignatureMap(); auto arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type); PADDLE_ENFORCE_NOT_NULL( arg_map_fn, platform::errors::NotFound( @@ -321,52 +411,47 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature; // 2. build infermeta context - phi::InferMetaContext infer_meta_context( + CompatInferMetaContext infer_meta_context( {ctx->IsRuntime(), ctx->IsRunMKLDNNKernel()}); auto& input_names = std::get<0>(signature.args); auto& attr_names = std::get<1>(signature.args); auto& output_names = std::get<2>(signature.args); - auto kernels_map = - phi::KernelFactory::Instance().SelectKernelMap(signature.name); - if (kernels_map.size() == 0) { - PADDLE_THROW( - platform::errors::Unimplemented("Not find `%s` kernels when construct " - "InferMetaContext.", - signature.name)); - } - auto attr_defs = kernels_map.cbegin()->second.args_def().attribute_defs(); + const auto& args_def = + phi::KernelFactory::Instance().GetFirstKernelArgsDef(signature.name); + const auto& attr_defs = args_def.attribute_defs(); - // TODO(chenweihang): support multiple inputs and outputs later - phi::InferMetaContext infer_mete_context; for (auto& in_name : input_names) { if (ctx->HasInputs(in_name)) { - auto input_var = ctx->GetInputVarPtrs(in_name); + auto input_var = std::move(ctx->GetInputVarPtrs(in_name)); if (input_var.size() == 1) { infer_meta_context.EmplaceBackInput( - std::make_shared(input_var[0], ctx->IsRuntime())); + std::move(CompatMetaTensor(input_var[0], ctx->IsRuntime()))); } else { - paddle::SmallVector> inputs; - inputs.reserve(input_var.size()); + paddle::SmallVector + inputs; for (const auto& in : input_var) { - inputs.push_back( - std::make_shared(in, ctx->IsRuntime())); + inputs.emplace_back( + std::move(CompatMetaTensor(in, ctx->IsRuntime()))); } infer_meta_context.EmplaceBackInputs(std::move(inputs)); } } else { - infer_meta_context.EmplaceBackInput({nullptr}); + infer_meta_context.EmplaceBackInput( + std::move(CompatMetaTensor(ctx->IsRuntime()))); } } + VLOG(6) << 
"BuildInferMetaContext: Done inputs"; + auto attr_reader = ctx->Attrs(); for (size_t i = 0; i < attr_names.size(); ++i) { - auto attr_name = attr_names[i]; + auto& attr_name = attr_names[i]; if (attr_defs[i].type_index == std::type_index(typeid(phi::IntArray))) { // When attr is a vector_tensor or tensor, transform it to IntArray if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) { - const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name); + auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name)); if (ctx->IsRuntime()) { // If is in runtime, we will get tensor's value for IntArray // and push it into attrs @@ -456,7 +541,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, attr_name)); } } else if (ctx->HasInput(attr_name)) { - const auto& infershape_input = ctx->GetInputVarPtrs(attr_name); + auto infershape_input = std::move(ctx->GetInputVarPtrs(attr_name)); if (infershape_input.size() == 1) { if (ctx->IsRuntime()) { Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]); @@ -581,7 +666,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, // convert from data if (attr_defs[i].type_index == std::type_index(typeid(int32_t))) { if (ctx->IsRuntime()) { - const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name); + auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name)); auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]); auto val = experimental::MakePhiScalarFromVar(*var_temp); int32_t val_int = val.template to(); @@ -596,36 +681,41 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, } } + VLOG(6) << "BuildInferMetaContext: Done attrs"; + for (auto& out_name : output_names) { if (ctx->HasOutputs(out_name, true)) { - auto output_var = ctx->GetOutputVarPtrs(out_name); + auto output_var = std::move(ctx->GetOutputVarPtrs(out_name)); if (output_var.size() == 1) { - infer_meta_context.EmplaceBackOutput(std::make_shared( - output_var[0], ctx->IsRuntime())); + infer_meta_context.EmplaceBackOutput( + std::move(CompatMetaTensor(output_var[0], ctx->IsRuntime()))); } else { - paddle::SmallVector> outputs; - outputs.reserve(output_var.size()); + paddle::SmallVector + outputs; for (const auto& out : output_var) { if (ctx->IsRuntime()) { if (BOOST_GET_CONST(Variable*, out)) { outputs.emplace_back( - std::make_shared(out, ctx->IsRuntime())); + std::move(CompatMetaTensor(out, ctx->IsRuntime()))); continue; } } else if (BOOST_GET_CONST(VarDesc*, out)) { outputs.emplace_back( - std::make_shared(out, ctx->IsRuntime())); + std::move(CompatMetaTensor(out, ctx->IsRuntime()))); continue; } - outputs.emplace_back(nullptr); + outputs.emplace_back(std::move(CompatMetaTensor(ctx->IsRuntime()))); } infer_meta_context.EmplaceBackOutputs(std::move(outputs)); } } else { - infer_meta_context.EmplaceBackOutput({nullptr}); + infer_meta_context.EmplaceBackOutput( + std::move(CompatMetaTensor(ctx->IsRuntime()))); } } + VLOG(6) << "BuildInferMetaContext: Done outputs"; + return infer_meta_context; } diff --git a/paddle/fluid/framework/infershape_utils.h b/paddle/fluid/framework/infershape_utils.h index 022f194b667..e54f2e81e7e 100644 --- a/paddle/fluid/framework/infershape_utils.h +++ b/paddle/fluid/framework/infershape_utils.h @@ -18,38 +18,24 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/shape_inference.h" +#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/meta_tensor.h" -namespace phi { -class InferMetaContext; -} // namespace phi namespace paddle { namespace framework { -phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, - const std::string& op_type); - -#define DECLARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn) \ - struct functor_name : public paddle::framework::InferShapeBase { \ - void operator()( \ - paddle::framework::InferShapeContext* ctx) const override { \ - auto infer_meta_context = \ - paddle::framework::BuildInferMetaContext(ctx, #op_type); \ - fn(&infer_meta_context); \ - } \ - } - // TODO(chenweihang): Support TensorArray later class CompatMetaTensor : public phi::MetaTensor { public: + explicit CompatMetaTensor(bool is_runtime) + : is_runtime_(is_runtime), initialized_(false) {} CompatMetaTensor(InferShapeVarPtr var, bool is_runtime) : var_(std::move(var)), is_runtime_(is_runtime) {} - CompatMetaTensor() = default; - CompatMetaTensor(const CompatMetaTensor&) = default; CompatMetaTensor(CompatMetaTensor&&) = default; - CompatMetaTensor& operator=(const CompatMetaTensor&) = delete; - CompatMetaTensor& operator=(CompatMetaTensor&&) = delete; + CompatMetaTensor& operator=(CompatMetaTensor&&) = default; + CompatMetaTensor(const CompatMetaTensor&) = default; + CompatMetaTensor& operator=(const CompatMetaTensor&) = default; int64_t numel() const override; @@ -71,6 +57,8 @@ class CompatMetaTensor : public phi::MetaTensor { void share_meta(const MetaTensor& meta_tensor) override; + bool initialized() const override { return initialized_; }; + private: const LoD& GetRuntimeLoD() const { auto* var = BOOST_GET_CONST(Variable*, var_); @@ -95,7 +83,62 @@ class CompatMetaTensor : public phi::MetaTensor { InferShapeVarPtr var_; bool is_runtime_; + bool initialized_{true}; +}; + +// Note: In order to avoid using shared_ptr to manage MetaTensor in +// InferMetaContext, inherit and implement InferMetaContext separately +// for compatibility with fluid, shared_ptr will cause significant decrease +// in scheduling performance +class CompatInferMetaContext : public phi::InferMetaContext { + public: + CompatInferMetaContext() = default; + explicit CompatInferMetaContext(phi::MetaConfig config) + : phi::InferMetaContext(config) {} + + void EmplaceBackInput(CompatMetaTensor input); + void EmplaceBackOutput(CompatMetaTensor output); + + void EmplaceBackInputs( + paddle::SmallVector inputs); + void EmplaceBackOutputs( + paddle::SmallVector + outputs); + + const phi::MetaTensor& InputAt(size_t idx) const override; + paddle::optional OptionalInputAt( + size_t idx) const override; + + std::vector InputsBetween(size_t start, + size_t end) const override; + paddle::optional> + OptionalInputsBetween(size_t start, size_t end) const override; + + phi::MetaTensor* MutableOutputAt(size_t idx) override; + std::vector MutableOutputBetween(size_t start, + size_t end) override; + + virtual ~CompatInferMetaContext() = default; + + private: + paddle::SmallVector + compat_inputs_; + paddle::SmallVector + compat_outputs_; }; +CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, + const std::string& op_type); + +#define DECLARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn) \ + struct functor_name : public paddle::framework::InferShapeBase { \ + void operator()( \ + paddle::framework::InferShapeContext* ctx) const override { \ + auto infer_meta_context = \ + 
paddle::framework::BuildInferMetaContext(ctx, #op_type); \ + fn(&infer_meta_context); \ + } \ + } + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index ac1a654df47..86d534b0b4e 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -328,20 +328,21 @@ bool InterpretercoreInferShapeContext::IsRunMKLDNNKernel() const { } // TODO(paddle-dev): Can this be template? -std::vector InterpretercoreInferShapeContext::GetInputVarPtrs( +paddle::SmallVector +InterpretercoreInferShapeContext::GetInputVarPtrs( const std::string& name) const { const std::vector& vars = InputVars(name); - std::vector res; + paddle::SmallVector res; res.reserve(vars.size()); res.insert(res.begin(), vars.begin(), vars.end()); return res; } -std::vector +paddle::SmallVector InterpretercoreInferShapeContext::GetOutputVarPtrs( const std::string& name) const { const std::vector& vars = OutputVars(name); - std::vector res; + paddle::SmallVector res; res.reserve(vars.size()); res.insert(res.begin(), vars.begin(), vars.end()); return res; diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index b223a2ad769..6a1e46e3592 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -90,11 +90,11 @@ class InterpretercoreInferShapeContext : public InferShapeContext { bool IsRunMKLDNNKernel() const override; // TODO(paddle-dev): Can this be template? - std::vector GetInputVarPtrs( - const std::string& name) const override; + paddle::SmallVector + GetInputVarPtrs(const std::string& name) const override; - std::vector GetOutputVarPtrs( - const std::string& name) const override; + paddle::SmallVector + GetOutputVarPtrs(const std::string& name) const override; DDim GetInputDim(const std::string& name) const override; diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 15b979086d1..d27bf0e150f 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -202,10 +202,10 @@ class CompileTimeInferShapeContext : public InferShapeContext { } } - std::vector GetInputVarPtrs( - const std::string &name) const override { + paddle::SmallVector + GetInputVarPtrs(const std::string &name) const override { const std::vector arg_names = Inputs(name); - std::vector res; + paddle::SmallVector res; res.reserve(arg_names.size()); std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res), [this](const std::string &name) { @@ -214,10 +214,10 @@ class CompileTimeInferShapeContext : public InferShapeContext { return res; } - std::vector GetOutputVarPtrs( - const std::string &name) const override { + paddle::SmallVector + GetOutputVarPtrs(const std::string &name) const override { const std::vector arg_names = Outputs(name); - std::vector res; + paddle::SmallVector res; res.reserve(arg_names.size()); std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res), [this](const std::string &name) { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index e3bc22ae88b..0291309aa0d 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -947,19 +947,19 @@ class RuntimeInferShapeContext : public InferShapeContext { } // TODO(paddle-dev): Can this be template? 
- std::vector GetInputVarPtrs( - const std::string& name) const override { + paddle::SmallVector + GetInputVarPtrs(const std::string& name) const override { const std::vector& vars = InputVars(name); - std::vector res; + paddle::SmallVector res; res.reserve(vars.size()); res.insert(res.begin(), vars.begin(), vars.end()); return res; } - std::vector GetOutputVarPtrs( - const std::string& name) const override { + paddle::SmallVector + GetOutputVarPtrs(const std::string& name) const override { const std::vector& vars = OutputVars(name); - std::vector res; + paddle::SmallVector res; res.reserve(vars.size()); res.insert(res.begin(), vars.begin(), vars.end()); return res; @@ -1326,8 +1326,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, << ", using_kernel_key:" << *kernel_type_.get(); auto try_pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get()); - if (!phi::KernelFactory::Instance().IsSelectKernelValid( - pt_kernel_name, try_pt_kernel_key)) { + if (!phi::KernelFactory::Instance().HasKernel(pt_kernel_name, + try_pt_kernel_key)) { kernel_type_->library_type_ = expected_kernel_key_library_type; VLOG(3) << "modify XPU KP kernel in static graph: " << type_ << " is failed " << *kernel_type_.get(); @@ -2115,10 +2115,12 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar( KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( const ExecutionContext& ctx) const { - InitDefaultKernelSignatureMap(); ExecutionArgumentMappingContext arg_mapping_ctx(ctx); - return phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type())( - arg_mapping_ctx); + if (arg_map_fn_ == nullptr) { + arg_map_fn_.reset(new phi::ArgumentMappingFn( + phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type()))); + } + return (*arg_map_fn_)(arg_mapping_ctx); } Scope* OperatorWithKernel::PreparePhiData( diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index f7fc83f1d6d..f0887eb919c 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -701,6 +701,7 @@ class OperatorWithKernel : public OperatorBase { mutable bool run_kp_kernel = false; mutable std::unique_ptr pt_kernel_signature_; mutable std::unique_ptr pt_kernel_; + mutable std::unique_ptr arg_map_fn_; }; extern bool OpSupportGPU(const std::string& op_type); diff --git a/paddle/fluid/framework/phi_utils.cc b/paddle/fluid/framework/phi_utils.cc index 8e6f082da10..75bab059475 100644 --- a/paddle/fluid/framework/phi_utils.cc +++ b/paddle/fluid/framework/phi_utils.cc @@ -25,6 +25,7 @@ limitations under the License. 
*/ #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/kernel_factory.h" +#include "paddle/phi/core/type_defs.h" namespace paddle { namespace framework { @@ -40,9 +41,9 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker { ~KernelArgsNameMakerByOpProto() {} - const paddle::SmallVector& GetInputArgsNames() override; - const paddle::SmallVector& GetOutputArgsNames() override; - const paddle::SmallVector& GetAttrsArgsNames() override; + const paddle::SmallVector& GetInputArgsNames() override; + const paddle::SmallVector& GetOutputArgsNames() override; + const paddle::SmallVector& GetAttrsArgsNames() override; KernelSignature GetKernelSignature(); @@ -52,9 +53,9 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker { private: const framework::proto::OpProto* op_proto_; - paddle::SmallVector input_names_; - paddle::SmallVector output_names_; - paddle::SmallVector attr_names_; + paddle::SmallVector input_names_; + paddle::SmallVector output_names_; + paddle::SmallVector attr_names_; }; OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) { @@ -102,7 +103,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, if (platform::is_xpu_place(expected_kernel_key.place_) || paddle::platform::is_in_xpu_black_list(op.Type())) { VLOG(3) << "phi missing XPU kernel: " << op.Type() - << ", phipected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << expected_kernel_key << ", fallbacking to CPU one!"; return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); @@ -111,7 +112,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, #ifdef PADDLE_WITH_ASCEND_CL if (platform::is_npu_place(expected_kernel_key.place_)) { VLOG(3) << "phi missing NPU kernel: " << op.Type() - << ", phipected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << expected_kernel_key << ", fallbacking to CPU one!"; return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); @@ -120,7 +121,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, #ifdef PADDLE_WITH_MLU if (platform::is_mlu_place(expected_kernel_key.place_)) { VLOG(3) << "phi missing MLU kernel: " << op.Type() - << ", phipected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << expected_kernel_key << ", fallbacking to CPU one!"; return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); @@ -129,7 +130,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, #ifdef PADDLE_WITH_IPU if (platform::is_ipu_place(expected_kernel_key.place_)) { VLOG(3) << "phi missing IPU kernel: " << op.Type() - << ", phipected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << expected_kernel_key << ", fallbacking to CPU one!"; return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); @@ -139,7 +140,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, if (platform::is_custom_place(expected_kernel_key.place_)) { VLOG(3) << "phi missing " << expected_kernel_key.place_.GetDeviceType() << " kernel: " << op.Type() - << ", phipected_kernel_key:" << expected_kernel_key + << ", expected_kernel_key:" << expected_kernel_key << ", fallbacking to CPU one!"; return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); @@ -148,45 +149,52 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, 
return phi::KernelKey(); } -const paddle::SmallVector& +const paddle::SmallVector& KernelArgsNameMakerByOpProto::GetInputArgsNames() { for (int i = 0; i < op_proto_->inputs_size(); ++i) { auto& in = op_proto_->inputs()[i]; auto& in_name = in.name(); if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) { - VLOG(6) << "Parse PhiKernel input: skip extra & quant input - " - << in_name; continue; } // If contains dispensable input, we should override the // OpArgumentMapping method self in phi/ops/compat dir if (in.has_dispensable() && in.dispensable()) { - VLOG(6) << "Parse PhiKernel input: skip dispensable input - " << in_name; continue; } - VLOG(6) << "Parse PhiKernel input: " << in_name; - input_names_.emplace_back(in_name); + input_names_.emplace_back(in_name.c_str()); + } + if (VLOG_IS_ON(10)) { + std::ostringstream sout; + sout << "PhiKernel inputs: "; + std::copy(input_names_.begin(), input_names_.end(), + std::ostream_iterator(sout, ", ")); + VLOG(10) << sout.str(); } return input_names_; } -const paddle::SmallVector& +const paddle::SmallVector& KernelArgsNameMakerByOpProto::GetOutputArgsNames() { for (int i = 0; i < op_proto_->outputs_size(); ++i) { auto& out = op_proto_->outputs()[i]; auto& out_name = out.name(); if ((out.has_extra() && out.extra()) || (out.has_quant() && out.quant())) { - VLOG(6) << "Parse PhiKernel output: skip extra & quant output - " - << out_name; continue; } - VLOG(6) << "Parse PhiKernel output: " << out_name; - output_names_.emplace_back(out_name); + output_names_.emplace_back(out_name.c_str()); + } + if (VLOG_IS_ON(10)) { + std::ostringstream sout; + sout << "PhiKernel outputs: "; + std::copy(output_names_.begin(), output_names_.end(), + std::ostream_iterator(sout, ", ")); + VLOG(10) << sout.str(); } return output_names_; } -const paddle::SmallVector& +const paddle::SmallVector& KernelArgsNameMakerByOpProto::GetAttrsArgsNames() { for (int i = 0; i < op_proto_->attrs_size(); ++i) { auto& attr = op_proto_->attrs()[i]; @@ -195,25 +203,26 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() { attr_name == "op_role" || attr_name == "op_role_var" || attr_name == "op_namescope" || attr_name == "op_callstack" || attr_name == "op_device") { - VLOG(6) << "Parse PhiKernel attribute: skip needless attr - " - << attr_name; continue; } if ((attr.has_extra() && attr.extra()) || (attr.has_quant() && attr.quant())) { - VLOG(6) << "Parse PhiKernel attribute: skip extra & quant attr - " - << attr_name; continue; } - VLOG(6) << "Parse PhiKernel attribute: " << attr_name; - attr_names_.emplace_back(attr_name); + attr_names_.emplace_back(attr_name.c_str()); + } + if (VLOG_IS_ON(10)) { + std::ostringstream sout; + sout << "PhiKernel attributes: "; + std::copy(attr_names_.begin(), attr_names_.end(), + std::ostream_iterator(sout, ", ")); + VLOG(10) << sout.str(); } - return attr_names_; } KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() { - return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()), + return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()).c_str(), GetInputArgsNames(), GetAttrsArgsNames(), GetOutputArgsNames()); } @@ -228,7 +237,7 @@ void InitDefaultKernelSignatureMap() { if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type) && op_proto) { paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto); - VLOG(10) << "Register kernel signature for " << op_type; + VLOG(10) << "Register `" << op_type << "` kernel signature:"; phi::DefaultKernelSignatureMap::Instance().Insert( op_type, 
std::move(maker.GetKernelSignature())); } diff --git a/paddle/fluid/framework/phi_utils.h b/paddle/fluid/framework/phi_utils.h index a1757881692..392a3f9b06b 100644 --- a/paddle/fluid/framework/phi_utils.h +++ b/paddle/fluid/framework/phi_utils.h @@ -55,9 +55,9 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, class KernelArgsNameMaker { public: virtual ~KernelArgsNameMaker() {} - virtual const paddle::SmallVector& GetInputArgsNames() = 0; - virtual const paddle::SmallVector& GetOutputArgsNames() = 0; - virtual const paddle::SmallVector& GetAttrsArgsNames() = 0; + virtual const paddle::SmallVector& GetInputArgsNames() = 0; + virtual const paddle::SmallVector& GetOutputArgsNames() = 0; + virtual const paddle::SmallVector& GetAttrsArgsNames() = 0; }; void InitDefaultKernelSignatureMap(); diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 6ba60590cf8..bf9731bafce 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -21,6 +21,8 @@ limitations under the License. */ #include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/variable.h" #include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/type_defs.h" +#include "paddle/utils/small_vector.h" namespace paddle { namespace framework { @@ -106,10 +108,10 @@ class InferShapeContext { virtual bool IsRunMKLDNNKernel() const = 0; - virtual std::vector GetInputVarPtrs( - const std::string &name) const = 0; - virtual std::vector GetOutputVarPtrs( - const std::string &name) const = 0; + virtual paddle::SmallVector + GetInputVarPtrs(const std::string &name) const = 0; + virtual paddle::SmallVector + GetOutputVarPtrs(const std::string &name) const = 0; protected: virtual std::vector GetRepeatedDims(const std::string &name) const = 0; diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h index 1e5b112ece2..5b63334c9ea 100644 --- a/paddle/fluid/imperative/infer_shape_context.h +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -235,9 +235,10 @@ class DygraphInferShapeContext : public framework::InferShapeContext { (op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN)); } - std::vector GetInputVarPtrs( - const std::string& name) const override { - std::vector res; + paddle::SmallVector + GetInputVarPtrs(const std::string& name) const override { + paddle::SmallVector + res; auto it = var_map_in_->find(name); PADDLE_ENFORCE_NE( it, var_map_in_->end(), @@ -248,9 +249,11 @@ class DygraphInferShapeContext : public framework::InferShapeContext { return res; } - std::vector GetOutputVarPtrs( - const std::string& name) const override { - std::vector res; + paddle::SmallVector + GetOutputVarPtrs(const std::string& name) const override { + paddle::SmallVector + res; auto it = var_map_out_->find(name); PADDLE_ENFORCE_NE( it, var_map_out_->end(), diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 0ad5e808b1d..cef7417ea41 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -36,6 +36,8 @@ DECLARE_bool(run_kp_kernel); namespace paddle { namespace imperative { +static const phi::Kernel empty_kernel; + const std::shared_ptr& GetVariableWrapper( const std::shared_ptr& var) { return var->SharedVar(); @@ -108,12 +110,13 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, ctx_(ctx), kernel_type_(kernel_type), func_(func), - 
dev_ctx_(dev_ctx) {} + dev_ctx_(dev_ctx), + pt_kernel_(empty_kernel) {} PreparedOp::PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, - const framework::KernelSignature& kernel_signature, + framework::KernelSignature&& kernel_signature, const phi::Kernel& pt_kernel, platform::DeviceContext* dev_ctx) : op_(op), @@ -122,7 +125,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, func_(nullptr), dev_ctx_(dev_ctx), run_phi_kernel_(true), - pt_kernel_signature_(kernel_signature), + pt_kernel_signature_(std::move(kernel_signature)), pt_kernel_(pt_kernel) {} template @@ -170,7 +173,8 @@ PreparedOp PrepareImpl(const NameVarMap& ins, #endif if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) { - pt_kernel_signature = op.GetExpectedPhiKernelArgs(dygraph_exe_ctx); + pt_kernel_signature = + std::move(op.GetExpectedPhiKernelArgs(dygraph_exe_ctx)); VLOG(6) << pt_kernel_signature; pt_kernel_name = pt_kernel_signature.name; @@ -200,8 +204,8 @@ PreparedOp PrepareImpl(const NameVarMap& ins, << ", using_kernel_key:" << expected_kernel_key; phi::KernelKey try_pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key); - if (!phi::KernelFactory::Instance().IsSelectKernelValid( - pt_kernel_name, try_pt_kernel_key)) { + if (!phi::KernelFactory::Instance().HasKernel(pt_kernel_name, + try_pt_kernel_key)) { expected_kernel_key.library_type_ = expected_kernel_key_library_type; VLOG(3) << "modify XPU KP kernel: " << op.Type() << " is failed " << expected_kernel_key; @@ -211,8 +215,8 @@ PreparedOp PrepareImpl(const NameVarMap& ins, #endif pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key); - auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name, - pt_kernel_key); + auto& pt_kernel = phi::KernelFactory::Instance().SelectKernel( + pt_kernel_name, pt_kernel_key); if (pt_kernel.IsValid() #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) @@ -227,9 +231,8 @@ PreparedOp PrepareImpl(const NameVarMap& ins, dev_ctx = pool.Get(expected_kernel_key.place_); } - // TODO(chenweihang): using CPUKernel when miss device kernel case - return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, - pt_kernel, dev_ctx); + return PreparedOp(op, ctx, expected_kernel_key, + std::move(pt_kernel_signature), pt_kernel, dev_ctx); } else { VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name << "` not found."; @@ -270,15 +273,16 @@ PreparedOp PrepareImpl(const NameVarMap& ins, if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) { auto pt_cpu_kernel_key = FallBackToCpu(expected_kernel_key, pt_kernel_key, op); - auto pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel( + auto& pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel( pt_kernel_name, pt_cpu_kernel_key); if (pt_cpu_kernel.IsValid()) { VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name << " | kernel key: " << pt_cpu_kernel_key << " | kernel: " << pt_cpu_kernel; auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace()); - return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, - pt_cpu_kernel, cpu_ctx); + return PreparedOp(op, ctx, expected_kernel_key, + std::move(pt_kernel_signature), pt_cpu_kernel, + cpu_ctx); } } } @@ -505,7 +509,6 @@ static void PreparedOpRunPtImpl( #endif } - // TODO(chenweihang): add debug flags later if (framework::IsComplexType(kernel_type.data_type_)) { HandleComplexGradToRealGrad(outs); } diff --git 
a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 04d0b4ca7a5..b3c5a6b5fa2 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -154,7 +154,7 @@ class PreparedOp { PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, - const framework::KernelSignature& kernel_signature, + framework::KernelSignature&& kernel_signature, const phi::Kernel& pt_kernel, platform::DeviceContext* dev_ctx); static PreparedOp Prepare(const NameVarMap& ins, @@ -206,7 +206,7 @@ class PreparedOp { bool run_phi_kernel_{false}; bool run_kp_kernel_{false}; framework::KernelSignature pt_kernel_signature_; - phi::Kernel pt_kernel_; + const phi::Kernel& pt_kernel_; }; const inline framework::Attribute& GetAttr( @@ -289,7 +289,7 @@ void BuildDygraphPhiKernelContext( } } - auto ins_vector = it->second; + auto& ins_vector = it->second; size_t end_idx = start_idx + ins_vector.size(); for (size_t offset = 0; offset < ins_vector.size(); ++offset) { @@ -587,7 +587,7 @@ void PreparePhiData(const phi::Kernel& pt_kernel, auto& ins_vector = ins.at(input_names[i]); for (size_t offset = 0; offset < ins_vector.size(); ++offset) { - auto var = ins_vector[offset]; + auto& var = ins_vector[offset]; const auto* tensor_in = GetTensorFromVar(var->Var()); if (tensor_in && tensor_in->IsInitialized()) { if (in_def.backend == phi::Backend::ALL_BACKEND) { diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 0093decea5a..14e4c3da624 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -226,6 +226,7 @@ bool AnalysisPredictor::PrepareScope( status_is_cloned_ = true; } else { paddle::framework::InitDevices(); + paddle::framework::InitDefaultKernelSignatureMap(); // TODO(wilber): we need to release memory occupied by weights. 
scope_.reset(new paddle::framework::Scope()); status_is_cloned_ = false; diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 73d14f215e2..1c4369af646 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -92,6 +92,7 @@ bool NativePaddlePredictor::Init( "The sub_scope should not be nullptr.")); } else { paddle::framework::InitDevices(); + paddle::framework::InitDefaultKernelSignatureMap(); scope_.reset(new paddle::framework::Scope()); } diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index 03a244a457c..eb44655c88f 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -517,10 +517,8 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { ctx->HasInputs(kOutputs); ctx->HasInputs(framework::GradVarName(kOutputs)); auto pg_ig_names = ctx->Outputs(kXGRAD); - std::vector in_var_ptrs = - ctx->GetInputVarPtrs(kX); - std::vector out_var_ptrs = - ctx->GetOutputVarPtrs(kXGRAD); + auto in_var_ptrs = ctx->GetInputVarPtrs(kX); + auto out_var_ptrs = ctx->GetOutputVarPtrs(kXGRAD); PADDLE_ENFORCE_EQ(in_var_ptrs.size(), out_var_ptrs.size(), platform::errors::InvalidArgument( "The size of Inputs(X) must be the same as " diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc index 44f602237da..92c9ab34aa4 100644 --- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc @@ -63,10 +63,8 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel { context->ShareLoD("MultiLevelRois", "FpnRois"); } if (context->IsRuntime() && !context->HasInputs("MultiLevelRoIsNum")) { - std::vector roi_inputs = - context->GetInputVarPtrs("MultiLevelRois"); - std::vector score_inputs = - context->GetInputVarPtrs("MultiLevelScores"); + auto roi_inputs = context->GetInputVarPtrs("MultiLevelRois"); + auto score_inputs = context->GetInputVarPtrs("MultiLevelScores"); for (size_t i = 0; i < roi_inputs.size(); ++i) { framework::Variable *roi_var = BOOST_GET(framework::Variable *, roi_inputs[i]); diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 7df6d8f7f79..93f10b34b6c 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -60,6 +60,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/uva_utils.h" #include "paddle/phi/core/compat/arg_map_context.h" #include "paddle/phi/core/compat/type_defs.h" +#include "paddle/phi/core/type_defs.h" namespace paddle { namespace pybind { @@ -2027,26 +2028,35 @@ void BindImperative(py::module *m_ptr) { *(imperative::AmpOperators::Instance().GetMutableAllowOps()), *(imperative::AmpOperators::Instance().GetMutableBlockOps())); }) - .def("_get_kernel_signature", - [](imperative::Tracer &self, const std::string &type, - const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, - framework::AttributeMap attrs) { - // TODO(xiongkun): move this function outside of tracer. 
- auto ins_map = ConvertToNameTensorMap(ins); - auto outs_map = ConvertToNameTensorMap(outs); - { - auto to_vector = [](paddle::SmallVector &vec) { - return std::vector(vec.begin(), vec.end()); - }; - auto ret = self.GetExpectedKernelSignature(type, ins_map, - outs_map, attrs); - auto kernelsig_ins = to_vector(std::get<0>(ret.args)); - auto kernelsig_attrs = to_vector(std::get<1>(ret.args)); - auto kernelsig_outs = to_vector(std::get<2>(ret.args)); - return std::make_tuple(kernelsig_ins, kernelsig_attrs, - kernelsig_outs); - } - }) + .def( + "_get_kernel_signature", + [](imperative::Tracer &self, const std::string &type, + const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, + framework::AttributeMap attrs) { + // TODO(xiongkun): move this function outside of tracer. + auto ins_map = ConvertToNameTensorMap(ins); + auto outs_map = ConvertToNameTensorMap(outs); + { + auto input_to_vector = + [](paddle::SmallVector &vec) { + return std::vector(vec.begin(), vec.end()); + }; + auto output_to_vector = + [](paddle::SmallVector &vec) { + return std::vector(vec.begin(), vec.end()); + }; + auto attr_to_vector = [](paddle::SmallVector &vec) { + return std::vector(vec.begin(), vec.end()); + }; + auto ret = self.GetExpectedKernelSignature(type, ins_map, + outs_map, attrs); + auto kernelsig_ins = input_to_vector(std::get<0>(ret.args)); + auto kernelsig_attrs = attr_to_vector(std::get<1>(ret.args)); + auto kernelsig_outs = output_to_vector(std::get<2>(ret.args)); + return std::make_tuple(kernelsig_ins, kernelsig_attrs, + kernelsig_outs); + } + }) .def("trace", [](imperative::Tracer &self, const std::string &type, const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 7cc9d2220af..5f9db51ee74 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -2907,6 +2907,8 @@ All parameter, weight, gradient are variables in Paddle. 
framework::LoadOpMetaInfoAndRegisterOp(dso_name)); }); m.def("init_devices", []() { framework::InitDevices(); }); + m.def("init_default_kernel_signatures", + []() { framework::InitDefaultKernelSignatureMap(); }); m.def("is_compiled_with_cuda", IsCompiledWithCUDA); m.def("is_compiled_with_ascend", IsCompiledWithAscend); m.def("is_compiled_with_rocm", IsCompiledWithROCM); diff --git a/paddle/infrt/dialect/phi/pass/kernel_op_desc.cc b/paddle/infrt/dialect/phi/pass/kernel_op_desc.cc index a26e8e2dca5..b1aa8126096 100644 --- a/paddle/infrt/dialect/phi/pass/kernel_op_desc.cc +++ b/paddle/infrt/dialect/phi/pass/kernel_op_desc.cc @@ -15,6 +15,7 @@ #include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h" #include #include "paddle/infrt/dialect/phi/data_type.h" +#include "paddle/phi/core/type_defs.h" #include "paddle/phi/kernels/declarations.h" namespace infrt { @@ -92,10 +93,10 @@ std::vector GetCandidateKernels( phi_kernel_desc.input_types.clear(); phi_kernel_desc.output_types.clear(); phi::KernelArgsDef args_def = kernel_key_map.at(kernel_key).args_def(); - const paddle::SmallVector& input_arg = - args_def.input_defs(); - const paddle::SmallVector& output_arg = - args_def.output_defs(); + const paddle::SmallVector& + input_arg = args_def.input_defs(); + const paddle::SmallVector& + output_arg = args_def.output_defs(); for (auto tensor_arg : input_arg) { phi_kernel_desc.input_types.emplace_back(ConvertPlaceFromPhi(tensor_arg)); } diff --git a/paddle/infrt/host_context/value.h b/paddle/infrt/host_context/value.h index ecd11881809..1834cb4c0db 100644 --- a/paddle/infrt/host_context/value.h +++ b/paddle/infrt/host_context/value.h @@ -91,6 +91,7 @@ using ValueVariantType = std::vector<::phi::DenseTensor*>, paddle::experimental::ScalarBase<::phi::DenseTensor>, paddle::experimental::IntArrayBase<::phi::DenseTensor>, + std::vector, std::vector<::phi::MetaTensor*>, ::phi::MetaConfig, paddle::experimental::Backend, diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 8e05f9d9090..70c3b27ede5 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -271,10 +271,10 @@ std::vector split_impl(const Tensor& x, // Calculate the number of out tensors size_t out_number; - if (num_or_sections.GetData().size() == 1) { + if (num_or_sections.size() == 1) { out_number = num_or_sections.GetData()[0]; } else { - out_number = num_or_sections.GetData().size(); + out_number = num_or_sections.size(); } std::vector out; @@ -449,54 +449,6 @@ std::tuple momentum_impl( return api_output; } -std::vector unbind_impl(const Tensor& input, int axis) { - auto kernel_key_set = ParseKernelKeyByInputArgs(input); - auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - - Backend kernel_backend = kernel_key.backend(); - DataLayout kernel_layout = kernel_key.layout(); - DataType kernel_data_type = kernel_key.dtype(); - - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "unbind", {kernel_backend, kernel_layout, kernel_data_type}); - VLOG(6) << "unbind API kernel key: [" << kernel_backend << ", " - << kernel_layout << ", " << kernel_data_type << "]"; - VLOG(6) << "unbind API kernel: " << kernel; - - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); - - auto dense_input = PrepareData(input, kernel.InputAt(0), {}); - - // Calculate the number of out tensors - auto input_shape = input.dims(); - if (axis < 0) { - axis = input_shape.size() + axis; - } - auto out_num = input_shape[axis]; - - std::vector out; - auto 
dense_outs = SetKernelOutput(out_num, kernel_backend, &out); - std::vector meta_outs; - meta_outs.reserve(out_num); - std::vector meta_out_ptrs; - meta_out_ptrs.reserve(out_num); - for (int64_t i = 0; i < out_num; ++i) { - meta_outs.push_back(dense_outs[i]); - meta_out_ptrs.push_back(&meta_outs.back()); - } - - phi::UnbindInferMeta(MakeMetaTensor(*dense_input), axis, meta_out_ptrs); - - using kernel_signature = void (*)(const phi::DeviceContext&, - const phi::DenseTensor&, - int, - std::vector&); - auto* kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)(*dev_ctx, *dense_input, axis, dense_outs); - - return out; -} - ////////////////// Backward(grad) api impls ////////////////////// // TODO(chenweihang): the original sum grad op can support higher-level @@ -674,71 +626,6 @@ std::tuple batch_norm_impl( return api_output; } -std::vector concat_grad_impl(const std::vector& x, - const Tensor& out_grad, - const Scalar& axis) { - auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad); - auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - - Backend kernel_backend = kernel_key.backend(); - DataLayout kernel_layout = kernel_key.layout(); - DataType kernel_data_type = kernel_key.dtype(); - - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "concat_grad", {kernel_backend, kernel_layout, kernel_data_type}); - VLOG(6) << "concat_grad API kernel key: [" << kernel_backend << ", " - << kernel_layout << ", " << kernel_data_type << "]"; - VLOG(6) << "concat_grad API kernel: " << kernel; - - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); - - // std::unique_ptr> - auto dense_x = PrepareData(x, kernel.InputAt(0), {}); - auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(1), {}); - - // Calculate the number of out tensors - size_t out_number = x.size(); - std::vector x_grad; - auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad); - - std::vector meta_x; - meta_x.reserve(x.size()); - std::vector meta_x_ptrs; - meta_x_ptrs.reserve(x.size()); - for (const auto& t : *dense_x) { - meta_x.push_back(t); - meta_x_ptrs.push_back(&meta_x.back()); - } - - std::vector meta_x_grad; - meta_x_grad.reserve(x.size()); - std::vector meta_x_grad_ptrs; - meta_x_grad_ptrs.reserve(x.size()); - for (size_t i = 0; i < out_number; ++i) { - meta_x_grad.push_back(*dense_x_grad[i]); - meta_x_grad_ptrs.push_back(&meta_x_grad.back()); - } - - phi::UnchangedMultiInferMeta(meta_x_ptrs, meta_x_grad_ptrs); - - std::vector dense_x_ptr; - dense_x_ptr.reserve(x.size()); - for (const auto& t : *dense_x) { - dense_x_ptr.push_back(&t); - } - - using kernel_signature = void (*)(const platform::DeviceContext&, - const std::vector&, - const phi::DenseTensor&, - const phi::Scalar&, - std::vector); - auto* kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)( - *dev_ctx, dense_x_ptr, *dense_out_grad, phi::Scalar(axis), dense_x_grad); - - return x_grad; -} - Tensor imag_grad_impl(const Tensor& out_grad) { phi::KernelKey kernel_key{ParseBackend(out_grad), out_grad.layout(), @@ -795,328 +682,5 @@ Tensor real_grad_impl(const Tensor& out_grad) { return out; } -std::vector stack_grad_impl(const std::vector& x, - const Tensor& out_grad, - int axis) { - auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad); - auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - - Backend kernel_backend = kernel_key.backend(); - DataLayout kernel_layout = kernel_key.layout(); - DataType kernel_data_type = kernel_key.dtype(); - - auto kernel = 
phi::KernelFactory::Instance().SelectKernelOrThrowError( - "stack_grad", {kernel_backend, kernel_layout, kernel_data_type}); - VLOG(6) << "stack_grad API kernel key: [" << kernel_backend << ", " - << kernel_layout << ", " << kernel_data_type << "]"; - VLOG(6) << "stack_grad API kernel: " << kernel; - - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); - - auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(0), {}); - - size_t out_number = x.size(); - std::vector x_grad; - auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad); - std::vector meta_x_grad; - meta_x_grad.reserve(out_number); - std::vector meta_x_grad_ptrs; - meta_x_grad_ptrs.reserve(out_number); - for (size_t i = 0; i < out_number; ++i) { - meta_x_grad.push_back(dense_x_grad[i]); - meta_x_grad_ptrs.push_back(&meta_x_grad.back()); - } - - phi::StackGradInferMeta( - MakeMetaTensor(*dense_out_grad), axis, meta_x_grad_ptrs); - - using kernel_signature = void (*)(const platform::DeviceContext&, - const phi::DenseTensor&, - int axis, - std::vector); - auto* kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)(*dev_ctx, *dense_out_grad, axis, dense_x_grad); - - return x_grad; -} - -std::vector meshgrid_impl(const std::vector& inputs) { - Backend kernel_backend = Backend::UNDEFINED; - DataLayout kernel_layout = DataLayout::UNDEFINED; - DataType kernel_data_type = DataType::UNDEFINED; - - if (kernel_backend == Backend::UNDEFINED || - kernel_layout == DataLayout::UNDEFINED || - kernel_data_type == DataType::UNDEFINED) { - auto kernel_key_set = ParseKernelKeyByInputArgs(inputs); - auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - if (kernel_backend == Backend::UNDEFINED) { - kernel_backend = kernel_key.backend(); - } - if (kernel_layout == DataLayout::UNDEFINED) { - kernel_layout = kernel_key.layout(); - } - if (kernel_data_type == DataType::UNDEFINED) { - kernel_data_type = kernel_key.dtype(); - } - } - - const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "meshgrid", {kernel_backend, kernel_layout, kernel_data_type}); - VLOG(6) << "meshgrid API kernel key: [" << kernel_backend << ", " - << kernel_layout << ", " << kernel_data_type << "]"; - VLOG(6) << "meshgrid API kernel: " << kernel; - - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); - - auto input_inputs_vec = PrepareData(inputs, kernel.InputAt(0), {}); - std::vector input_inputs(input_inputs_vec->size()); - for (size_t i = 0; i < input_inputs.size(); ++i) { - input_inputs[i] = &input_inputs_vec->at(i); - } - - auto x_meta_vec = MakeMetaTensor(input_inputs); - std::vector inputs_metas(x_meta_vec.size()); - for (size_t i = 0; i < x_meta_vec.size(); ++i) { - inputs_metas[i] = &x_meta_vec[i]; - } - - // Calculate the number of out tensors - size_t out_number = inputs.size(); - - std::vector out; - auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out); - - std::vector meta_outs; - meta_outs.reserve(out_number); - std::vector meta_out_ptrs; - meta_out_ptrs.reserve(out_number); - for (size_t i = 0; i < out_number; ++i) { - meta_outs.push_back(dense_outs[i]); - meta_out_ptrs.push_back(&meta_outs.back()); - } - phi::MeshgridInferMeta(inputs_metas, meta_out_ptrs); - - using kernel_signature = void (*)(const platform::DeviceContext&, - const std::vector&, - std::vector&); - auto* kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)(*dev_ctx, input_inputs, dense_outs); - - return out; -} - -std::vector meshgrid_grad_impl( - const std::vector& inputs, - const std::vector& 
outputs_grad) { - Backend kernel_backend = Backend::UNDEFINED; - DataLayout kernel_layout = DataLayout::UNDEFINED; - DataType kernel_data_type = DataType::UNDEFINED; - - if (kernel_backend == Backend::UNDEFINED || - kernel_layout == DataLayout::UNDEFINED || - kernel_data_type == DataType::UNDEFINED) { - auto kernel_key_set = ParseKernelKeyByInputArgs(inputs, outputs_grad); - auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - if (kernel_backend == Backend::UNDEFINED) { - kernel_backend = kernel_key.backend(); - } - if (kernel_layout == DataLayout::UNDEFINED) { - kernel_layout = kernel_key.layout(); - } - if (kernel_data_type == DataType::UNDEFINED) { - kernel_data_type = kernel_key.dtype(); - } - } - - const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "meshgrid_grad", {kernel_backend, kernel_layout, kernel_data_type}); - VLOG(6) << "meshgrid_grad API kernel key: [" << kernel_backend << ", " - << kernel_layout << ", " << kernel_data_type << "]"; - VLOG(6) << "meshgrid_grad API kernel: " << kernel; - - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); - - auto input_inputs_vec = PrepareData(inputs, kernel.InputAt(0), {}); - std::vector input_inputs(input_inputs_vec->size()); - for (size_t i = 0; i < input_inputs.size(); ++i) { - input_inputs[i] = &input_inputs_vec->at(i); - } - auto input_outputs_grad_vec = - PrepareData(outputs_grad, kernel.InputAt(1), {}); - std::vector input_outputs_grad( - input_outputs_grad_vec->size()); - for (size_t i = 0; i < input_outputs_grad.size(); ++i) { - input_outputs_grad[i] = &input_outputs_grad_vec->at(i); - } - - size_t out_number = inputs.size(); - std::vector api_output; - auto kernel_out = SetKernelOutput(out_number, kernel_backend, &api_output); - - auto inputs_meta_vec = MakeMetaTensor(input_inputs); - std::vector inputs_metas(inputs_meta_vec.size()); - for (size_t i = 0; i < inputs_meta_vec.size(); ++i) { - inputs_metas[i] = &inputs_meta_vec[i]; - } - - auto outputs_grad_meta_vec = MakeMetaTensor(input_outputs_grad); - std::vector outputs_grad_metas( - outputs_grad_meta_vec.size()); - for (size_t i = 0; i < outputs_grad_meta_vec.size(); ++i) { - outputs_grad_metas[i] = &outputs_grad_meta_vec[i]; - } - - std::vector meta_outs; - meta_outs.reserve(out_number); - std::vector meta_out_ptrs; - meta_out_ptrs.reserve(out_number); - for (size_t i = 0; i < out_number; ++i) { - meta_outs.push_back(kernel_out[i]); - meta_out_ptrs.push_back(&meta_outs.back()); - } - - phi::MeshgridGradInferMeta(inputs_metas, outputs_grad_metas, meta_out_ptrs); - - using kernel_signature = void (*)(const platform::DeviceContext&, - const std::vector&, - const std::vector&, - std::vector&); - auto* kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)(*dev_ctx, input_inputs, input_outputs_grad, kernel_out); - - return api_output; -} - -std::vector multi_dot_grad_impl(const std::vector& x, - const Tensor& out_grad) { - Backend kernel_backend = Backend::UNDEFINED; - DataLayout kernel_layout = DataLayout::UNDEFINED; - DataType kernel_data_type = DataType::UNDEFINED; - - if (kernel_backend == Backend::UNDEFINED || - kernel_layout == DataLayout::UNDEFINED || - kernel_data_type == DataType::UNDEFINED) { - auto kernel_key_set = ParseKernelKeyByInputArgs(x, out_grad); - auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - if (kernel_backend == Backend::UNDEFINED) { - kernel_backend = kernel_key.backend(); - } - if (kernel_layout == DataLayout::UNDEFINED) { - kernel_layout = kernel_key.layout(); - } - if (kernel_data_type 
== DataType::UNDEFINED) { - kernel_data_type = kernel_key.dtype(); - } - } - - VLOG(6) << "multi_dot_grad API kernel key: [" << kernel_backend << ", " - << kernel_layout << ", " << kernel_data_type << "]"; - const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "multi_dot_grad", {kernel_backend, kernel_layout, kernel_data_type}); - VLOG(6) << "multi_dot_grad API kernel: " << kernel; - - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); - - auto input_x_vec = PrepareData(x, kernel.InputAt(0), {}); - std::vector input_x(input_x_vec->size()); - for (size_t i = 0; i < input_x.size(); ++i) { - input_x[i] = &input_x_vec->at(i); - } - auto input_out_grad = PrepareData(out_grad, kernel.InputAt(1), {}); - - size_t out_number = input_x.size(); - std::vector api_output; - auto kernel_out = SetKernelOutput(out_number, kernel_backend, &api_output); - - auto x_meta_vec = MakeMetaTensor(input_x); - std::vector x_metas(x_meta_vec.size()); - for (size_t i = 0; i < x_meta_vec.size(); ++i) { - x_metas[i] = &x_meta_vec[i]; - } - - std::vector meta_outs; - meta_outs.reserve(out_number); - std::vector meta_out_ptrs; - meta_out_ptrs.reserve(out_number); - for (size_t i = 0; i < out_number; ++i) { - meta_outs.push_back(kernel_out[i]); - meta_out_ptrs.push_back(&meta_outs.back()); - } - - phi::MultiDotGradInferMeta( - x_metas, MakeMetaTensor(*input_out_grad), meta_out_ptrs); - - using kernel_signature = void (*)(const platform::DeviceContext&, - const std::vector&, - const phi::DenseTensor&, - std::vector&); - auto* kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)(*dev_ctx, input_x, *input_out_grad, kernel_out); - - return api_output; -} - -std::vector multiplex_grad_impl(const std::vector& inputs, - const Tensor& ids, - const Tensor& out_grad) { - Backend kernel_backend = Backend::UNDEFINED; - DataLayout kernel_layout = DataLayout::UNDEFINED; - DataType kernel_data_type = DataType::UNDEFINED; - - if (kernel_backend == Backend::UNDEFINED || - kernel_layout == DataLayout::UNDEFINED || - kernel_data_type == DataType::UNDEFINED) { - auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad); - auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - if (kernel_backend == Backend::UNDEFINED) { - kernel_backend = kernel_key.backend(); - } - if (kernel_layout == DataLayout::UNDEFINED) { - kernel_layout = kernel_key.layout(); - } - if (kernel_data_type == DataType::UNDEFINED) { - kernel_data_type = kernel_key.dtype(); - } - } - - VLOG(6) << "multiplex_grad API kernel key: [" << kernel_backend << ", " - << kernel_layout << ", " << kernel_data_type << "]"; - const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "multiplex_grad", {kernel_backend, kernel_layout, kernel_data_type}); - VLOG(6) << "multiplex_grad API kernel: " << kernel; - - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); - - auto input_ids = PrepareData(ids, kernel.InputAt(0), {}); - auto input_out_grad = PrepareData(out_grad, kernel.InputAt(1), {}); - - auto out_number = inputs.size(); - std::vector api_output; - auto kernel_out = SetKernelOutput(out_number, kernel_backend, &api_output); - - std::vector meta_outs; - meta_outs.reserve(out_number); - std::vector meta_out_ptrs; - meta_out_ptrs.reserve(out_number); - for (size_t i = 0; i < out_number; ++i) { - meta_outs.push_back(kernel_out[i]); - meta_out_ptrs.push_back(&meta_outs.back()); - } - - phi::MultiplexGradInferMeta(MakeMetaTensor(*input_ids), - MakeMetaTensor(*input_out_grad), - meta_out_ptrs); - - using 
kernel_signature = void (*)(const platform::DeviceContext&, - const phi::DenseTensor&, - const phi::DenseTensor&, - std::vector&); - auto* kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)(*dev_ctx, *input_ids, *input_out_grad, kernel_out); - - return api_output; -} - } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/lib/api_custom_impl.h b/paddle/phi/api/lib/api_custom_impl.h index 0e360ce4a99..0d1ba3e98e5 100644 --- a/paddle/phi/api/lib/api_custom_impl.h +++ b/paddle/phi/api/lib/api_custom_impl.h @@ -30,6 +30,20 @@ namespace experimental { ////////////////// Forward api impls ////////////////////// +std::tuple batch_norm_impl( + const Tensor& x, + const Tensor& scale, + const Tensor& bias, + const Tensor& mean, + const Tensor& variance, + float momentum, + float epsilon, + const std::string& data_layout, + bool is_test, + bool use_global_stats, + bool trainable_statistics, + bool fuse_with_relu); + Tensor conv2d_impl(const Tensor& input, const Tensor& filter, const std::vector& strides, @@ -62,8 +76,6 @@ std::vector split_impl(const Tensor& x, const IntArray& num_or_sections, const Scalar& axis); -std::vector meshgrid_impl(const std::vector& inputs); - std::tuple momentum_impl( const Tensor& param, const Tensor& grad, @@ -77,49 +89,14 @@ std::tuple momentum_impl( bool multi_precision, float rescale_grad); -std::vector unbind_impl(const Tensor& input, int axis); - ////////////////// Backward(grad) api impls ////////////////////// std::vector add_n_grad_impl(const std::vector& x, const Tensor& out_grad); -std::tuple batch_norm_impl( - const Tensor& x, - const Tensor& scale, - const Tensor& bias, - const Tensor& mean, - const Tensor& variance, - float momentum, - float epsilon, - const std::string& data_layout, - bool is_test, - bool use_global_stats, - bool trainable_statistics, - bool fuse_with_relu); - -/************************ backward api impl ***************************/ - -std::vector concat_grad_impl(const std::vector& x, - const Tensor& out_grad, - const Scalar& axis); - Tensor imag_grad_impl(const Tensor& x); Tensor real_grad_impl(const Tensor& x); -std::vector stack_grad_impl(const std::vector& x, - const Tensor& out_grad, - int axis); -std::vector meshgrid_grad_impl(const std::vector& inputs, - const std::vector& outputs_grad); - -std::vector multi_dot_grad_impl(const std::vector& x, - const Tensor& out_grad); - -std::vector multiplex_grad_impl(const std::vector& inputs, - const Tensor& ids, - const Tensor& out_grad); - } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index 732ecacde94..f9db1529569 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -76,6 +76,16 @@ std::vector MakeMetaTensor( return meta_tensors; } +std::vector MakeMetaTensor( + const std::vector& tensors) { + std::vector meta_tensors; + meta_tensors.reserve(tensors.size()); + for (auto* t : tensors) { + meta_tensors.emplace_back(*t); + } + return meta_tensors; +} + phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor) { return phi::MetaTensor(tensor); } diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h index d7ecef61c5b..035dfc52047 100644 --- a/paddle/phi/api/lib/api_gen_utils.h +++ b/paddle/phi/api/lib/api_gen_utils.h @@ -53,6 +53,9 @@ phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor); std::vector MakeMetaTensor( const std::vector& tensors); +std::vector MakeMetaTensor( + const 
std::vector<const phi::DenseTensor*>& tensors);
+
 phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor);
 
 phi::MetaTensor MakeMetaTensor(const phi::StringTensor& tensor);
diff --git a/paddle/phi/common/int_array.h b/paddle/phi/common/int_array.h
index 490d7dabd40..f9d07249e0f 100644
--- a/paddle/phi/common/int_array.h
+++ b/paddle/phi/common/int_array.h
@@ -96,6 +96,8 @@ class IntArrayBase {
   template <typename OtherT>
   IntArrayBase(const IntArrayBase<OtherT>& other) : array_(other.GetData()) {}
 
+  size_t size() const { return array_.size(); }
+
   const std::vector<T>& GetData() const { return array_; }
 
  private:
diff --git a/paddle/phi/core/compat/arg_map_context.h b/paddle/phi/core/compat/arg_map_context.h
index 71cec011411..122ebed2194 100644
--- a/paddle/phi/core/compat/arg_map_context.h
+++ b/paddle/phi/core/compat/arg_map_context.h
@@ -19,45 +19,33 @@ limitations under the License. */
 #include
 
 #include "paddle/phi/common/place.h"
+#include "paddle/phi/core/type_defs.h"
 #include "paddle/utils/any.h"
 #include "paddle/utils/flat_hash_map.h"
 #include "paddle/utils/small_vector.h"
 
 namespace phi {
 
-constexpr char kGradVarSuffix[] = "@GRAD";
-
-constexpr size_t kGradVarSuffixSize = 5U;
-
-inline std::string GradVarName(const std::string& var_name) {
-  std::string result;
-  result.reserve(var_name.size() + kGradVarSuffixSize);
-  result += var_name;
-  result += kGradVarSuffix;
-  return result;
-}
-
 // tuple(input_names, attr_names, output_names)
-using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>,
-                                   paddle::SmallVector<std::string>,
-                                   paddle::SmallVector<std::string>>;
+using KernelArgsTuple = std::tuple<paddle::SmallVector<const char*>,
+                                   paddle::SmallVector<const char*>,
+                                   paddle::SmallVector<const char*>>;
 
 struct KernelSignature {
-  std::string name;
+  const char* name;
   KernelArgsTuple args;
 
   KernelSignature() = default;
 
-  KernelSignature(std::string&& kernel_name,
-                  paddle::SmallVector<std::string>&& inputs,
-                  paddle::SmallVector<std::string>&& attrs,
-                  paddle::SmallVector<std::string>&& outputs)
-      : name(std::move(kernel_name)),
-        args(std::make_tuple(inputs, attrs, outputs)) {}
-  KernelSignature(const std::string& kernel_name,
-                  const paddle::SmallVector<std::string>& inputs,
-                  const paddle::SmallVector<std::string>& attrs,
-                  const paddle::SmallVector<std::string>& outputs)
+  KernelSignature(const char* kernel_name,
+                  paddle::SmallVector<const char*>&& inputs,
+                  paddle::SmallVector<const char*>&& attrs,
+                  paddle::SmallVector<const char*>&& outputs)
+      : name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {}
+  KernelSignature(const char* kernel_name,
+                  const paddle::SmallVector<const char*>& inputs,
+                  const paddle::SmallVector<const char*>& attrs,
+                  const paddle::SmallVector<const char*>& outputs)
       : name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {}
 
   // TODO(chenweihang): add assign constructor to solve windows compile
diff --git a/paddle/phi/core/compat/convert_utils.cc b/paddle/phi/core/compat/convert_utils.cc
index 43febb2ac04..4fa11ac7860 100644
--- a/paddle/phi/core/compat/convert_utils.cc
+++ b/paddle/phi/core/compat/convert_utils.cc
@@ -102,7 +102,7 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) {
   }
 }
 
-std::string TransToPhiKernelName(const std::string& fluid_op_name) {
+const std::string& TransToPhiKernelName(const std::string& fluid_op_name) {
   return OpUtilsMap::Instance().GetBaseKernelName(fluid_op_name);
 }
 
diff --git a/paddle/phi/core/compat/convert_utils.h b/paddle/phi/core/compat/convert_utils.h
index 62145976487..5982ab0deff 100644
--- a/paddle/phi/core/compat/convert_utils.h
+++ b/paddle/phi/core/compat/convert_utils.h
@@ -22,7 +22,7 @@ limitations under the License.
*/ namespace phi { -std::string TransToPhiKernelName(const std::string& fluid_op_name); +const std::string& TransToPhiKernelName(const std::string& fluid_op_name); const std::string& TransToFluidOpName(const std::string& phi_kernel_name); Backend TransToPhiBackend(const phi::Place& place); diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h index 6716f479180..9c926fa871b 100644 --- a/paddle/phi/core/compat/op_utils.h +++ b/paddle/phi/core/compat/op_utils.h @@ -26,6 +26,8 @@ limitations under the License. */ namespace phi { +const static std::string deprecated_kernel_name = "deprecated"; // NOLINT + const std::unordered_set standard_kernel_suffixs({ "sr", // SelectedRows kernel "raw" // fallback kernel of origfinal fluid op @@ -134,9 +136,9 @@ class OpUtilsMap { arg_mapping_fn_map_.insert({std::move(op_type), std::move(fn)}); } - std::string GetBaseKernelName(const std::string& op_type) const { + const std::string& GetBaseKernelName(const std::string& op_type) const { if (deprecated_op_names.find(op_type) != deprecated_op_names.end()) { - return "deprecated"; + return deprecated_kernel_name; } auto it = base_kernel_name_map_.find(op_type); if (it == base_kernel_name_map_.end()) { @@ -150,7 +152,7 @@ class OpUtilsMap { auto it = arg_mapping_fn_map_.find(op_type); if (it == arg_mapping_fn_map_.end()) { auto func = - [op_type](const ArgumentMappingContext& ctx) -> KernelSignature { + [&op_type](const ArgumentMappingContext& ctx) -> KernelSignature { return DefaultKernelSignatureMap::Instance().Get(op_type); }; return func; diff --git a/paddle/phi/core/infermeta_utils.cc b/paddle/phi/core/infermeta_utils.cc index 0496d727e8d..70f26102cba 100644 --- a/paddle/phi/core/infermeta_utils.cc +++ b/paddle/phi/core/infermeta_utils.cc @@ -20,14 +20,12 @@ void InferMetaContext::SetMetaConfig(MetaConfig config) { config_ = std::move(config); } -void InferMetaContext::EmplaceBackInput( - std::shared_ptr input) { +void InferMetaContext::EmplaceBackInput(MetaTensor input) { int index = inputs_.size(); inputs_.emplace_back(std::move(input)); input_range_.emplace_back(std::pair(index, index + 1)); } -void InferMetaContext::EmplaceBackOutput( - std::shared_ptr output) { +void InferMetaContext::EmplaceBackOutput(MetaTensor output) { int index = outputs_.size(); outputs_.emplace_back(std::move(output)); output_range_.emplace_back(std::pair(index, index + 1)); @@ -37,7 +35,7 @@ void InferMetaContext::EmplaceBackAttr(paddle::any attr) { } void InferMetaContext::EmplaceBackInputs( - paddle::SmallVector> inputs) { + paddle::SmallVector inputs) { int index = inputs_.size(); input_range_.emplace_back(std::pair(index, index + inputs.size())); inputs_.insert(inputs_.end(), @@ -45,7 +43,7 @@ void InferMetaContext::EmplaceBackInputs( std::make_move_iterator(inputs.end())); } void InferMetaContext::EmplaceBackOutputs( - paddle::SmallVector> outputs) { + paddle::SmallVector outputs) { int index = outputs_.size(); output_range_.emplace_back( std::pair(index, index + outputs.size())); @@ -64,24 +62,25 @@ const std::pair& InferMetaContext::OutputRangeAt(size_t idx) const { const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; } const MetaTensor& InferMetaContext::InputAt(size_t idx) const { - return *inputs_.at(idx); + return inputs_.at(idx); } -paddle::optional InferMetaContext::OptionalInputAt( +paddle::optional InferMetaContext::OptionalInputAt( size_t idx) const { const auto& input = inputs_.at(idx); - return input ? 
paddle::optional{static_cast< - const phi::MetaTensor&>(*input)} - : paddle::optional{paddle::none}; + return input.initialized() + ? paddle::optional{input} + : paddle::optional{paddle::none}; } -std::vector InferMetaContext::InputsBetween(size_t start, - size_t end) const { - std::vector result; +std::vector InferMetaContext::InputsBetween( + size_t start, size_t end) const { + std::vector result; result.reserve(end - start); for (size_t i = start; i < end; ++i) { - result.push_back(inputs_.at(i).get()); + auto& in = inputs_.at(i); + result.emplace_back(in.initialized() ? &in : nullptr); } return result; @@ -91,12 +90,13 @@ paddle::optional> InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const { const auto& first = inputs_.at(start); - if (first) { + if (first.initialized()) { std::vector result; result.reserve(end - start); for (size_t i = start; i < end; ++i) { - result.push_back(inputs_.at(i).get()); + auto& in = inputs_.at(i); + result.emplace_back(in.initialized() ? &in : nullptr); } return paddle::optional>(result); @@ -105,7 +105,8 @@ InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const { } MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) { - return outputs_.at(idx).get(); + auto& out = outputs_.at(idx); + return out.initialized() ? &out : nullptr; } std::vector InferMetaContext::MutableOutputBetween(size_t start, @@ -113,7 +114,8 @@ std::vector InferMetaContext::MutableOutputBetween(size_t start, std::vector result; result.reserve(end - start); for (size_t i = start; i < end; ++i) { - result.emplace_back(outputs_.at(i).get()); + auto& out = outputs_.at(i); + result.emplace_back(out.initialized() ? &out : nullptr); } return result; } diff --git a/paddle/phi/core/infermeta_utils.h b/paddle/phi/core/infermeta_utils.h index fad437f82c3..699c38ebd47 100644 --- a/paddle/phi/core/infermeta_utils.h +++ b/paddle/phi/core/infermeta_utils.h @@ -37,28 +37,28 @@ class InferMetaContext { explicit InferMetaContext(MetaConfig config) : config_(config) {} void SetMetaConfig(MetaConfig config); - void EmplaceBackInput(std::shared_ptr input); - void EmplaceBackOutput(std::shared_ptr output); + const MetaConfig& GetMetaConfig() const; + + void EmplaceBackInput(MetaTensor input); + void EmplaceBackOutput(MetaTensor output); void EmplaceBackAttr(paddle::any attr); void EmplaceBackInputs( - paddle::SmallVector> inputs); + paddle::SmallVector inputs); void EmplaceBackOutputs( - paddle::SmallVector> outputs); + paddle::SmallVector outputs); - const std::pair& InputRangeAt(size_t idx) const; - const std::pair& OutputRangeAt(size_t idx) const; + virtual const MetaTensor& InputAt(size_t idx) const; + virtual paddle::optional OptionalInputAt(size_t idx) const; - const MetaConfig& GetMetaConfig() const; - - const MetaTensor& InputAt(size_t idx) const; - paddle::optional OptionalInputAt(size_t idx) const; - std::vector InputsBetween(size_t start, size_t end) const; - paddle::optional> + virtual std::vector InputsBetween(size_t start, + size_t end) const; + virtual paddle::optional> OptionalInputsBetween(size_t start, size_t end) const; - MetaTensor* MutableOutputAt(size_t idx); - std::vector MutableOutputBetween(size_t start, size_t end); + virtual MetaTensor* MutableOutputAt(size_t idx); + virtual std::vector MutableOutputBetween(size_t start, + size_t end); template AttrType AttrAt(size_t idx) { @@ -73,19 +73,24 @@ class InferMetaContext { } } - private: + const std::pair& InputRangeAt(size_t idx) const; + const std::pair& OutputRangeAt(size_t idx) const; + + 
virtual ~InferMetaContext() = default; + + protected: MetaConfig config_; - // NOTE(chenweihang): Because the MetaTensor is a base class, and MetaTensor - // objects are all created in each round, so we have to use smart pointer - // here, maybe we can implemented a new InferMetaContext and a series utils - // specifically for fluid to avoid using shared_ptr - paddle::SmallVector> inputs_; - paddle::SmallVector> outputs_; - paddle::SmallVector attrs_; + paddle::SmallVector attrs_; - paddle::SmallVector> input_range_; - paddle::SmallVector> output_range_; + paddle::SmallVector, phi::kInputSmallVectorSize> + input_range_; + paddle::SmallVector, phi::kOutputSmallVectorSize> + output_range_; + + private: + paddle::SmallVector inputs_; + paddle::SmallVector outputs_; }; #define PD_INFER_META(...) \ @@ -159,7 +164,7 @@ struct InferMetaFnImpl { }; template - struct InferMetaFnCallHelper&, Tail...> { + struct InferMetaFnCallHelper&, Tail...> { template static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { static_assert(attr_idx == 0, @@ -167,7 +172,7 @@ struct InferMetaFnImpl { static_assert(out_idx == 0, "InferMeta's Input should appear before Outputs."); const std::pair range = ctx->InputRangeAt(in_idx); - std::vector arg = + std::vector arg = ctx->InputsBetween(range.first, range.second); InferMetaFnCallHelper< Tail...>::template Call(ctx, diff --git a/paddle/phi/core/kernel_context.cc b/paddle/phi/core/kernel_context.cc index 234e3528c36..cf862cbde18 100644 --- a/paddle/phi/core/kernel_context.cc +++ b/paddle/phi/core/kernel_context.cc @@ -79,7 +79,7 @@ void KernelContext::EmplaceBackAttr(paddle::any attr) { void KernelContext::AssignInputRange(std::pair&& range, size_t idx) { if (idx < input_range_.size()) { - input_range_[idx] = range; + input_range_[idx] = std::move(range); } else if (idx == input_range_.size()) { input_range_.emplace_back(range); } else { @@ -93,7 +93,7 @@ void KernelContext::AssignInputRange(std::pair&& range, size_t idx) { void KernelContext::AssignOutputRange(std::pair&& range, size_t idx) { if (idx < output_range_.size()) { - output_range_[idx] = range; + output_range_[idx] = std::move(range); } else if (idx == output_range_.size()) { output_range_.emplace_back(range); } else { diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index a1ce90c2c78..d3fd2e0204e 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -19,6 +19,8 @@ namespace phi { +const static Kernel empty_kernel; // NOLINT + uint32_t KernelKey::Hash::operator()(const KernelKey& key) const { uint32_t hash_value = 0; // |----31-20------|---19-12---|---11-8----|---7-0---| @@ -37,15 +39,15 @@ KernelFactory& KernelFactory::Instance() { return g_op_kernel_factory; } -Kernel KernelFactory::SelectKernel(const std::string& kernel_name, - const KernelKey& kernel_key) const { +const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name, + const KernelKey& kernel_key) const { auto iter = kernels_.find(kernel_name); if (iter == kernels_.end()) { - return Kernel(); + return empty_kernel; } auto kernel_iter = iter->second.find(kernel_key); if (kernel_iter == iter->second.end()) { - return Kernel(); + return empty_kernel; } return kernel_iter->second; } @@ -59,8 +61,8 @@ KernelKeyMap KernelFactory::SelectKernelMap( return iter->second; } -bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name, - const KernelKey& kernel_key) const { +bool KernelFactory::HasKernel(const std::string& kernel_name, + const KernelKey& 
kernel_key) const { auto iter = kernels_.find(kernel_name); PADDLE_ENFORCE_NE( iter, @@ -128,6 +130,16 @@ const Kernel& KernelFactory::SelectKernelOrThrowError( KernelKey(backend, layout, dtype)); } +const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef( + const std::string& kernel_name) const { + auto iter = kernels_.find(kernel_name); + PADDLE_ENFORCE_NE( + iter, + kernels_.end(), + phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); + return iter->second.cbegin()->second.args_def(); +} + // print kernel info with json format: // { // "(CPU, Undefined(AnyLayout), complex64)": { diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index 8fd25b691bd..812b6222cb5 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -151,30 +151,38 @@ class KernelArgsDef { attribute_defs_.emplace_back(AttributeArgDef(type_index)); } - const paddle::SmallVector& input_defs() const { + const paddle::SmallVector& input_defs() + const { return input_defs_; } - const paddle::SmallVector& output_defs() const { + const paddle::SmallVector& output_defs() + const { return output_defs_; } - const paddle::SmallVector& attribute_defs() const { + const paddle::SmallVector& + attribute_defs() const { return attribute_defs_; } - paddle::SmallVector& input_defs() { return input_defs_; } + paddle::SmallVector& input_defs() { + return input_defs_; + } - paddle::SmallVector& output_defs() { return output_defs_; } + paddle::SmallVector& output_defs() { + return output_defs_; + } - paddle::SmallVector& attribute_defs() { + paddle::SmallVector& attribute_defs() { return attribute_defs_; } private: - paddle::SmallVector input_defs_{{}}; - paddle::SmallVector output_defs_{{}}; - paddle::SmallVector attribute_defs_{{}}; + paddle::SmallVector input_defs_{{}}; + paddle::SmallVector output_defs_{{}}; + paddle::SmallVector attribute_defs_{ + {}}; }; class Kernel { @@ -209,7 +217,7 @@ class Kernel { TensorArgDef& OutputAt(size_t idx) { return args_def_.output_defs().at(idx); } - bool IsValid() { return fn_ != nullptr; } + bool IsValid() const { return fn_ != nullptr; } private: KernelFn fn_{nullptr}; @@ -246,14 +254,17 @@ class KernelFactory { DataLayout layout, DataType dtype) const; - bool IsSelectKernelValid(const std::string& kernel_name, - const KernelKey& kernel_key) const; + bool HasKernel(const std::string& kernel_name, + const KernelKey& kernel_key) const; - Kernel SelectKernel(const std::string& kernel_name, - const KernelKey& kernel_key) const; + const Kernel& SelectKernel(const std::string& kernel_name, + const KernelKey& kernel_key) const; KernelKeyMap SelectKernelMap(const std::string& kernel_name) const; + const KernelArgsDef& GetFirstKernelArgsDef( + const std::string& kernel_name) const; + private: KernelFactory() = default; diff --git a/paddle/phi/core/meta_tensor.cc b/paddle/phi/core/meta_tensor.cc index 04dfbf96031..2178855aa0f 100644 --- a/paddle/phi/core/meta_tensor.cc +++ b/paddle/phi/core/meta_tensor.cc @@ -148,4 +148,6 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) { } } +bool MetaTensor::initialized() const { return tensor_ != nullptr; } + } // namespace phi diff --git a/paddle/phi/core/meta_tensor.h b/paddle/phi/core/meta_tensor.h index 10c3a7c1a3d..3cdbfda61d6 100644 --- a/paddle/phi/core/meta_tensor.h +++ b/paddle/phi/core/meta_tensor.h @@ -45,10 +45,10 @@ class MetaTensor { : tensor_(const_cast(&tensor)) {} MetaTensor(TensorBase& tensor) : tensor_(&tensor) {} // NOLINT - MetaTensor(const MetaTensor&) = 
default; MetaTensor(MetaTensor&&) = default; - MetaTensor& operator=(const MetaTensor&) = delete; - MetaTensor& operator=(MetaTensor&&) = delete; + MetaTensor& operator=(MetaTensor&&) = default; + MetaTensor(const MetaTensor&) = default; + MetaTensor& operator=(const MetaTensor&) = default; virtual ~MetaTensor() = default; @@ -64,6 +64,8 @@ class MetaTensor { virtual void share_meta(const MetaTensor& meta_tensor); virtual void share_dims(const MetaTensor& meta_tensor); + virtual bool initialized() const; + private: // Because the lod in compiletime and runtime is different, // so `LoD` cannot in public methods diff --git a/paddle/phi/core/type_defs.h b/paddle/phi/core/type_defs.h index 3c879267bb8..a1e78360883 100644 --- a/paddle/phi/core/type_defs.h +++ b/paddle/phi/core/type_defs.h @@ -22,7 +22,7 @@ class Kernel; class KernelKey; class KernelArgsDef; class KernelContext; -class KernelSignature; +struct KernelSignature; class ArgumentMappingContext; class InferMetaContext; @@ -35,4 +35,9 @@ using ArgumentMappingFn = std::function; using InferMetaFn = void (*)(InferMetaContext* ctx); +// Global SmallVector size setting +constexpr size_t kInputSmallVectorSize = 10U; +constexpr size_t kAttrSmallVectorSize = 10U; +constexpr size_t kOutputSmallVectorSize = 5U; + } // namespace phi diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 84db67978fc..567f39a915c 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -315,8 +315,8 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, dx->share_meta(x); } -void MeshgridGradInferMeta(const std::vector& inputs, - const std::vector& outputs_grad, +void MeshgridGradInferMeta(const std::vector& inputs, + const std::vector& outputs_grad, std::vector inputs_grad) { PADDLE_ENFORCE_GT(outputs_grad.size(), 1, @@ -329,7 +329,7 @@ void MeshgridGradInferMeta(const std::vector& inputs, } } -void MultiDotGradInferMeta(const std::vector& x, +void MultiDotGradInferMeta(const std::vector& x, const MetaTensor& out_grad, std::vector x_grad) { PADDLE_ENFORCE_EQ( diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index c51708bb543..6807438ebbb 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -151,11 +151,11 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, bool adaptive, MetaTensor* dx); -void MeshgridGradInferMeta(const std::vector& inputs, - const std::vector& outputs_grad, +void MeshgridGradInferMeta(const std::vector& inputs, + const std::vector& outputs_grad, std::vector inputs_grad); -void MultiDotGradInferMeta(const std::vector& x, +void MultiDotGradInferMeta(const std::vector& x, const MetaTensor& out_grad, std::vector x_grad); diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 3ce24139fe1..152e04b74b0 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -21,7 +21,8 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/funcs/concat_funcs.h" namespace phi { -std::vector GetMetaTensorsDim(const std::vector& tensors) { +std::vector GetMetaTensorsDim( + const std::vector& tensors) { std::vector dims; dims.reserve(tensors.size()); for (const MetaTensor* tensor : tensors) { @@ -148,7 +149,7 @@ void AdamaxInferMeta(const MetaTensor& param, inf_norm_out->set_dtype(inf_norm.dtype()); } -void AddNInferMeta(const std::vector& x, +void AddNInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config) { auto N = x.size(); @@ -511,7 +512,7 @@ void BilinearTensorProductInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } -void BroadcastTensorsInferMeta(const std::vector& x, +void BroadcastTensorsInferMeta(const std::vector& x, std::vector out) { int target_rank = 0; const auto& input_dims = GetMetaTensorsDim(x); @@ -565,7 +566,7 @@ void BroadcastTensorsInferMeta(const std::vector& x, } } -void ConcatInferMeta(const std::vector& x, +void ConcatInferMeta(const std::vector& x, const Scalar& axis_scalar, MetaTensor* out, MetaConfig config) { @@ -1357,7 +1358,7 @@ void InterpolateInferMeta( } } -void MeshgridInferMeta(const std::vector& inputs, +void MeshgridInferMeta(const std::vector& inputs, std::vector outputs) { const size_t inputs_num = inputs.size(); @@ -1420,7 +1421,8 @@ void MomentumInferMeta(const MetaTensor& param, } } -void MultiDotInferMeta(const std::vector& x, MetaTensor* out) { +void MultiDotInferMeta(const std::vector& x, + MetaTensor* out) { auto inputs_dims = GetMetaTensorsDim(x); const size_t inputs_num = inputs_dims.size(); @@ -1493,7 +1495,7 @@ void MultiDotInferMeta(const std::vector& x, MetaTensor* out) { out->share_lod(*x.at(0)); } -void MultiplexInferMeta(const std::vector& ins, +void MultiplexInferMeta(const std::vector& ins, const MetaTensor& ids, MetaTensor* out) { PADDLE_ENFORCE_NE( @@ -1672,8 +1674,8 @@ void RmspropInferMeta(const MetaTensor& param, } void RnnInferMeta(const MetaTensor& x, - const std::vector& pre_state, - const std::vector& weight_list, + const std::vector& pre_state, + const std::vector& weight_list, paddle::optional sequence_length, float dropout_prob, bool is_bidirec, @@ -1779,7 +1781,7 @@ void SGDInferMeta(const MetaTensor& param, param_out->set_dtype(param.dtype()); } -void StackInferMeta(const std::vector& x, +void StackInferMeta(const std::vector& x, int axis, MetaTensor* out) { PADDLE_ENFORCE_GT(x.size(), @@ -1825,7 +1827,7 @@ void StackInferMeta(const std::vector& x, out->share_lod(*x.at(0)); } -void UnchangedMultiInferMeta(const std::vector& x, +void UnchangedMultiInferMeta(const std::vector& x, std::vector out) { for (size_t i = 0; i < x.size(); ++i) { out[i]->share_meta(*x[i]); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 7db4480ffb2..bf3e1d8af6e 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -35,7 +35,8 @@ namespace phi { // // NOTE: The InferMeta Functions in this file are arranged in alphabetic order. 
-std::vector GetMetaTensorsDim(const std::vector& tensors); +std::vector GetMetaTensorsDim( + const std::vector& tensors); void AdadeltaInferMeta(const MetaTensor& param, const MetaTensor& grad, @@ -68,7 +69,7 @@ void AdamaxInferMeta(const MetaTensor& param, MetaTensor* moment_out, MetaTensor* inf_norm_out); -void AddNInferMeta(const std::vector& x, +void AddNInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config = MetaConfig()); @@ -124,10 +125,10 @@ void BilinearTensorProductInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); -void BroadcastTensorsInferMeta(const std::vector& x, +void BroadcastTensorsInferMeta(const std::vector& x, std::vector out); -void ConcatInferMeta(const std::vector& x, +void ConcatInferMeta(const std::vector& x, const Scalar& axis_scalar, MetaTensor* out, MetaConfig config = MetaConfig()); @@ -178,7 +179,7 @@ void InterpolateInferMeta( MetaTensor* output, MetaConfig config = MetaConfig()); -void MeshgridInferMeta(const std::vector& inputs, +void MeshgridInferMeta(const std::vector& inputs, std::vector outputs); void MomentumInferMeta(const MetaTensor& param, @@ -196,9 +197,10 @@ void MomentumInferMeta(const MetaTensor& param, MetaTensor* velocity_out, MetaTensor* master_param_out); -void MultiDotInferMeta(const std::vector& x, MetaTensor* out); +void MultiDotInferMeta(const std::vector& x, + MetaTensor* out); -void MultiplexInferMeta(const std::vector& ins, +void MultiplexInferMeta(const std::vector& ins, const MetaTensor& ids, MetaTensor* out); @@ -227,8 +229,8 @@ void RmspropInferMeta(const MetaTensor& param, MetaTensor* mean_grad_out); void RnnInferMeta(const MetaTensor& x, - const std::vector& pre_state, - const std::vector& weight_list, + const std::vector& pre_state, + const std::vector& weight_list, paddle::optional sequence_length, float dropout_prob, bool is_bidirec, @@ -251,11 +253,11 @@ void SGDInferMeta(const MetaTensor& param, MetaTensor* param_out, MetaTensor* master_param_out); -void StackInferMeta(const std::vector& x, +void StackInferMeta(const std::vector& x, int axis, MetaTensor* out); -void UnchangedMultiInferMeta(const std::vector& x, +void UnchangedMultiInferMeta(const std::vector& x, std::vector out); void WarpctcInferMeta(const MetaTensor& logits, diff --git a/paddle/phi/kernels/concat_kernel.h b/paddle/phi/kernels/concat_kernel.h index cf83ab9aaab..f5ac2d3cbb7 100644 --- a/paddle/phi/kernels/concat_kernel.h +++ b/paddle/phi/kernels/concat_kernel.h @@ -32,7 +32,7 @@ DenseTensor Concat(const Context& dev_ctx, const Scalar& axis) { std::vector meta_x; meta_x.reserve(x.size()); - std::vector meta_x_ptr; + std::vector meta_x_ptr; for (const auto* t : x) { meta_x.emplace_back(*t); meta_x_ptr.push_back(&meta_x.back()); diff --git a/paddle/phi/ops/compat/abs_sig.cc b/paddle/phi/ops/compat/abs_sig.cc index b4b94457e6b..92d29dd0189 100644 --- a/paddle/phi/ops/compat/abs_sig.cc +++ b/paddle/phi/ops/compat/abs_sig.cc @@ -21,8 +21,7 @@ KernelSignature AbsOpArgumentMapping(const ArgumentMappingContext& ctx) { } KernelSignature AbsGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "abs_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")}); + return KernelSignature("abs_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"}); } KernelSignature AbsDoubleGradOpArgumentMapping( diff --git a/paddle/phi/ops/compat/activation_sig.cc b/paddle/phi/ops/compat/activation_sig.cc index 8add832c366..5900b499466 100644 --- a/paddle/phi/ops/compat/activation_sig.cc +++ 
b/paddle/phi/ops/compat/activation_sig.cc @@ -19,26 +19,22 @@ namespace phi { #define DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(func_name, op_name, attrs) \ KernelSignature func_name##GradOpArgumentMapping( \ const ArgumentMappingContext& ctx) { \ - return KernelSignature(op_name "_grad", \ - {"X", GradVarName("Out")}, \ - {attrs}, \ - {GradVarName("X")}); \ + return KernelSignature( \ + op_name "_grad", {"X", "Out@GRAD"}, {attrs}, {"X@GRAD"}); \ } #define DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(func_name, op_name, attrs) \ KernelSignature func_name##GradOpArgumentMapping( \ const ArgumentMappingContext& ctx) { \ - return KernelSignature(op_name "_grad", \ - {"Out", GradVarName("Out")}, \ - {attrs}, \ - {GradVarName("X")}); \ + return KernelSignature( \ + op_name "_grad", {"Out", "Out@GRAD"}, {attrs}, {"X@GRAD"}); \ } -#define DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(func_name, op_name, attrs) \ - KernelSignature func_name##GradOpArgumentMapping( \ - const ArgumentMappingContext& ctx) { \ - return KernelSignature( \ - op_name "_grad", {GradVarName("Out")}, {attrs}, {GradVarName("X")}); \ +#define DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(func_name, op_name, attrs) \ + KernelSignature func_name##GradOpArgumentMapping( \ + const ArgumentMappingContext& ctx) { \ + return KernelSignature( \ + op_name "_grad", {"Out@GRAD"}, {attrs}, {"X@GRAD"}); \ } #define comma , @@ -165,15 +161,12 @@ KernelSignature EluOpArgumentMapping(const ArgumentMappingContext& ctx) { } KernelSignature LogitGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "logit_grad", {"X", GradVarName("Out")}, {"eps"}, {GradVarName("X")}); + return KernelSignature("logit_grad", {"X", "Out@GRAD"}, {"eps"}, {"X@GRAD"}); } KernelSignature EluGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("elu_grad", - {"X", "Out", GradVarName("Out")}, - {"alpha"}, - {GradVarName("X")}); + return KernelSignature( + "elu_grad", {"X", "Out", "Out@GRAD"}, {"alpha"}, {"X@GRAD"}); } KernelSignature EluDoubleGradOpArgumentMapping( @@ -198,13 +191,11 @@ KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature PowGradOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.HasInput("FactorTensor")) { - return KernelSignature("pow_grad", - {"X", GradVarName("Out")}, - {"FactorTensor"}, - {GradVarName("X")}); + return KernelSignature( + "pow_grad", {"X", "Out@GRAD"}, {"FactorTensor"}, {"X@GRAD"}); } else { return KernelSignature( - "pow_grad", {"X", GradVarName("Out")}, {"factor"}, {GradVarName("X")}); + "pow_grad", {"X", "Out@GRAD"}, {"factor"}, {"X@GRAD"}); } } diff --git a/paddle/phi/ops/compat/addmm_sig.cc b/paddle/phi/ops/compat/addmm_sig.cc index b3bc0bb23a7..3919c875f56 100644 --- a/paddle/phi/ops/compat/addmm_sig.cc +++ b/paddle/phi/ops/compat/addmm_sig.cc @@ -17,11 +17,10 @@ namespace phi { KernelSignature AddmmGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "addmm_grad", - {"Input", "X", "Y", GradVarName("Out")}, - {"Alpha", "Beta"}, - {GradVarName("Input"), GradVarName("X"), GradVarName("Y")}); + return KernelSignature("addmm_grad", + {"Input", "X", "Y", "Out@GRAD"}, + {"Alpha", "Beta"}, + {"Input@GRAD", "X@GRAD", "Y@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/argsort_sig.cc b/paddle/phi/ops/compat/argsort_sig.cc index 62133a441ff..70531f16916 100644 --- a/paddle/phi/ops/compat/argsort_sig.cc +++ b/paddle/phi/ops/compat/argsort_sig.cc @@ -19,9 +19,9 @@ namespace phi { KernelSignature ArgsortGradOpArgumentMapping( const 
ArgumentMappingContext& ctx) { return KernelSignature("argsort_grad", - {"Indices", "X", GradVarName("Out")}, + {"Indices", "X", "Out@GRAD"}, {"axis", "descending"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/atan2_sig.cc b/paddle/phi/ops/compat/atan2_sig.cc index 8a6049e67b6..9fef8560df9 100644 --- a/paddle/phi/ops/compat/atan2_sig.cc +++ b/paddle/phi/ops/compat/atan2_sig.cc @@ -17,10 +17,8 @@ namespace phi { KernelSignature Atan2GradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("atan2_grad", - {"X1", "X2", GradVarName("Out")}, - {}, - {GradVarName("X1"), GradVarName("X2")}); + return KernelSignature( + "atan2_grad", {"X1", "X2", "Out@GRAD"}, {}, {"X1@GRAD", "X2@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/batch_norm_sig.cc b/paddle/phi/ops/compat/batch_norm_sig.cc index cfd9f4def93..14affe60b9d 100644 --- a/paddle/phi/ops/compat/batch_norm_sig.cc +++ b/paddle/phi/ops/compat/batch_norm_sig.cc @@ -57,27 +57,26 @@ KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature BatchNormGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "batch_norm_grad", - { - "X", - "Scale", - "Bias", - "Mean", - "Variance", - "SavedMean", - "SavedVariance", - "ReserveSpace", - GradVarName("Y"), - }, - {"momentum", - "epsilon", - "data_layout", - "is_test", - "use_global_stats", - "trainable_statistics", - "fuse_with_relu"}, - {GradVarName("X"), GradVarName("Scale"), GradVarName("Bias")}); + return KernelSignature("batch_norm_grad", + { + "X", + "Scale", + "Bias", + "Mean", + "Variance", + "SavedMean", + "SavedVariance", + "ReserveSpace", + "Y@GRAD", + }, + {"momentum", + "epsilon", + "data_layout", + "is_test", + "use_global_stats", + "trainable_statistics", + "fuse_with_relu"}, + {"X@GRAD", "Scale@GRAD", "Bias@GRAD"}); } KernelSignature BatchNormGradGradOpArgumentMapping( diff --git a/paddle/phi/ops/compat/bce_loss_sig.cc b/paddle/phi/ops/compat/bce_loss_sig.cc index 17f76067d22..5575fa277eb 100644 --- a/paddle/phi/ops/compat/bce_loss_sig.cc +++ b/paddle/phi/ops/compat/bce_loss_sig.cc @@ -18,10 +18,8 @@ namespace phi { KernelSignature BCELossGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("bce_loss_grad", - {"X", "Label", GradVarName("Out")}, - {}, - {GradVarName("X")}); + return KernelSignature( + "bce_loss_grad", {"X", "Label", "Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/bilinear_tensor_product_sig.cc b/paddle/phi/ops/compat/bilinear_tensor_product_sig.cc index 570bf7ce943..95a867fd3f7 100644 --- a/paddle/phi/ops/compat/bilinear_tensor_product_sig.cc +++ b/paddle/phi/ops/compat/bilinear_tensor_product_sig.cc @@ -25,12 +25,9 @@ KernelSignature BilinearTensorProductOpArgumentMapping( KernelSignature BilinearTensorProductGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("bilinear_tensor_product_grad", - {"X", "Y", "Weight", GradVarName("Out")}, + {"X", "Y", "Weight", "Out@GRAD"}, {}, - {GradVarName("X"), - GradVarName("Y"), - GradVarName("Weight"), - GradVarName("Bias")}); + {"X@GRAD", "Y@GRAD", "Weight@GRAD", "Bias@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/broadcast_tensors_sig.cc b/paddle/phi/ops/compat/broadcast_tensors_sig.cc index 2c979c4aedc..d0fcbb33be2 100644 --- a/paddle/phi/ops/compat/broadcast_tensors_sig.cc +++ b/paddle/phi/ops/compat/broadcast_tensors_sig.cc @@ -19,7 +19,7 @@ namespace phi 
{ KernelSignature BroadcastTensorsGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( - "broadcast_tensors_grad", {GradVarName("Out")}, {}, {GradVarName("X")}); + "broadcast_tensors_grad", {"Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/cholesky_sig.cc b/paddle/phi/ops/compat/cholesky_sig.cc index 8c7ca757046..9a26ea5c0c5 100644 --- a/paddle/phi/ops/compat/cholesky_sig.cc +++ b/paddle/phi/ops/compat/cholesky_sig.cc @@ -18,10 +18,8 @@ namespace phi { KernelSignature CholeskyGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("cholesky_grad", - {"Out", GradVarName("Out")}, - {"upper"}, - {GradVarName("X")}); + return KernelSignature( + "cholesky_grad", {"Out", "Out@GRAD"}, {"upper"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/cholesky_solve_sig.cc b/paddle/phi/ops/compat/cholesky_solve_sig.cc index 6a9759f8352..2696d80a49f 100644 --- a/paddle/phi/ops/compat/cholesky_solve_sig.cc +++ b/paddle/phi/ops/compat/cholesky_solve_sig.cc @@ -19,9 +19,9 @@ namespace phi { KernelSignature CholeskySolveGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("cholesky_solve_grad", - {"X", "Y", "Out", GradVarName("Out")}, + {"X", "Y", "Out", "Out@GRAD"}, {"upper"}, - {GradVarName("X"), GradVarName("Y")}); + {"X@GRAD", "Y@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/clip_sig.cc b/paddle/phi/ops/compat/clip_sig.cc index 78fa6c36a51..25a34f2b9c8 100644 --- a/paddle/phi/ops/compat/clip_sig.cc +++ b/paddle/phi/ops/compat/clip_sig.cc @@ -18,7 +18,7 @@ namespace phi { KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) { - paddle::SmallVector attr_names; + paddle::SmallVector attr_names; attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min"); attr_names.emplace_back(ctx.HasInput("Max") ? 
"Max" : "max"); if (ctx.IsDenseTensorInput("X")) { @@ -57,27 +57,19 @@ KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ClipGradOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.HasInput("Min")) { if (ctx.HasInput("Max")) { - return KernelSignature("clip_grad", - {"X", GradVarName("Out")}, - {"Min", "Max"}, - {GradVarName("X")}); + return KernelSignature( + "clip_grad", {"X", "Out@GRAD"}, {"Min", "Max"}, {"X@GRAD"}); } else { - return KernelSignature("clip_grad", - {"X", GradVarName("Out")}, - {"Min", "max"}, - {GradVarName("X")}); + return KernelSignature( + "clip_grad", {"X", "Out@GRAD"}, {"Min", "max"}, {"X@GRAD"}); } } else { if (ctx.HasInput("Max")) { - return KernelSignature("clip_grad", - {"X", GradVarName("Out")}, - {"min", "Max"}, - {GradVarName("X")}); + return KernelSignature( + "clip_grad", {"X", "Out@GRAD"}, {"min", "Max"}, {"X@GRAD"}); } else { - return KernelSignature("clip_grad", - {"X", GradVarName("Out")}, - {"min", "max"}, - {GradVarName("X")}); + return KernelSignature( + "clip_grad", {"X", "Out@GRAD"}, {"min", "max"}, {"X@GRAD"}); } } } diff --git a/paddle/phi/ops/compat/complex_sig.cc b/paddle/phi/ops/compat/complex_sig.cc index b9f59c97fb5..88156677d34 100644 --- a/paddle/phi/ops/compat/complex_sig.cc +++ b/paddle/phi/ops/compat/complex_sig.cc @@ -17,13 +17,11 @@ namespace phi { KernelSignature RealGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "real_grad", {GradVarName("Out")}, {}, {GradVarName("X")}); + return KernelSignature("real_grad", {"Out@GRAD"}, {}, {"X@GRAD"}); } KernelSignature ImagGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "imag_grad", {GradVarName("Out")}, {}, {GradVarName("X")}); + return KernelSignature("imag_grad", {"Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/concat_sig.cc b/paddle/phi/ops/compat/concat_sig.cc index d443f521c61..d53bb5793bc 100644 --- a/paddle/phi/ops/compat/concat_sig.cc +++ b/paddle/phi/ops/compat/concat_sig.cc @@ -25,15 +25,11 @@ KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ConcatGradOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.HasInput("AxisTensor")) { - return KernelSignature("concat_grad", - {"X", {GradVarName("Out")}}, - {"AxisTensor"}, - {{GradVarName("X")}}); + return KernelSignature( + "concat_grad", {"X", {"Out@GRAD"}}, {"AxisTensor"}, {{"X@GRAD"}}); } - return KernelSignature("concat_grad", - {"X", {GradVarName("Out")}}, - {"axis"}, - {{GradVarName("X")}}); + return KernelSignature( + "concat_grad", {"X", {"Out@GRAD"}}, {"axis"}, {{"X@GRAD"}}); } } // namespace phi diff --git a/paddle/phi/ops/compat/conv2d_sig.cc b/paddle/phi/ops/compat/conv2d_sig.cc index 7cc0d6ad175..617c6e289bf 100644 --- a/paddle/phi/ops/compat/conv2d_sig.cc +++ b/paddle/phi/ops/compat/conv2d_sig.cc @@ -46,7 +46,7 @@ KernelSignature Conv2dOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("conv2d_grad", - {"Input", "Filter", GradVarName("Output")}, + {"Input", "Filter", "Output@GRAD"}, {"strides", "paddings", "padding_algorithm", @@ -56,7 +56,7 @@ KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { "use_addto", "workspace_size_MB", "exhaustive_search"}, - {GradVarName("Input"), GradVarName("Filter")}); + {"Input@GRAD", "Filter@GRAD"}); } KernelSignature 
Conv2dDoubleGradOpArgumentMapping( diff --git a/paddle/phi/ops/compat/conv3d_sig.cc b/paddle/phi/ops/compat/conv3d_sig.cc index b24c08b60c9..c6aae1bf5bb 100644 --- a/paddle/phi/ops/compat/conv3d_sig.cc +++ b/paddle/phi/ops/compat/conv3d_sig.cc @@ -33,7 +33,7 @@ KernelSignature Conv3dOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("conv2d_grad", - {"Input", "Filter", GradVarName("Output")}, + {"Input", "Filter", "Output@GRAD"}, {"strides", "paddings", "padding_algorithm", @@ -43,7 +43,7 @@ KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { "use_addto", "workspace_size_MB", "exhaustive_search"}, - {GradVarName("Input"), GradVarName("Filter")}); + {"Input@GRAD", "Filter@GRAD"}); } KernelSignature Conv3dDoubleGradOpArgumentMapping( diff --git a/paddle/phi/ops/compat/conv_transpose_sig.cc b/paddle/phi/ops/compat/conv_transpose_sig.cc index 8697168b827..a040bce6f78 100644 --- a/paddle/phi/ops/compat/conv_transpose_sig.cc +++ b/paddle/phi/ops/compat/conv_transpose_sig.cc @@ -34,7 +34,7 @@ KernelSignature Conv2dTransposeOpArgumentMapping( KernelSignature Conv2dTransposeGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("conv2d_transpose_grad", - {"Input", "Filter", GradVarName("Output")}, + {"Input", "Filter", "Output@GRAD"}, {"strides", "paddings", "output_padding", @@ -43,7 +43,7 @@ KernelSignature Conv2dTransposeGradOpArgumentMapping( "groups", "dilations", "data_format"}, - {GradVarName("Input"), GradVarName("Filter")}); + {"Input@GRAD", "Filter@GRAD"}); } KernelSignature Conv2dTransposeDoubleGradOpArgumentMapping( @@ -79,7 +79,7 @@ KernelSignature Conv3dTransposeOpArgumentMapping( KernelSignature Conv3dTransposeGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("conv3d_transpose_grad", - {"Input", "Filter", GradVarName("Output")}, + {"Input", "Filter", "Output@GRAD"}, {"strides", "paddings", "output_padding", @@ -88,7 +88,7 @@ KernelSignature Conv3dTransposeGradOpArgumentMapping( "groups", "dilations", "data_format"}, - {GradVarName("Input"), GradVarName("Filter")}); + {"Input@GRAD", "Filter@GRAD"}); } KernelSignature DepthwiseConv2dTransposeOpArgumentMapping( @@ -109,7 +109,7 @@ KernelSignature DepthwiseConv2dTransposeOpArgumentMapping( KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("depthwise_conv2d_transpose_grad", - {"Input", "Filter", GradVarName("Output")}, + {"Input", "Filter", "Output@GRAD"}, {"strides", "paddings", "output_padding", @@ -118,7 +118,7 @@ KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping( "groups", "dilations", "data_format"}, - {GradVarName("Input"), GradVarName("Filter")}); + {"Input@GRAD", "Filter@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/cross_sig.cc b/paddle/phi/ops/compat/cross_sig.cc index 307c2ac5164..2a8a46678cd 100644 --- a/paddle/phi/ops/compat/cross_sig.cc +++ b/paddle/phi/ops/compat/cross_sig.cc @@ -21,10 +21,8 @@ KernelSignature CrossOpArgumentMapping(const ArgumentMappingContext& ctx) { } KernelSignature CrossGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("cross_grad", - {"X", "Y", GradVarName("Out")}, - {"dim"}, - {GradVarName("X"), GradVarName("Y")}); + return KernelSignature( + "cross_grad", {"X", "Y", "Out@GRAD"}, {"dim"}, {"X@GRAD", "Y@GRAD"}); } } // namespace phi diff --git 
a/paddle/phi/ops/compat/cumprod_sig.cc b/paddle/phi/ops/compat/cumprod_sig.cc index 01084e764ed..ffe0ba75bb9 100644 --- a/paddle/phi/ops/compat/cumprod_sig.cc +++ b/paddle/phi/ops/compat/cumprod_sig.cc @@ -18,10 +18,8 @@ namespace phi { KernelSignature CumprodGradGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("cumprod_grad", - {"X", "Out", GradVarName("Out")}, - {"dim"}, - {GradVarName("X")}); + return KernelSignature( + "cumprod_grad", {"X", "Out", "Out@GRAD"}, {"dim"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/deformable_conv_sig.cc b/paddle/phi/ops/compat/deformable_conv_sig.cc index a84a0840090..aa2537aa10e 100644 --- a/paddle/phi/ops/compat/deformable_conv_sig.cc +++ b/paddle/phi/ops/compat/deformable_conv_sig.cc @@ -33,17 +33,14 @@ KernelSignature DeformableConvGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( "deformable_conv_grad", - {"Input", "Offset", "Filter", "Mask", GradVarName("Output")}, + {"Input", "Offset", "Filter", "Mask", "Output@GRAD"}, {"strides", "paddings", "dilations", "deformable_groups", "groups", "im2col_step"}, - {GradVarName("Input"), - GradVarName("Offset"), - GradVarName("Filter"), - GradVarName("Mask")}); + {"Input@GRAD", "Offset@GRAD", "Filter@GRAD", "Mask@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/depthwise_conv2d_sig.cc b/paddle/phi/ops/compat/depthwise_conv2d_sig.cc index d2d7451ecaf..1014d45e70a 100644 --- a/paddle/phi/ops/compat/depthwise_conv2d_sig.cc +++ b/paddle/phi/ops/compat/depthwise_conv2d_sig.cc @@ -36,7 +36,7 @@ KernelSignature DepthwiseConv2dOpArgumentMapping( KernelSignature DepthwiseConv2dGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("depthwise_conv2d_grad", - {"Input", "Filter", GradVarName("Output")}, + {"Input", "Filter", "Output@GRAD"}, {"strides", "paddings", "padding_algorithm", @@ -47,7 +47,7 @@ KernelSignature DepthwiseConv2dGradOpArgumentMapping( "workspace_size_MB", "exhaustive_search", "fuse_relu_before_depthwise_conv"}, - {GradVarName("Input"), GradVarName("Filter")}); + {"Input@GRAD", "Filter@GRAD"}); } KernelSignature DepthwiseConv2dDoubleGradOpArgumentMapping( diff --git a/paddle/phi/ops/compat/determinant_sig.cc b/paddle/phi/ops/compat/determinant_sig.cc index 7bcd30ec5d7..ee1d53704c1 100644 --- a/paddle/phi/ops/compat/determinant_sig.cc +++ b/paddle/phi/ops/compat/determinant_sig.cc @@ -18,10 +18,8 @@ namespace phi { KernelSignature DeterminantGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("determinant_grad", - {"Input", "Out", GradVarName("Out")}, - {}, - {GradVarName("Input")}); + return KernelSignature( + "determinant_grad", {"Input", "Out", "Out@GRAD"}, {}, {"Input@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/diag_sig.cc b/paddle/phi/ops/compat/diag_sig.cc index f3245b922c0..b232c714c97 100644 --- a/paddle/phi/ops/compat/diag_sig.cc +++ b/paddle/phi/ops/compat/diag_sig.cc @@ -22,7 +22,7 @@ KernelSignature DiagOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature DiagGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature( - "diag_grad", {"X", GradVarName("Out")}, {"offset"}, {GradVarName("X")}); + "diag_grad", {"X", "Out@GRAD"}, {"offset"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/diagonal_sig.cc b/paddle/phi/ops/compat/diagonal_sig.cc index b4a424ec06b..94cecc3042a 100644 --- a/paddle/phi/ops/compat/diagonal_sig.cc +++ 
b/paddle/phi/ops/compat/diagonal_sig.cc @@ -19,9 +19,9 @@ namespace phi { KernelSignature DiagonalGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("diagonal_grad", - {"Input", GradVarName("Out")}, + {"Input", "Out@GRAD"}, {"offset", "axis1", "axis2"}, - {GradVarName("Input")}); + {"Input@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/digamma_sig.cc b/paddle/phi/ops/compat/digamma_sig.cc index 12ef3056f1e..6c14dd9bf17 100644 --- a/paddle/phi/ops/compat/digamma_sig.cc +++ b/paddle/phi/ops/compat/digamma_sig.cc @@ -18,8 +18,7 @@ namespace phi { KernelSignature DigammaGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "digamma_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")}); + return KernelSignature("digamma_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/dist_sig.cc b/paddle/phi/ops/compat/dist_sig.cc index 18a30b9b840..cc702fefbc9 100644 --- a/paddle/phi/ops/compat/dist_sig.cc +++ b/paddle/phi/ops/compat/dist_sig.cc @@ -17,10 +17,8 @@ limitations under the License. */ namespace phi { KernelSignature DistGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("dist_grad", - {"X", "Y", "Out", GradVarName("Out")}, - {"p"}, - {GradVarName("X"), GradVarName("Y")}); + return KernelSignature( + "dist_grad", {"X", "Y", "Out", "Out@GRAD"}, {"p"}, {"X@GRAD", "Y@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/dot_sig.cc b/paddle/phi/ops/compat/dot_sig.cc index 2437ecc1ca7..2187a7eb4fc 100644 --- a/paddle/phi/ops/compat/dot_sig.cc +++ b/paddle/phi/ops/compat/dot_sig.cc @@ -17,10 +17,8 @@ limitations under the License. */ namespace phi { KernelSignature DotGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("dot_grad", - {"X", "Y", GradVarName("Out")}, - {}, - {GradVarName("X"), GradVarName("Y")}); + return KernelSignature( + "dot_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/dropout_sig.cc b/paddle/phi/ops/compat/dropout_sig.cc index 6bf229c98bd..712c5cbb0d6 100644 --- a/paddle/phi/ops/compat/dropout_sig.cc +++ b/paddle/phi/ops/compat/dropout_sig.cc @@ -27,9 +27,9 @@ KernelSignature DropoutOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature DropoutGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("dropout_grad", - {"Mask", GradVarName("Out")}, + {"Mask", "Out@GRAD"}, {"dropout_prob", "is_test", "dropout_implementation"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/eigh_sig.cc b/paddle/phi/ops/compat/eigh_sig.cc index e50a9a5a12a..58718b6e32c 100644 --- a/paddle/phi/ops/compat/eigh_sig.cc +++ b/paddle/phi/ops/compat/eigh_sig.cc @@ -17,13 +17,11 @@ namespace phi { KernelSignature EighGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("eigh_grad", - {"Eigenvalues", - "Eigenvectors", - GradVarName("Eigenvalues"), - GradVarName("Eigenvectors")}, - {}, - {GradVarName("X")}); + return KernelSignature( + "eigh_grad", + {"Eigenvalues", "Eigenvectors", "Eigenvalues@GRAD", "Eigenvectors@GRAD"}, + {}, + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/elementwise_sig.cc b/paddle/phi/ops/compat/elementwise_sig.cc index 0a58d86b05b..19110eb0e0a 100644 --- a/paddle/phi/ops/compat/elementwise_sig.cc +++ b/paddle/phi/ops/compat/elementwise_sig.cc @@ -106,10 +106,8 @@ 
KernelSignature ElementwisePowOpArgumentMapping( KernelSignature ElementwiseAddGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("add_grad", - {"X", "Y", GradVarName("Out")}, - {"axis"}, - {GradVarName("X"), GradVarName("Y")}); + return KernelSignature( + "add_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"}); } KernelSignature ElementwiseAddDoubleGradOpArgumentMapping( @@ -128,10 +126,8 @@ KernelSignature ElementwiseAddTripleGradOpArgumentMapping( KernelSignature ElementwiseSubGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("subtract_grad", - {"X", "Y", GradVarName("Out")}, - {"axis"}, - {GradVarName("X"), GradVarName("Y")}); + return KernelSignature( + "subtract_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"}); } KernelSignature ElementwiseSubDoubleGradOpArgumentMapping( @@ -143,17 +139,15 @@ KernelSignature ElementwiseSubDoubleGradOpArgumentMapping( KernelSignature ElementwiseDivGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("divide_grad", - {"X", "Y", "Out", GradVarName("Out")}, + {"X", "Y", "Out", "Out@GRAD"}, {"axis"}, - {GradVarName("X"), GradVarName("Y")}); + {"X@GRAD", "Y@GRAD"}); } KernelSignature ElementwiseFMinGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("fmin_grad", - {"X", "Y", GradVarName("Out")}, - {"axis"}, - {GradVarName("X"), GradVarName("Y")}); + return KernelSignature( + "fmin_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"}); } KernelSignature ElementwiseDivDoubleGradOpArgumentMapping( @@ -161,15 +155,13 @@ KernelSignature ElementwiseDivDoubleGradOpArgumentMapping( return KernelSignature("divide_double_grad", {"Y", "Out", "DX", "DDX", "DDY"}, {"axis"}, - {GradVarName("Y"), "DOut", "DDOut"}); + {"Y@GRAD", "DOut", "DDOut"}); } KernelSignature ElementwiseMulGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("multiply_grad", - {"X", "Y", GradVarName("Out")}, - {"axis"}, - {GradVarName("X"), GradVarName("Y")}); + return KernelSignature( + "multiply_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"}); } KernelSignature ElementwiseFMaxOpArgumentMapping( @@ -184,10 +176,8 @@ KernelSignature ElementwiseFMinOpArgumentMapping( KernelSignature ElementwiseFMaxGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("fmax_grad", - {"X", "Y", GradVarName("Out")}, - {"axis"}, - {GradVarName("X"), GradVarName("Y")}); + return KernelSignature( + "fmax_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"}); } KernelSignature ElementwiseMulDoubleGradOpArgumentMapping( @@ -195,7 +185,7 @@ KernelSignature ElementwiseMulDoubleGradOpArgumentMapping( return KernelSignature("multiply_double_grad", {"X", "Y", "DOut", "DDX", "DDY"}, {"axis"}, - {GradVarName("X"), GradVarName("Y"), "DDOut"}); + {"X@GRAD", "Y@GRAD", "DDOut"}); } KernelSignature ElementwiseMulTripleGradOpArgumentMapping( @@ -209,25 +199,21 @@ KernelSignature ElementwiseMulTripleGradOpArgumentMapping( KernelSignature ElementwiseMaxGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("maximum_grad", - {"X", "Y", GradVarName("Out")}, - {"axis"}, - {GradVarName("X"), GradVarName("Y")}); + return KernelSignature( + "maximum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"}); } KernelSignature ElementwiseMinGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("minimum_grad", - {"X", "Y", 
GradVarName("Out")}, - {"axis"}, - {GradVarName("X"), GradVarName("Y")}); + return KernelSignature( + "minimum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"}); } KernelSignature ElementwisePowGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("elementwise_pow_grad", - {"X", "Y", GradVarName("Out")}, + {"X", "Y", "Out@GRAD"}, {"axis"}, - {GradVarName("X"), GradVarName("Y")}); + {"X@GRAD", "Y@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/embedding_sig.cc b/paddle/phi/ops/compat/embedding_sig.cc index b79a381dcec..48debcafaf2 100644 --- a/paddle/phi/ops/compat/embedding_sig.cc +++ b/paddle/phi/ops/compat/embedding_sig.cc @@ -30,26 +30,26 @@ KernelSignature EmbeddingGradOpArgumentMapping( if (ctx.IsDenseTensorInput("W")) { if ((paddle::any_cast(ctx.Attr("is_sparse"))) == true) { return KernelSignature("embedding_sparse_grad", - {"Ids", "W", GradVarName("Out")}, + {"Ids", "W", "Out@GRAD"}, {"padding_idx"}, - {GradVarName("W")}); + {"W@GRAD"}); } else { return KernelSignature("embedding_grad", - {"Ids", "W", GradVarName("Out")}, + {"Ids", "W", "Out@GRAD"}, {"padding_idx"}, - {GradVarName("W")}); + {"W@GRAD"}); } } else { if ((paddle::any_cast(ctx.Attr("is_sparse"))) == true) { return KernelSignature("sparse_weight_embedding_sparse_grad", - {"Ids", "W", GradVarName("Out")}, + {"Ids", "W", "Out@GRAD"}, {"padding_idx"}, - {GradVarName("W")}); + {"W@GRAD"}); } else { return KernelSignature("sparse_weight_embedding_grad", - {"Ids", "W", GradVarName("Out")}, + {"Ids", "W", "Out@GRAD"}, {"padding_idx"}, - {GradVarName("W")}); + {"W@GRAD"}); } } } diff --git a/paddle/phi/ops/compat/erf_sig.cc b/paddle/phi/ops/compat/erf_sig.cc index 784727a9804..6cd94e46c3e 100644 --- a/paddle/phi/ops/compat/erf_sig.cc +++ b/paddle/phi/ops/compat/erf_sig.cc @@ -17,8 +17,7 @@ namespace phi { KernelSignature ErfGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "erf_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")}); + return KernelSignature("erf_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/erfinv_sig.cc b/paddle/phi/ops/compat/erfinv_sig.cc index 49057319153..37d30aaaeb6 100644 --- a/paddle/phi/ops/compat/erfinv_sig.cc +++ b/paddle/phi/ops/compat/erfinv_sig.cc @@ -17,8 +17,7 @@ namespace phi { KernelSignature ErfinvGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "erfinv_grad", {"Out", GradVarName("Out")}, {}, {GradVarName("X")}); + return KernelSignature("erfinv_grad", {"Out", "Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/expand_as_sig.cc b/paddle/phi/ops/compat/expand_as_sig.cc index a616b63c10b..03b308f4a8b 100644 --- a/paddle/phi/ops/compat/expand_as_sig.cc +++ b/paddle/phi/ops/compat/expand_as_sig.cc @@ -22,10 +22,8 @@ KernelSignature ExpandAsOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ExpandAsGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("expand_as_grad", - {"X", GradVarName("Out")}, - {"target_shape"}, - {GradVarName("X")}); + return KernelSignature( + "expand_as_grad", {"X", "Out@GRAD"}, {"target_shape"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/expand_sig.cc b/paddle/phi/ops/compat/expand_sig.cc index 9b0a1f5ab7d..b0f4ff79b4c 100644 --- a/paddle/phi/ops/compat/expand_sig.cc +++ b/paddle/phi/ops/compat/expand_sig.cc @@ -39,20 +39,14 @@ KernelSignature ExpandGradOpArgumentMapping(const 
ArgumentMappingContext& ctx) { "expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"}); } if (ctx.HasInput("Shape")) { - return KernelSignature("expand_grad", - {"X", GradVarName("Out")}, - {"Shape"}, - {GradVarName("X")}); + return KernelSignature( + "expand_grad", {"X", "Out@GRAD"}, {"Shape"}, {"X@GRAD"}); } else if (ctx.InputSize("expand_shapes_tensor") > 0) { - return KernelSignature("expand_grad", - {"X", GradVarName("Out")}, - {"expand_shapes_tensor"}, - {GradVarName("X")}); + return KernelSignature( + "expand_grad", {"X", "Out@GRAD"}, {"expand_shapes_tensor"}, {"X@GRAD"}); } else { - return KernelSignature("expand_grad", - {"X", GradVarName("Out")}, - {"shape"}, - {GradVarName("X")}); + return KernelSignature( + "expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"}); } } diff --git a/paddle/phi/ops/compat/flatten_sig.cc b/paddle/phi/ops/compat/flatten_sig.cc index 3e8119c38cf..122e0efa22b 100644 --- a/paddle/phi/ops/compat/flatten_sig.cc +++ b/paddle/phi/ops/compat/flatten_sig.cc @@ -31,7 +31,7 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature FlattenGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( - "flatten_grad", {"XShape", GradVarName("Out")}, {}, {GradVarName("X")}); + "flatten_grad", {"XShape", "Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/frobenius_norm_sig.cc b/paddle/phi/ops/compat/frobenius_norm_sig.cc index 8fddee5edb1..1fb53c36caf 100644 --- a/paddle/phi/ops/compat/frobenius_norm_sig.cc +++ b/paddle/phi/ops/compat/frobenius_norm_sig.cc @@ -25,9 +25,9 @@ KernelSignature FrobeniusNormOpArgumentMapping( KernelSignature FrobeniusNormGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("frobenius_norm_grad", - {"X", "Out", GradVarName("Out")}, + {"X", "Out", "Out@GRAD"}, {"dim", "keep_dim", "reduce_all"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/gather_scatter_sig.cc b/paddle/phi/ops/compat/gather_scatter_sig.cc index f71e30f85b0..a942ebb4408 100644 --- a/paddle/phi/ops/compat/gather_scatter_sig.cc +++ b/paddle/phi/ops/compat/gather_scatter_sig.cc @@ -17,25 +17,23 @@ namespace phi { KernelSignature GatherNdGradArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("gather_nd_grad", - {"X", "Index", GradVarName("Out")}, - {}, - {GradVarName("X")}); + return KernelSignature( + "gather_nd_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"}); } KernelSignature ScatterGradArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("scatter_grad", - {"Ids", "Updates", GradVarName("Out")}, + {"Ids", "Updates", "Out@GRAD"}, {"overwrite"}, - {GradVarName("X"), GradVarName("Updates")}); + {"X@GRAD", "Updates@GRAD"}); } KernelSignature ScatterNdAddGradArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("scatter_nd_add_grad", - {"Index", "Updates", GradVarName("Out")}, + {"Index", "Updates", "Out@GRAD"}, {}, - {GradVarName("X"), GradVarName("Updates")}); + {"X@GRAD", "Updates@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/gather_sig.cc b/paddle/phi/ops/compat/gather_sig.cc index 6c47bbe48b8..af9e50638ce 100644 --- a/paddle/phi/ops/compat/gather_sig.cc +++ b/paddle/phi/ops/compat/gather_sig.cc @@ -27,14 +27,14 @@ KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.HasInput("Axis")) { return 
KernelSignature("gather_grad", - {"X", "Index", GradVarName("Out")}, + {"X", "Index", "Out@GRAD"}, {"Axis", "overwrite"}, - {GradVarName("X")}); + {"X@GRAD"}); } else { return KernelSignature("gather_grad", - {"X", "Index", GradVarName("Out")}, + {"X", "Index", "Out@GRAD"}, {"axis", "overwrite"}, - {GradVarName("X")}); + {"X@GRAD"}); } } diff --git a/paddle/phi/ops/compat/gelu_sig.cc b/paddle/phi/ops/compat/gelu_sig.cc index bf4b47bcf5f..45a0ecea713 100644 --- a/paddle/phi/ops/compat/gelu_sig.cc +++ b/paddle/phi/ops/compat/gelu_sig.cc @@ -21,10 +21,8 @@ KernelSignature GeluOpArgumentMapping(const ArgumentMappingContext& ctx) { } KernelSignature GeluGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("gelu_grad", - {"X", GradVarName("Out")}, - {"approximate"}, - {GradVarName("X")}); + return KernelSignature( + "gelu_grad", {"X", "Out@GRAD"}, {"approximate"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/graph_send_recv_sig.cc b/paddle/phi/ops/compat/graph_send_recv_sig.cc index cf36b9baa2d..9df2cf4d0fe 100644 --- a/paddle/phi/ops/compat/graph_send_recv_sig.cc +++ b/paddle/phi/ops/compat/graph_send_recv_sig.cc @@ -28,9 +28,9 @@ KernelSignature GraphSendRecvGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( "graph_send_recv_grad", - {"X", "Src_index", "Dst_index", "Out", "Dst_count", GradVarName("Out")}, + {"X", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"}, {"pool_type"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/grid_sampler_sig.cc b/paddle/phi/ops/compat/grid_sampler_sig.cc index b76a9770d4d..486d5230ee7 100644 --- a/paddle/phi/ops/compat/grid_sampler_sig.cc +++ b/paddle/phi/ops/compat/grid_sampler_sig.cc @@ -27,9 +27,9 @@ KernelSignature GridSamplerOpArgumentMapping( KernelSignature GridSamplerGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("grid_sample_grad", - {"X", "Grid", GradVarName("Output")}, + {"X", "Grid", "Output@GRAD"}, {"mode", "padding_mode", "align_corners"}, - {GradVarName("X"), GradVarName("Grid")}); + {"X@GRAD", "Grid@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/gumbel_softmax_sig.cc b/paddle/phi/ops/compat/gumbel_softmax_sig.cc index c7585a4e5f3..65537f8c894 100644 --- a/paddle/phi/ops/compat/gumbel_softmax_sig.cc +++ b/paddle/phi/ops/compat/gumbel_softmax_sig.cc @@ -18,10 +18,8 @@ namespace phi { KernelSignature GumbelSoftmaxGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("gumbel_softmax_grad", - {"Out", GradVarName("Out")}, - {"axis"}, - {GradVarName("X")}); + return KernelSignature( + "gumbel_softmax_grad", {"Out", "Out@GRAD"}, {"axis"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/hierarchical_sigmoid_sig.cc b/paddle/phi/ops/compat/hierarchical_sigmoid_sig.cc index 58c190fb657..5393439901b 100644 --- a/paddle/phi/ops/compat/hierarchical_sigmoid_sig.cc +++ b/paddle/phi/ops/compat/hierarchical_sigmoid_sig.cc @@ -32,44 +32,42 @@ KernelSignature HierarchicalSigmoidOpArgumentMapping( KernelSignature HierarchicalSigmoidGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - if (ctx.IsDenseTensorOutput(GradVarName("W"))) { - return KernelSignature( - "hierarchical_sigmoid_grad", - {"X", - "W", - "Label", - "PathTable", - "PathCode", - "Bias", - "PreOut", - GradVarName("Out")}, - {"num_classes", - "remote_prefetch", - "trainer_id", - "height_sections", - "epmap", - "table_names", - "is_sparse"}, - 
{GradVarName("X"), GradVarName("W"), GradVarName("Bias")}); - } else if (ctx.IsSelectedRowsOutput(GradVarName("W"))) { - return KernelSignature( - "hierarchical_sigmoid_grad_sr", - {"X", - "W", - "Label", - "PathTable", - "PathCode", - "Bias", - "PreOut", - GradVarName("Out")}, - {"num_classes", - "remote_prefetch", - "trainer_id", - "height_sections", - "epmap", - "table_names", - "is_sparse"}, - {GradVarName("X"), GradVarName("W"), GradVarName("Bias")}); + if (ctx.IsDenseTensorOutput("W@GRAD")) { + return KernelSignature("hierarchical_sigmoid_grad", + {"X", + "W", + "Label", + "PathTable", + "PathCode", + "Bias", + "PreOut", + "Out@GRAD"}, + {"num_classes", + "remote_prefetch", + "trainer_id", + "height_sections", + "epmap", + "table_names", + "is_sparse"}, + {"X@GRAD", "W@GRAD", "Bias@GRAD"}); + } else if (ctx.IsSelectedRowsOutput("W@GRAD")) { + return KernelSignature("hierarchical_sigmoid_grad_sr", + {"X", + "W", + "Label", + "PathTable", + "PathCode", + "Bias", + "PreOut", + "Out@GRAD"}, + {"num_classes", + "remote_prefetch", + "trainer_id", + "height_sections", + "epmap", + "table_names", + "is_sparse"}, + {"X@GRAD", "W@GRAD", "Bias@GRAD"}); } else { return KernelSignature("unregistered", {}, {}, {}); } diff --git a/paddle/phi/ops/compat/huber_loss_sig.cc b/paddle/phi/ops/compat/huber_loss_sig.cc index 6f669a4a8b6..b7bf143fd40 100644 --- a/paddle/phi/ops/compat/huber_loss_sig.cc +++ b/paddle/phi/ops/compat/huber_loss_sig.cc @@ -24,9 +24,9 @@ KernelSignature HuberLossOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature HuberLossGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("huber_loss_grad", - {"Residual", GradVarName("Out")}, + {"Residual", "Out@GRAD"}, {"delta"}, - {GradVarName("X"), GradVarName("Y")}); + {"X@GRAD", "Y@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/index_sample_sig.cc b/paddle/phi/ops/compat/index_sample_sig.cc index 3b7e3f063d6..9c1b7e27f04 100644 --- a/paddle/phi/ops/compat/index_sample_sig.cc +++ b/paddle/phi/ops/compat/index_sample_sig.cc @@ -18,10 +18,8 @@ namespace phi { KernelSignature IndexSampleGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("index_sample_grad", - {"X", "Index", GradVarName("Out")}, - {}, - {GradVarName("X")}); + return KernelSignature( + "index_sample_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/index_select_sig.cc b/paddle/phi/ops/compat/index_select_sig.cc index 53eff1bbcd7..096ad2332c9 100644 --- a/paddle/phi/ops/compat/index_select_sig.cc +++ b/paddle/phi/ops/compat/index_select_sig.cc @@ -18,10 +18,8 @@ namespace phi { KernelSignature IndexSelectGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("index_select_grad", - {"X", "Index", GradVarName("Out")}, - {"dim"}, - {GradVarName("X")}); + return KernelSignature( + "index_select_grad", {"X", "Index", "Out@GRAD"}, {"dim"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/interpolate_sig.cc b/paddle/phi/ops/compat/interpolate_sig.cc index ba0e971e4ab..61b02240730 100644 --- a/paddle/phi/ops/compat/interpolate_sig.cc +++ b/paddle/phi/ops/compat/interpolate_sig.cc @@ -92,81 +92,76 @@ KernelSignature BicubicInterpOpArgumentMapping( KernelSignature BilinearInterpGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "bilinear_interp_v2_grad", - {"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, - {"data_layout", - "out_d", - 
"out_h", - "out_w", - "scale", - "interp_method", - "align_corners", - "align_mode"}, - {GradVarName("X")}); + return KernelSignature("bilinear_interp_v2_grad", + {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"}, + {"data_layout", + "out_d", + "out_h", + "out_w", + "scale", + "interp_method", + "align_corners", + "align_mode"}, + {"X@GRAD"}); } KernelSignature NearestInterpGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "nearest_interp_v2_grad", - {"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, - {"data_layout", - "out_d", - "out_h", - "out_w", - "scale", - "interp_method", - "align_corners", - "align_mode"}, - {GradVarName("X")}); + return KernelSignature("nearest_interp_v2_grad", + {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"}, + {"data_layout", + "out_d", + "out_h", + "out_w", + "scale", + "interp_method", + "align_corners", + "align_mode"}, + {"X@GRAD"}); } KernelSignature TrilinearInterpGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "trilinear_interp_v2_grad", - {"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, - {"data_layout", - "out_d", - "out_h", - "out_w", - "scale", - "interp_method", - "align_corners", - "align_mode"}, - {GradVarName("X")}); + return KernelSignature("trilinear_interp_v2_grad", + {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"}, + {"data_layout", + "out_d", + "out_h", + "out_w", + "scale", + "interp_method", + "align_corners", + "align_mode"}, + {"X@GRAD"}); } KernelSignature LinearInterpGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "linear_interp_v2_grad", - {"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, - {"data_layout", - "out_d", - "out_h", - "out_w", - "scale", - "interp_method", - "align_corners", - "align_mode"}, - {GradVarName("X")}); + return KernelSignature("linear_interp_v2_grad", + {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"}, + {"data_layout", + "out_d", + "out_h", + "out_w", + "scale", + "interp_method", + "align_corners", + "align_mode"}, + {"X@GRAD"}); } KernelSignature BicubicInterpGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "bicubic_interp_v2_grad", - {"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, - {"data_layout", - "out_d", - "out_h", - "out_w", - "scale", - "interp_method", - "align_corners", - "align_mode"}, - {GradVarName("X")}); + return KernelSignature("bicubic_interp_v2_grad", + {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"}, + {"data_layout", + "out_d", + "out_h", + "out_w", + "scale", + "interp_method", + "align_corners", + "align_mode"}, + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/kldiv_loss_sig.cc b/paddle/phi/ops/compat/kldiv_loss_sig.cc index 22d2f074e9f..8af0edd3164 100644 --- a/paddle/phi/ops/compat/kldiv_loss_sig.cc +++ b/paddle/phi/ops/compat/kldiv_loss_sig.cc @@ -20,9 +20,9 @@ namespace phi { KernelSignature KLDivLossGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("kldiv_loss_grad", - {"X", "Target", GradVarName("Loss")}, + {"X", "Target", "Loss@GRAD"}, {"reduction"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/kron_sig.cc b/paddle/phi/ops/compat/kron_sig.cc index 06b6545f58e..e2ba41dcadd 100644 --- a/paddle/phi/ops/compat/kron_sig.cc +++ b/paddle/phi/ops/compat/kron_sig.cc @@ -17,10 +17,8 @@ namespace phi { KernelSignature KronGradOpArgumentMapping(const 
ArgumentMappingContext& ctx) { - return KernelSignature("kron_grad", - {"X", "Y", GradVarName("Out")}, - {}, - {GradVarName("X"), GradVarName("Y")}); + return KernelSignature( + "kron_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/kthvalue_sig.cc b/paddle/phi/ops/compat/kthvalue_sig.cc index 3b1a6a45f9a..b04726ec3b3 100644 --- a/paddle/phi/ops/compat/kthvalue_sig.cc +++ b/paddle/phi/ops/compat/kthvalue_sig.cc @@ -20,9 +20,9 @@ namespace phi { KernelSignature KthvalueGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("kthvalue_grad", - {"X", "Indices", GradVarName("Out")}, + {"X", "Indices", "Out@GRAD"}, {"k", "axis", "keepdim"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/label_smooth_sig.cc b/paddle/phi/ops/compat/label_smooth_sig.cc index 4fb62a8ca26..7607af2b61b 100644 --- a/paddle/phi/ops/compat/label_smooth_sig.cc +++ b/paddle/phi/ops/compat/label_smooth_sig.cc @@ -24,10 +24,8 @@ KernelSignature LabelSmoothOpArgumentMapping( KernelSignature LabelSmoothGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("label_smooth_grad", - {GradVarName("Out")}, - {"epsilon"}, - {GradVarName("X")}); + return KernelSignature( + "label_smooth_grad", {"Out@GRAD"}, {"epsilon"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/layer_norm_sig.cc b/paddle/phi/ops/compat/layer_norm_sig.cc index 17a81e9ec01..eb47c516ab3 100644 --- a/paddle/phi/ops/compat/layer_norm_sig.cc +++ b/paddle/phi/ops/compat/layer_norm_sig.cc @@ -25,11 +25,10 @@ KernelSignature LayerNormOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature LayerNormGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "layer_norm_grad", - {"X", "Mean", "Variance", "Scale", "Bias", GradVarName("Y")}, - {"epsilon", "begin_norm_axis", "is_test"}, - {GradVarName("X"), GradVarName("Scale"), GradVarName("Bias")}); + return KernelSignature("layer_norm_grad", + {"X", "Mean", "Variance", "Scale", "Bias", "Y@GRAD"}, + {"epsilon", "begin_norm_axis", "is_test"}, + {"X@GRAD", "Scale@GRAD", "Bias@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/lerp_sig.cc b/paddle/phi/ops/compat/lerp_sig.cc index 3a8b23ca4c4..154424468d6 100644 --- a/paddle/phi/ops/compat/lerp_sig.cc +++ b/paddle/phi/ops/compat/lerp_sig.cc @@ -22,9 +22,9 @@ KernelSignature LerpOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature LerpGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("lerp_grad", - {"X", "Y", "Weight", "Out", GradVarName("Out")}, + {"X", "Y", "Weight", "Out", "Out@GRAD"}, {}, - {GradVarName("X"), GradVarName("Y")}); + {"X@GRAD", "Y@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/lgamma_sig.cc b/paddle/phi/ops/compat/lgamma_sig.cc index 452ba5e2b45..192754cc846 100644 --- a/paddle/phi/ops/compat/lgamma_sig.cc +++ b/paddle/phi/ops/compat/lgamma_sig.cc @@ -17,8 +17,7 @@ namespace phi { KernelSignature LgammaGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "lgamma_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")}); + return KernelSignature("lgamma_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/log_loss_sig.cc b/paddle/phi/ops/compat/log_loss_sig.cc index c4ae746e975..adf40bac000 100644 --- a/paddle/phi/ops/compat/log_loss_sig.cc +++ 
b/paddle/phi/ops/compat/log_loss_sig.cc @@ -19,9 +19,9 @@ namespace phi { KernelSignature LogLossGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("log_loss_grad", - {"Predicted", "Labels", GradVarName("Loss")}, + {"Predicted", "Labels", "Loss@GRAD"}, {"epsilon"}, - {GradVarName("Predicted")}); + {"Predicted@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/log_softmax_sig.cc b/paddle/phi/ops/compat/log_softmax_sig.cc index b1ecc6d5676..20635c89875 100644 --- a/paddle/phi/ops/compat/log_softmax_sig.cc +++ b/paddle/phi/ops/compat/log_softmax_sig.cc @@ -18,10 +18,8 @@ namespace phi { KernelSignature LogSoftmaxGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("log_softmax_grad", - {"Out", GradVarName("Out")}, - {"axis"}, - {GradVarName("X")}); + return KernelSignature( + "log_softmax_grad", {"Out", "Out@GRAD"}, {"axis"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/logsumexp_sig.cc b/paddle/phi/ops/compat/logsumexp_sig.cc index ca7345dbe70..6d988c71880 100644 --- a/paddle/phi/ops/compat/logsumexp_sig.cc +++ b/paddle/phi/ops/compat/logsumexp_sig.cc @@ -19,9 +19,9 @@ namespace phi { KernelSignature LogsumexpGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("logsumexp_grad", - {"X", "Out", GradVarName("Out")}, + {"X", "Out", "Out@GRAD"}, {"axis", "keepdim", "reduce_all"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/masked_select_sig.cc b/paddle/phi/ops/compat/masked_select_sig.cc index ec0eb90315b..47b4f2fac31 100644 --- a/paddle/phi/ops/compat/masked_select_sig.cc +++ b/paddle/phi/ops/compat/masked_select_sig.cc @@ -23,10 +23,8 @@ KernelSignature MaskedSelectOpArgumentMapping( KernelSignature MaskedSelectGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("masked_select_grad", - {"X", "Mask", GradVarName("Y")}, - {}, - {GradVarName("X")}); + return KernelSignature( + "masked_select_grad", {"X", "Mask", "Y@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/matmul_sig.cc b/paddle/phi/ops/compat/matmul_sig.cc index 771a7c3acc3..4e125f0dbea 100644 --- a/paddle/phi/ops/compat/matmul_sig.cc +++ b/paddle/phi/ops/compat/matmul_sig.cc @@ -19,14 +19,14 @@ namespace phi { KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.HasAttr("use_addto")) { return KernelSignature("addto_matmul_grad", - {"X", "Y", GradVarName("Out")}, + {"X", "Y", "Out@GRAD"}, {"trans_x", "trans_y", "use_addto"}, - {GradVarName("X"), GradVarName("Y")}); + {"X@GRAD", "Y@GRAD"}); } else { return KernelSignature("matmul_grad", - {"X", "Y", GradVarName("Out")}, + {"X", "Y", "Out@GRAD"}, {"trans_x", "trans_y"}, - {GradVarName("X"), GradVarName("Y")}); + {"X@GRAD", "Y@GRAD"}); } } diff --git a/paddle/phi/ops/compat/matrix_power_sig.cc b/paddle/phi/ops/compat/matrix_power_sig.cc index 4c9ad4e74ab..00cb1f82b80 100644 --- a/paddle/phi/ops/compat/matrix_power_sig.cc +++ b/paddle/phi/ops/compat/matrix_power_sig.cc @@ -18,10 +18,8 @@ namespace phi { KernelSignature MatrixPowerGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("matrix_power_grad", - {"X", "Out", GradVarName("Out")}, - {"n"}, - {GradVarName("X")}); + return KernelSignature( + "matrix_power_grad", {"X", "Out", "Out@GRAD"}, {"n"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/maxout_sig.cc b/paddle/phi/ops/compat/maxout_sig.cc index 
d16dd1c8617..9e028bc81fb 100644 --- a/paddle/phi/ops/compat/maxout_sig.cc +++ b/paddle/phi/ops/compat/maxout_sig.cc @@ -21,10 +21,8 @@ KernelSignature MaxoutArgumentMapping(const ArgumentMappingContext& ctx) { } KernelSignature MaxoutGradArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("maxout_grad", - {"X", "Out", GradVarName("Out")}, - {"groups", "axis"}, - {GradVarName("X")}); + return KernelSignature( + "maxout_grad", {"X", "Out", "Out@GRAD"}, {"groups", "axis"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/mean_sig.cc b/paddle/phi/ops/compat/mean_sig.cc index 6decd0da0b0..461d6ab32ce 100644 --- a/paddle/phi/ops/compat/mean_sig.cc +++ b/paddle/phi/ops/compat/mean_sig.cc @@ -22,8 +22,7 @@ KernelSignature MeanOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature MeanGradOpGradArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "mean_all_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")}); + return KernelSignature("mean_all_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/meshgrid_sig.cc b/paddle/phi/ops/compat/meshgrid_sig.cc index 44671c84e7a..f0c8cc7ea62 100644 --- a/paddle/phi/ops/compat/meshgrid_sig.cc +++ b/paddle/phi/ops/compat/meshgrid_sig.cc @@ -22,8 +22,7 @@ KernelSignature MeshgridOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature MeshgridGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "meshgrid_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")}); + return KernelSignature("meshgrid_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/mode_sig.cc b/paddle/phi/ops/compat/mode_sig.cc index 20994c08aa7..e21cd69bf60 100644 --- a/paddle/phi/ops/compat/mode_sig.cc +++ b/paddle/phi/ops/compat/mode_sig.cc @@ -23,9 +23,9 @@ KernelSignature ModeOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ModeGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("mode_grad", - {"X", "Indices", GradVarName("Out")}, + {"X", "Indices", "Out@GRAD"}, {"axis", "keepdim"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/mul_sig.cc b/paddle/phi/ops/compat/mul_sig.cc index 8770db1039e..4afff4aa1d7 100644 --- a/paddle/phi/ops/compat/mul_sig.cc +++ b/paddle/phi/ops/compat/mul_sig.cc @@ -18,9 +18,9 @@ namespace phi { KernelSignature MulGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("matmul_with_flatten_grad", - {"X", "Y", GradVarName("Out")}, + {"X", "Y", "Out@GRAD"}, {"x_num_col_dims", "y_num_col_dims"}, - {GradVarName("X"), GradVarName("Y")}); + {"X@GRAD", "Y@GRAD"}); } KernelSignature MulDoubleGradOpArgumentMapping( diff --git a/paddle/phi/ops/compat/multi_dot_sig.cc b/paddle/phi/ops/compat/multi_dot_sig.cc index 2e05bd6d155..29af82c9d1d 100644 --- a/paddle/phi/ops/compat/multi_dot_sig.cc +++ b/paddle/phi/ops/compat/multi_dot_sig.cc @@ -18,8 +18,7 @@ namespace phi { KernelSignature MultiDotGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "multi_dot_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")}); + return KernelSignature("multi_dot_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/multiplex_sig.cc b/paddle/phi/ops/compat/multiplex_sig.cc index 9dab4655d17..538b1c13dda 100644 --- a/paddle/phi/ops/compat/multiplex_sig.cc +++ 
b/paddle/phi/ops/compat/multiplex_sig.cc @@ -22,8 +22,7 @@ KernelSignature MultiplexOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature MultiplexGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "multiplex_grad", {"Ids", GradVarName("Out")}, {}, {GradVarName("X")}); + return KernelSignature("multiplex_grad", {"Ids", "Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/mv_sig.cc b/paddle/phi/ops/compat/mv_sig.cc index 0012f8e1ccb..e965ddbb726 100644 --- a/paddle/phi/ops/compat/mv_sig.cc +++ b/paddle/phi/ops/compat/mv_sig.cc @@ -17,10 +17,8 @@ namespace phi { KernelSignature MvGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("mv_grad", - {"X", "Vec", GradVarName("Out")}, - {}, - {GradVarName("X"), GradVarName("Vec")}); + return KernelSignature( + "mv_grad", {"X", "Vec", "Out@GRAD"}, {}, {"X@GRAD", "Vec@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/nll_loss_sig.cc b/paddle/phi/ops/compat/nll_loss_sig.cc index 87a060ce7a6..f3f9c531781 100644 --- a/paddle/phi/ops/compat/nll_loss_sig.cc +++ b/paddle/phi/ops/compat/nll_loss_sig.cc @@ -27,11 +27,10 @@ KernelSignature NllLossOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature NllLossGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "nll_loss_grad", - {"X", "Label", "Weight", "Total_weight", GradVarName("Out")}, - {"ignore_index", "reduction"}, - {GradVarName("X")}); + return KernelSignature("nll_loss_grad", + {"X", "Label", "Weight", "Total_weight", "Out@GRAD"}, + {"ignore_index", "reduction"}, + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/norm_sig.cc b/paddle/phi/ops/compat/norm_sig.cc index a74db9b5686..b9e56f3d166 100644 --- a/paddle/phi/ops/compat/norm_sig.cc +++ b/paddle/phi/ops/compat/norm_sig.cc @@ -23,9 +23,9 @@ KernelSignature NormOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature NormGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("norm_grad", - {"X", "Norm", GradVarName("Out")}, + {"X", "Norm", "Out@GRAD"}, {"axis", "epsilon", "is_test"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/p_norm_sig.cc b/paddle/phi/ops/compat/p_norm_sig.cc index d3bff55346c..82b88aa09ff 100644 --- a/paddle/phi/ops/compat/p_norm_sig.cc +++ b/paddle/phi/ops/compat/p_norm_sig.cc @@ -17,9 +17,9 @@ namespace phi { KernelSignature PNormGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("p_norm_grad", - {"X", "Out", GradVarName("Out")}, + {"X", "Out", "Out@GRAD"}, {"porder", "axis", "epsilon", "keepdim", "asvector"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/pad3d_sig.cc b/paddle/phi/ops/compat/pad3d_sig.cc index c43b98fa27e..dd8a37d24b7 100644 --- a/paddle/phi/ops/compat/pad3d_sig.cc +++ b/paddle/phi/ops/compat/pad3d_sig.cc @@ -29,14 +29,14 @@ KernelSignature Pad3dOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature Pad3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.HasInput("Paddings")) { return KernelSignature("pad3d_grad", - {"X", GradVarName("Out")}, + {"X", "Out@GRAD"}, {"Paddings", "mode", "value", "data_format"}, - {GradVarName("X")}); + {"X@GRAD"}); } return KernelSignature("pad3d_grad", - {"X", GradVarName("Out")}, + {"X", "Out@GRAD"}, {"paddings", "mode", "value", "data_format"}, - {GradVarName("X")}); + 
{"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/pad_sig.cc b/paddle/phi/ops/compat/pad_sig.cc index 4eadbfa98be..bb870eb256c 100644 --- a/paddle/phi/ops/compat/pad_sig.cc +++ b/paddle/phi/ops/compat/pad_sig.cc @@ -18,10 +18,8 @@ namespace phi { KernelSignature PadGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("pad_grad", - {GradVarName("Out")}, - {"paddings", "pad_value"}, - {GradVarName("X")}); + return KernelSignature( + "pad_grad", {"Out@GRAD"}, {"paddings", "pad_value"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/pixel_shuffle_sig.cc b/paddle/phi/ops/compat/pixel_shuffle_sig.cc index 641288cf12a..96cb01a38fc 100644 --- a/paddle/phi/ops/compat/pixel_shuffle_sig.cc +++ b/paddle/phi/ops/compat/pixel_shuffle_sig.cc @@ -25,9 +25,9 @@ KernelSignature PixelShuffleOpArgumentMapping( KernelSignature PixelShuffleGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("pixel_shuffle_grad", - {GradVarName("Out")}, + {"Out@GRAD"}, {"upscale_factor", "data_format"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/poisson_sig.cc b/paddle/phi/ops/compat/poisson_sig.cc index e45640c11b6..6022c3b608d 100644 --- a/paddle/phi/ops/compat/poisson_sig.cc +++ b/paddle/phi/ops/compat/poisson_sig.cc @@ -18,8 +18,7 @@ namespace phi { KernelSignature PoissonGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "poisson_grad", {GradVarName("Out")}, {}, {GradVarName("X")}); + return KernelSignature("poisson_grad", {"Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/pool_sig.cc b/paddle/phi/ops/compat/pool_sig.cc index 390d3db5e78..b807b21a1c0 100644 --- a/paddle/phi/ops/compat/pool_sig.cc +++ b/paddle/phi/ops/compat/pool_sig.cc @@ -34,7 +34,7 @@ KernelSignature Pool2dOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature Pool2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("pool2d_grad", - {"X", "Out", GradVarName("Out")}, + {"X", "Out", "Out@GRAD"}, {"ksize", "strides", "paddings", @@ -45,7 +45,7 @@ KernelSignature Pool2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { "global_pooling", "adaptive", "padding_algorithm"}, - {GradVarName("X")}); + {"X@GRAD"}); } KernelSignature Pool2dDoubleGradOpArgumentMapping( @@ -78,9 +78,9 @@ KernelSignature MaxPool2dWithIndexGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( "max_pool2d_with_index_grad", - {"X", "Mask", GradVarName("Out")}, + {"X", "Mask", "Out@GRAD"}, {"ksize", "strides", "paddings", "global_pooling", "adaptive"}, - {GradVarName("X")}); + {"X@GRAD"}); } KernelSignature Pool3dOpArgumentMapping(const ArgumentMappingContext& ctx) { @@ -101,7 +101,7 @@ KernelSignature Pool3dOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature Pool3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("pool3d_grad", - {"X", "Out", GradVarName("Out")}, + {"X", "Out", "Out@GRAD"}, {"ksize", "strides", "paddings", @@ -112,7 +112,7 @@ KernelSignature Pool3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { "global_pooling", "adaptive", "padding_algorithm"}, - {GradVarName("X")}); + {"X@GRAD"}); } KernelSignature MaxPool3dWithIndexOpArgumentMapping( @@ -128,9 +128,9 @@ KernelSignature MaxPool3dWithIndexGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( "max_pool3d_with_index_grad", - 
{"X", "Mask", GradVarName("Out")}, + {"X", "Mask", "Out@GRAD"}, {"ksize", "strides", "paddings", "global_pooling", "adaptive"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/prelu_sig.cc b/paddle/phi/ops/compat/prelu_sig.cc index 43e5f20a926..6e25e1d9f75 100644 --- a/paddle/phi/ops/compat/prelu_sig.cc +++ b/paddle/phi/ops/compat/prelu_sig.cc @@ -23,9 +23,9 @@ KernelSignature PReluOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature PReluGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("prelu_grad", - {"X", "Alpha", GradVarName("Out")}, + {"X", "Alpha", "Out@GRAD"}, {"data_format", "mode"}, - {GradVarName("X"), GradVarName("Alpha")}); + {"X@GRAD", "Alpha@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/psroi_pool_sig.cc b/paddle/phi/ops/compat/psroi_pool_sig.cc index 4d694d9a775..df1dc1113cc 100644 --- a/paddle/phi/ops/compat/psroi_pool_sig.cc +++ b/paddle/phi/ops/compat/psroi_pool_sig.cc @@ -28,9 +28,9 @@ KernelSignature PsroiPoolGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( "psroi_pool_grad", - {"X", "ROIs", "RoisNum", GradVarName("Out")}, + {"X", "ROIs", "RoisNum", "Out@GRAD"}, {"pooled_height", "pooled_width", "output_channels", "spatial_scale"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/put_along_axis_sig.cc b/paddle/phi/ops/compat/put_along_axis_sig.cc index 5f8dc1cf4cd..83f0e5f65a0 100644 --- a/paddle/phi/ops/compat/put_along_axis_sig.cc +++ b/paddle/phi/ops/compat/put_along_axis_sig.cc @@ -26,9 +26,9 @@ KernelSignature PutAlongAxisArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature PutAlongAxisGradArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("put_along_axis_grad", - {"Input", "Index", GradVarName("Result")}, + {"Input", "Index", "Result@GRAD"}, {"Axis", "Reduce"}, - {GradVarName("Input"), GradVarName("Value")}); + {"Input@GRAD", "Value@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/reduce_sig.cc b/paddle/phi/ops/compat/reduce_sig.cc index cf2edf9f20f..a0ba07f5e8e 100644 --- a/paddle/phi/ops/compat/reduce_sig.cc +++ b/paddle/phi/ops/compat/reduce_sig.cc @@ -130,41 +130,41 @@ KernelSignature ReduceAllOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ReduceSumGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("sum_grad", - {"X", GradVarName("Out")}, + {"X", "Out@GRAD"}, {"dim", "keep_dim", "reduce_all"}, - {GradVarName("X")}); + {"X@GRAD"}); } KernelSignature ReduceMeanGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("mean_grad", - {"X", GradVarName("Out")}, + {"X", "Out@GRAD"}, {"dim", "keep_dim", "reduce_all"}, - {GradVarName("X")}); + {"X@GRAD"}); } KernelSignature ReduceMaxGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("max_grad", - {"X", "Out", GradVarName("Out")}, + {"X", "Out", "Out@GRAD"}, {"dim", "keep_dim", "reduce_all"}, - {GradVarName("X")}); + {"X@GRAD"}); } KernelSignature ReduceMinGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("min_grad", - {"X", "Out", GradVarName("Out")}, + {"X", "Out", "Out@GRAD"}, {"dim", "keep_dim", "reduce_all"}, - {GradVarName("X")}); + {"X@GRAD"}); } KernelSignature ReduceProdGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("prod_grad", - {"X", "Out", GradVarName("Out")}, + {"X", "Out", 
"Out@GRAD"}, {"dim", "keep_dim", "reduce_all"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/reshape_sig.cc b/paddle/phi/ops/compat/reshape_sig.cc index 04f64e40352..a01f2a98c9b 100644 --- a/paddle/phi/ops/compat/reshape_sig.cc +++ b/paddle/phi/ops/compat/reshape_sig.cc @@ -41,8 +41,7 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ReshapeGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "reshape_grad", {GradVarName("Out")}, {}, {GradVarName("X")}); + return KernelSignature("reshape_grad", {"Out@GRAD"}, {}, {"X@GRAD"}); } KernelSignature ReshapeDoubleGradOpArgumentMapping( diff --git a/paddle/phi/ops/compat/rnn_sig.cc b/paddle/phi/ops/compat/rnn_sig.cc index 352510d5b2e..87c99ac13aa 100644 --- a/paddle/phi/ops/compat/rnn_sig.cc +++ b/paddle/phi/ops/compat/rnn_sig.cc @@ -39,8 +39,8 @@ KernelSignature RnnGradOpArgumentMapping(const ArgumentMappingContext& ctx) { "Out", "DropoutState", "Reserve", - GradVarName("Out"), - GradVarName("State")}, + "Out@GRAD", + "State@GRAD"}, {"dropout_prob", "is_bidirec", "input_size", @@ -49,9 +49,7 @@ KernelSignature RnnGradOpArgumentMapping(const ArgumentMappingContext& ctx) { "mode", "seed", "is_test"}, - {GradVarName("Input"), - GradVarName("PreState"), - GradVarName("WeightList")}); + {"Input@GRAD", "PreState@GRAD", "WeightList@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/roi_align_sig.cc b/paddle/phi/ops/compat/roi_align_sig.cc index 1717ec8f788..7279e82139b 100644 --- a/paddle/phi/ops/compat/roi_align_sig.cc +++ b/paddle/phi/ops/compat/roi_align_sig.cc @@ -30,13 +30,13 @@ KernelSignature RoiAlignOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature RoiAlignGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("roi_align_grad", - {"X", "ROIs", "RoisNum", GradVarName("Out")}, + {"X", "ROIs", "RoisNum", "Out@GRAD"}, {"pooled_height", "pooled_width", "spatial_scale", "sampling_ratio", "aligned"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/roi_pool_sig.cc b/paddle/phi/ops/compat/roi_pool_sig.cc index d04c645f183..971b4b9d5bf 100644 --- a/paddle/phi/ops/compat/roi_pool_sig.cc +++ b/paddle/phi/ops/compat/roi_pool_sig.cc @@ -26,9 +26,9 @@ KernelSignature RoiPoolOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature RoiPoolOpGradArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("roi_pool_grad", - {"X", "ROIs", "RoisNum", "Argmax", GradVarName("Out")}, + {"X", "ROIs", "RoisNum", "Argmax", "Out@GRAD"}, {"pooled_height", "pooled_width", "spatial_scale"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/roll_sig.cc b/paddle/phi/ops/compat/roll_sig.cc index a144f0e8e8a..e6817555bc4 100644 --- a/paddle/phi/ops/compat/roll_sig.cc +++ b/paddle/phi/ops/compat/roll_sig.cc @@ -24,10 +24,8 @@ KernelSignature RollOpArgumentMapping(const ArgumentMappingContext& ctx) { } KernelSignature RollGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("roll_grad", - {"X", GradVarName("Out")}, - {"shifts", "axis"}, - {GradVarName("X")}); + return KernelSignature( + "roll_grad", {"X", "Out@GRAD"}, {"shifts", "axis"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/segment_pool_sig.cc b/paddle/phi/ops/compat/segment_pool_sig.cc index 97646a2ac31..db07343f9ad 100644 --- 
a/paddle/phi/ops/compat/segment_pool_sig.cc +++ b/paddle/phi/ops/compat/segment_pool_sig.cc @@ -18,13 +18,12 @@ namespace phi { KernelSignature SegmentPoolGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "segment_pool_grad", - { - "X", "SegmentIds", "Out", "SummedIds", GradVarName("Out"), - }, - {"pooltype"}, - {GradVarName("X")}); + return KernelSignature("segment_pool_grad", + { + "X", "SegmentIds", "Out", "SummedIds", "Out@GRAD", + }, + {"pooltype"}, + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/selu_sig.cc b/paddle/phi/ops/compat/selu_sig.cc index 23f5cc34515..08087584a10 100644 --- a/paddle/phi/ops/compat/selu_sig.cc +++ b/paddle/phi/ops/compat/selu_sig.cc @@ -19,10 +19,8 @@ namespace phi { KernelSignature SeluGradGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("selu_grad", - {"Out", GradVarName("Out")}, - {"scale", "alpha"}, - {GradVarName("X")}); + return KernelSignature( + "selu_grad", {"Out", "Out@GRAD"}, {"scale", "alpha"}, {"X@GRAD"}); } } // namespace phi PD_REGISTER_ARG_MAPPING_FN(selu_grad, phi::SeluGradGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/set_value_sig.cc b/paddle/phi/ops/compat/set_value_sig.cc index 5feff54b028..6ff94a6e263 100644 --- a/paddle/phi/ops/compat/set_value_sig.cc +++ b/paddle/phi/ops/compat/set_value_sig.cc @@ -737,96 +737,89 @@ KernelSignature SetValueGradOpArgumentMapping( if (ctx.InputSize("StartsTensorList") > 0) { if (ctx.InputSize("EndsTensorList") > 0) { if (ctx.InputSize("StepsTensorList") > 0) { - return KernelSignature( - "set_value_grad", - {GradVarName("Out")}, - {"StartsTensorList", - "EndsTensorList", - "StepsTensorList", - "axes", - "decrease_axes", - "none_axes"}, - {GradVarName("Input"), GradVarName("ValueTensor")}); + return KernelSignature("set_value_grad", + {"Out@GRAD"}, + {"StartsTensorList", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes"}, + {"Input@GRAD", "ValueTensor@GRAD"}); } else { - return KernelSignature( - "set_value_grad", - {GradVarName("Out")}, - {"StartsTensorList", - "EndsTensorList", - "steps", - "axes", - "decrease_axes", - "none_axes"}, - {GradVarName("Input"), GradVarName("ValueTensor")}); + return KernelSignature("set_value_grad", + {"Out@GRAD"}, + {"StartsTensorList", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes"}, + {"Input@GRAD", "ValueTensor@GRAD"}); } } else { if (ctx.InputSize("StepsTensorList") > 0) { - return KernelSignature( - "set_value_grad", - {GradVarName("Out")}, - {"StartsTensorList", - "ends", - "StepsTensorList", - "axes", - "decrease_axes", - "none_axes"}, - {GradVarName("Input"), GradVarName("ValueTensor")}); + return KernelSignature("set_value_grad", + {"Out@GRAD"}, + {"StartsTensorList", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes"}, + {"Input@GRAD", "ValueTensor@GRAD"}); } else { - return KernelSignature( - "set_value_grad", - {GradVarName("Out")}, - {"StartsTensorList", - "ends", - "steps", - "axes", - "decrease_axes", - "none_axes"}, - {GradVarName("Input"), GradVarName("ValueTensor")}); + return KernelSignature("set_value_grad", + {"Out@GRAD"}, + {"StartsTensorList", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes"}, + {"Input@GRAD", "ValueTensor@GRAD"}); } } } else { if (ctx.InputSize("EndsTensorList") > 0) { if (ctx.InputSize("StepsTensorList") > 0) { - return KernelSignature( - "set_value_grad", - {GradVarName("Out")}, - {"starts", - "EndsTensorList", - 
"StepsTensorList", - "axes", - "decrease_axes", - "none_axes"}, - {GradVarName("Input"), GradVarName("ValueTensor")}); + return KernelSignature("set_value_grad", + {"Out@GRAD"}, + {"starts", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes"}, + {"Input@GRAD", "ValueTensor@GRAD"}); } else { - return KernelSignature( - "set_value_grad", - {GradVarName("Out")}, - {"starts", - "EndsTensorList", - "steps", - "axes", - "decrease_axes", - "none_axes"}, - {GradVarName("Input"), GradVarName("ValueTensor")}); + return KernelSignature("set_value_grad", + {"Out@GRAD"}, + {"starts", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes"}, + {"Input@GRAD", "ValueTensor@GRAD"}); } } else { if (ctx.InputSize("StepsTensorList") > 0) { - return KernelSignature( - "set_value_grad", - {GradVarName("Out")}, - {"starts", - "ends", - "StepsTensorList", - "axes", - "decrease_axes", - "none_axes"}, - {GradVarName("Input"), GradVarName("ValueTensor")}); + return KernelSignature("set_value_grad", + {"Out@GRAD"}, + {"starts", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes"}, + {"Input@GRAD", "ValueTensor@GRAD"}); } else { return KernelSignature( "set_value_grad", - {GradVarName("Out")}, + {"Out@GRAD"}, {"starts", "ends", "steps", "axes", "decrease_axes", "none_axes"}, - {GradVarName("Input"), GradVarName("ValueTensor")}); + {"Input@GRAD", "ValueTensor@GRAD"}); } } } diff --git a/paddle/phi/ops/compat/sigmoid_cross_entropy_with_logits_sig.cc b/paddle/phi/ops/compat/sigmoid_cross_entropy_with_logits_sig.cc index 61ad9627a96..795e287d53d 100644 --- a/paddle/phi/ops/compat/sigmoid_cross_entropy_with_logits_sig.cc +++ b/paddle/phi/ops/compat/sigmoid_cross_entropy_with_logits_sig.cc @@ -19,9 +19,9 @@ namespace phi { KernelSignature SigmoidCrossEntropyWithLogitsKernelGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("sigmoid_cross_entropy_with_logits_grad", - {"X", "Label", GradVarName("Out")}, + {"X", "Label", "Out@GRAD"}, {"normalize", "ignore_index"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/slice_sig.cc b/paddle/phi/ops/compat/slice_sig.cc index ba3bafdaa51..607d0b31310 100644 --- a/paddle/phi/ops/compat/slice_sig.cc +++ b/paddle/phi/ops/compat/slice_sig.cc @@ -105,74 +105,74 @@ KernelSignature SliceGradOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.HasInput("StartsTensor")) { if (ctx.HasInput("EndsTensor")) { return KernelSignature("slice_grad", - {"Input", GradVarName("Out")}, + {"Input", "Out@GRAD"}, {"axes", "StartsTensor", "EndsTensor", "infer_flags", "decrease_axis"}, - {GradVarName("Input")}); + {"Input@GRAD"}); } else if (ctx.InputSize("EndsTensorList") > 0) { return KernelSignature("slice_grad", - {"Input", GradVarName("Out")}, + {"Input", "Out@GRAD"}, {"axes", "StartsTensor", "EndsTensorList", "infer_flags", "decrease_axis"}, - {GradVarName("Input")}); + {"Input@GRAD"}); } else { return KernelSignature( "slice_grad", - {"Input", GradVarName("Out")}, + {"Input", "Out@GRAD"}, {"axes", "StartsTensor", "ends", "infer_flags", "decrease_axis"}, - {GradVarName("Input")}); + {"Input@GRAD"}); } } else if (ctx.InputSize("StartsTensorList") > 0) { if (ctx.HasInput("EndsTensor")) { return KernelSignature("slice_grad", - {"Input", GradVarName("Out")}, + {"Input", "Out@GRAD"}, {"axes", "StartsTensorList", "EndsTensor", "infer_flags", "decrease_axis"}, - {GradVarName("Input")}); + {"Input@GRAD"}); } else if (ctx.InputSize("EndsTensorList") > 0) { 
return KernelSignature("slice_grad", - {"Input", GradVarName("Out")}, + {"Input", "Out@GRAD"}, {"axes", "StartsTensorList", "EndsTensorList", "infer_flags", "decrease_axis"}, - {GradVarName("Input")}); + {"Input@GRAD"}); } else { return KernelSignature( "slice_grad", - {"Input", GradVarName("Out")}, + {"Input", "Out@GRAD"}, {"axes", "StartsTensorList", "ends", "infer_flags", "decrease_axis"}, - {GradVarName("Input")}); + {"Input@GRAD"}); } } else { if (ctx.HasInput("EndsTensor")) { return KernelSignature( "slice_grad", - {"Input", GradVarName("Out")}, + {"Input", "Out@GRAD"}, {"axes", "starts", "EndsTensor", "infer_flags", "decrease_axis"}, - {GradVarName("Input")}); + {"Input@GRAD"}); } else if (ctx.InputSize("EndsTensorList") > 0) { return KernelSignature( "slice_grad", - {"Input", GradVarName("Out")}, + {"Input", "Out@GRAD"}, {"axes", "starts", "EndsTensorList", "infer_flags", "decrease_axis"}, - {GradVarName("Input")}); + {"Input@GRAD"}); } else { return KernelSignature( "slice_grad", - {"Input", GradVarName("Out")}, + {"Input", "Out@GRAD"}, {"axes", "starts", "ends", "infer_flags", "decrease_axis"}, - {GradVarName("Input")}); + {"Input@GRAD"}); } } } diff --git a/paddle/phi/ops/compat/softmax_sig.cc b/paddle/phi/ops/compat/softmax_sig.cc index 65a915b51d0..a30a2a2b06f 100644 --- a/paddle/phi/ops/compat/softmax_sig.cc +++ b/paddle/phi/ops/compat/softmax_sig.cc @@ -22,10 +22,8 @@ KernelSignature SoftmaxOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature SoftmaxGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("softmax_grad", - {"Out", GradVarName("Out")}, - {"axis"}, - {GradVarName("X")}); + return KernelSignature( + "softmax_grad", {"Out", "Out@GRAD"}, {"axis"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/softmax_with_cross_entropy_sig.cc b/paddle/phi/ops/compat/softmax_with_cross_entropy_sig.cc index 9cfc5ded90a..c75d4f711dc 100644 --- a/paddle/phi/ops/compat/softmax_with_cross_entropy_sig.cc +++ b/paddle/phi/ops/compat/softmax_with_cross_entropy_sig.cc @@ -31,13 +31,13 @@ KernelSignature SoftmaxWithCrossEntropyOpArgumentMapping( KernelSignature SoftmaxWithCrossEntropyGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("cross_entropy_with_softmax_grad", - {"Label", "Softmax", GradVarName("Loss")}, + {"Label", "Softmax", "Loss@GRAD"}, {"soft_label", "use_softmax", "numeric_stable_mode", "ignore_index", "axis"}, - {GradVarName("Logits")}); + {"Logits@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/squeeze_sig.cc b/paddle/phi/ops/compat/squeeze_sig.cc index 276246533e8..c65d77df980 100644 --- a/paddle/phi/ops/compat/squeeze_sig.cc +++ b/paddle/phi/ops/compat/squeeze_sig.cc @@ -23,10 +23,8 @@ KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature SqueezeGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("squeeze_grad", - {"XShape", GradVarName("Out")}, - {"axes"}, - {GradVarName("X")}); + return KernelSignature( + "squeeze_grad", {"XShape", "Out@GRAD"}, {"axes"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/stack_sig.cc b/paddle/phi/ops/compat/stack_sig.cc index 97768eb8902..334fdb29e5f 100644 --- a/paddle/phi/ops/compat/stack_sig.cc +++ b/paddle/phi/ops/compat/stack_sig.cc @@ -14,8 +14,7 @@ limitations under the License. 
*/ namespace phi { KernelSignature StackGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "stack_grad", {GradVarName("Y")}, {"axis"}, {GradVarName("X")}); + return KernelSignature("stack_grad", {"Y@GRAD"}, {"axis"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/strided_slice_sig.cc b/paddle/phi/ops/compat/strided_slice_sig.cc index 9fb70af0dea..5421fcd616c 100644 --- a/paddle/phi/ops/compat/strided_slice_sig.cc +++ b/paddle/phi/ops/compat/strided_slice_sig.cc @@ -29,35 +29,35 @@ KernelSignature StridedSliceOpArgumentMapping( bool use_attr_ends = !ctx.IsRuntime() && !ends.empty(); bool use_attr_strides = !ctx.IsRuntime() && !strides.empty(); - std::string starts_key = + const char* starts_key = ctx.HasInput("StartsTensor") ? "StartsTensor" : (ctx.InputSize("StartsTensorList") > 0 ? (use_attr_starts ? "starts" : "StartsTensorList") : "starts"); - std::string ends_key = + const char* ends_key = ctx.HasInput("EndsTensor") ? "EndsTensor" : (ctx.InputSize("EndsTensorList") > 0 ? (use_attr_ends ? "ends" : "EndsTensorList") : "ends"); - std::string strides_key = + const char* strides_key = ctx.HasInput("StridesTensor") ? "StridesTensor" : (ctx.InputSize("StridesTensorList") > 0 ? (use_attr_strides ? "strides" : "StridesTensorList") : "strides"); - paddle::SmallVector inputs = {"Input"}; - paddle::SmallVector attrs = {"axes", + paddle::SmallVector inputs = {"Input"}; + paddle::SmallVector attrs = {"axes", starts_key, ends_key, strides_key, "infer_flags", "decrease_axis"}; - paddle::SmallVector outputs = {"Out"}; + paddle::SmallVector outputs = {"Out"}; - std::string kernel_name; + const char* kernel_name; if (ctx.IsDenseTensorVectorInput("Input")) { kernel_name = "strided_slice_array"; } else { @@ -78,35 +78,35 @@ KernelSignature StridedSliceGradOpArgumentMapping( bool use_attr_ends = !ctx.IsRuntime() && !ends.empty(); bool use_attr_strides = !ctx.IsRuntime() && !strides.empty(); - std::string starts_key = + const char* starts_key = ctx.HasInput("StartsTensor") ? "StartsTensor" : (ctx.InputSize("StartsTensorList") > 0 ? (use_attr_starts ? "starts" : "StartsTensorList") : "starts"); - std::string ends_key = + const char* ends_key = ctx.HasInput("EndsTensor") ? "EndsTensor" : (ctx.InputSize("EndsTensorList") > 0 ? (use_attr_ends ? "ends" : "EndsTensorList") : "ends"); - std::string strides_key = + const char* strides_key = ctx.HasInput("StridesTensor") ? "StridesTensor" : (ctx.InputSize("StridesTensorList") > 0 ? (use_attr_strides ? 
"strides" : "StridesTensorList") : "strides"); - paddle::SmallVector inputs = {"Input", GradVarName("Out")}; - paddle::SmallVector attrs = {"axes", + paddle::SmallVector inputs = {"Input", "Out@GRAD"}; + paddle::SmallVector attrs = {"axes", starts_key, ends_key, strides_key, "infer_flags", "decrease_axis"}; - paddle::SmallVector outputs = {GradVarName("Input")}; + paddle::SmallVector outputs = {"Input@GRAD"}; - std::string kernel_name; + const char* kernel_name; if (ctx.IsDenseTensorVectorInput("Input")) { kernel_name = "strided_slice_array_grad"; } else { diff --git a/paddle/phi/ops/compat/take_along_axis_sig.cc b/paddle/phi/ops/compat/take_along_axis_sig.cc index 27a996a270d..a35c1c2db44 100644 --- a/paddle/phi/ops/compat/take_along_axis_sig.cc +++ b/paddle/phi/ops/compat/take_along_axis_sig.cc @@ -25,9 +25,9 @@ KernelSignature TakeAlongAxisArgumentMapping( KernelSignature TakeAlongAxisGradArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("take_along_axis_grad", - {"Input", "Index", GradVarName("Result")}, + {"Input", "Index", "Result@GRAD"}, {"Axis"}, - {GradVarName("Input")}); + {"Input@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/temporal_shift_sig.cc b/paddle/phi/ops/compat/temporal_shift_sig.cc index a686c37ff7e..a6eed22716c 100644 --- a/paddle/phi/ops/compat/temporal_shift_sig.cc +++ b/paddle/phi/ops/compat/temporal_shift_sig.cc @@ -27,9 +27,9 @@ KernelSignature TemporalShiftOpArgumentMapping( KernelSignature TemporalShiftGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("temporal_shift_grad", - {GradVarName("Out")}, + {"Out@GRAD"}, {"seg_num", "shift_ratio", "data_format"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/tile_sig.cc b/paddle/phi/ops/compat/tile_sig.cc index ca3fa5fe1f8..be401e40c49 100644 --- a/paddle/phi/ops/compat/tile_sig.cc +++ b/paddle/phi/ops/compat/tile_sig.cc @@ -33,20 +33,14 @@ KernelSignature TileOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature TileGradOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.HasInput("RepeatTimes")) { - return KernelSignature("tile_grad", - {"X", GradVarName("Out")}, - {"RepeatTimes"}, - {GradVarName("X")}); + return KernelSignature( + "tile_grad", {"X", "Out@GRAD"}, {"RepeatTimes"}, {"X@GRAD"}); } else if (ctx.InputSize("repeat_times_tensor") > 0) { - return KernelSignature("tile_grad", - {"X", GradVarName("Out")}, - {"repeat_times_tensor"}, - {GradVarName("X")}); + return KernelSignature( + "tile_grad", {"X", "Out@GRAD"}, {"repeat_times_tensor"}, {"X@GRAD"}); } else { - return KernelSignature("tile_grad", - {"X", GradVarName("Out")}, - {"repeat_times"}, - {GradVarName("X")}); + return KernelSignature( + "tile_grad", {"X", "Out@GRAD"}, {"repeat_times"}, {"X@GRAD"}); } } diff --git a/paddle/phi/ops/compat/top_k_sig.cc b/paddle/phi/ops/compat/top_k_sig.cc index 8488a18e34c..c1073f9efdc 100644 --- a/paddle/phi/ops/compat/top_k_sig.cc +++ b/paddle/phi/ops/compat/top_k_sig.cc @@ -29,9 +29,9 @@ KernelSignature TopkOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature TopkGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("top_k_grad", - {"X", "Indices", GradVarName("Out")}, + {"X", "Indices", "Out@GRAD"}, {"k", "axis", "largest", "sorted"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/trace_sig.cc b/paddle/phi/ops/compat/trace_sig.cc index c3f5d6d2875..2cb7d9a80bc 100644 --- 
a/paddle/phi/ops/compat/trace_sig.cc +++ b/paddle/phi/ops/compat/trace_sig.cc @@ -23,9 +23,9 @@ KernelSignature TraceOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature TraceGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("trace_grad", - {"Input", GradVarName("Out")}, + {"Input", "Out@GRAD"}, {"offset", "axis1", "axis2"}, - {GradVarName("Input")}); + {"Input@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/transpose_sig.cc b/paddle/phi/ops/compat/transpose_sig.cc index 90961760cfc..0f2a3108ec9 100644 --- a/paddle/phi/ops/compat/transpose_sig.cc +++ b/paddle/phi/ops/compat/transpose_sig.cc @@ -22,8 +22,7 @@ KernelSignature TransposeOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature TransposeGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "transpose_grad", {GradVarName("Out")}, {"axis"}, {GradVarName("X")}); + return KernelSignature("transpose_grad", {"Out@GRAD"}, {"axis"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/triangular_solve_sig.cc b/paddle/phi/ops/compat/triangular_solve_sig.cc index c56af3e21e5..851db32a032 100644 --- a/paddle/phi/ops/compat/triangular_solve_sig.cc +++ b/paddle/phi/ops/compat/triangular_solve_sig.cc @@ -19,9 +19,9 @@ namespace phi { KernelSignature TriangularSolveGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("triangular_solve_grad", - {"X", "Y", "Out", GradVarName("Out")}, + {"X", "Y", "Out", "Out@GRAD"}, {"upper", "transpose", "unitriangular"}, - {GradVarName("X"), GradVarName("Y")}); + {"X@GRAD", "Y@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/tril_triu_sig.cc b/paddle/phi/ops/compat/tril_triu_sig.cc index 4f79f8650de..3c5fa15b41c 100644 --- a/paddle/phi/ops/compat/tril_triu_sig.cc +++ b/paddle/phi/ops/compat/tril_triu_sig.cc @@ -22,10 +22,8 @@ KernelSignature TrilTriuOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature TrilTriuGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("tril_triu_grad", - {GradVarName("Out")}, - {"diagonal", "lower"}, - {GradVarName("X")}); + return KernelSignature( + "tril_triu_grad", {"Out@GRAD"}, {"diagonal", "lower"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/trunc_sig.cc b/paddle/phi/ops/compat/trunc_sig.cc index 2d35439216d..7b6a7771fbe 100644 --- a/paddle/phi/ops/compat/trunc_sig.cc +++ b/paddle/phi/ops/compat/trunc_sig.cc @@ -21,8 +21,7 @@ KernelSignature TruncOpArgumentMapping(const ArgumentMappingContext& ctx) { } KernelSignature TruncGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "trunc_grad", {GradVarName("Out")}, {}, {GradVarName("X")}); + return KernelSignature("trunc_grad", {"Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/unfold_sig.cc b/paddle/phi/ops/compat/unfold_sig.cc index ddc3b1813cb..45415616f29 100644 --- a/paddle/phi/ops/compat/unfold_sig.cc +++ b/paddle/phi/ops/compat/unfold_sig.cc @@ -18,9 +18,9 @@ namespace phi { KernelSignature UnfoldGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("unfold_grad", - {"X", GradVarName("Y")}, + {"X", "Y@GRAD"}, {"kernel_sizes", "strides", "paddings", "dilations"}, - {GradVarName("X")}); + {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/unsqueeze_sig.cc b/paddle/phi/ops/compat/unsqueeze_sig.cc index 20cd9701e83..c802c2684b2 100644 --- 
a/paddle/phi/ops/compat/unsqueeze_sig.cc +++ b/paddle/phi/ops/compat/unsqueeze_sig.cc @@ -35,7 +35,7 @@ KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature UnsqueezeGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( - "unsqueeze_grad", {"XShape", GradVarName("Out")}, {}, {GradVarName("X")}); + "unsqueeze_grad", {"XShape", "Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi PD_REGISTER_BASE_KERNEL_NAME(unsqueeze2, unsqueeze); diff --git a/paddle/phi/ops/compat/unstack_sig.cc b/paddle/phi/ops/compat/unstack_sig.cc index 41d7fc120a9..d03499f94b6 100644 --- a/paddle/phi/ops/compat/unstack_sig.cc +++ b/paddle/phi/ops/compat/unstack_sig.cc @@ -15,8 +15,7 @@ namespace phi { KernelSignature UnStackGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "unstack_grad", {GradVarName("Y")}, {"axis"}, {GradVarName("X")}); + return KernelSignature("unstack_grad", {"Y@GRAD"}, {"axis"}, {"X@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/warpctc_sig.cc b/paddle/phi/ops/compat/warpctc_sig.cc index 75f440de7f2..ac3dc366ad8 100644 --- a/paddle/phi/ops/compat/warpctc_sig.cc +++ b/paddle/phi/ops/compat/warpctc_sig.cc @@ -25,11 +25,10 @@ KernelSignature WarpctcOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature WarpctcGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature( - "warpctc_grad", - {"WarpCTCGrad", "Logits", GradVarName("Loss"), "LogitsLength"}, - {"blank", "norm_by_times"}, - {GradVarName("Logits")}); + return KernelSignature("warpctc_grad", + {"WarpCTCGrad", "Logits", "Loss@GRAD", "LogitsLength"}, + {"blank", "norm_by_times"}, + {"Logits@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/where_grad_sig.cc b/paddle/phi/ops/compat/where_grad_sig.cc index 71984a26d35..e0c380672c8 100644 --- a/paddle/phi/ops/compat/where_grad_sig.cc +++ b/paddle/phi/ops/compat/where_grad_sig.cc @@ -18,9 +18,9 @@ namespace phi { KernelSignature WhereGradOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("where_grad", - {"Condition", "X", "Y", GradVarName("Out")}, + {"Condition", "X", "Y", "Out@GRAD"}, {}, - {GradVarName("X"), GradVarName("Y")}); + {"X@GRAD", "Y@GRAD"}); } } // namespace phi diff --git a/paddle/phi/ops/compat/yolov3_loss_sig.cc b/paddle/phi/ops/compat/yolov3_loss_sig.cc index bbdadfa93ba..8d5d82a9e72 100644 --- a/paddle/phi/ops/compat/yolov3_loss_sig.cc +++ b/paddle/phi/ops/compat/yolov3_loss_sig.cc @@ -31,25 +31,23 @@ KernelSignature Yolov3LossOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature Yolov3LossGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("yolov3_loss_grad", - {"X", - "GTBox", - "GTLabel", - "GTScore", - GradVarName("Loss"), - "ObjectnessMask", - "GTMatchMask"}, - {"anchors", - "anchor_mask", - "class_num", - "ignore_thresh", - "downsample_ratio", - "use_label_smooth", - "scale_x_y"}, - {GradVarName("X"), - GradVarName("GTBox"), - GradVarName("GTLabel"), - GradVarName("GTScore")}); + return KernelSignature( + "yolov3_loss_grad", + {"X", + "GTBox", + "GTLabel", + "GTScore", + "Loss@GRAD", + "ObjectnessMask", + "GTMatchMask"}, + {"anchors", + "anchor_mask", + "class_num", + "ignore_thresh", + "downsample_ratio", + "use_label_smooth", + "scale_x_y"}, + {"X@GRAD", "GTBox@GRAD", "GTLabel@GRAD", "GTScore@GRAD"}); } } // namespace phi diff --git a/paddle/phi/tests/core/test_meta_fn_utils.cc 
b/paddle/phi/tests/core/test_meta_fn_utils.cc index c90e2f3dbcd..028b9d23352 100644 --- a/paddle/phi/tests/core/test_meta_fn_utils.cc +++ b/paddle/phi/tests/core/test_meta_fn_utils.cc @@ -46,9 +46,9 @@ TEST(MetaFnFactory, InferMetaFnExists) { phi::MetaTensor meta_out(&dense_out1); phi::UnchangedInferMeta(meta_x, &meta_out); - auto shared_meat_x = std::make_shared(&dense_x); + auto shared_meat_x = phi::MetaTensor(&dense_x); phi::DenseTensor dense_out2; - auto shared_meta_out = std::make_shared(&dense_out2); + auto shared_meta_out = phi::MetaTensor(&dense_out2); phi::InferMetaContext ctx; ctx.EmplaceBackInput(shared_meat_x); ctx.EmplaceBackOutput(shared_meta_out); @@ -69,9 +69,9 @@ TEST(MetaFnFactory, CopyInferMetaFn) { phi::MetaTensor meta_out(&dense_out1); phi::UnchangedInferMeta(meta_x, &meta_out); - auto shared_meat_x = std::make_shared(&dense_x); + auto shared_meat_x = phi::MetaTensor(&dense_x); phi::DenseTensor dense_out2; - auto shared_meta_out = std::make_shared(&dense_out2); + auto shared_meta_out = phi::MetaTensor(&dense_out2); phi::InferMetaContext ctx; ctx.EmplaceBackInput(shared_meat_x); @@ -90,13 +90,13 @@ TEST(MetaFnFactory, SplitInferMetaFn) { phi::DenseTensor dense_x; dense_x.Resize({4, 10}); phi::MetaTensor meta_x(&dense_x); - auto shared_meat_x = std::make_shared(&dense_x); + auto shared_meat_x = phi::MetaTensor(&dense_x); phi::DenseTensor dense_out1; phi::DenseTensor dense_out2; - paddle::SmallVector> out; - out.push_back(std::make_shared(&dense_out1)); - out.push_back(std::make_shared(&dense_out2)); + paddle::SmallVector out; + out.emplace_back(phi::MetaTensor(&dense_out1)); + out.emplace_back(phi::MetaTensor(&dense_out2)); phi::InferMetaContext ctx; ctx.EmplaceBackInput(shared_meat_x); diff --git a/paddle/testing/CMakeLists.txt b/paddle/testing/CMakeLists.txt index 0cc68bf3161..2c977e923b5 100644 --- a/paddle/testing/CMakeLists.txt +++ b/paddle/testing/CMakeLists.txt @@ -1,5 +1,5 @@ # for paddle test case if(WITH_TESTING) - cc_library(paddle_gtest_main SRCS paddle_gtest_main.cc DEPS init device_context memory gtest gflags proto_desc) + cc_library(paddle_gtest_main SRCS paddle_gtest_main.cc DEPS init device_context memory gtest gflags proto_desc phi_utils) endif() diff --git a/paddle/testing/paddle_gtest_main.cc b/paddle/testing/paddle_gtest_main.cc index 0fb5412ff05..bb919f0e911 100644 --- a/paddle/testing/paddle_gtest_main.cc +++ b/paddle/testing/paddle_gtest_main.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "gflags/gflags.h" #include "gtest/gtest.h" +#include "paddle/fluid/framework/phi_utils.h" #include "paddle/fluid/memory/allocation/allocator_strategy.h" #include "paddle/fluid/platform/device/npu/npu_info.h" #include "paddle/fluid/platform/flags.h" @@ -85,6 +86,7 @@ int main(int argc, char** argv) { ::GFLAGS_NAMESPACE::ParseCommandLineFlags( &new_argc, &new_argv_address, false); paddle::framework::InitDevices(); + paddle::framework::InitDefaultKernelSignatureMap(); int ret = RUN_ALL_TESTS(); diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index fb9e8d8ece1..13b964274fd 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -212,6 +212,7 @@ def __bootstrap__(): core.init_glog(sys.argv[0]) # don't init_p2p when in unittest to save time. core.init_devices() + core.init_default_kernel_signatures() # TODO(panyx0718): Avoid doing complex initialization logic in __init__.py. 
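The compat-sig hunks above all apply one mechanical rewrite: GradVarName("X"), which appends the "@GRAD" suffix at runtime and therefore needs a std::string, becomes the literal "X@GRAD", so every name inside a KernelSignature can stay a compile-time const char*. Below is a minimal sketch of an argument-mapping function written in the new style; the op "my_relu" and its argument names are hypothetical, and the header path and registration macro are assumed to be the ones the real *_sig.cc files use (the hunks above do not show them).

    #include "paddle/phi/core/compat/op_utils.h"  // header assumed, as in the other *_sig.cc files

    namespace phi {

    // Hypothetical op, for illustration only: grad arguments carry the "@GRAD"
    // suffix directly in the string literal instead of going through GradVarName().
    KernelSignature MyReluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
      return KernelSignature("my_relu_grad", {"Out", "Out@GRAD"}, {}, {"X@GRAD"});
    }

    }  // namespace phi

    // Registration macro name assumed from the existing sig files.
    PD_REGISTER_ARG_MAPPING_FN(my_relu_grad, phi::MyReluGradOpArgumentMapping);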
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 4779750b5b4..a4c6fac836c 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -1332,8 +1332,11 @@ - api : meshgrid args : (Tensor[] inputs) - output : Tensor[] - invoke : meshgrid_impl(inputs) + output : Tensor[]{inputs.size()} + infer_meta : + func : MeshgridInferMeta + kernel : + func : meshgrid backward : meshgrid_grad - api : min @@ -2103,8 +2106,11 @@ - api : unbind args : (Tensor input, int axis) - output : Tensor[] - invoke : unbind_impl(input, axis) + output : Tensor[] {axis<0 ? input.dims()[input.dims().size()+axis]:input.dims()[axis]} + infer_meta : + func : UnbindInferMeta + kernel : + func : unbind backward : unbind_grad # unfold diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index e8d067483d8..378ead7ff20 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -31,6 +31,7 @@ class BaseAPI(object): # outputs: # names : [], list of output names # types : [], list of output types + # out_size_expr : [], expression for getting size of vector # return_type : Tensor, vector, ..., the return type of api # args_str: # args_declare : "str" // str of function params with default value. Example: (..., bool flag=false) @@ -67,11 +68,12 @@ class BaseAPI(object): ] inputs, attrs, args_str = self.parse_input_and_attr( api_name, api_item_yaml['args'], optional_vars) - output_type_list, output_names, return_type = self.parse_output( + output_type_list, output_names, out_size_expr, return_type = self.parse_output( api_name, api_item_yaml['output']) return inputs, attrs, { 'names': output_names, 'types': output_type_list, + 'out_size_expr': out_size_expr, 'return_type': return_type }, args_str, optional_vars @@ -184,39 +186,36 @@ class BaseAPI(object): 'Tensor': 'Tensor', 'Tensor[]': 'std::vector' } - if re.search(r'\([a-zA-Z0-9_@]*\)', output_item): - result = re.search( - r"(?P[a-zA-Z0-9_[\]]+)\s*\((?P[a-zA-Z0-9_@]+)\)", - output_item) - out_type = result.group('out_type') - assert out_type in output_type_map, \ - f"{api_name} : Output type error: the output type only support Tensor and Tensor[], \ - but now is {out_type}." - - return output_type_map[out_type], result.group('name') - - else: - if output_item.strip() in output_type_map: - return output_type_map[output_item.strip()], 'out' - else: - raise ValueError( - "{} : Output type error: the output type only support Tensor and Tensor[], \ - but now is {}.".format(api_name, output_item.strip())) + result = re.search( + r"(?P[a-zA-Z0-9_[\]]+)\s*(?P\([a-zA-Z0-9_@]+\))?\s*(?P\{[^\}]+\})?", + output_item) + assert result is not None, f"{api_name} : the output config parse error." + out_type = result.group('out_type') + assert out_type in output_type_map, \ + f"{api_name} : Output type error: the output type only support Tensor and Tensor[], \ + but now is {out_type}." 
+ + out_name = 'out' if result.group('name') is None else result.group( + 'name')[1:-1] + out_size_expr = None if result.group( + 'expr') is None else result.group('expr')[1:-1] + return output_type_map[out_type], out_name, out_size_expr temp_list = output_config.split(',') if len(temp_list) == 1: - out_type, out_name = parse_output_item(temp_list[0]) - return [out_type], [out_name], self.get_return_type([out_type]) + out_type, out_name, size_expr = parse_output_item(temp_list[0]) + return [out_type], [out_name], size_expr, self.get_return_type( + [out_type]) else: out_type_list = [] out_name_list = [] for output_item in temp_list: - out_type, out_name = parse_output_item(output_item) + out_type, out_name, size_expr = parse_output_item(output_item) out_type_list.append(out_type) out_name_list.append(out_name) - return out_type_list, out_name_list, self.get_return_type( + return out_type_list, out_name_list, size_expr, self.get_return_type( out_type_list) def parse_infer_meta(self, infer_meta_config): @@ -462,9 +461,8 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self attr_names = self.attrs['names'] infer_meta = self.infer_meta - infer_meta_params = infer_meta[ - 'param'] + kernel_output_names if infer_meta[ - 'param'] is not None else input_names + attr_names + kernel_output_names + infer_meta_params = infer_meta['param'] if infer_meta[ + 'param'] is not None else input_names + attr_names # generate meta tensors meta_tensor_code = "" param_code = "" @@ -476,7 +474,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self param] == "const std::vector&": meta_tensor_code = meta_tensor_code + f""" {code_indent} auto {param}_meta_vec = MakeMetaTensor({PREFIX_TENSOR_NAME}{param}); -{code_indent} std::vector {param}_metas({param}_meta_vec.size()); +{code_indent} std::vector {param}_metas({param}_meta_vec.size()); {code_indent} for (size_t i = 0; i < {param}_meta_vec.size(); ++i) {{ {code_indent} {param}_metas[i] = &{param}_meta_vec[i]; {code_indent} }} @@ -500,11 +498,6 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self raise ValueError( f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported." 
) - elif param in kernel_output_names: - meta_tensor_code = meta_tensor_code + code_indent + " phi::MetaTensor " + param.replace( - 'kernel_', PREFIX_META_TENSOR_NAME) + "(" + param + ");\n" - param_code = param_code + "&" + param.replace( - 'kernel_', PREFIX_META_TENSOR_NAME) + ", " elif param in attr_names: param_code = param_code + param + ", " elif isinstance(param, str): @@ -514,6 +507,23 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self else: param_code = param_code + str(param) + ", " + for i, out_name in enumerate(kernel_output_names): + if self.outputs['types'][i] == 'std::vector': + meta_tensor_code = meta_tensor_code + f""" +{code_indent} auto {out_name}_{PREFIX_META_TENSOR_NAME}vec = MakeMetaTensor({out_name}); +{code_indent} std::vector {out_name}_metas({out_name}_{PREFIX_META_TENSOR_NAME}vec.size()); +{code_indent} for (size_t i = 0; i < {out_name}_{PREFIX_META_TENSOR_NAME}vec.size(); ++i) {{ +{code_indent} {out_name}_metas[i] = &{out_name}_{PREFIX_META_TENSOR_NAME}vec[i]; +{code_indent} }}""" + + param_code = param_code + out_name + '_metas, ' + else: + meta_tensor_code = meta_tensor_code + code_indent + " phi::MetaTensor " + out_name.replace( + 'kernel_', + PREFIX_META_TENSOR_NAME) + "(" + out_name + ");\n" + param_code = param_code + "&" + out_name.replace( + 'kernel_', PREFIX_META_TENSOR_NAME) + ", " + param_code = param_code[:-2] return f"""{meta_tensor_code} {code_indent} phi::{infer_meta['func']}({param_code}); diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index 4087b55b513..538958c2361 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -91,7 +91,16 @@ class ForwardAPI(BaseAPI): 0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 'names'][0] in self.inplace_map else "" output_create = f""" -{code_indent} {self.outputs['return_type']} api_output{inplace_assign}; +{code_indent} {self.outputs['return_type']} api_output{inplace_assign};""" + + if self.outputs['return_type'] == 'std::vector': + assert self.outputs['out_size_expr'] is not None, \ + f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." + output_create = output_create + f""" +{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr']}, kernel_backend, &api_output);""" + + else: + output_create = output_create + f""" {code_indent} auto kernel_out = {set_out_func}(kernel_backend, &api_output);""" if not inplace_flag and self.view_map is not None and self.outputs[ @@ -113,7 +122,14 @@ class ForwardAPI(BaseAPI): output_create = output_create + f""" {code_indent} std::get<{i}>(api_output) = {self.inplace_map[self.outputs['names'][i]]};""" - output_create = output_create + f""" + if output_type_list[i] == 'std::vector': + assert self.outputs['out_size_expr'][i] is not None, \ + f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." 
+ output_create = output_create + f""" +{code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, kernel_backend, &std::get<{i}>(api_output));""" + + else: + output_create = output_create + f""" {code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &std::get<{i}>(api_output));""" if not inplace_flag and self.view_map is not None and self.outputs[ diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 55efb61e73f..59ad29db61f 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -57,7 +57,7 @@ - backward_api : add_n_grad forward : add_n (Tensor[] x) -> Tensor(out) args : (Tensor[] x, Tensor out_grad) - output : Tensor[](x_grad) + output : Tensor[](x_grad){x.size()} invoke : add_n_grad_impl(x, out_grad) no_need_buffer : x @@ -238,8 +238,12 @@ - backward_api : concat_grad forward : concat (Tensor[] x, Scalar axis) -> Tensor(out) args : (Tensor[] x, Tensor out_grad, Scalar axis = 0) - output : Tensor[](x_grad) - invoke : concat_grad_impl(x, out_grad, axis) + output : Tensor[](x_grad){x.size()} + infer_meta : + func : UnchangedMultiInferMeta + param : [x] + kernel : + func : concat_grad no_need_buffer : x - backward_api : conj_grad @@ -1018,8 +1022,11 @@ - backward_api : meshgrid_grad forward : meshgrid (Tensor[] inputs) -> Tensor[](outputs) args : (Tensor[] inputs, Tensor[] outputs_grad) - output : Tensor[](inputs_grad) - invoke : meshgrid_grad_impl(inputs, outputs_grad) + output : Tensor[](inputs_grad){inputs.size()} + infer_meta : + func : MeshgridGradInferMeta + kernel : + func : meshgrid_grad - backward_api : min_grad forward: min (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out) @@ -1075,14 +1082,22 @@ - backward_api : multi_dot_grad forward : multi_dot (Tensor[] x) -> Tensor(out) args : (Tensor[] x, Tensor out_grad) - output : Tensor[](x_grad) - invoke : multi_dot_grad_impl(x, out_grad) + output : Tensor[](x_grad) {x.size()} + infer_meta : + func : MultiDotGradInferMeta + kernel : + func : multi_dot_grad - backward_api : multiplex_grad forward : multiplex (Tensor[] ins, Tensor ids) -> Tensor(out) args : (Tensor[] ins, Tensor ids, Tensor out_grad) - output : Tensor[](ins_grad) - invoke : multiplex_grad_impl(ins, ids, out_grad) + output : Tensor[](ins_grad){ins.size()} + infer_meta : + func : MultiplexGradInferMeta + param : [ids, out_grad] + kernel : + func : multiplex_grad + param : [ids, out_grad] - backward_api : multiply_double_grad forward : multiply_grad (Tensor x, Tensor y, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y) @@ -1581,8 +1596,13 @@ - backward_api : stack_grad forward : stack (Tensor[] x, int axis) -> Tensor(out) args : (Tensor[] x, Tensor out_grad, int axis) - output : Tensor[](x_grad) - invoke : stack_grad_impl(x, out_grad, axis) + output : Tensor[](x_grad){x.size()} + infer_meta : + func : StackGradInferMeta + param: [out_grad, axis] + kernel : + func : stack_grad + param : [out_grad, axis] no_need_buffer : x - backward_api : strided_slice_grad diff --git a/python/paddle/utils/code_gen/backward_api_gen.py b/python/paddle/utils/code_gen/backward_api_gen.py index 46aa3e7e23d..a88339c607c 100644 --- a/python/paddle/utils/code_gen/backward_api_gen.py +++ b/python/paddle/utils/code_gen/backward_api_gen.py @@ -35,7 +35,7 @@ class BackwardAPI(BaseAPI): r"(?P[a-z][a-z0-9_]+)\s*(?P\([^\)]+\))\s*->\s*(?P.+)", forward_config) api = result.group('api') - _, outputs, _ = self.parse_output(self.api, 
result.group('outputs')) + _, outputs, _, _ = self.parse_output(self.api, result.group('outputs')) outputs = [item.split('@')[0] for item in outputs] fw_inputs, fw_attrs, _, = self.parse_input_and_attr( api, result.group('args')) @@ -110,7 +110,16 @@ class BackwardAPI(BaseAPI): 0]] if inplace_flag and self.inplace_map is not None and self.outputs[ 'names'][0] in self.inplace_map else "" output_create = f""" -{code_indent} {self.outputs['return_type']} api_output{inplace_assign}; +{code_indent} {self.outputs['return_type']} api_output{inplace_assign};""" + + if output_type_list[0] == 'std::vector': + assert self.outputs['out_size_expr'] is not None, \ + f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." + output_create = output_create + f""" +{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr']}, kernel_backend, &api_output);""" + + else: + output_create = output_create + f""" {code_indent} auto kernel_out = {set_out_func}(kernel_backend, &api_output);""" elif len(output_type_list) > 1: @@ -121,7 +130,6 @@ class BackwardAPI(BaseAPI): kernel_output = kernel_output + f'kernel_out_{i}, ' output_names.append(f'kernel_out_{i}') if out_type_item == 'Tensor': - get_out_code = f'&api_output[{i}][0]' if inplace_flag and self.inplace_map is not None and self.outputs[ 'names'][i] in self.inplace_map: output_create = output_create + f""" @@ -131,6 +139,9 @@ class BackwardAPI(BaseAPI): output_create = output_create + f""" {code_indent} api_output[{i}].emplace_back();""" + output_create = output_create + f""" +{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, &api_output[{i}][0]);""" + else: get_out_code = f'&api_output[{i}]' if inplace_flag and self.inplace_map is not None and self.outputs[ @@ -138,8 +149,10 @@ class BackwardAPI(BaseAPI): output_create = output_create + f""" {code_indent} api_output[{i}] = {self.inplace_map[self.outputs['names'][i]]};""" - output_create = output_create + f""" -{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, {get_out_code});""" + assert self.outputs['out_size_expr'][i] is not None, \ + f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." + output_create = output_create + f""" +{code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, kernel_backend, &api_output[{i}]);""" kernel_output = kernel_output[:-2] else: -- GitLab
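The test_meta_fn_utils.cc hunks earlier in this patch show the companion change on the InferMeta side: phi::MetaTensor is treated as a plain value type, so InferMetaContext inputs and outputs are emplaced directly instead of being wrapped in std::shared_ptr. A minimal sketch of the new usage, mirroring the updated test (include paths are assumed; the classes and calls are the ones the test itself exercises):

    #include "paddle/phi/core/dense_tensor.h"
    #include "paddle/phi/core/infermeta_utils.h"
    #include "paddle/phi/core/meta_tensor.h"

    // Builds an InferMetaContext the way the updated tests do: MetaTensor values,
    // not shared_ptrs, are pushed into the context consumed by the InferMeta fn.
    void BuildUnaryInferMetaContext() {
      phi::DenseTensor dense_x;
      dense_x.Resize({4, 10});
      phi::DenseTensor dense_out;

      phi::MetaTensor meta_x(&dense_x);
      phi::MetaTensor meta_out(&dense_out);

      phi::InferMetaContext ctx;
      ctx.EmplaceBackInput(meta_x);    // was: std::make_shared<phi::MetaTensor>(&dense_x)
      ctx.EmplaceBackOutput(meta_out);
    }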