未验证 提交 0d537003 编写于 作者: Z zyfncg 提交者: GitHub

[cherry-pick] Optimize performance of dygraph (#42093, #42103, #42137) (#42171)

* optimiaze performance of PreparePhiData (#42093)

* Dygraph performance optimization (v2) (#42103)

* optimiaze performance of PreparePhiData

* dygraph performance optimization

* optimize performance of dygraph (#42137)
上级 26167969
...@@ -402,21 +402,20 @@ std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween( ...@@ -402,21 +402,20 @@ std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type) { const std::string& op_type) {
// 1. get kernel args // 1. get kernel args
auto arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type); auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
PADDLE_ENFORCE_NOT_NULL(
arg_map_fn, platform::errors::NotFound(
"The ArgumentMappingFn of %s op is not found.", op_type));
InferShapeArgumentMappingContext arg_map_context(*ctx); InferShapeArgumentMappingContext arg_map_context(*ctx);
auto signature = arg_map_fn(arg_map_context); KernelSignature signature =
arg_map_fn ? (*arg_map_fn)(arg_map_context)
: phi::DefaultKernelSignatureMap::Instance().Get(op_type);
VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature; VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;
// 2. build infermeta context // 2. build infermeta context
CompatInferMetaContext infer_meta_context( CompatInferMetaContext infer_meta_context(
{ctx->IsRuntime(), ctx->IsRunMKLDNNKernel()}); {ctx->IsRuntime(), ctx->IsRunMKLDNNKernel()});
auto& input_names = std::get<0>(signature.args); const auto& input_names = signature.input_names;
auto& attr_names = std::get<1>(signature.args); const auto& attr_names = signature.attr_names;
auto& output_names = std::get<2>(signature.args); const auto& output_names = signature.output_names;
const auto& args_def = const auto& args_def =
phi::KernelFactory::Instance().GetFirstKernelArgsDef(signature.name); phi::KernelFactory::Instance().GetFirstKernelArgsDef(signature.name);
......
...@@ -1200,8 +1200,10 @@ bool OperatorWithKernel::SupportsMKLDNN( ...@@ -1200,8 +1200,10 @@ bool OperatorWithKernel::SupportsMKLDNN(
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const { proto::VarType::Type data_type) const {
bool use_mkldnn_ctx = ctx.HasAttr("use_mkldnn") && const auto& attrs_map = ctx.Attrs();
ctx.Attr<bool>("use_mkldnn") && auto iter = attrs_map.find("use_mkldnn");
bool use_mkldnn_ctx = iter != attrs_map.end() &&
BOOST_GET_CONST(bool, iter->second) &&
platform::is_cpu_place(ctx.GetPlace()); platform::is_cpu_place(ctx.GetPlace());
return use_mkldnn_ctx && this->SupportsMKLDNN(data_type); return use_mkldnn_ctx && this->SupportsMKLDNN(data_type);
} }
...@@ -2117,8 +2119,16 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( ...@@ -2117,8 +2119,16 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
ExecutionArgumentMappingContext arg_mapping_ctx(ctx); ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
if (arg_map_fn_ == nullptr) { if (arg_map_fn_ == nullptr) {
arg_map_fn_.reset(new phi::ArgumentMappingFn( auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(type_);
phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type()))); if (arg_map_fn) {
arg_map_fn_.reset(new phi::ArgumentMappingFn(*arg_map_fn));
} else {
auto func =
[this](const phi::ArgumentMappingContext& ctx) -> KernelSignature {
return phi::DefaultKernelSignatureMap::Instance().Get(type_);
};
arg_map_fn_.reset(new phi::ArgumentMappingFn(func));
}
} }
return (*arg_map_fn_)(arg_mapping_ctx); return (*arg_map_fn_)(arg_mapping_ctx);
} }
...@@ -2126,7 +2136,7 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( ...@@ -2126,7 +2136,7 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
Scope* OperatorWithKernel::PreparePhiData( Scope* OperatorWithKernel::PreparePhiData(
const Scope& scope, const phi::Kernel& pt_kernel, const Scope& scope, const phi::Kernel& pt_kernel,
const KernelSignature& pt_kernel_signature, RuntimeContext* ctx) const { const KernelSignature& pt_kernel_signature, RuntimeContext* ctx) const {
auto& input_names = std::get<0>(pt_kernel_signature.args); const auto& input_names = pt_kernel_signature.input_names;
auto input_defs = pt_kernel.args_def().input_defs(); auto input_defs = pt_kernel.args_def().input_defs();
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -2178,11 +2188,15 @@ Scope* OperatorWithKernel::PreparePhiData( ...@@ -2178,11 +2188,15 @@ Scope* OperatorWithKernel::PreparePhiData(
if (in_def.backend == phi::Backend::ALL_BACKEND) { if (in_def.backend == phi::Backend::ALL_BACKEND) {
continue; continue;
} }
auto expected_place = phi::TransToPhiPlace(in_def.backend);
if (platform::is_same_place(tensor_in->place(), expected_place)) { auto tensor_backend = phi::TransToPhiBackend(tensor_in->place());
if (in_def.backend == tensor_backend ||
(in_def.backend == phi::Backend::GPUDNN &&
tensor_backend == phi::Backend::GPU)) {
continue; continue;
} }
auto expected_place = phi::TransToPhiPlace(in_def.backend);
VLOG(3) << "phi Transform Variable " << input_names[i] << " from " VLOG(3) << "phi Transform Variable " << input_names[i] << " from "
<< tensor_in->place() << " to " << expected_place; << tensor_in->place() << " to " << expected_place;
...@@ -2219,9 +2233,9 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2219,9 +2233,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi::KernelContext* pt_kernel_context) const { phi::KernelContext* pt_kernel_context) const {
pt_kernel_context->SetDeviceContext(dev_ctx); pt_kernel_context->SetDeviceContext(dev_ctx);
auto& input_names = std::get<0>(pt_kernel_signature_->args); auto& input_names = pt_kernel_signature_->input_names;
auto& attr_names = std::get<1>(pt_kernel_signature_->args); auto& attr_names = pt_kernel_signature_->attr_names;
auto& output_names = std::get<2>(pt_kernel_signature_->args); auto& output_names = pt_kernel_signature_->output_names;
auto input_defs = pt_kernel_->args_def().input_defs(); auto input_defs = pt_kernel_->args_def().input_defs();
auto attr_defs = pt_kernel_->args_def().attribute_defs(); auto attr_defs = pt_kernel_->args_def().attribute_defs();
......
...@@ -37,6 +37,8 @@ namespace paddle { ...@@ -37,6 +37,8 @@ namespace paddle {
namespace imperative { namespace imperative {
static const phi::Kernel empty_kernel; static const phi::Kernel empty_kernel;
static const framework::RuntimeContext empty_ctx({}, {});
static const framework::Scope empty_scope;
const std::shared_ptr<VariableWrapper>& GetVariableWrapper( const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<paddle::imperative::VarBase>& var) { const std::shared_ptr<paddle::imperative::VarBase>& var) {
...@@ -138,8 +140,6 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -138,8 +140,6 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
framework::RuntimeContext ctx({}, {});
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and // MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and
// GetKernelType functions, so we need to copy the attributes there. // GetKernelType functions, so we need to copy the attributes there.
...@@ -158,7 +158,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -158,7 +158,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
// 1. get expected kernel key // 1. get expected kernel key
auto dygraph_exe_ctx = DygraphExecutionContext<VarType>( auto dygraph_exe_ctx = DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs, default_attrs); op, empty_scope, *dev_ctx, empty_ctx, ins, outs, attrs, default_attrs);
auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx); auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx);
framework::KernelSignature pt_kernel_signature; framework::KernelSignature pt_kernel_signature;
...@@ -172,11 +172,26 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -172,11 +172,26 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
paddle::platform::is_in_xpu_black_list(op.Type()); paddle::platform::is_in_xpu_black_list(op.Type());
#endif #endif
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
pt_kernel_signature =
std::move(op.GetExpectedPhiKernelArgs(dygraph_exe_ctx));
VLOG(6) << pt_kernel_signature;
bool has_phi_kernel = false;
const auto* arg_map_fn =
phi::OpUtilsMap::Instance().GetArgumentMappingFn(op.Type());
if (arg_map_fn) {
has_phi_kernel = true;
pt_kernel_signature = (*arg_map_fn)(
framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
} else {
const auto* kernel_sig =
phi::DefaultKernelSignatureMap::Instance().GetNullable(op.Type());
if (kernel_sig) {
has_phi_kernel = true;
pt_kernel_signature = *kernel_sig;
}
}
if (has_phi_kernel) {
VLOG(6) << pt_kernel_signature;
pt_kernel_name = pt_kernel_signature.name; pt_kernel_name = pt_kernel_signature.name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], // NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the // But the default library_type is Plain, so we need to modify the
...@@ -231,7 +246,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -231,7 +246,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
} }
return PreparedOp(op, ctx, expected_kernel_key, return PreparedOp(op, empty_ctx, expected_kernel_key,
std::move(pt_kernel_signature), pt_kernel, dev_ctx); std::move(pt_kernel_signature), pt_kernel, dev_ctx);
} else { } else {
VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name
...@@ -280,7 +295,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -280,7 +295,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
<< " | kernel key: " << pt_cpu_kernel_key << " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << pt_cpu_kernel; << " | kernel: " << pt_cpu_kernel;
auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace()); auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
return PreparedOp(op, ctx, expected_kernel_key, return PreparedOp(op, empty_ctx, expected_kernel_key,
std::move(pt_kernel_signature), pt_cpu_kernel, std::move(pt_kernel_signature), pt_cpu_kernel,
cpu_ctx); cpu_ctx);
} }
...@@ -373,7 +388,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -373,7 +388,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
} }
return PreparedOp(op, ctx, expected_kernel_key, kernel_iter->second, dev_ctx); return PreparedOp(op, empty_ctx, expected_kernel_key, kernel_iter->second,
dev_ctx);
} }
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins, PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
......
...@@ -233,9 +233,9 @@ void BuildDygraphPhiKernelContext( ...@@ -233,9 +233,9 @@ void BuildDygraphPhiKernelContext(
platform::DeviceContext* dev_ctx, phi::KernelContext* kernel_ctx) { platform::DeviceContext* dev_ctx, phi::KernelContext* kernel_ctx) {
kernel_ctx->SetDeviceContext(dev_ctx); kernel_ctx->SetDeviceContext(dev_ctx);
auto& input_names = std::get<0>(pt_kernel_signature.args); const auto& input_names = pt_kernel_signature.input_names;
auto& attr_names = std::get<1>(pt_kernel_signature.args); const auto& attr_names = pt_kernel_signature.attr_names;
auto& output_names = std::get<2>(pt_kernel_signature.args); const auto& output_names = pt_kernel_signature.output_names;
auto& input_defs = pt_kernel.args_def().input_defs(); auto& input_defs = pt_kernel.args_def().input_defs();
auto& output_defs = pt_kernel.args_def().output_defs(); auto& output_defs = pt_kernel.args_def().output_defs();
...@@ -570,7 +570,7 @@ template <typename VarType> ...@@ -570,7 +570,7 @@ template <typename VarType>
void PreparePhiData(const phi::Kernel& pt_kernel, void PreparePhiData(const phi::Kernel& pt_kernel,
const framework::KernelSignature& pt_kernel_signature, const framework::KernelSignature& pt_kernel_signature,
const NameVarMap<VarType>& ins) { const NameVarMap<VarType>& ins) {
auto& input_names = std::get<0>(pt_kernel_signature.args); const auto& input_names = pt_kernel_signature.input_names;
auto& input_defs = pt_kernel.args_def().input_defs(); auto& input_defs = pt_kernel.args_def().input_defs();
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
...@@ -581,10 +581,11 @@ void PreparePhiData(const phi::Kernel& pt_kernel, ...@@ -581,10 +581,11 @@ void PreparePhiData(const phi::Kernel& pt_kernel,
for (size_t i = 0; i < input_names.size(); ++i) { for (size_t i = 0; i < input_names.size(); ++i) {
auto& in_def = input_defs.at(i); auto& in_def = input_defs.at(i);
if (ins.find(input_names[i]) == ins.end()) { auto iter = ins.find(input_names[i]);
if (iter == ins.end()) {
continue; continue;
} }
auto& ins_vector = ins.at(input_names[i]); auto& ins_vector = iter->second;
for (size_t offset = 0; offset < ins_vector.size(); ++offset) { for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
auto& var = ins_vector[offset]; auto& var = ins_vector[offset];
...@@ -593,11 +594,15 @@ void PreparePhiData(const phi::Kernel& pt_kernel, ...@@ -593,11 +594,15 @@ void PreparePhiData(const phi::Kernel& pt_kernel,
if (in_def.backend == phi::Backend::ALL_BACKEND) { if (in_def.backend == phi::Backend::ALL_BACKEND) {
continue; continue;
} }
auto expected_place = phi::TransToPhiPlace(in_def.backend); auto tensor_backend = phi::TransToPhiBackend(tensor_in->place());
if (platform::is_same_place(tensor_in->place(), expected_place)) { if (in_def.backend == tensor_backend ||
(in_def.backend == phi::Backend::GPUDNN &&
tensor_backend == phi::Backend::GPU)) {
continue; continue;
} }
auto expected_place = phi::TransToPhiPlace(in_def.backend);
VLOG(3) << "Phi Transform Variable " << input_names[i] << " from " VLOG(3) << "Phi Transform Variable " << input_names[i] << " from "
<< tensor_in->place() << " to " << expected_place; << tensor_in->place() << " to " << expected_place;
......
...@@ -2050,9 +2050,9 @@ void BindImperative(py::module *m_ptr) { ...@@ -2050,9 +2050,9 @@ void BindImperative(py::module *m_ptr) {
}; };
auto ret = self.GetExpectedKernelSignature(type, ins_map, auto ret = self.GetExpectedKernelSignature(type, ins_map,
outs_map, attrs); outs_map, attrs);
auto kernelsig_ins = input_to_vector(std::get<0>(ret.args)); auto kernelsig_ins = input_to_vector(ret.input_names);
auto kernelsig_attrs = attr_to_vector(std::get<1>(ret.args)); auto kernelsig_attrs = attr_to_vector(ret.attr_names);
auto kernelsig_outs = output_to_vector(std::get<2>(ret.args)); auto kernelsig_outs = output_to_vector(ret.output_names);
return std::make_tuple(kernelsig_ins, kernelsig_attrs, return std::make_tuple(kernelsig_ins, kernelsig_attrs,
kernelsig_outs); kernelsig_outs);
} }
......
...@@ -58,10 +58,10 @@ int main(int argc, char **argv) { ...@@ -58,10 +58,10 @@ int main(int argc, char **argv) {
if (kernel_signature_map.Has(op_name)) { if (kernel_signature_map.Has(op_name)) {
kernel_signature_map_str = kernel_signature_map_str =
kernel_signature_map_str + "\"" + op_kernel_pair.first + "\":{"; kernel_signature_map_str + "\"" + op_kernel_pair.first + "\":{";
auto &args = kernel_signature_map.Get(op_name).args; const auto &args = kernel_signature_map.Get(op_name);
kernel_signature_map_str += "\"inputs\":["; kernel_signature_map_str += "\"inputs\":[";
auto inputs_ = std::get<0>(args); auto inputs_ = args.input_names;
for (size_t i = 0; i < inputs_.size(); i++) { for (size_t i = 0; i < inputs_.size(); i++) {
kernel_signature_map_str = kernel_signature_map_str =
kernel_signature_map_str + "\"" + inputs_[i] + "\","; kernel_signature_map_str + "\"" + inputs_[i] + "\",";
...@@ -69,14 +69,14 @@ int main(int argc, char **argv) { ...@@ -69,14 +69,14 @@ int main(int argc, char **argv) {
if (inputs_.size()) kernel_signature_map_str.pop_back(); if (inputs_.size()) kernel_signature_map_str.pop_back();
kernel_signature_map_str += "],\"attrs\":["; kernel_signature_map_str += "],\"attrs\":[";
auto attrs_ = std::get<1>(args); auto attrs_ = args.attr_names;
for (size_t i = 0; i < attrs_.size(); i++) { for (size_t i = 0; i < attrs_.size(); i++) {
kernel_signature_map_str = kernel_signature_map_str =
kernel_signature_map_str + "\"" + attrs_[i] + "\","; kernel_signature_map_str + "\"" + attrs_[i] + "\",";
} }
if (attrs_.size()) kernel_signature_map_str.pop_back(); if (attrs_.size()) kernel_signature_map_str.pop_back();
kernel_signature_map_str += "],\"outputs\":["; kernel_signature_map_str += "],\"outputs\":[";
auto outputs_ = std::get<2>(args); auto outputs_ = args.output_names;
for (size_t i = 0; i < outputs_.size(); i++) { for (size_t i = 0; i < outputs_.size(); i++) {
kernel_signature_map_str = kernel_signature_map_str =
kernel_signature_map_str + "\"" + outputs_[i] + "\","; kernel_signature_map_str + "\"" + outputs_[i] + "\",";
......
...@@ -108,14 +108,14 @@ void PhiOpConvertPass::convertStage() { ...@@ -108,14 +108,14 @@ void PhiOpConvertPass::convertStage() {
op->replaceAllUsesWith(kernel_op.getResults()); op->replaceAllUsesWith(kernel_op.getResults());
} else { } else {
::phi::KernelSignature kernel_sign = ::phi::KernelSignature kernel_sign =
::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)( (*::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name))(
infrt::ProtoArgumentMappingContext(op)); infrt::ProtoArgumentMappingContext(op));
VLOG(3) << "IncompatiblePhiKernel: op(" << op_name << "), kernel(" VLOG(3) << "IncompatiblePhiKernel: op(" << op_name << "), kernel("
<< kernel_sign.name << ")"; << kernel_sign.name << ")";
// resort input&output according to kernel_sign // resort input&output according to kernel_sign
::llvm::SmallVector<mlir::Value, 4> inputs, ori_output; ::llvm::SmallVector<mlir::Value, 4> inputs, ori_output;
::llvm::SmallVector<mlir::Type, 4> output_types; ::llvm::SmallVector<mlir::Type, 4> output_types;
for (const std::string &str : std::get<0>(kernel_sign.args)) { for (const std::string &str : kernel_sign.input_names) {
if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) { if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) {
LOG(ERROR) << "No input info for Op " << op_name << " and argument " LOG(ERROR) << "No input info for Op " << op_name << " and argument "
<< str; << str;
...@@ -125,7 +125,7 @@ void PhiOpConvertPass::convertStage() { ...@@ -125,7 +125,7 @@ void PhiOpConvertPass::convertStage() {
inputs.push_back(op->getOperands()[index]); inputs.push_back(op->getOperands()[index]);
} }
for (const std::string &str : std::get<2>(kernel_sign.args)) { for (const std::string &str : kernel_sign.output_names) {
if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) { if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) {
LOG(ERROR) << "No output info for Op " << op_name << " and argument " LOG(ERROR) << "No output info for Op " << op_name << " and argument "
<< str; << str;
......
...@@ -20,11 +20,11 @@ limitations under the License. */ ...@@ -20,11 +20,11 @@ limitations under the License. */
namespace phi { namespace phi {
std::ostream& operator<<(std::ostream& os, KernelSignature signature) { std::ostream& operator<<(std::ostream& os, KernelSignature signature) {
os << "Kernel Signature - name: " << signature.name << "; inputs: " os << "Kernel Signature - name: " << signature.name << "; inputs: "
<< paddle::string::join_strings(std::get<0>(signature.args), ", ") << paddle::string::join_strings(signature.input_names, ", ")
<< "; attributes: " << "; attributes: "
<< paddle::string::join_strings(std::get<1>(signature.args), ", ") << paddle::string::join_strings(signature.attr_names, ", ")
<< "; outputs: " << "; outputs: "
<< paddle::string::join_strings(std::get<2>(signature.args), ", "); << paddle::string::join_strings(signature.output_names, ", ");
return os; return os;
} }
......
...@@ -33,7 +33,9 @@ using KernelArgsTuple = std::tuple<paddle::SmallVector<const char*>, ...@@ -33,7 +33,9 @@ using KernelArgsTuple = std::tuple<paddle::SmallVector<const char*>,
struct KernelSignature { struct KernelSignature {
const char* name; const char* name;
KernelArgsTuple args; paddle::SmallVector<const char*> input_names;
paddle::SmallVector<const char*> attr_names;
paddle::SmallVector<const char*> output_names;
KernelSignature() = default; KernelSignature() = default;
...@@ -41,18 +43,26 @@ struct KernelSignature { ...@@ -41,18 +43,26 @@ struct KernelSignature {
paddle::SmallVector<const char*>&& inputs, paddle::SmallVector<const char*>&& inputs,
paddle::SmallVector<const char*>&& attrs, paddle::SmallVector<const char*>&& attrs,
paddle::SmallVector<const char*>&& outputs) paddle::SmallVector<const char*>&& outputs)
: name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {} : name(kernel_name),
input_names(std::move(inputs)),
attr_names(std::move(attrs)),
output_names(std::move(outputs)) {}
KernelSignature(const char* kernel_name, KernelSignature(const char* kernel_name,
const paddle::SmallVector<const char*>& inputs, const paddle::SmallVector<const char*>& inputs,
const paddle::SmallVector<const char*>& attrs, const paddle::SmallVector<const char*>& attrs,
const paddle::SmallVector<const char*>& outputs) const paddle::SmallVector<const char*>& outputs)
: name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {} : name(kernel_name),
input_names(inputs),
attr_names(attrs),
output_names(outputs) {}
// TODO(chenweihang): add assign constructor to solve windows compile // TODO(chenweihang): add assign constructor to solve windows compile
// problem, remove it later // problem, remove it later
KernelSignature& operator=(const KernelSignature& other) { KernelSignature& operator=(const KernelSignature& other) {
name = other.name; name = other.name;
args = other.args; input_names = other.input_names;
attr_names = other.attr_names;
output_names = other.output_names;
return *this; return *this;
} }
}; };
......
...@@ -86,6 +86,14 @@ class DefaultKernelSignatureMap { ...@@ -86,6 +86,14 @@ class DefaultKernelSignatureMap {
return it->second; return it->second;
} }
const KernelSignature* GetNullable(const std::string& op_type) const {
auto it = map_.find(op_type);
if (it != map_.end()) {
return &it->second;
}
return nullptr;
}
void Insert(std::string op_type, KernelSignature signature) { void Insert(std::string op_type, KernelSignature signature) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
Has(op_type), Has(op_type),
...@@ -148,16 +156,13 @@ class OpUtilsMap { ...@@ -148,16 +156,13 @@ class OpUtilsMap {
} }
} }
ArgumentMappingFn GetArgumentMappingFn(const std::string& op_type) const { const ArgumentMappingFn* GetArgumentMappingFn(
const std::string& op_type) const {
auto it = arg_mapping_fn_map_.find(op_type); auto it = arg_mapping_fn_map_.find(op_type);
if (it == arg_mapping_fn_map_.end()) { if (it == arg_mapping_fn_map_.end()) {
auto func = return nullptr;
[&op_type](const ArgumentMappingContext& ctx) -> KernelSignature {
return DefaultKernelSignatureMap::Instance().Get(op_type);
};
return func;
} else { } else {
return it->second; return &it->second;
} }
} }
......
...@@ -30,8 +30,8 @@ namespace tests { ...@@ -30,8 +30,8 @@ namespace tests {
TEST(ARG_MAP, fill_constant) { TEST(ARG_MAP, fill_constant) {
TestArgumentMappingContext arg_case1( TestArgumentMappingContext arg_case1(
{"ShapeTensor", "ValueTensor"}, {}, {}, {}, {"Out"}); {"ShapeTensor", "ValueTensor"}, {}, {}, {}, {"Out"});
auto signature1 = auto signature1 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case1); "fill_constant"))(arg_case1);
ASSERT_EQ(signature1.name, "full_sr"); ASSERT_EQ(signature1.name, "full_sr");
TestArgumentMappingContext arg_case2( TestArgumentMappingContext arg_case2(
...@@ -40,8 +40,8 @@ TEST(ARG_MAP, fill_constant) { ...@@ -40,8 +40,8 @@ TEST(ARG_MAP, fill_constant) {
{{"str_value", paddle::any{std::string{"10"}}}}, {{"str_value", paddle::any{std::string{"10"}}}},
{}, {},
{"Out"}); {"Out"});
auto signature2 = auto signature2 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case2); "fill_constant"))(arg_case2);
ASSERT_EQ(signature2.name, "full_sr"); ASSERT_EQ(signature2.name, "full_sr");
TestArgumentMappingContext arg_case3( TestArgumentMappingContext arg_case3(
...@@ -50,14 +50,14 @@ TEST(ARG_MAP, fill_constant) { ...@@ -50,14 +50,14 @@ TEST(ARG_MAP, fill_constant) {
{{"value", paddle::any{0}}, {"str_value", paddle::any{std::string{""}}}}, {{"value", paddle::any{0}}, {"str_value", paddle::any{std::string{""}}}},
{}, {},
{"Out"}); {"Out"});
auto signature3 = auto signature3 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case3); "fill_constant"))(arg_case3);
ASSERT_EQ(signature3.name, "full_sr"); ASSERT_EQ(signature3.name, "full_sr");
TestArgumentMappingContext arg_case4( TestArgumentMappingContext arg_case4(
{"ShapeTensorList", "ValueTensor"}, {}, {}, {}, {"Out"}); {"ShapeTensorList", "ValueTensor"}, {}, {}, {}, {"Out"});
auto signature4 = auto signature4 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case4); "fill_constant"))(arg_case4);
ASSERT_EQ(signature4.name, "full_sr"); ASSERT_EQ(signature4.name, "full_sr");
TestArgumentMappingContext arg_case5( TestArgumentMappingContext arg_case5(
...@@ -66,8 +66,8 @@ TEST(ARG_MAP, fill_constant) { ...@@ -66,8 +66,8 @@ TEST(ARG_MAP, fill_constant) {
{{"str_value", paddle::any{std::string{"10"}}}}, {{"str_value", paddle::any{std::string{"10"}}}},
{}, {},
{"Out"}); {"Out"});
auto signature5 = auto signature5 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case5); "fill_constant"))(arg_case5);
ASSERT_EQ(signature5.name, "full_sr"); ASSERT_EQ(signature5.name, "full_sr");
TestArgumentMappingContext arg_case6( TestArgumentMappingContext arg_case6(
...@@ -76,8 +76,8 @@ TEST(ARG_MAP, fill_constant) { ...@@ -76,8 +76,8 @@ TEST(ARG_MAP, fill_constant) {
{{"value", paddle::any{0}}, {"str_value", paddle::any{std::string{""}}}}, {{"value", paddle::any{0}}, {"str_value", paddle::any{std::string{""}}}},
{}, {},
{"Out"}); {"Out"});
auto signature6 = auto signature6 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case6); "fill_constant"))(arg_case6);
ASSERT_EQ(signature6.name, "full_sr"); ASSERT_EQ(signature6.name, "full_sr");
TestArgumentMappingContext arg_case7( TestArgumentMappingContext arg_case7(
...@@ -86,8 +86,8 @@ TEST(ARG_MAP, fill_constant) { ...@@ -86,8 +86,8 @@ TEST(ARG_MAP, fill_constant) {
{{"shape", paddle::any{std::vector<int64_t>{2, 3}}}}, {{"shape", paddle::any{std::vector<int64_t>{2, 3}}}},
{}, {},
{"Out"}); {"Out"});
auto signature7 = auto signature7 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case7); "fill_constant"))(arg_case7);
ASSERT_EQ(signature7.name, "full_sr"); ASSERT_EQ(signature7.name, "full_sr");
TestArgumentMappingContext arg_case8( TestArgumentMappingContext arg_case8(
...@@ -98,8 +98,8 @@ TEST(ARG_MAP, fill_constant) { ...@@ -98,8 +98,8 @@ TEST(ARG_MAP, fill_constant) {
{"str_value", paddle::any{std::string{""}}}}, {"str_value", paddle::any{std::string{""}}}},
{}, {},
{"Out"}); {"Out"});
auto signature8 = auto signature8 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case8); "fill_constant"))(arg_case8);
ASSERT_EQ(signature8.name, "full_sr"); ASSERT_EQ(signature8.name, "full_sr");
TestArgumentMappingContext arg_case9( TestArgumentMappingContext arg_case9(
...@@ -109,8 +109,8 @@ TEST(ARG_MAP, fill_constant) { ...@@ -109,8 +109,8 @@ TEST(ARG_MAP, fill_constant) {
{"str_value", paddle::any{std::string{"10"}}}}, {"str_value", paddle::any{std::string{"10"}}}},
{}, {},
{"Out"}); {"Out"});
auto signature9 = auto signature9 = (*OpUtilsMap::Instance().GetArgumentMappingFn(
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case9); "fill_constant"))(arg_case9);
ASSERT_EQ(signature9.name, "full_sr"); ASSERT_EQ(signature9.name, "full_sr");
} }
...@@ -122,7 +122,8 @@ TEST(ARG_MAP, set_value) { ...@@ -122,7 +122,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case1( TestArgumentMappingContext arg_case1(
...@@ -132,7 +133,8 @@ TEST(ARG_MAP, set_value) { ...@@ -132,7 +133,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case1).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case1)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case2( TestArgumentMappingContext arg_case2(
...@@ -142,7 +144,8 @@ TEST(ARG_MAP, set_value) { ...@@ -142,7 +144,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case2).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case2)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case3( TestArgumentMappingContext arg_case3(
...@@ -152,7 +155,8 @@ TEST(ARG_MAP, set_value) { ...@@ -152,7 +155,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case3).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case3)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case4( TestArgumentMappingContext arg_case4(
...@@ -162,7 +166,8 @@ TEST(ARG_MAP, set_value) { ...@@ -162,7 +166,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case4).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case4)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case5( TestArgumentMappingContext arg_case5(
...@@ -172,7 +177,8 @@ TEST(ARG_MAP, set_value) { ...@@ -172,7 +177,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case5).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case5)
.name,
"set_value_with_tensor"); "set_value_with_tensor");
TestArgumentMappingContext arg_case6( TestArgumentMappingContext arg_case6(
...@@ -182,7 +188,8 @@ TEST(ARG_MAP, set_value) { ...@@ -182,7 +188,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case6).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case6)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case7( TestArgumentMappingContext arg_case7(
...@@ -192,7 +199,8 @@ TEST(ARG_MAP, set_value) { ...@@ -192,7 +199,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case7).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case7)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case8( TestArgumentMappingContext arg_case8(
...@@ -202,7 +210,8 @@ TEST(ARG_MAP, set_value) { ...@@ -202,7 +210,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case8).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case8)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case9( TestArgumentMappingContext arg_case9(
...@@ -212,7 +221,8 @@ TEST(ARG_MAP, set_value) { ...@@ -212,7 +221,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case9).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case9)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case10( TestArgumentMappingContext arg_case10(
...@@ -222,7 +232,8 @@ TEST(ARG_MAP, set_value) { ...@@ -222,7 +232,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case10).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case10)
.name,
"set_value_with_tensor"); "set_value_with_tensor");
TestArgumentMappingContext arg_case11( TestArgumentMappingContext arg_case11(
...@@ -232,7 +243,8 @@ TEST(ARG_MAP, set_value) { ...@@ -232,7 +243,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case11).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case11)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case12( TestArgumentMappingContext arg_case12(
...@@ -242,7 +254,8 @@ TEST(ARG_MAP, set_value) { ...@@ -242,7 +254,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case12).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case12)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case13( TestArgumentMappingContext arg_case13(
...@@ -252,7 +265,8 @@ TEST(ARG_MAP, set_value) { ...@@ -252,7 +265,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case13).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case13)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case14( TestArgumentMappingContext arg_case14(
...@@ -262,13 +276,15 @@ TEST(ARG_MAP, set_value) { ...@@ -262,13 +276,15 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case14).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case14)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case15( TestArgumentMappingContext arg_case15(
{"Input", "StartsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); {"Input", "StartsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case15).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case15)
.name,
"set_value_with_tensor"); "set_value_with_tensor");
TestArgumentMappingContext arg_case16( TestArgumentMappingContext arg_case16(
...@@ -278,7 +294,8 @@ TEST(ARG_MAP, set_value) { ...@@ -278,7 +294,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case16).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case16)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case17( TestArgumentMappingContext arg_case17(
...@@ -288,7 +305,8 @@ TEST(ARG_MAP, set_value) { ...@@ -288,7 +305,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case17).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case17)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case18( TestArgumentMappingContext arg_case18(
...@@ -298,7 +316,8 @@ TEST(ARG_MAP, set_value) { ...@@ -298,7 +316,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case18).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case18)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case19( TestArgumentMappingContext arg_case19(
...@@ -308,7 +327,8 @@ TEST(ARG_MAP, set_value) { ...@@ -308,7 +327,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case19).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case19)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case20( TestArgumentMappingContext arg_case20(
...@@ -318,7 +338,8 @@ TEST(ARG_MAP, set_value) { ...@@ -318,7 +338,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case20).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case20)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case21( TestArgumentMappingContext arg_case21(
...@@ -328,7 +349,8 @@ TEST(ARG_MAP, set_value) { ...@@ -328,7 +349,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case21).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case21)
.name,
"set_value_with_tensor"); "set_value_with_tensor");
TestArgumentMappingContext arg_case22( TestArgumentMappingContext arg_case22(
...@@ -338,7 +360,8 @@ TEST(ARG_MAP, set_value) { ...@@ -338,7 +360,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case22).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case22)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case23( TestArgumentMappingContext arg_case23(
...@@ -348,7 +371,8 @@ TEST(ARG_MAP, set_value) { ...@@ -348,7 +371,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case23).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case23)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case24( TestArgumentMappingContext arg_case24(
...@@ -358,7 +382,8 @@ TEST(ARG_MAP, set_value) { ...@@ -358,7 +382,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case24).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case24)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case25( TestArgumentMappingContext arg_case25(
...@@ -368,13 +393,15 @@ TEST(ARG_MAP, set_value) { ...@@ -368,13 +393,15 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case25).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case25)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case26( TestArgumentMappingContext arg_case26(
{"Input", "EndsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); {"Input", "EndsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case26).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case26)
.name,
"set_value_with_tensor"); "set_value_with_tensor");
TestArgumentMappingContext arg_case27( TestArgumentMappingContext arg_case27(
...@@ -384,7 +411,8 @@ TEST(ARG_MAP, set_value) { ...@@ -384,7 +411,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case27).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case27)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case28( TestArgumentMappingContext arg_case28(
...@@ -394,7 +422,8 @@ TEST(ARG_MAP, set_value) { ...@@ -394,7 +422,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case28).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case28)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case29( TestArgumentMappingContext arg_case29(
...@@ -404,7 +433,8 @@ TEST(ARG_MAP, set_value) { ...@@ -404,7 +433,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case29).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case29)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case30( TestArgumentMappingContext arg_case30(
...@@ -414,7 +444,8 @@ TEST(ARG_MAP, set_value) { ...@@ -414,7 +444,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case30).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case30)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case31( TestArgumentMappingContext arg_case31(
...@@ -424,13 +455,15 @@ TEST(ARG_MAP, set_value) { ...@@ -424,13 +455,15 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case31).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case31)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case32( TestArgumentMappingContext arg_case32(
{"Input", "StepsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); {"Input", "StepsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case32).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case32)
.name,
"set_value_with_tensor"); "set_value_with_tensor");
TestArgumentMappingContext arg_case33( TestArgumentMappingContext arg_case33(
...@@ -440,7 +473,8 @@ TEST(ARG_MAP, set_value) { ...@@ -440,7 +473,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case33).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case33)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case34( TestArgumentMappingContext arg_case34(
...@@ -450,7 +484,8 @@ TEST(ARG_MAP, set_value) { ...@@ -450,7 +484,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case34).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case34)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case35( TestArgumentMappingContext arg_case35(
...@@ -460,7 +495,8 @@ TEST(ARG_MAP, set_value) { ...@@ -460,7 +495,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case35).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case35)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case36( TestArgumentMappingContext arg_case36(
...@@ -470,7 +506,8 @@ TEST(ARG_MAP, set_value) { ...@@ -470,7 +506,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case36).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case36)
.name,
"set_value"); "set_value");
TestArgumentMappingContext arg_case37( TestArgumentMappingContext arg_case37(
...@@ -480,7 +517,8 @@ TEST(ARG_MAP, set_value) { ...@@ -480,7 +517,8 @@ TEST(ARG_MAP, set_value) {
{"Out"}, {"Out"},
{}); {});
ASSERT_EQ( ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case37).name, (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value"))(arg_case37)
.name,
"set_value"); "set_value");
} }
...@@ -491,10 +529,10 @@ TEST(ARG_MAP, set_value_grad) { ...@@ -491,10 +529,10 @@ TEST(ARG_MAP, set_value_grad) {
{}, {},
{"Input@GRAD", "ValueTensor@GRAD"}, {"Input@GRAD", "ValueTensor@GRAD"},
{}); {});
ASSERT_EQ(OpUtilsMap::Instance() ASSERT_EQ(
.GetArgumentMappingFn("set_value_grad")(arg_case) (*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(arg_case)
.name, .name,
"set_value_grad"); "set_value_grad");
TestArgumentMappingContext arg_case1( TestArgumentMappingContext arg_case1(
{"Out@GRAD", "StartsTensorList", "StepsTensorList"}, {"Out@GRAD", "StartsTensorList", "StepsTensorList"},
...@@ -502,8 +540,8 @@ TEST(ARG_MAP, set_value_grad) { ...@@ -502,8 +540,8 @@ TEST(ARG_MAP, set_value_grad) {
{}, {},
{"Input@GRAD", "ValueTensor@GRAD"}, {"Input@GRAD", "ValueTensor@GRAD"},
{}); {});
ASSERT_EQ(OpUtilsMap::Instance() ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
.GetArgumentMappingFn("set_value_grad")(arg_case1) arg_case1)
.name, .name,
"set_value_grad"); "set_value_grad");
...@@ -512,8 +550,8 @@ TEST(ARG_MAP, set_value_grad) { ...@@ -512,8 +550,8 @@ TEST(ARG_MAP, set_value_grad) {
{}, {},
{"Input@GRAD", "ValueTensor@GRAD"}, {"Input@GRAD", "ValueTensor@GRAD"},
{}); {});
ASSERT_EQ(OpUtilsMap::Instance() ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
.GetArgumentMappingFn("set_value_grad")(arg_case2) arg_case2)
.name, .name,
"set_value_grad"); "set_value_grad");
...@@ -523,8 +561,8 @@ TEST(ARG_MAP, set_value_grad) { ...@@ -523,8 +561,8 @@ TEST(ARG_MAP, set_value_grad) {
{}, {},
{"Input@GRAD", "ValueTensor@GRAD"}, {"Input@GRAD", "ValueTensor@GRAD"},
{}); {});
ASSERT_EQ(OpUtilsMap::Instance() ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
.GetArgumentMappingFn("set_value_grad")(arg_case3) arg_case3)
.name, .name,
"set_value_grad"); "set_value_grad");
...@@ -533,8 +571,8 @@ TEST(ARG_MAP, set_value_grad) { ...@@ -533,8 +571,8 @@ TEST(ARG_MAP, set_value_grad) {
{}, {},
{"Input@GRAD", "ValueTensor@GRAD"}, {"Input@GRAD", "ValueTensor@GRAD"},
{}); {});
ASSERT_EQ(OpUtilsMap::Instance() ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
.GetArgumentMappingFn("set_value_grad")(arg_case4) arg_case4)
.name, .name,
"set_value_grad"); "set_value_grad");
...@@ -543,8 +581,8 @@ TEST(ARG_MAP, set_value_grad) { ...@@ -543,8 +581,8 @@ TEST(ARG_MAP, set_value_grad) {
{}, {},
{"Input@GRAD", "ValueTensor@GRAD"}, {"Input@GRAD", "ValueTensor@GRAD"},
{}); {});
ASSERT_EQ(OpUtilsMap::Instance() ASSERT_EQ((*OpUtilsMap::Instance().GetArgumentMappingFn("set_value_grad"))(
.GetArgumentMappingFn("set_value_grad")(arg_case5) arg_case5)
.name, .name,
"set_value_grad"); "set_value_grad");
} }
...@@ -558,10 +596,9 @@ TEST(ARG_MAP, allclose) { ...@@ -558,10 +596,9 @@ TEST(ARG_MAP, allclose) {
{"Out"}, {"Out"},
{}); {});
auto signature1 = auto signature1 =
OpUtilsMap::Instance().GetArgumentMappingFn("allclose")(arg_case1); (*OpUtilsMap::Instance().GetArgumentMappingFn("allclose"))(arg_case1);
ASSERT_EQ(signature1.name, "allclose"); ASSERT_EQ(signature1.name, "allclose");
auto attr_names1 = std::get<1>(signature1.args); ASSERT_EQ(signature1.attr_names[0], "Rtol");
ASSERT_EQ(attr_names1[0], "Rtol");
TestArgumentMappingContext arg_case2( TestArgumentMappingContext arg_case2(
{"Input", "Other", "Atol"}, {"Input", "Other", "Atol"},
...@@ -571,27 +608,26 @@ TEST(ARG_MAP, allclose) { ...@@ -571,27 +608,26 @@ TEST(ARG_MAP, allclose) {
{"Out"}, {"Out"},
{}); {});
auto signature2 = auto signature2 =
OpUtilsMap::Instance().GetArgumentMappingFn("allclose")(arg_case2); (*OpUtilsMap::Instance().GetArgumentMappingFn("allclose"))(arg_case2);
ASSERT_EQ(signature2.name, "allclose"); ASSERT_EQ(signature2.name, "allclose");
auto attr_names2 = std::get<1>(signature2.args); ASSERT_EQ(signature2.attr_names[1], "Atol");
ASSERT_EQ(attr_names2[1], "Atol");
} }
TEST(ARG_MAP, reshape) { TEST(ARG_MAP, reshape) {
TestArgumentMappingContext arg_case1({"X", "ShapeTensor"}, {}, {}, {"Out"}); TestArgumentMappingContext arg_case1({"X", "ShapeTensor"}, {}, {}, {"Out"});
auto signature1 = auto signature1 =
OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case1); (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case1);
ASSERT_EQ(signature1.name, "reshape"); ASSERT_EQ(signature1.name, "reshape");
TestArgumentMappingContext arg_case2({"X", "Shape"}, {}, {}, {"Out"}); TestArgumentMappingContext arg_case2({"X", "Shape"}, {}, {}, {"Out"});
auto signature2 = auto signature2 =
OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case2); (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case2);
ASSERT_EQ(signature2.name, "reshape"); ASSERT_EQ(signature2.name, "reshape");
TestArgumentMappingContext arg_case3( TestArgumentMappingContext arg_case3(
{"X"}, {}, {{"shape", paddle::any(std::vector<int>({1, 2}))}}, {"Out"}); {"X"}, {}, {{"shape", paddle::any(std::vector<int>({1, 2}))}}, {"Out"});
auto signature3 = auto signature3 =
OpUtilsMap::Instance().GetArgumentMappingFn("reshape2")(arg_case3); (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case3);
ASSERT_EQ(signature3.name, "reshape"); ASSERT_EQ(signature3.name, "reshape");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册