From 0d537003eb0681f010d08fd227f34ab952d10463 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 25 Apr 2022 11:05:37 +0800 Subject: [PATCH] [cherry-pick] Optimize performance of dygraph (#42093, #42103, #42137) (#42171) * optimiaze performance of PreparePhiData (#42093) * Dygraph performance optimization (v2) (#42103) * optimiaze performance of PreparePhiData * dygraph performance optimization * optimize performance of dygraph (#42137) --- paddle/fluid/framework/infershape_utils.cc | 15 +- paddle/fluid/framework/operator.cc | 34 ++- paddle/fluid/imperative/prepared_operator.cc | 36 +++- paddle/fluid/imperative/prepared_operator.h | 21 +- paddle/fluid/pybind/imperative.cc | 6 +- .../pybind/kernel_signature_generator.cc | 8 +- .../dialect/phi/pass/phi_op_convert_pass.cc | 6 +- paddle/phi/core/compat/arg_map_context.cc | 6 +- paddle/phi/core/compat/arg_map_context.h | 18 +- paddle/phi/core/compat/op_utils.h | 19 +- paddle/phi/tests/ops/test_op_signature.cc | 194 +++++++++++------- 11 files changed, 224 insertions(+), 139 deletions(-) diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index bd71ade7e9..6deebe93dc 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -402,21 +402,20 @@ std::vector CompatInferMetaContext::MutableOutputBetween( CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, const std::string& op_type) { // 1. get kernel args - auto arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type); - PADDLE_ENFORCE_NOT_NULL( - arg_map_fn, platform::errors::NotFound( - "The ArgumentMappingFn of %s op is not found.", op_type)); + auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type); InferShapeArgumentMappingContext arg_map_context(*ctx); - auto signature = arg_map_fn(arg_map_context); + KernelSignature signature = + arg_map_fn ? (*arg_map_fn)(arg_map_context) + : phi::DefaultKernelSignatureMap::Instance().Get(op_type); VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature; // 2. build infermeta context CompatInferMetaContext infer_meta_context( {ctx->IsRuntime(), ctx->IsRunMKLDNNKernel()}); - auto& input_names = std::get<0>(signature.args); - auto& attr_names = std::get<1>(signature.args); - auto& output_names = std::get<2>(signature.args); + const auto& input_names = signature.input_names; + const auto& attr_names = signature.attr_names; + const auto& output_names = signature.output_names; const auto& args_def = phi::KernelFactory::Instance().GetFirstKernelArgsDef(signature.name); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 0291309aa0..1dd47873c0 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1200,8 +1200,10 @@ bool OperatorWithKernel::SupportsMKLDNN( bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, proto::VarType::Type data_type) const { - bool use_mkldnn_ctx = ctx.HasAttr("use_mkldnn") && - ctx.Attr("use_mkldnn") && + const auto& attrs_map = ctx.Attrs(); + auto iter = attrs_map.find("use_mkldnn"); + bool use_mkldnn_ctx = iter != attrs_map.end() && + BOOST_GET_CONST(bool, iter->second) && platform::is_cpu_place(ctx.GetPlace()); return use_mkldnn_ctx && this->SupportsMKLDNN(data_type); } @@ -2117,8 +2119,16 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( const ExecutionContext& ctx) const { ExecutionArgumentMappingContext arg_mapping_ctx(ctx); if (arg_map_fn_ == nullptr) { - arg_map_fn_.reset(new phi::ArgumentMappingFn( - phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type()))); + auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(type_); + if (arg_map_fn) { + arg_map_fn_.reset(new phi::ArgumentMappingFn(*arg_map_fn)); + } else { + auto func = + [this](const phi::ArgumentMappingContext& ctx) -> KernelSignature { + return phi::DefaultKernelSignatureMap::Instance().Get(type_); + }; + arg_map_fn_.reset(new phi::ArgumentMappingFn(func)); + } } return (*arg_map_fn_)(arg_mapping_ctx); } @@ -2126,7 +2136,7 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( Scope* OperatorWithKernel::PreparePhiData( const Scope& scope, const phi::Kernel& pt_kernel, const KernelSignature& pt_kernel_signature, RuntimeContext* ctx) const { - auto& input_names = std::get<0>(pt_kernel_signature.args); + const auto& input_names = pt_kernel_signature.input_names; auto input_defs = pt_kernel.args_def().input_defs(); PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), platform::errors::InvalidArgument( @@ -2178,11 +2188,15 @@ Scope* OperatorWithKernel::PreparePhiData( if (in_def.backend == phi::Backend::ALL_BACKEND) { continue; } - auto expected_place = phi::TransToPhiPlace(in_def.backend); - if (platform::is_same_place(tensor_in->place(), expected_place)) { + + auto tensor_backend = phi::TransToPhiBackend(tensor_in->place()); + if (in_def.backend == tensor_backend || + (in_def.backend == phi::Backend::GPUDNN && + tensor_backend == phi::Backend::GPU)) { continue; } + auto expected_place = phi::TransToPhiPlace(in_def.backend); VLOG(3) << "phi Transform Variable " << input_names[i] << " from " << tensor_in->place() << " to " << expected_place; @@ -2219,9 +2233,9 @@ void OperatorWithKernel::BuildPhiKernelContext( phi::KernelContext* pt_kernel_context) const { pt_kernel_context->SetDeviceContext(dev_ctx); - auto& input_names = std::get<0>(pt_kernel_signature_->args); - auto& attr_names = std::get<1>(pt_kernel_signature_->args); - auto& output_names = std::get<2>(pt_kernel_signature_->args); + auto& input_names = pt_kernel_signature_->input_names; + auto& attr_names = pt_kernel_signature_->attr_names; + auto& output_names = pt_kernel_signature_->output_names; auto input_defs = pt_kernel_->args_def().input_defs(); auto attr_defs = pt_kernel_->args_def().attribute_defs(); diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index cef7417ea4..fdeda8aa97 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -37,6 +37,8 @@ namespace paddle { namespace imperative { static const phi::Kernel empty_kernel; +static const framework::RuntimeContext empty_ctx({}, {}); +static const framework::Scope empty_scope; const std::shared_ptr& GetVariableWrapper( const std::shared_ptr& var) { @@ -138,8 +140,6 @@ PreparedOp PrepareImpl(const NameVarMap& ins, platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); - framework::RuntimeContext ctx({}, {}); - #ifdef PADDLE_WITH_MKLDNN // MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and // GetKernelType functions, so we need to copy the attributes there. @@ -158,7 +158,7 @@ PreparedOp PrepareImpl(const NameVarMap& ins, // 1. get expected kernel key auto dygraph_exe_ctx = DygraphExecutionContext( - op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs, default_attrs); + op, empty_scope, *dev_ctx, empty_ctx, ins, outs, attrs, default_attrs); auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx); framework::KernelSignature pt_kernel_signature; @@ -172,11 +172,26 @@ PreparedOp PrepareImpl(const NameVarMap& ins, paddle::platform::is_in_xpu_black_list(op.Type()); #endif - if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) { - pt_kernel_signature = - std::move(op.GetExpectedPhiKernelArgs(dygraph_exe_ctx)); - VLOG(6) << pt_kernel_signature; + bool has_phi_kernel = false; + + const auto* arg_map_fn = + phi::OpUtilsMap::Instance().GetArgumentMappingFn(op.Type()); + if (arg_map_fn) { + has_phi_kernel = true; + pt_kernel_signature = (*arg_map_fn)( + framework::ExecutionArgumentMappingContext(dygraph_exe_ctx)); + } else { + const auto* kernel_sig = + phi::DefaultKernelSignatureMap::Instance().GetNullable(op.Type()); + if (kernel_sig) { + has_phi_kernel = true; + pt_kernel_signature = *kernel_sig; + } + } + + if (has_phi_kernel) { + VLOG(6) << pt_kernel_signature; pt_kernel_name = pt_kernel_signature.name; // NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], // But the default library_type is Plain, so we need to modify the @@ -231,7 +246,7 @@ PreparedOp PrepareImpl(const NameVarMap& ins, dev_ctx = pool.Get(expected_kernel_key.place_); } - return PreparedOp(op, ctx, expected_kernel_key, + return PreparedOp(op, empty_ctx, expected_kernel_key, std::move(pt_kernel_signature), pt_kernel, dev_ctx); } else { VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name @@ -280,7 +295,7 @@ PreparedOp PrepareImpl(const NameVarMap& ins, << " | kernel key: " << pt_cpu_kernel_key << " | kernel: " << pt_cpu_kernel; auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace()); - return PreparedOp(op, ctx, expected_kernel_key, + return PreparedOp(op, empty_ctx, expected_kernel_key, std::move(pt_kernel_signature), pt_cpu_kernel, cpu_ctx); } @@ -373,7 +388,8 @@ PreparedOp PrepareImpl(const NameVarMap& ins, dev_ctx = pool.Get(expected_kernel_key.place_); } - return PreparedOp(op, ctx, expected_kernel_key, kernel_iter->second, dev_ctx); + return PreparedOp(op, empty_ctx, expected_kernel_key, kernel_iter->second, + dev_ctx); } PreparedOp PreparedOp::Prepare(const NameVarMap& ins, diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index b3c5a6b5fa..754b553bd1 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -233,9 +233,9 @@ void BuildDygraphPhiKernelContext( platform::DeviceContext* dev_ctx, phi::KernelContext* kernel_ctx) { kernel_ctx->SetDeviceContext(dev_ctx); - auto& input_names = std::get<0>(pt_kernel_signature.args); - auto& attr_names = std::get<1>(pt_kernel_signature.args); - auto& output_names = std::get<2>(pt_kernel_signature.args); + const auto& input_names = pt_kernel_signature.input_names; + const auto& attr_names = pt_kernel_signature.attr_names; + const auto& output_names = pt_kernel_signature.output_names; auto& input_defs = pt_kernel.args_def().input_defs(); auto& output_defs = pt_kernel.args_def().output_defs(); @@ -570,7 +570,7 @@ template void PreparePhiData(const phi::Kernel& pt_kernel, const framework::KernelSignature& pt_kernel_signature, const NameVarMap& ins) { - auto& input_names = std::get<0>(pt_kernel_signature.args); + const auto& input_names = pt_kernel_signature.input_names; auto& input_defs = pt_kernel.args_def().input_defs(); PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), @@ -581,10 +581,11 @@ void PreparePhiData(const phi::Kernel& pt_kernel, for (size_t i = 0; i < input_names.size(); ++i) { auto& in_def = input_defs.at(i); - if (ins.find(input_names[i]) == ins.end()) { + auto iter = ins.find(input_names[i]); + if (iter == ins.end()) { continue; } - auto& ins_vector = ins.at(input_names[i]); + auto& ins_vector = iter->second; for (size_t offset = 0; offset < ins_vector.size(); ++offset) { auto& var = ins_vector[offset]; @@ -593,11 +594,15 @@ void PreparePhiData(const phi::Kernel& pt_kernel, if (in_def.backend == phi::Backend::ALL_BACKEND) { continue; } - auto expected_place = phi::TransToPhiPlace(in_def.backend); - if (platform::is_same_place(tensor_in->place(), expected_place)) { + auto tensor_backend = phi::TransToPhiBackend(tensor_in->place()); + if (in_def.backend == tensor_backend || + (in_def.backend == phi::Backend::GPUDNN && + tensor_backend == phi::Backend::GPU)) { continue; } + auto expected_place = phi::TransToPhiPlace(in_def.backend); + VLOG(3) << "Phi Transform Variable " << input_names[i] << " from " << tensor_in->place() << " to " << expected_place; diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 93f10b34b6..6e5c095d69 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -2050,9 +2050,9 @@ void BindImperative(py::module *m_ptr) { }; auto ret = self.GetExpectedKernelSignature(type, ins_map, outs_map, attrs); - auto kernelsig_ins = input_to_vector(std::get<0>(ret.args)); - auto kernelsig_attrs = attr_to_vector(std::get<1>(ret.args)); - auto kernelsig_outs = output_to_vector(std::get<2>(ret.args)); + auto kernelsig_ins = input_to_vector(ret.input_names); + auto kernelsig_attrs = attr_to_vector(ret.attr_names); + auto kernelsig_outs = output_to_vector(ret.output_names); return std::make_tuple(kernelsig_ins, kernelsig_attrs, kernelsig_outs); } diff --git a/paddle/fluid/pybind/kernel_signature_generator.cc b/paddle/fluid/pybind/kernel_signature_generator.cc index 1520174fba..0b0a8628b1 100644 --- a/paddle/fluid/pybind/kernel_signature_generator.cc +++ b/paddle/fluid/pybind/kernel_signature_generator.cc @@ -58,10 +58,10 @@ int main(int argc, char **argv) { if (kernel_signature_map.Has(op_name)) { kernel_signature_map_str = kernel_signature_map_str + "\"" + op_kernel_pair.first + "\":{"; - auto &args = kernel_signature_map.Get(op_name).args; + const auto &args = kernel_signature_map.Get(op_name); kernel_signature_map_str += "\"inputs\":["; - auto inputs_ = std::get<0>(args); + auto inputs_ = args.input_names; for (size_t i = 0; i < inputs_.size(); i++) { kernel_signature_map_str = kernel_signature_map_str + "\"" + inputs_[i] + "\","; @@ -69,14 +69,14 @@ int main(int argc, char **argv) { if (inputs_.size()) kernel_signature_map_str.pop_back(); kernel_signature_map_str += "],\"attrs\":["; - auto attrs_ = std::get<1>(args); + auto attrs_ = args.attr_names; for (size_t i = 0; i < attrs_.size(); i++) { kernel_signature_map_str = kernel_signature_map_str + "\"" + attrs_[i] + "\","; } if (attrs_.size()) kernel_signature_map_str.pop_back(); kernel_signature_map_str += "],\"outputs\":["; - auto outputs_ = std::get<2>(args); + auto outputs_ = args.output_names; for (size_t i = 0; i < outputs_.size(); i++) { kernel_signature_map_str = kernel_signature_map_str + "\"" + outputs_[i] + "\","; diff --git a/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc b/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc index e3fdd5ae5b..d1f5f37593 100644 --- a/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc +++ b/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc @@ -108,14 +108,14 @@ void PhiOpConvertPass::convertStage() { op->replaceAllUsesWith(kernel_op.getResults()); } else { ::phi::KernelSignature kernel_sign = - ::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)( + (*::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name))( infrt::ProtoArgumentMappingContext(op)); VLOG(3) << "IncompatiblePhiKernel: op(" << op_name << "), kernel(" << kernel_sign.name << ")"; // resort input&output according to kernel_sign ::llvm::SmallVector inputs, ori_output; ::llvm::SmallVector output_types; - for (const std::string &str : std::get<0>(kernel_sign.args)) { + for (const std::string &str : kernel_sign.input_names) { if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) { LOG(ERROR) << "No input info for Op " << op_name << " and argument " << str; @@ -125,7 +125,7 @@ void PhiOpConvertPass::convertStage() { inputs.push_back(op->getOperands()[index]); } - for (const std::string &str : std::get<2>(kernel_sign.args)) { + for (const std::string &str : kernel_sign.output_names) { if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) { LOG(ERROR) << "No output info for Op " << op_name << " and argument " << str; diff --git a/paddle/phi/core/compat/arg_map_context.cc b/paddle/phi/core/compat/arg_map_context.cc index 6f678966ba..800245406a 100644 --- a/paddle/phi/core/compat/arg_map_context.cc +++ b/paddle/phi/core/compat/arg_map_context.cc @@ -20,11 +20,11 @@ limitations under the License. */ namespace phi { std::ostream& operator<<(std::ostream& os, KernelSignature signature) { os << "Kernel Signature - name: " << signature.name << "; inputs: " - << paddle::string::join_strings(std::get<0>(signature.args), ", ") + << paddle::string::join_strings(signature.input_names, ", ") << "; attributes: " - << paddle::string::join_strings(std::get<1>(signature.args), ", ") + << paddle::string::join_strings(signature.attr_names, ", ") << "; outputs: " - << paddle::string::join_strings(std::get<2>(signature.args), ", "); + << paddle::string::join_strings(signature.output_names, ", "); return os; } diff --git a/paddle/phi/core/compat/arg_map_context.h b/paddle/phi/core/compat/arg_map_context.h index 122ebed219..102dca48b9 100644 --- a/paddle/phi/core/compat/arg_map_context.h +++ b/paddle/phi/core/compat/arg_map_context.h @@ -33,7 +33,9 @@ using KernelArgsTuple = std::tuple, struct KernelSignature { const char* name; - KernelArgsTuple args; + paddle::SmallVector input_names; + paddle::SmallVector attr_names; + paddle::SmallVector output_names; KernelSignature() = default; @@ -41,18 +43,26 @@ struct KernelSignature { paddle::SmallVector&& inputs, paddle::SmallVector&& attrs, paddle::SmallVector&& outputs) - : name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {} + : name(kernel_name), + input_names(std::move(inputs)), + attr_names(std::move(attrs)), + output_names(std::move(outputs)) {} KernelSignature(const char* kernel_name, const paddle::SmallVector& inputs, const paddle::SmallVector& attrs, const paddle::SmallVector& outputs) - : name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {} + : name(kernel_name), + input_names(inputs), + attr_names(attrs), + output_names(outputs) {} // TODO(chenweihang): add assign constructor to solve windows compile // problem, remove it later KernelSignature& operator=(const KernelSignature& other) { name = other.name; - args = other.args; + input_names = other.input_names; + attr_names = other.attr_names; + output_names = other.output_names; return *this; } }; diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h index 9c926fa871..bd19d403c9 100644 --- a/paddle/phi/core/compat/op_utils.h +++ b/paddle/phi/core/compat/op_utils.h @@ -86,6 +86,14 @@ class DefaultKernelSignatureMap { return it->second; } + const KernelSignature* GetNullable(const std::string& op_type) const { + auto it = map_.find(op_type); + if (it != map_.end()) { + return &it->second; + } + return nullptr; + } + void Insert(std::string op_type, KernelSignature signature) { PADDLE_ENFORCE_NE( Has(op_type), @@ -148,16 +156,13 @@ class OpUtilsMap { } } - ArgumentMappingFn GetArgumentMappingFn(const std::string& op_type) const { + const ArgumentMappingFn* GetArgumentMappingFn( + const std::string& op_type) const { auto it = arg_mapping_fn_map_.find(op_type); if (it == arg_mapping_fn_map_.end()) { - auto func = - [&op_type](const ArgumentMappingContext& ctx) -> KernelSignature { - return DefaultKernelSignatureMap::Instance().Get(op_type); - }; - return func; + return nullptr; } else { - return it->second; + return &it->second; } } diff --git a/paddle/phi/tests/ops/test_op_signature.cc b/paddle/phi/tests/ops/test_op_signature.cc index 6acf3916a1..4379dfd7cc 100644 --- a/paddle/phi/tests/ops/test_op_signature.cc +++ b/paddle/phi/tests/ops/test_op_signature.cc @@ -30,8 +30,8 @@ namespace tests { TEST(ARG_MAP, fill_constant) { TestArgumentMappingContext arg_case1( {"ShapeTensor", "ValueTensor"}, {}, {}, {}, {"Out"}); - auto signature1 = - OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case1); + auto signature1 = (*OpUtilsMap::Instance().GetArgumentMappingFn( + "fill_constant"))(arg_case1); ASSERT_EQ(signature1.name, "full_sr"); TestArgumentMappingContext arg_case2( @@ -40,8 +40,8 @@ TEST(ARG_MAP, fill_constant) { {{"str_value", paddle::any{std::string{"10"}}}}, {}, {"Out"}); - auto signature2 = - OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case2); + auto signature2 = (*OpUtilsMap::Instance().GetArgumentMappingFn( + "fill_constant"))(arg_case2); ASSERT_EQ(signature2.name, "full_sr"); TestArgumentMappingContext arg_case3( @@ -50,14 +50,14 @@ TEST(ARG_MAP, fill_constant) { {{"value", paddle::any{0}}, {"str_value", paddle::any{std::string{""}}}}, {}, {"Out"}); - auto signature3 = - OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case3); + auto signature3 = (*OpUtilsMap::Instance().GetArgumentMappingFn( + "fill_constant"))(arg_case3); ASSERT_EQ(signature3.name, "full_sr"); TestArgumentMappingContext arg_case4( {"ShapeTensorList", "ValueTensor"}, {}, {}, {}, {"Out"}); - auto signature4 = - OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case4); + auto signature4 = (*OpUtilsMap::Instance().GetArgumentMappingFn( + "fill_constant"))(arg_case4); ASSERT_EQ(signature4.name, "full_sr"); TestArgumentMappingContext arg_case5( @@ -66,8 +66,8 @@ TEST(ARG_MAP, fill_constant) { {{"str_value", paddle::any{std::string{"10"}}}}, {}, {"Out"}); - auto signature5 = - OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case5); + auto signature5 = (*OpUtilsMap::Instance().GetArgumentMappingFn( + "fill_constant"))(arg_case5); ASSERT_EQ(signature5.name, "full_sr"); TestArgumentMappingContext arg_case6( @@ -76,8 +76,8 @@ TEST(ARG_MAP, fill_constant) { {{"value", paddle::any{0}}, {"str_value", paddle::any{std::string{""}}}}, {}, {"Out"}); - auto signature6 = - OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case6); + auto signature6 = (*OpUtilsMap::Instance().GetArgumentMappingFn( + "fill_constant"))(arg_case6); ASSERT_EQ(signature6.name, "full_sr"); TestArgumentMappingContext arg_case7( @@ -86,8 +86,8 @@ TEST(ARG_MAP, fill_constant) { {{"shape", paddle::any{std::vector{2, 3}}}}, {}, {"Out"}); - auto signature7 = - OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case7); + auto signature7 = (*OpUtilsMap::Instance().GetArgumentMappingFn( + "fill_constant"))(arg_case7); ASSERT_EQ(signature7.name, "full_sr"); TestArgumentMappingContext arg_case8( @@ -98,8 +98,8 @@ TEST(ARG_MAP, fill_constant) { {"str_value", paddle::any{std::string{""}}}}, {}, {"Out"}); - auto signature8 = - OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case8); + auto signature8 = (*OpUtilsMap::Instance().GetArgumentMappingFn( + "fill_constant"))(arg_case8); ASSERT_EQ(signature8.name, "full_sr"); TestArgumentMappingContext arg_case9( @@ -109,8 +109,8 @@ TEST(ARG_MAP, fill_constant) { {"str_value", paddle::any{std::string{"10"}}}}, {}, {"Out"}); - auto signature9 = - OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case9); + auto signature9 = (*OpUtilsMap::Instance().GetArgumentMappingFn( + "fill_constant"))(arg_case9); ASSERT_EQ(signature9.name, "full_sr"); } @@ -122,7 +122,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case) + .name, "set_value"); TestArgumentMappingContext arg_case1( @@ -132,7 +133,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case1).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case1) + .name, "set_value"); TestArgumentMappingContext arg_case2( @@ -142,7 +144,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case2).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case2) + .name, "set_value"); TestArgumentMappingContext arg_case3( @@ -152,7 +155,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case3).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case3) + .name, "set_value"); TestArgumentMappingContext arg_case4( @@ -162,7 +166,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case4).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case4) + .name, "set_value"); TestArgumentMappingContext arg_case5( @@ -172,7 +177,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case5).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case5) + .name, "set_value_with_tensor"); TestArgumentMappingContext arg_case6( @@ -182,7 +188,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case6).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case6) + .name, "set_value"); TestArgumentMappingContext arg_case7( @@ -192,7 +199,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case7).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case7) + .name, "set_value"); TestArgumentMappingContext arg_case8( @@ -202,7 +210,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case8).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case8) + .name, "set_value"); TestArgumentMappingContext arg_case9( @@ -212,7 +221,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case9).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case9) + .name, "set_value"); TestArgumentMappingContext arg_case10( @@ -222,7 +232,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case10).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case10) + .name, "set_value_with_tensor"); TestArgumentMappingContext arg_case11( @@ -232,7 +243,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case11).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case11) + .name, "set_value"); TestArgumentMappingContext arg_case12( @@ -242,7 +254,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case12).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case12) + .name, "set_value"); TestArgumentMappingContext arg_case13( @@ -252,7 +265,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case13).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case13) + .name, "set_value"); TestArgumentMappingContext arg_case14( @@ -262,13 +276,15 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case14).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case14) + .name, "set_value"); TestArgumentMappingContext arg_case15( {"Input", "StartsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case15).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case15) + .name, "set_value_with_tensor"); TestArgumentMappingContext arg_case16( @@ -278,7 +294,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case16).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case16) + .name, "set_value"); TestArgumentMappingContext arg_case17( @@ -288,7 +305,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case17).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case17) + .name, "set_value"); TestArgumentMappingContext arg_case18( @@ -298,7 +316,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case18).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case18) + .name, "set_value"); TestArgumentMappingContext arg_case19( @@ -308,7 +327,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case19).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case19) + .name, "set_value"); TestArgumentMappingContext arg_case20( @@ -318,7 +338,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case20).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case20) + .name, "set_value"); TestArgumentMappingContext arg_case21( @@ -328,7 +349,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case21).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case21) + .name, "set_value_with_tensor"); TestArgumentMappingContext arg_case22( @@ -338,7 +360,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case22).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case22) + .name, "set_value"); TestArgumentMappingContext arg_case23( @@ -348,7 +371,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case23).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case23) + .name, "set_value"); TestArgumentMappingContext arg_case24( @@ -358,7 +382,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case24).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case24) + .name, "set_value"); TestArgumentMappingContext arg_case25( @@ -368,13 +393,15 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case25).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case25) + .name, "set_value"); TestArgumentMappingContext arg_case26( {"Input", "EndsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case26).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case26) + .name, "set_value_with_tensor"); TestArgumentMappingContext arg_case27( @@ -384,7 +411,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case27).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case27) + .name, "set_value"); TestArgumentMappingContext arg_case28( @@ -394,7 +422,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case28).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case28) + .name, "set_value"); TestArgumentMappingContext arg_case29( @@ -404,7 +433,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case29).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case29) + .name, "set_value"); TestArgumentMappingContext arg_case30( @@ -414,7 +444,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case30).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case30) + .name, "set_value"); TestArgumentMappingContext arg_case31( @@ -424,13 +455,15 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case31).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case31) + .name, "set_value"); TestArgumentMappingContext arg_case32( {"Input", "StepsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case32).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case32) + .name, "set_value_with_tensor"); TestArgumentMappingContext arg_case33( @@ -440,7 +473,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case33).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case33) + .name, "set_value"); TestArgumentMappingContext arg_case34( @@ -450,7 +484,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case34).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case34) + .name, "set_value"); TestArgumentMappingContext arg_case35( @@ -460,7 +495,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case35).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case35) + .name, "set_value"); TestArgumentMappingContext arg_case36( @@ -470,7 +506,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case36).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case36) + .name, "set_value"); TestArgumentMappingContext arg_case37( @@ -480,7 +517,8 @@ TEST(ARG_MAP, set_value) { {"Out"}, {}); ASSERT_EQ( - OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case37).name, + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case37) + .name, "set_value"); } @@ -491,10 +529,10 @@ TEST(ARG_MAP, set_value_grad) { {}, {"Input@GRAD", "ValueTensor@GRAD"}, {}); - ASSERT_EQ(OpUtilsMap::Instance() - .GetArgumentMappingFn("set_value_grad")(arg_case) - .name, - "set_value_grad"); + ASSERT_EQ( + (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(arg_case) + .name, + "set_value_grad"); TestArgumentMappingContext arg_case1( {"Out@GRAD", "StartsTensorList", "StepsTensorList"}, @@ -502,8 +540,8 @@ TEST(ARG_MAP, set_value_grad) { {}, {"Input@GRAD", "ValueTensor@GRAD"}, {}); - ASSERT_EQ(OpUtilsMap::Instance() - .GetArgumentMappingFn("set_value_grad")(arg_case1) + ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))( + arg_case1) .name, "set_value_grad"); @@ -512,8 +550,8 @@ TEST(ARG_MAP, set_value_grad) { {}, {"Input@GRAD", "ValueTensor@GRAD"}, {}); - ASSERT_EQ(OpUtilsMap::Instance() - .GetArgumentMappingFn("set_value_grad")(arg_case2) + ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))( + arg_case2) .name, "set_value_grad"); @@ -523,8 +561,8 @@ TEST(ARG_MAP, set_value_grad) { {}, {"Input@GRAD", "ValueTensor@GRAD"}, {}); - ASSERT_EQ(OpUtilsMap::Instance() - .GetArgumentMappingFn("set_value_grad")(arg_case3) + ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))( + arg_case3) .name, "set_value_grad"); @@ -533,8 +571,8 @@ TEST(ARG_MAP, set_value_grad) { {}, {"Input@GRAD", "ValueTensor@GRAD"}, {}); - ASSERT_EQ(OpUtilsMap::Instance() - .GetArgumentMappingFn("set_value_grad")(arg_case4) + ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))( + arg_case4) .name, "set_value_grad"); @@ -543,8 +581,8 @@ TEST(ARG_MAP, set_value_grad) { {}, {"Input@GRAD", "ValueTensor@GRAD"}, {}); - ASSERT_EQ(OpUtilsMap::Instance() - .GetArgumentMappingFn("set_value_grad")(arg_case5) + ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))( + arg_case5) .name, "set_value_grad"); } @@ -558,10 +596,9 @@ TEST(ARG_MAP, allclose) { {"Out"}, {}); auto signature1 = - OpUtilsMap::Instance().GetArgumentMappingFn("allclose")(arg_case1); + (*OpUtilsMap::Instance().GetArgumentMappingFn("allclose"))(arg_case1); ASSERT_EQ(signature1.name, "allclose"); - auto attr_names1 = std::get<1>(signature1.args); - ASSERT_EQ(attr_names1[0], "Rtol"); + ASSERT_EQ(signature1.attr_names[0], "Rtol"); TestArgumentMappingContext arg_case2( {"Input", "Other", "Atol"}, @@ -571,27 +608,26 @@ TEST(ARG_MAP, allclose) { {"Out"}, {}); auto signature2 = - OpUtilsMap::Instance().GetArgumentMappingFn("allclose")(arg_case2); + (*OpUtilsMap::Instance().GetArgumentMappingFn("allclose"))(arg_case2); ASSERT_EQ(signature2.name, "allclose"); - auto attr_names2 = std::get<1>(signature2.args); - ASSERT_EQ(attr_names2[1], "Atol"); + ASSERT_EQ(signature2.attr_names[1], "Atol"); } TEST(ARG_MAP, reshape) { TestArgumentMappingContext arg_case1({"X", "ShapeTensor"}, {}, {}, {"Out"}); auto signature1 = - OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case1); + (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case1); ASSERT_EQ(signature1.name, "reshape"); TestArgumentMappingContext arg_case2({"X", "Shape"}, {}, {}, {"Out"}); auto signature2 = - OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case2); + (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case2); ASSERT_EQ(signature2.name, "reshape"); TestArgumentMappingContext arg_case3( {"X"}, {}, {{"shape", paddle::any(std::vector({1, 2}))}}, {"Out"}); auto signature3 = - OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case3); + (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case3); ASSERT_EQ(signature3.name, "reshape"); } -- GitLab