未验证 提交 c79d1186 编写于 作者: Z zyfncg 提交者: GitHub

Dygraph performance optimization (v2) (#42103)

* optimize performance of PreparePhiData

* dygraph performance optimization
上级 f0ec580e
...@@ -414,9 +414,9 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -414,9 +414,9 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
CompatInferMetaContext infer_meta_context( CompatInferMetaContext infer_meta_context(
{ctx->IsRuntime(), ctx->IsRunMKLDNNKernel()}); {ctx->IsRuntime(), ctx->IsRunMKLDNNKernel()});
auto& input_names = std::get<0>(signature.args); const auto& input_names = signature.input_names;
auto& attr_names = std::get<1>(signature.args); const auto& attr_names = signature.attr_names;
auto& output_names = std::get<2>(signature.args); const auto& output_names = signature.output_names;
const auto& args_def = const auto& args_def =
phi::KernelFactory::Instance().GetFirstKernelArgsDef(signature.name); phi::KernelFactory::Instance().GetFirstKernelArgsDef(signature.name);
......
...@@ -1198,8 +1198,10 @@ bool OperatorWithKernel::SupportsMKLDNN( ...@@ -1198,8 +1198,10 @@ bool OperatorWithKernel::SupportsMKLDNN(
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const { proto::VarType::Type data_type) const {
bool use_mkldnn_ctx = ctx.HasAttr("use_mkldnn") && const auto& attrs_map = ctx.Attrs();
ctx.Attr<bool>("use_mkldnn") && auto iter = attrs_map.find("use_mkldnn");
bool use_mkldnn_ctx = iter != attrs_map.end() &&
BOOST_GET_CONST(bool, iter->second) &&
platform::is_cpu_place(ctx.GetPlace()); platform::is_cpu_place(ctx.GetPlace());
return use_mkldnn_ctx && this->SupportsMKLDNN(data_type); return use_mkldnn_ctx && this->SupportsMKLDNN(data_type);
} }
...@@ -2124,7 +2126,7 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs( ...@@ -2124,7 +2126,7 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
Scope* OperatorWithKernel::PreparePhiData( Scope* OperatorWithKernel::PreparePhiData(
const Scope& scope, const phi::Kernel& pt_kernel, const Scope& scope, const phi::Kernel& pt_kernel,
const KernelSignature& pt_kernel_signature, RuntimeContext* ctx) const { const KernelSignature& pt_kernel_signature, RuntimeContext* ctx) const {
auto& input_names = std::get<0>(pt_kernel_signature.args); const auto& input_names = pt_kernel_signature.input_names;
auto input_defs = pt_kernel.args_def().input_defs(); auto input_defs = pt_kernel.args_def().input_defs();
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -2176,11 +2178,15 @@ Scope* OperatorWithKernel::PreparePhiData( ...@@ -2176,11 +2178,15 @@ Scope* OperatorWithKernel::PreparePhiData(
if (in_def.backend == phi::Backend::ALL_BACKEND) { if (in_def.backend == phi::Backend::ALL_BACKEND) {
continue; continue;
} }
auto expected_place = phi::TransToPhiPlace(in_def.backend);
if (platform::is_same_place(tensor_in->place(), expected_place)) { auto tensor_backend = phi::TransToPhiBackend(tensor_in->place());
if (in_def.backend == tensor_backend ||
(in_def.backend == phi::Backend::GPUDNN &&
tensor_backend == phi::Backend::GPU)) {
continue; continue;
} }
auto expected_place = phi::TransToPhiPlace(in_def.backend);
VLOG(3) << "phi Transform Variable " << input_names[i] << " from " VLOG(3) << "phi Transform Variable " << input_names[i] << " from "
<< tensor_in->place() << " to " << expected_place; << tensor_in->place() << " to " << expected_place;
...@@ -2217,9 +2223,9 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2217,9 +2223,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi::KernelContext* pt_kernel_context) const { phi::KernelContext* pt_kernel_context) const {
pt_kernel_context->SetDeviceContext(dev_ctx); pt_kernel_context->SetDeviceContext(dev_ctx);
auto& input_names = std::get<0>(pt_kernel_signature_->args); auto& input_names = pt_kernel_signature_->input_names;
auto& attr_names = std::get<1>(pt_kernel_signature_->args); auto& attr_names = pt_kernel_signature_->attr_names;
auto& output_names = std::get<2>(pt_kernel_signature_->args); auto& output_names = pt_kernel_signature_->output_names;
auto input_defs = pt_kernel_->args_def().input_defs(); auto input_defs = pt_kernel_->args_def().input_defs();
auto attr_defs = pt_kernel_->args_def().attribute_defs(); auto attr_defs = pt_kernel_->args_def().attribute_defs();
......
...@@ -233,9 +233,9 @@ void BuildDygraphPhiKernelContext( ...@@ -233,9 +233,9 @@ void BuildDygraphPhiKernelContext(
platform::DeviceContext* dev_ctx, phi::KernelContext* kernel_ctx) { platform::DeviceContext* dev_ctx, phi::KernelContext* kernel_ctx) {
kernel_ctx->SetDeviceContext(dev_ctx); kernel_ctx->SetDeviceContext(dev_ctx);
auto& input_names = std::get<0>(pt_kernel_signature.args); const auto& input_names = pt_kernel_signature.input_names;
auto& attr_names = std::get<1>(pt_kernel_signature.args); const auto& attr_names = pt_kernel_signature.attr_names;
auto& output_names = std::get<2>(pt_kernel_signature.args); const auto& output_names = pt_kernel_signature.output_names;
auto& input_defs = pt_kernel.args_def().input_defs(); auto& input_defs = pt_kernel.args_def().input_defs();
auto& output_defs = pt_kernel.args_def().output_defs(); auto& output_defs = pt_kernel.args_def().output_defs();
...@@ -570,7 +570,7 @@ template <typename VarType> ...@@ -570,7 +570,7 @@ template <typename VarType>
void PreparePhiData(const phi::Kernel& pt_kernel, void PreparePhiData(const phi::Kernel& pt_kernel,
const framework::KernelSignature& pt_kernel_signature, const framework::KernelSignature& pt_kernel_signature,
const NameVarMap<VarType>& ins) { const NameVarMap<VarType>& ins) {
auto& input_names = std::get<0>(pt_kernel_signature.args); const auto& input_names = pt_kernel_signature.input_names;
auto& input_defs = pt_kernel.args_def().input_defs(); auto& input_defs = pt_kernel.args_def().input_defs();
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(), PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
......
...@@ -2050,9 +2050,9 @@ void BindImperative(py::module *m_ptr) { ...@@ -2050,9 +2050,9 @@ void BindImperative(py::module *m_ptr) {
}; };
auto ret = self.GetExpectedKernelSignature(type, ins_map, auto ret = self.GetExpectedKernelSignature(type, ins_map,
outs_map, attrs); outs_map, attrs);
auto kernelsig_ins = input_to_vector(std::get<0>(ret.args)); auto kernelsig_ins = input_to_vector(ret.input_names);
auto kernelsig_attrs = attr_to_vector(std::get<1>(ret.args)); auto kernelsig_attrs = attr_to_vector(ret.attr_names);
auto kernelsig_outs = output_to_vector(std::get<2>(ret.args)); auto kernelsig_outs = output_to_vector(ret.output_names);
return std::make_tuple(kernelsig_ins, kernelsig_attrs, return std::make_tuple(kernelsig_ins, kernelsig_attrs,
kernelsig_outs); kernelsig_outs);
} }
......
...@@ -58,10 +58,10 @@ int main(int argc, char **argv) { ...@@ -58,10 +58,10 @@ int main(int argc, char **argv) {
if (kernel_signature_map.Has(op_name)) { if (kernel_signature_map.Has(op_name)) {
kernel_signature_map_str = kernel_signature_map_str =
kernel_signature_map_str + "\"" + op_kernel_pair.first + "\":{"; kernel_signature_map_str + "\"" + op_kernel_pair.first + "\":{";
auto &args = kernel_signature_map.Get(op_name).args; const auto &args = kernel_signature_map.Get(op_name);
kernel_signature_map_str += "\"inputs\":["; kernel_signature_map_str += "\"inputs\":[";
auto inputs_ = std::get<0>(args); auto inputs_ = args.input_names;
for (size_t i = 0; i < inputs_.size(); i++) { for (size_t i = 0; i < inputs_.size(); i++) {
kernel_signature_map_str = kernel_signature_map_str =
kernel_signature_map_str + "\"" + inputs_[i] + "\","; kernel_signature_map_str + "\"" + inputs_[i] + "\",";
...@@ -69,14 +69,14 @@ int main(int argc, char **argv) { ...@@ -69,14 +69,14 @@ int main(int argc, char **argv) {
if (inputs_.size()) kernel_signature_map_str.pop_back(); if (inputs_.size()) kernel_signature_map_str.pop_back();
kernel_signature_map_str += "],\"attrs\":["; kernel_signature_map_str += "],\"attrs\":[";
auto attrs_ = std::get<1>(args); auto attrs_ = args.attr_names;
for (size_t i = 0; i < attrs_.size(); i++) { for (size_t i = 0; i < attrs_.size(); i++) {
kernel_signature_map_str = kernel_signature_map_str =
kernel_signature_map_str + "\"" + attrs_[i] + "\","; kernel_signature_map_str + "\"" + attrs_[i] + "\",";
} }
if (attrs_.size()) kernel_signature_map_str.pop_back(); if (attrs_.size()) kernel_signature_map_str.pop_back();
kernel_signature_map_str += "],\"outputs\":["; kernel_signature_map_str += "],\"outputs\":[";
auto outputs_ = std::get<2>(args); auto outputs_ = args.output_names;
for (size_t i = 0; i < outputs_.size(); i++) { for (size_t i = 0; i < outputs_.size(); i++) {
kernel_signature_map_str = kernel_signature_map_str =
kernel_signature_map_str + "\"" + outputs_[i] + "\","; kernel_signature_map_str + "\"" + outputs_[i] + "\",";
......
...@@ -200,7 +200,7 @@ void PhiOpConvertPass::convertStage() { ...@@ -200,7 +200,7 @@ void PhiOpConvertPass::convertStage() {
// resort input&output according to kernel_sign // resort input&output according to kernel_sign
::llvm::SmallVector<mlir::Value, 4> inputs, ori_output; ::llvm::SmallVector<mlir::Value, 4> inputs, ori_output;
::llvm::SmallVector<mlir::Type, 4> output_types; ::llvm::SmallVector<mlir::Type, 4> output_types;
for (const std::string &str : std::get<0>(kernel_sign.args)) { for (const std::string &str : kernel_sign.input_names) {
if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) { if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) {
LOG(ERROR) << "No input info for Op " << op_name << " and argument " LOG(ERROR) << "No input info for Op " << op_name << " and argument "
<< str; << str;
...@@ -210,7 +210,7 @@ void PhiOpConvertPass::convertStage() { ...@@ -210,7 +210,7 @@ void PhiOpConvertPass::convertStage() {
inputs.push_back(op->getOperands()[index]); inputs.push_back(op->getOperands()[index]);
} }
for (const std::string &str : std::get<2>(kernel_sign.args)) { for (const std::string &str : kernel_sign.output_names) {
if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) { if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) {
LOG(ERROR) << "No output info for Op " << op_name << " and argument " LOG(ERROR) << "No output info for Op " << op_name << " and argument "
<< str; << str;
......
...@@ -20,11 +20,11 @@ limitations under the License. */ ...@@ -20,11 +20,11 @@ limitations under the License. */
namespace phi { namespace phi {
std::ostream& operator<<(std::ostream& os, KernelSignature signature) { std::ostream& operator<<(std::ostream& os, KernelSignature signature) {
os << "Kernel Signature - name: " << signature.name << "; inputs: " os << "Kernel Signature - name: " << signature.name << "; inputs: "
<< paddle::string::join_strings(std::get<0>(signature.args), ", ") << paddle::string::join_strings(signature.input_names, ", ")
<< "; attributes: " << "; attributes: "
<< paddle::string::join_strings(std::get<1>(signature.args), ", ") << paddle::string::join_strings(signature.attr_names, ", ")
<< "; outputs: " << "; outputs: "
<< paddle::string::join_strings(std::get<2>(signature.args), ", "); << paddle::string::join_strings(signature.output_names, ", ");
return os; return os;
} }
......
...@@ -33,7 +33,9 @@ using KernelArgsTuple = std::tuple<paddle::SmallVector<const char*>, ...@@ -33,7 +33,9 @@ using KernelArgsTuple = std::tuple<paddle::SmallVector<const char*>,
struct KernelSignature { struct KernelSignature {
const char* name; const char* name;
KernelArgsTuple args; paddle::SmallVector<const char*> input_names;
paddle::SmallVector<const char*> attr_names;
paddle::SmallVector<const char*> output_names;
KernelSignature() = default; KernelSignature() = default;
...@@ -41,18 +43,26 @@ struct KernelSignature { ...@@ -41,18 +43,26 @@ struct KernelSignature {
paddle::SmallVector<const char*>&& inputs, paddle::SmallVector<const char*>&& inputs,
paddle::SmallVector<const char*>&& attrs, paddle::SmallVector<const char*>&& attrs,
paddle::SmallVector<const char*>&& outputs) paddle::SmallVector<const char*>&& outputs)
: name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {} : name(kernel_name),
input_names(std::move(inputs)),
attr_names(std::move(attrs)),
output_names(std::move(outputs)) {}
KernelSignature(const char* kernel_name, KernelSignature(const char* kernel_name,
const paddle::SmallVector<const char*>& inputs, const paddle::SmallVector<const char*>& inputs,
const paddle::SmallVector<const char*>& attrs, const paddle::SmallVector<const char*>& attrs,
const paddle::SmallVector<const char*>& outputs) const paddle::SmallVector<const char*>& outputs)
: name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {} : name(kernel_name),
input_names(inputs),
attr_names(attrs),
output_names(outputs) {}
// TODO(chenweihang): add assign constructor to solve windows compile // TODO(chenweihang): add assign constructor to solve windows compile
// problem, remove it later // problem, remove it later
KernelSignature& operator=(const KernelSignature& other) { KernelSignature& operator=(const KernelSignature& other) {
name = other.name; name = other.name;
args = other.args; input_names = other.input_names;
attr_names = other.attr_names;
output_names = other.output_names;
return *this; return *this;
} }
}; };
......
...@@ -560,8 +560,7 @@ TEST(ARG_MAP, allclose) { ...@@ -560,8 +560,7 @@ TEST(ARG_MAP, allclose) {
auto signature1 = auto signature1 =
OpUtilsMap::Instance().GetArgumentMappingFn("allclose")(arg_case1); OpUtilsMap::Instance().GetArgumentMappingFn("allclose")(arg_case1);
ASSERT_EQ(signature1.name, "allclose"); ASSERT_EQ(signature1.name, "allclose");
auto attr_names1 = std::get<1>(signature1.args); ASSERT_EQ(signature1.attr_names[0], "Rtol");
ASSERT_EQ(attr_names1[0], "Rtol");
TestArgumentMappingContext arg_case2( TestArgumentMappingContext arg_case2(
{"Input", "Other", "Atol"}, {"Input", "Other", "Atol"},
...@@ -573,8 +572,7 @@ TEST(ARG_MAP, allclose) { ...@@ -573,8 +572,7 @@ TEST(ARG_MAP, allclose) {
auto signature2 = auto signature2 =
OpUtilsMap::Instance().GetArgumentMappingFn("allclose")(arg_case2); OpUtilsMap::Instance().GetArgumentMappingFn("allclose")(arg_case2);
ASSERT_EQ(signature2.name, "allclose"); ASSERT_EQ(signature2.name, "allclose");
auto attr_names2 = std::get<1>(signature2.args); ASSERT_EQ(signature2.attr_names[1], "Atol");
ASSERT_EQ(attr_names2[1], "Atol");
} }
TEST(ARG_MAP, reshape) { TEST(ARG_MAP, reshape) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册