提交 62eb43ba 编写于 作者: X Xin Pan

convert more

test=develop
上级 dfcf746e
...@@ -142,12 +142,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames, ...@@ -142,12 +142,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames,
const Scope& scope) { const Scope& scope) {
for (auto& var_name_item : innames) { for (auto& var_name_item : innames) {
std::vector<Variable*>& input_vars = inputs[var_name_item.first]; std::vector<Variable*>& input_vars = inputs[var_name_item.first];
input_vars.reserve(var_name_item.second.size());
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
input_vars.push_back(scope.FindVar(var_name)); input_vars.push_back(scope.FindVar(var_name));
} }
} }
for (auto& var_name_item : outnames) { for (auto& var_name_item : outnames) {
std::vector<Variable*>& output_vars = outputs[var_name_item.first]; std::vector<Variable*>& output_vars = outputs[var_name_item.first];
output_vars.reserve(var_name_item.second.size());
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
output_vars.push_back(scope.FindVar(var_name)); output_vars.push_back(scope.FindVar(var_name));
} }
...@@ -556,30 +558,28 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -556,30 +558,28 @@ class RuntimeInferShapeContext : public InferShapeContext {
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
// has only one output // has only one output
const auto& outs = op_.Outputs(); const auto& outs = ctx_.outputs;
auto it = outs.find(name); auto it = outs.find(name);
if (it == outs.end()) { if (it == outs.end()) {
return false; return false;
} }
const auto& out = it->second; const auto& out = it->second;
if (out.size() == 0 || out[0] == kEmptyVarName) { if (out.size() == 0) {
return false; return false;
} }
PADDLE_ENFORCE_EQ(out.size(), 1UL, PADDLE_ENFORCE_EQ(out.size(), 1UL,
"Output %s should not have more than one outputs", name); "Output %s should not have more than one outputs", name);
return scope_.FindVar(out[0]) != nullptr; return out[0] != nullptr;
} }
bool HasInputs(const std::string& name) const override { bool HasInputs(const std::string& name) const override {
if (!op_.HasInputs(name)) { const auto& ins = ctx_.inputs;
return false; auto it = ins.find(name);
} if (it == ins.end()) {
auto inputs = op_.Inputs(name);
if (inputs.empty()) {
return false; return false;
} }
for (auto& input : inputs) { for (auto& input : it->second) {
if (scope_.FindVar(input) == nullptr) { if (input == nullptr) {
return false; return false;
} }
} }
...@@ -587,15 +587,13 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -587,15 +587,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
bool HasOutputs(const std::string& name) const override { bool HasOutputs(const std::string& name) const override {
if (!op_.HasOutputs(name)) { const auto& outs = ctx_.outputs;
return false; auto it = outs.find(name);
} if (it == outs.end()) {
auto outputs = op_.Outputs(name);
if (outputs.empty()) {
return false; return false;
} }
for (auto& output : outputs) { for (auto& output : it->second) {
if (scope_.FindVar(output) == nullptr) { if (output == nullptr) {
return false; return false;
} }
} }
...@@ -864,8 +862,7 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -864,8 +862,7 @@ Scope* OperatorWithKernel::PrepareData(
for (size_t i = 0; i < var_name_item.second.size(); ++i) { for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto& var_name = var_name_item.second[i]; auto& var_name = var_name_item.second[i];
auto* var = scope.FindVar(var_name); auto* var = input_vars[i];
input_vars[i] = var;
// Only tensor can be tranfer to another device. // Only tensor can be tranfer to another device.
if (var == nullptr || !VarIsTensor(*var)) { if (var == nullptr || !VarIsTensor(*var)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册