diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc
index 55825c5b7d2155df5a80fc15933579441aea75a8..fed958db1584c4fda5394d59a2ef8936045a9ce9 100644
--- a/paddle/framework/data_transform.cc
+++ b/paddle/framework/data_transform.cc
@@ -146,9 +146,6 @@ void TransDataLayout(const std::vector<int>& axis,
   auto* dst = out->GetMutable<Tensor>();
   PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
 
-  auto place = kernel_pair.second.place_;
-  CopyFrom(src, place, *ctx, dst);
-
   auto src_dim = src.dims();
   std::vector<int64_t> dst_dim;
 
@@ -158,6 +155,8 @@ void TransDataLayout(const std::vector<int>& axis,
   }
 
   dst->Resize(make_ddim(dst_dim));
+  auto place = kernel_pair.second.place_;
+  dst->mutable_data(place, src.type());
 
   auto src_type = kernel_pair.first.data_type_;
   framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst));
diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index e02e572af2c4ee122d033bc7b5231be94026c180..47c91290e4bf90897d35f1b3bce2e1f10ad0782c 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -66,6 +66,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
                    out);
     out_var->SetLoDLevel(in_var->GetLoDLevel());
   }
+
   bool IsRuntime() const override;
 
  protected:
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index adc85b1049f982449e7bf9c8aea1f096c974693d..a1f1be5f34264c11e8125f78650d63d9996aea84 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -417,6 +417,25 @@ class RuntimeInferShapeContext : public InferShapeContext {
     auto in_tensor = in_var->Get<LoDTensor>();
     auto* out_tensor = out_var->GetMutable<LoDTensor>();
     out_tensor->set_lod(in_tensor.lod());
+
+    // TODO(dzhwinter) : reuse ShareLoD in most operators.
+    // Need to call ShareLayout explicitly in sequence related ops.
+    // Shall we have a better method to shared info between in/out Tensor?
+    out_tensor->set_layout(in_tensor.layout());
+  }
+
+  void ShareLayout(const std::string& in, const std::string& out, size_t i = 0,
+                   size_t j = 0) const {
+    PADDLE_ENFORCE_LT(i, Inputs(in).size());
+    PADDLE_ENFORCE_LT(j, Outputs(out).size());
+    Variable* in_var = scope_.FindVar(Inputs(in)[i]);
+    Variable* out_var = scope_.FindVar(Outputs(out)[j]);
+    if (!in_var->IsType<LoDTensor>()) return;
+    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
+                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
+    auto in_tensor = in_var->Get<LoDTensor>();
+    auto* out_tensor = out_var->GetMutable<LoDTensor>();
+    out_tensor->set_layout(in_tensor.layout());
   }
 
   bool IsRuntime() const override { return true; }