未验证 提交 943e4deb 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #13750 from panyx0718/fix

clean unused code and small optimize
......@@ -544,11 +544,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
const std::vector<std::string>& inputs = Inputs(in);
const std::vector<std::string>& outputs = Outputs(out);
PADDLE_ENFORCE_LT(i, inputs.size());
PADDLE_ENFORCE_LT(j, outputs.size());
Variable* in_var = scope_.FindVar(inputs.at(i));
if (!in_var->IsType<LoDTensor>()) return;
Variable* out_var = scope_.FindVar(outputs.at(j));
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>();
......@@ -576,20 +578,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
out_tensor->set_layout(in_tensor.layout());
}
void ShareLayout(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
if (!in_var->IsType<LoDTensor>()) return;
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_layout(in_tensor.layout());
}
bool IsRuntime() const override { return true; }
protected:
......
......@@ -46,16 +46,6 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
return this->GetRepeatedDims(arg_names[0]);
}
void InferShapeContext::ShareLoDs(const std::string &in,
const std::string &out) const {
PADDLE_ENFORCE_EQ(Inputs(in).size(), Outputs(out).size(),
"The number of arguments in %s and %s is not equal.", in,
out);
for (size_t i = 0; i < in.size(); ++i) {
ShareLoD(in, out, i, i);
}
}
DDim InferShapeContext::GetInputsElementDim(const std::string &name,
int idx) const {
const std::vector<std::string> &names = Inputs(name);
......
......@@ -56,8 +56,6 @@ class InferShapeContext {
virtual const std::vector<std::string> &Outputs(
const std::string &name) const = 0;
void ShareLoDs(const std::string &in, const std::string &out) const;
virtual void ShareLoD(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册