未验证 提交 e94db381 编写于 作者: D dzhwinter 提交者: GitHub

Feature/add shared layout (#7233)

* "reuse ShareLoD with no regret"

* "removed base class shareLayout"

* "fix CI"
上级 5e90f5e1
......@@ -146,9 +146,6 @@ void TransDataLayout(const std::vector<int>& axis,
auto* dst = out->GetMutable<Tensor>();
PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
auto place = kernel_pair.second.place_;
CopyFrom(src, place, *ctx, dst);
auto src_dim = src.dims();
std::vector<int64_t> dst_dim;
......@@ -158,6 +155,8 @@ void TransDataLayout(const std::vector<int>& axis,
}
dst->Resize(make_ddim(dst_dim));
auto place = kernel_pair.second.place_;
dst->mutable_data(place, src.type());
auto src_type = kernel_pair.first.data_type_;
framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst));
......
......@@ -66,6 +66,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
out);
out_var->SetLoDLevel(in_var->GetLoDLevel());
}
bool IsRuntime() const override;
protected:
......
......@@ -417,6 +417,25 @@ class RuntimeInferShapeContext : public InferShapeContext {
auto in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_lod(in_tensor.lod());
// TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to shared info between in/out Tensor?
out_tensor->set_layout(in_tensor.layout());
}
void ShareLayout(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
if (!in_var->IsType<LoDTensor>()) return;
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_layout(in_tensor.layout());
}
bool IsRuntime() const override { return true; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册