From a308ff29af714be50e321c65fdcd88729a505ebe Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 10 Oct 2017 10:25:01 -0700 Subject: [PATCH] make infershape of feedop and fetchop compatible with compile-time design --- paddle/framework/executor_test.cc | 22 ++++++---------------- paddle/operators/feed_op.cc | 13 ++----------- paddle/operators/feed_op.h | 3 ++- paddle/operators/fetch_op.cc | 20 ++------------------ paddle/operators/fetch_op.h | 8 +++++++- 5 files changed, 19 insertions(+), 47 deletions(-) diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc index 259205f7c1d..0710eb57797 100644 --- a/paddle/framework/executor_test.cc +++ b/paddle/framework/executor_test.cc @@ -116,12 +116,8 @@ class ExecutorTesterRandom : public ::testing::Test { {{"dims", std::vector{input_dim, embed_dim}}}, init_root_block); AddOp("gaussian_random", {}, {{"Out", {"w2"}}}, {{"dims", std::vector{embed_dim, input_dim}}}, init_root_block); - AddOp("fetch", {{"Input", {"w1"}}}, {}, - {{"dims", std::vector{input_dim, embed_dim}}, {"col", 0}}, - init_root_block); - AddOp("fetch", {{"Input", {"w2"}}}, {}, - {{"dims", std::vector{embed_dim, input_dim}}, {"col", 1}}, - init_root_block); + AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, init_root_block); + AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, init_root_block); // flush init_program.Proto(); @@ -163,12 +159,8 @@ class ExecutorTesterRandom : public ::testing::Test { {"Grad", {"w2@GRAD"}}}, {{"ParamOut", {"w2"}}}, {}, root_block); - AddOp("fetch", {{"Input", {"w1"}}}, {}, - {{"dims", std::vector{input_dim, embed_dim}}, {"col", 0}}, - root_block); - AddOp("fetch", {{"Input", {"w2"}}}, {}, - {{"dims", std::vector{embed_dim, input_dim}}, {"col", 1}}, - root_block); + AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, root_block); + AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, root_block); // flush program.Proto(); @@ -197,10 +189,8 @@ class ExecutorTesterFeedAndFetch : public ::testing::Test { root_block); AddOp("feed", {}, {{"Out", {"b"}}}, {{"dims", dim}, {"col", 1}}, root_block); - AddOp("fetch", {{"Input", {"a"}}}, {}, {{"dims", dim}, {"col", 0}}, - root_block); - AddOp("fetch", {{"Input", {"b"}}}, {}, {{"dims", dim}, {"col", 1}}, - root_block); + AddOp("fetch", {{"Input", {"a"}}}, {}, {{"col", 0}}, root_block); + AddOp("fetch", {{"Input", {"b"}}}, {}, {{"col", 1}}, root_block); // flush program.Proto(); diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc index 1d65c2bb469..fa325bb2829 100644 --- a/paddle/operators/feed_op.cc +++ b/paddle/operators/feed_op.cc @@ -24,15 +24,6 @@ class FeedOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null."); - int col = ctx->Attrs().Get("col"); - framework::Variable* g_feed_variable = - framework::GetGlobalScope()->FindVar("feed_value"); - - const auto& tensors = - g_feed_variable->Get>(); - - PADDLE_ENFORCE_GT(tensors.size(), static_cast(col)); - auto& shape = ctx->Attrs().Get>("dims"); std::vector shape_int64(shape.size(), 0); std::transform(shape.begin(), shape.end(), shape_int64.begin(), @@ -43,7 +34,7 @@ class FeedOp : public framework::OperatorWithKernel { framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { - return static_cast(Attr("data_type")); + return static_cast(Attr("dataType")); } }; @@ -51,7 +42,7 @@ class FeedOpMaker : public framework::OpProtoAndCheckerMaker { public: FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddAttr("data_type", "output data type") + AddAttr("dataType", "output data type") .SetDefault(framework::DataType::FP32); AddAttr("col", "The col in global feed variable").SetDefault(0); AddAttr>("dims", "The dimension of feed tensor."); diff --git a/paddle/operators/feed_op.h b/paddle/operators/feed_op.h index 96e3bf52bd1..47344e309ce 100644 --- a/paddle/operators/feed_op.h +++ b/paddle/operators/feed_op.h @@ -27,9 +27,10 @@ class FeedKernel : public framework::OpKernel { out->mutable_data(ctx.GetPlace()); framework::Variable* g_feed_variable = framework::GetGlobalScope()->FindVar("feed_value"); - int col = ctx.template Attr("col"); const auto& tensors = g_feed_variable->Get>(); + int col = ctx.template Attr("col"); + PADDLE_ENFORCE_GT(tensors.size(), static_cast(col)); out->CopyFrom(tensors[col], ctx.GetPlace()); } }; diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc index 77e3450a73f..90737c8c550 100644 --- a/paddle/operators/fetch_op.cc +++ b/paddle/operators/fetch_op.cc @@ -24,26 +24,11 @@ class FetchOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null."); - int col = ctx->Attrs().Get("col"); - framework::Variable* g_fetch_variable = - framework::GetGlobalScope()->FindVar("fetch_value"); - - auto* tensors = - g_fetch_variable->GetMutable>(); - if (tensors->size() < static_cast(col + 1)) { - tensors->resize(col + 1); - } - - auto input_dim = ctx->GetInputDim("Input"); - PADDLE_ENFORCE_GT(tensors->size(), col); - (*tensors)[col].Resize(input_dim); - - // TODO(qijun): need to handle LodTensor later } framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { - return static_cast(Attr("data_type")); + return static_cast(Attr("dataType")); } }; @@ -51,10 +36,9 @@ class FetchOpMaker : public framework::OpProtoAndCheckerMaker { public: FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddAttr("data_type", "output data type") + AddAttr("dataType", "output data type") .SetDefault(framework::DataType::FP32); AddAttr("col", "The col in global fetch variable").SetDefault(0); - AddAttr>("dims", "The dimension of fetch tensor."); AddInput("Input", "The output of fetch op."); AddComment(R"DOC(Fetch data to global fetch variable)DOC"); } diff --git a/paddle/operators/fetch_op.h b/paddle/operators/fetch_op.h index fd98552055f..6fee8b05892 100644 --- a/paddle/operators/fetch_op.h +++ b/paddle/operators/fetch_op.h @@ -24,13 +24,19 @@ class FetchKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor* input = ctx.Input("Input"); - int col = ctx.template Attr("col"); framework::Variable* g_fetch_variable = framework::GetGlobalScope()->FindVar("fetch_value"); auto* tensors = g_fetch_variable->GetMutable>(); + int col = ctx.template Attr("col"); + if (tensors->size() < static_cast(col + 1)) { + tensors->resize(col + 1); + } + PADDLE_ENFORCE_GT(tensors->size(), static_cast(col)); + (*tensors)[col].Resize(input->dims()); (*tensors)[col].mutable_data(platform::CPUPlace()); (*tensors)[col].CopyFrom(*input, platform::CPUPlace()); + // TODO(qijun): need to handle LodTensor later } }; -- GitLab