提交 62eb43ba 编写于 作者: X Xin Pan

convert more

test=develop
上级 dfcf746e
......@@ -142,12 +142,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames,
const Scope& scope) {
for (auto& var_name_item : innames) {
std::vector<Variable*>& input_vars = inputs[var_name_item.first];
input_vars.reserve(var_name_item.second.size());
for (auto& var_name : var_name_item.second) {
input_vars.push_back(scope.FindVar(var_name));
}
}
for (auto& var_name_item : outnames) {
std::vector<Variable*>& output_vars = outputs[var_name_item.first];
output_vars.reserve(var_name_item.second.size());
for (auto& var_name : var_name_item.second) {
output_vars.push_back(scope.FindVar(var_name));
}
......@@ -556,30 +558,28 @@ class RuntimeInferShapeContext : public InferShapeContext {
bool HasOutput(const std::string& name) const override {
// has only one output
const auto& outs = op_.Outputs();
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end()) {
return false;
}
const auto& out = it->second;
if (out.size() == 0 || out[0] == kEmptyVarName) {
if (out.size() == 0) {
return false;
}
PADDLE_ENFORCE_EQ(out.size(), 1UL,
"Output %s should not have more than one outputs", name);
return scope_.FindVar(out[0]) != nullptr;
return out[0] != nullptr;
}
bool HasInputs(const std::string& name) const override {
if (!op_.HasInputs(name)) {
return false;
}
auto inputs = op_.Inputs(name);
if (inputs.empty()) {
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end()) {
return false;
}
for (auto& input : inputs) {
if (scope_.FindVar(input) == nullptr) {
for (auto& input : it->second) {
if (input == nullptr) {
return false;
}
}
......@@ -587,15 +587,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
bool HasOutputs(const std::string& name) const override {
if (!op_.HasOutputs(name)) {
return false;
}
auto outputs = op_.Outputs(name);
if (outputs.empty()) {
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end()) {
return false;
}
for (auto& output : outputs) {
if (scope_.FindVar(output) == nullptr) {
for (auto& output : it->second) {
if (output == nullptr) {
return false;
}
}
......@@ -864,8 +862,7 @@ Scope* OperatorWithKernel::PrepareData(
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto& var_name = var_name_item.second[i];
auto* var = scope.FindVar(var_name);
input_vars[i] = var;
auto* var = input_vars[i];
// Only tensor can be tranfer to another device.
if (var == nullptr || !VarIsTensor(*var)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册