未验证 提交 943e4deb 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #13750 from panyx0718/fix

clean unused code and small optimize
...@@ -544,11 +544,13 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -544,11 +544,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override { size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size()); const std::vector<std::string>& inputs = Inputs(in);
PADDLE_ENFORCE_LT(j, Outputs(out).size()); const std::vector<std::string>& outputs = Outputs(out);
Variable* in_var = scope_.FindVar(Inputs(in)[i]); PADDLE_ENFORCE_LT(i, inputs.size());
Variable* out_var = scope_.FindVar(Outputs(out)[j]); PADDLE_ENFORCE_LT(j, outputs.size());
Variable* in_var = scope_.FindVar(inputs.at(i));
if (!in_var->IsType<LoDTensor>()) return; if (!in_var->IsType<LoDTensor>()) return;
Variable* out_var = scope_.FindVar(outputs.at(j));
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(), PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out); "The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>(); auto in_tensor = in_var->Get<LoDTensor>();
...@@ -576,20 +578,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -576,20 +578,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
out_tensor->set_layout(in_tensor.layout()); out_tensor->set_layout(in_tensor.layout());
} }
void ShareLayout(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
if (!in_var->IsType<LoDTensor>()) return;
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_layout(in_tensor.layout());
}
bool IsRuntime() const override { return true; } bool IsRuntime() const override { return true; }
protected: protected:
......
...@@ -46,16 +46,6 @@ std::vector<DDim> InferShapeContext::GetReaderDims( ...@@ -46,16 +46,6 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
return this->GetRepeatedDims(arg_names[0]); return this->GetRepeatedDims(arg_names[0]);
} }
void InferShapeContext::ShareLoDs(const std::string &in,
const std::string &out) const {
PADDLE_ENFORCE_EQ(Inputs(in).size(), Outputs(out).size(),
"The number of arguments in %s and %s is not equal.", in,
out);
for (size_t i = 0; i < in.size(); ++i) {
ShareLoD(in, out, i, i);
}
}
DDim InferShapeContext::GetInputsElementDim(const std::string &name, DDim InferShapeContext::GetInputsElementDim(const std::string &name,
int idx) const { int idx) const {
const std::vector<std::string> &names = Inputs(name); const std::vector<std::string> &names = Inputs(name);
......
...@@ -56,8 +56,6 @@ class InferShapeContext { ...@@ -56,8 +56,6 @@ class InferShapeContext {
virtual const std::vector<std::string> &Outputs( virtual const std::vector<std::string> &Outputs(
const std::string &name) const = 0; const std::string &name) const = 0;
void ShareLoDs(const std::string &in, const std::string &out) const;
virtual void ShareLoD(const std::string &in, const std::string &out, virtual void ShareLoD(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0; size_t i = 0, size_t j = 0) const = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册