未验证 提交 c56fffb4 编写于 作者: Z zyfncg 提交者: GitHub

optimize performance of dygraph (#42137)

上级 79ac8870
......@@ -402,12 +402,11 @@ std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type) {
// 1. get kernel args
auto arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
PADDLE_ENFORCE_NOT_NULL(
arg_map_fn, platform::errors::NotFound(
"The ArgumentMappingFn of %s op is not found.", op_type));
auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
InferShapeArgumentMappingContext arg_map_context(*ctx);
auto signature = arg_map_fn(arg_map_context);
KernelSignature signature =
arg_map_fn ? (*arg_map_fn)(arg_map_context)
: phi::DefaultKernelSignatureMap::Instance().Get(op_type);
VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;
// 2. build infermeta context
......
......@@ -2117,8 +2117,16 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
const ExecutionContext& ctx) const {
ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
if (arg_map_fn_ == nullptr) {
arg_map_fn_.reset(new phi::ArgumentMappingFn(
phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type())));
auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(type_);
if (arg_map_fn) {
arg_map_fn_.reset(new phi::ArgumentMappingFn(*arg_map_fn));
} else {
auto func =
[this](const phi::ArgumentMappingContext& ctx) -> KernelSignature {
return phi::DefaultKernelSignatureMap::Instance().Get(type_);
};
arg_map_fn_.reset(new phi::ArgumentMappingFn(func));
}
}
return (*arg_map_fn_)(arg_mapping_ctx);
}
......
......@@ -37,6 +37,8 @@ namespace paddle {
namespace imperative {
static const phi::Kernel empty_kernel;
static const framework::RuntimeContext empty_ctx({}, {});
static const framework::Scope empty_scope;
const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<paddle::imperative::VarBase>& var) {
......@@ -138,8 +140,6 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
framework::RuntimeContext ctx({}, {});
#ifdef PADDLE_WITH_MKLDNN
// MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and
// GetKernelType functions, so we need to copy the attributes there.
......@@ -158,7 +158,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
// 1. get expected kernel key
auto dygraph_exe_ctx = DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs, default_attrs);
op, empty_scope, *dev_ctx, empty_ctx, ins, outs, attrs, default_attrs);
auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx);
framework::KernelSignature pt_kernel_signature;
......@@ -172,11 +172,26 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
paddle::platform::is_in_xpu_black_list(op.Type());
#endif
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
pt_kernel_signature =
std::move(op.GetExpectedPhiKernelArgs(dygraph_exe_ctx));
VLOG(6) << pt_kernel_signature;
bool has_phi_kernel = false;
const auto* arg_map_fn =
phi::OpUtilsMap::Instance().GetArgumentMappingFn(op.Type());
if (arg_map_fn) {
has_phi_kernel = true;
pt_kernel_signature = (*arg_map_fn)(
framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
} else {
const auto* kernel_sig =
phi::DefaultKernelSignatureMap::Instance().GetNullable(op.Type());
if (kernel_sig) {
has_phi_kernel = true;
pt_kernel_signature = *kernel_sig;
}
}
if (has_phi_kernel) {
VLOG(6) << pt_kernel_signature;
pt_kernel_name = pt_kernel_signature.name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the
......@@ -231,7 +246,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
dev_ctx = pool.Get(expected_kernel_key.place_);
}
return PreparedOp(op, ctx, expected_kernel_key,
return PreparedOp(op, empty_ctx, expected_kernel_key,
std::move(pt_kernel_signature), pt_kernel, dev_ctx);
} else {
VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name
......@@ -280,7 +295,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
<< " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << pt_cpu_kernel;
auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
return PreparedOp(op, ctx, expected_kernel_key,
return PreparedOp(op, empty_ctx, expected_kernel_key,
std::move(pt_kernel_signature), pt_cpu_kernel,
cpu_ctx);
}
......@@ -373,7 +388,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
dev_ctx = pool.Get(expected_kernel_key.place_);
}
return PreparedOp(op, ctx, expected_kernel_key, kernel_iter->second, dev_ctx);
return PreparedOp(op, empty_ctx, expected_kernel_key, kernel_iter->second,
dev_ctx);
}
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
......
......@@ -193,7 +193,7 @@ void PhiOpConvertPass::convertStage() {
op->replaceAllUsesWith(kernel_op.getResults());
} else {
::phi::KernelSignature kernel_sign =
::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
(*::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name))(
infrt::ProtoArgumentMappingContext(op));
VLOG(3) << "IncompatiblePhiKernel: op(" << op_name << "), kernel("
<< kernel_sign.name << ")";
......
......@@ -86,6 +86,14 @@ class DefaultKernelSignatureMap {
return it->second;
}
const KernelSignature* GetNullable(const std::string& op_type) const {
auto it = map_.find(op_type);
if (it != map_.end()) {
return &it->second;
}
return nullptr;
}
void Insert(std::string op_type, KernelSignature signature) {
PADDLE_ENFORCE_NE(
Has(op_type),
......@@ -148,16 +156,13 @@ class OpUtilsMap {
}
}
ArgumentMappingFn GetArgumentMappingFn(const std::string& op_type) const {
const ArgumentMappingFn* GetArgumentMappingFn(
const std::string& op_type) const {
auto it = arg_mapping_fn_map_.find(op_type);
if (it == arg_mapping_fn_map_.end()) {
auto func =
[&op_type](const ArgumentMappingContext& ctx) -> KernelSignature {
return DefaultKernelSignatureMap::Instance().Get(op_type);
};
return func;
return nullptr;
} else {
return it->second;
return &it->second;
}
}
......
......@@ -30,8 +30,8 @@ namespace tests {
TEST(ARG_MAP, fill_constant) {
TestArgumentMappingContext arg_case1(
{"ShapeTensor", "ValueTensor"}, {}, {}, {}, {"Out"});
auto signature1 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case1);
auto signature1 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case1);
ASSERT_EQ(signature1.name, "full_sr");
TestArgumentMappingContext arg_case2(
......@@ -40,8 +40,8 @@ TEST(ARG_MAP, fill_constant) {
{{"str_value", paddle::any{std::string{"10"}}}},
{},
{"Out"});
auto signature2 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case2);
auto signature2 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case2);
ASSERT_EQ(signature2.name, "full_sr");
TestArgumentMappingContext arg_case3(
......@@ -50,14 +50,14 @@ TEST(ARG_MAP, fill_constant) {
{{"value", paddle::any{0}}, {"str_value", paddle::any{std::string{""}}}},
{},
{"Out"});
auto signature3 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case3);
auto signature3 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case3);
ASSERT_EQ(signature3.name, "full_sr");
TestArgumentMappingContext arg_case4(
{"ShapeTensorList", "ValueTensor"}, {}, {}, {}, {"Out"});
auto signature4 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case4);
auto signature4 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case4);
ASSERT_EQ(signature4.name, "full_sr");
TestArgumentMappingContext arg_case5(
......@@ -66,8 +66,8 @@ TEST(ARG_MAP, fill_constant) {
{{"str_value", paddle::any{std::string{"10"}}}},
{},
{"Out"});
auto signature5 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case5);
auto signature5 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case5);
ASSERT_EQ(signature5.name, "full_sr");
TestArgumentMappingContext arg_case6(
......@@ -76,8 +76,8 @@ TEST(ARG_MAP, fill_constant) {
{{"value", paddle::any{0}}, {"str_value", paddle::any{std::string{""}}}},
{},
{"Out"});
auto signature6 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case6);
auto signature6 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case6);
ASSERT_EQ(signature6.name, "full_sr");
TestArgumentMappingContext arg_case7(
......@@ -86,8 +86,8 @@ TEST(ARG_MAP, fill_constant) {
{{"shape", paddle::any{std::vector<int64_t>{2, 3}}}},
{},
{"Out"});
auto signature7 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case7);
auto signature7 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case7);
ASSERT_EQ(signature7.name, "full_sr");
TestArgumentMappingContext arg_case8(
......@@ -98,8 +98,8 @@ TEST(ARG_MAP, fill_constant) {
{"str_value", paddle::any{std::string{""}}}},
{},
{"Out"});
auto signature8 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case8);
auto signature8 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case8);
ASSERT_EQ(signature8.name, "full_sr");
TestArgumentMappingContext arg_case9(
......@@ -109,8 +109,8 @@ TEST(ARG_MAP, fill_constant) {
{"str_value", paddle::any{std::string{"10"}}}},
{},
{"Out"});
auto signature9 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case9);
auto signature9 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
"fill_constant"))(arg_case9);
ASSERT_EQ(signature9.name, "full_sr");
}
......@@ -122,7 +122,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case)
.name,
"set_value");
TestArgumentMappingContext arg_case1(
......@@ -132,7 +133,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case1).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case1)
.name,
"set_value");
TestArgumentMappingContext arg_case2(
......@@ -142,7 +144,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case2).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case2)
.name,
"set_value");
TestArgumentMappingContext arg_case3(
......@@ -152,7 +155,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case3).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case3)
.name,
"set_value");
TestArgumentMappingContext arg_case4(
......@@ -162,7 +166,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case4).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case4)
.name,
"set_value");
TestArgumentMappingContext arg_case5(
......@@ -172,7 +177,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case5).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case5)
.name,
"set_value_with_tensor");
TestArgumentMappingContext arg_case6(
......@@ -182,7 +188,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case6).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case6)
.name,
"set_value");
TestArgumentMappingContext arg_case7(
......@@ -192,7 +199,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case7).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case7)
.name,
"set_value");
TestArgumentMappingContext arg_case8(
......@@ -202,7 +210,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case8).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case8)
.name,
"set_value");
TestArgumentMappingContext arg_case9(
......@@ -212,7 +221,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case9).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case9)
.name,
"set_value");
TestArgumentMappingContext arg_case10(
......@@ -222,7 +232,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case10).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case10)
.name,
"set_value_with_tensor");
TestArgumentMappingContext arg_case11(
......@@ -232,7 +243,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case11).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case11)
.name,
"set_value");
TestArgumentMappingContext arg_case12(
......@@ -242,7 +254,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case12).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case12)
.name,
"set_value");
TestArgumentMappingContext arg_case13(
......@@ -252,7 +265,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case13).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case13)
.name,
"set_value");
TestArgumentMappingContext arg_case14(
......@@ -262,13 +276,15 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case14).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case14)
.name,
"set_value");
TestArgumentMappingContext arg_case15(
{"Input", "StartsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case15).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case15)
.name,
"set_value_with_tensor");
TestArgumentMappingContext arg_case16(
......@@ -278,7 +294,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case16).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case16)
.name,
"set_value");
TestArgumentMappingContext arg_case17(
......@@ -288,7 +305,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case17).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case17)
.name,
"set_value");
TestArgumentMappingContext arg_case18(
......@@ -298,7 +316,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case18).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case18)
.name,
"set_value");
TestArgumentMappingContext arg_case19(
......@@ -308,7 +327,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case19).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case19)
.name,
"set_value");
TestArgumentMappingContext arg_case20(
......@@ -318,7 +338,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case20).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case20)
.name,
"set_value");
TestArgumentMappingContext arg_case21(
......@@ -328,7 +349,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case21).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case21)
.name,
"set_value_with_tensor");
TestArgumentMappingContext arg_case22(
......@@ -338,7 +360,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case22).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case22)
.name,
"set_value");
TestArgumentMappingContext arg_case23(
......@@ -348,7 +371,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case23).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case23)
.name,
"set_value");
TestArgumentMappingContext arg_case24(
......@@ -358,7 +382,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case24).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case24)
.name,
"set_value");
TestArgumentMappingContext arg_case25(
......@@ -368,13 +393,15 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case25).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case25)
.name,
"set_value");
TestArgumentMappingContext arg_case26(
{"Input", "EndsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case26).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case26)
.name,
"set_value_with_tensor");
TestArgumentMappingContext arg_case27(
......@@ -384,7 +411,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case27).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case27)
.name,
"set_value");
TestArgumentMappingContext arg_case28(
......@@ -394,7 +422,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case28).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case28)
.name,
"set_value");
TestArgumentMappingContext arg_case29(
......@@ -404,7 +433,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case29).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case29)
.name,
"set_value");
TestArgumentMappingContext arg_case30(
......@@ -414,7 +444,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case30).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case30)
.name,
"set_value");
TestArgumentMappingContext arg_case31(
......@@ -424,13 +455,15 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case31).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case31)
.name,
"set_value");
TestArgumentMappingContext arg_case32(
{"Input", "StepsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case32).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case32)
.name,
"set_value_with_tensor");
TestArgumentMappingContext arg_case33(
......@@ -440,7 +473,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case33).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case33)
.name,
"set_value");
TestArgumentMappingContext arg_case34(
......@@ -450,7 +484,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case34).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case34)
.name,
"set_value");
TestArgumentMappingContext arg_case35(
......@@ -460,7 +495,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case35).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case35)
.name,
"set_value");
TestArgumentMappingContext arg_case36(
......@@ -470,7 +506,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case36).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case36)
.name,
"set_value");
TestArgumentMappingContext arg_case37(
......@@ -480,7 +517,8 @@ TEST(ARG_MAP, set_value) {
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case37).name,
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case37)
.name,
"set_value");
}
......@@ -491,8 +529,8 @@ TEST(ARG_MAP, set_value_grad) {
{},
{"Input@GRAD", "ValueTensor@GRAD"},
{});
ASSERT_EQ(OpUtilsMap::Instance()
.GetArgumentMappingFn("set_value_grad")(arg_case)
ASSERT_EQ(
(*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(arg_case)
.name,
"set_value_grad");
......@@ -502,8 +540,8 @@ TEST(ARG_MAP, set_value_grad) {
{},
{"Input@GRAD", "ValueTensor@GRAD"},
{});
ASSERT_EQ(OpUtilsMap::Instance()
.GetArgumentMappingFn("set_value_grad")(arg_case1)
ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
arg_case1)
.name,
"set_value_grad");
......@@ -512,8 +550,8 @@ TEST(ARG_MAP, set_value_grad) {
{},
{"Input@GRAD", "ValueTensor@GRAD"},
{});
ASSERT_EQ(OpUtilsMap::Instance()
.GetArgumentMappingFn("set_value_grad")(arg_case2)
ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
arg_case2)
.name,
"set_value_grad");
......@@ -523,8 +561,8 @@ TEST(ARG_MAP, set_value_grad) {
{},
{"Input@GRAD", "ValueTensor@GRAD"},
{});
ASSERT_EQ(OpUtilsMap::Instance()
.GetArgumentMappingFn("set_value_grad")(arg_case3)
ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
arg_case3)
.name,
"set_value_grad");
......@@ -533,8 +571,8 @@ TEST(ARG_MAP, set_value_grad) {
{},
{"Input@GRAD", "ValueTensor@GRAD"},
{});
ASSERT_EQ(OpUtilsMap::Instance()
.GetArgumentMappingFn("set_value_grad")(arg_case4)
ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
arg_case4)
.name,
"set_value_grad");
......@@ -543,8 +581,8 @@ TEST(ARG_MAP, set_value_grad) {
{},
{"Input@GRAD", "ValueTensor@GRAD"},
{});
ASSERT_EQ(OpUtilsMap::Instance()
.GetArgumentMappingFn("set_value_grad")(arg_case5)
ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
arg_case5)
.name,
"set_value_grad");
}
......@@ -558,7 +596,7 @@ TEST(ARG_MAP, allclose) {
{"Out"},
{});
auto signature1 =
OpUtilsMap::Instance().GetArgumentMappingFn("allclose")(arg_case1);
(*OpUtilsMap::Instance().GetArgumentMappingFn("allclose"))(arg_case1);
ASSERT_EQ(signature1.name, "allclose");
ASSERT_EQ(signature1.attr_names[0], "Rtol");
......@@ -570,7 +608,7 @@ TEST(ARG_MAP, allclose) {
{"Out"},
{});
auto signature2 =
OpUtilsMap::Instance().GetArgumentMappingFn("allclose")(arg_case2);
(*OpUtilsMap::Instance().GetArgumentMappingFn("allclose"))(arg_case2);
ASSERT_EQ(signature2.name, "allclose");
ASSERT_EQ(signature2.attr_names[1], "Atol");
}
......@@ -578,18 +616,18 @@ TEST(ARG_MAP, allclose) {
TEST(ARG_MAP, reshape) {
TestArgumentMappingContext arg_case1({"X", "ShapeTensor"}, {}, {}, {"Out"});
auto signature1 =
OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case1);
(*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case1);
ASSERT_EQ(signature1.name, "reshape");
TestArgumentMappingContext arg_case2({"X", "Shape"}, {}, {}, {"Out"});
auto signature2 =
OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case2);
(*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case2);
ASSERT_EQ(signature2.name, "reshape");
TestArgumentMappingContext arg_case3(
{"X"}, {}, {{"shape", paddle::any(std::vector<int>({1, 2}))}}, {"Out"});
auto signature3 =
OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case3);
(*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case3);
ASSERT_EQ(signature3.name, "reshape");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册