From 62eb43ba98931f303127441b0f53f142b12f439f Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 18 Dec 2018 20:22:56 +0800 Subject: [PATCH] convert more test=develop --- paddle/fluid/framework/operator.cc | 35 ++++++++++++++---------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 8c8374866..5bee6b41b 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -142,12 +142,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames, const Scope& scope) { for (auto& var_name_item : innames) { std::vector& input_vars = inputs[var_name_item.first]; + input_vars.reserve(var_name_item.second.size()); for (auto& var_name : var_name_item.second) { input_vars.push_back(scope.FindVar(var_name)); } } for (auto& var_name_item : outnames) { std::vector& output_vars = outputs[var_name_item.first]; + output_vars.reserve(var_name_item.second.size()); for (auto& var_name : var_name_item.second) { output_vars.push_back(scope.FindVar(var_name)); } @@ -556,30 +558,28 @@ class RuntimeInferShapeContext : public InferShapeContext { bool HasOutput(const std::string& name) const override { // has only one output - const auto& outs = op_.Outputs(); + const auto& outs = ctx_.outputs; auto it = outs.find(name); if (it == outs.end()) { return false; } const auto& out = it->second; - if (out.size() == 0 || out[0] == kEmptyVarName) { + if (out.size() == 0) { return false; } PADDLE_ENFORCE_EQ(out.size(), 1UL, "Output %s should not have more than one outputs", name); - return scope_.FindVar(out[0]) != nullptr; + return out[0] != nullptr; } bool HasInputs(const std::string& name) const override { - if (!op_.HasInputs(name)) { - return false; - } - auto inputs = op_.Inputs(name); - if (inputs.empty()) { + const auto& ins = ctx_.inputs; + auto it = ins.find(name); + if (it == ins.end()) { return false; } - for (auto& input : inputs) { - if (scope_.FindVar(input) == nullptr) { + for (auto& input : it->second) { + if (input == nullptr) { return false; } } @@ -587,15 +587,13 @@ class RuntimeInferShapeContext : public InferShapeContext { } bool HasOutputs(const std::string& name) const override { - if (!op_.HasOutputs(name)) { - return false; - } - auto outputs = op_.Outputs(name); - if (outputs.empty()) { + const auto& outs = ctx_.outputs; + auto it = outs.find(name); + if (it == outs.end()) { return false; } - for (auto& output : outputs) { - if (scope_.FindVar(output) == nullptr) { + for (auto& output : it->second) { + if (output == nullptr) { return false; } } @@ -864,8 +862,7 @@ Scope* OperatorWithKernel::PrepareData( for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto& var_name = var_name_item.second[i]; - auto* var = scope.FindVar(var_name); - input_vars[i] = var; + auto* var = input_vars[i]; // Only tensor can be tranfer to another device. if (var == nullptr || !VarIsTensor(*var)) { -- GitLab