From 0e0983cc1d9a607ba8a339bbbe9e495e304cd11f Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 18 Dec 2018 21:27:04 +0800 Subject: [PATCH] convert more infer shape --- paddle/fluid/framework/operator.cc | 34 ++++++++++++++++++------------ 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 5bee6b41bde..a7bee3344df 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -614,16 +614,19 @@ class RuntimeInferShapeContext : public InferShapeContext { void ShareDim(const std::string& in, const std::string& out, size_t i = 0, size_t j = 0) override { - PADDLE_ENFORCE_LT(i, Inputs(in).size()); - PADDLE_ENFORCE_LT(j, Outputs(out).size()); - const std::string& input_n = Inputs(in)[i]; - const std::string& output_n = Outputs(out)[j]; + auto in_it = ctx_.inputs.find(in); + auto out_it = ctx_.outputs.find(out); + PADDLE_ENFORCE(in_it != ctx_.inputs.end() && in_it->second.size() > i, + "Inputs %s should have %llu argument", in, i); + PADDLE_ENFORCE(out_it != ctx_.outputs.end() && out_it->second.size() > j, + "Outputs %s should have %llu argument", out, j); + + Variable* in_var = in_it->second[i]; + Variable* out_var = out_it->second[j]; - Variable* in_var = scope_.FindVar(input_n); - Variable* out_var = scope_.FindVar(output_n); PADDLE_ENFORCE(in_var->Type() == out_var->Type(), - "The type of %s and %s is not the same.", output_n, - GetDim(input_n)); + "The type of %s and %s is not the same.", in_var->Type(), + out_var->Type()); if (in_var->IsType()) { auto& in_sele_rows = in_var->Get(); @@ -644,13 +647,16 @@ class RuntimeInferShapeContext : public InferShapeContext { void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, size_t j = 0) const override { - const std::vector& inputs = Inputs(in); - const std::vector& outputs = Outputs(out); - PADDLE_ENFORCE_LT(i, inputs.size()); - PADDLE_ENFORCE_LT(j, outputs.size()); - Variable* in_var = scope_.FindVar(inputs.at(i)); + auto in_it = ctx_.inputs.find(in); + auto out_it = ctx_.outputs.find(out); + PADDLE_ENFORCE(in_it != ctx_.inputs.end() && in_it->second.size() > i, + "Inputs %s should have %llu argument", in, i); + PADDLE_ENFORCE(out_it != ctx_.outputs.end() && out_it->second.size() > j, + "Outputs %s should have %llu argument", out, j); + + Variable* in_var = in_it->second.at(i); if (!in_var->IsType()) return; - Variable* out_var = scope_.FindVar(outputs.at(j)); + Variable* out_var = out_it->second.at(j); PADDLE_ENFORCE(out_var->IsType(), "The %d-th output of Output(%s) must be LoDTensor.", j, out); auto in_tensor = in_var->Get(); -- GitLab