diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc index 259205f7c1ddaeca0b1537956966a9ef5b83c44b..0710eb577979ba436c60757d0c9a46617a1e3b58 100644 --- a/paddle/framework/executor_test.cc +++ b/paddle/framework/executor_test.cc @@ -116,12 +116,8 @@ class ExecutorTesterRandom : public ::testing::Test { {{"dims", std::vector{input_dim, embed_dim}}}, init_root_block); AddOp("gaussian_random", {}, {{"Out", {"w2"}}}, {{"dims", std::vector{embed_dim, input_dim}}}, init_root_block); - AddOp("fetch", {{"Input", {"w1"}}}, {}, - {{"dims", std::vector{input_dim, embed_dim}}, {"col", 0}}, - init_root_block); - AddOp("fetch", {{"Input", {"w2"}}}, {}, - {{"dims", std::vector{embed_dim, input_dim}}, {"col", 1}}, - init_root_block); + AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, init_root_block); + AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, init_root_block); // flush init_program.Proto(); @@ -163,12 +159,8 @@ class ExecutorTesterRandom : public ::testing::Test { {"Grad", {"w2@GRAD"}}}, {{"ParamOut", {"w2"}}}, {}, root_block); - AddOp("fetch", {{"Input", {"w1"}}}, {}, - {{"dims", std::vector{input_dim, embed_dim}}, {"col", 0}}, - root_block); - AddOp("fetch", {{"Input", {"w2"}}}, {}, - {{"dims", std::vector{embed_dim, input_dim}}, {"col", 1}}, - root_block); + AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, root_block); + AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, root_block); // flush program.Proto(); @@ -197,10 +189,8 @@ class ExecutorTesterFeedAndFetch : public ::testing::Test { root_block); AddOp("feed", {}, {{"Out", {"b"}}}, {{"dims", dim}, {"col", 1}}, root_block); - AddOp("fetch", {{"Input", {"a"}}}, {}, {{"dims", dim}, {"col", 0}}, - root_block); - AddOp("fetch", {{"Input", {"b"}}}, {}, {{"dims", dim}, {"col", 1}}, - root_block); + AddOp("fetch", {{"Input", {"a"}}}, {}, {{"col", 0}}, root_block); + AddOp("fetch", {{"Input", {"b"}}}, {}, {{"col", 1}}, root_block); // flush program.Proto(); diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc index 1d65c2bb4693c8c4a131c073637bc7f5b860ab64..fa325bb28299afe24a67772473529fb76b9c73e1 100644 --- a/paddle/operators/feed_op.cc +++ b/paddle/operators/feed_op.cc @@ -24,15 +24,6 @@ class FeedOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null."); - int col = ctx->Attrs().Get("col"); - framework::Variable* g_feed_variable = - framework::GetGlobalScope()->FindVar("feed_value"); - - const auto& tensors = - g_feed_variable->Get>(); - - PADDLE_ENFORCE_GT(tensors.size(), static_cast(col)); - auto& shape = ctx->Attrs().Get>("dims"); std::vector shape_int64(shape.size(), 0); std::transform(shape.begin(), shape.end(), shape_int64.begin(), @@ -43,7 +34,7 @@ class FeedOp : public framework::OperatorWithKernel { framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { - return static_cast(Attr("data_type")); + return static_cast(Attr("dataType")); } }; @@ -51,7 +42,7 @@ class FeedOpMaker : public framework::OpProtoAndCheckerMaker { public: FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddAttr("data_type", "output data type") + AddAttr("dataType", "output data type") .SetDefault(framework::DataType::FP32); AddAttr("col", "The col in global feed variable").SetDefault(0); AddAttr>("dims", "The dimension of feed tensor."); diff --git a/paddle/operators/feed_op.h b/paddle/operators/feed_op.h index 96e3bf52bd12faaa5a3fef125499e17afbb89741..47344e309ce381e25a51b7162d8b5e28ccec09cf 100644 --- a/paddle/operators/feed_op.h +++ b/paddle/operators/feed_op.h @@ -27,9 +27,10 @@ class FeedKernel : public framework::OpKernel { out->mutable_data(ctx.GetPlace()); framework::Variable* g_feed_variable = framework::GetGlobalScope()->FindVar("feed_value"); - int col = ctx.template Attr("col"); const auto& tensors = g_feed_variable->Get>(); + int col = ctx.template Attr("col"); + PADDLE_ENFORCE_GT(tensors.size(), static_cast(col)); out->CopyFrom(tensors[col], ctx.GetPlace()); } }; diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc index 77e3450a73fb43997aa36967977b66b942f82bcd..90737c8c550ca18f03c6a9ad0d9323d0b4d0b96d 100644 --- a/paddle/operators/fetch_op.cc +++ b/paddle/operators/fetch_op.cc @@ -24,26 +24,11 @@ class FetchOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null."); - int col = ctx->Attrs().Get("col"); - framework::Variable* g_fetch_variable = - framework::GetGlobalScope()->FindVar("fetch_value"); - - auto* tensors = - g_fetch_variable->GetMutable>(); - if (tensors->size() < static_cast(col + 1)) { - tensors->resize(col + 1); - } - - auto input_dim = ctx->GetInputDim("Input"); - PADDLE_ENFORCE_GT(tensors->size(), col); - (*tensors)[col].Resize(input_dim); - - // TODO(qijun): need to handle LodTensor later } framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { - return static_cast(Attr("data_type")); + return static_cast(Attr("dataType")); } }; @@ -51,10 +36,9 @@ class FetchOpMaker : public framework::OpProtoAndCheckerMaker { public: FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddAttr("data_type", "output data type") + AddAttr("dataType", "output data type") .SetDefault(framework::DataType::FP32); AddAttr("col", "The col in global fetch variable").SetDefault(0); - AddAttr>("dims", "The dimension of fetch tensor."); AddInput("Input", "The output of fetch op."); AddComment(R"DOC(Fetch data to global fetch variable)DOC"); } diff --git a/paddle/operators/fetch_op.h b/paddle/operators/fetch_op.h index fd98552055f95bc21f47a4808735a6fb4ecced3d..6fee8b05892687d06eb1d3f7c92f0df92a8a63e6 100644 --- a/paddle/operators/fetch_op.h +++ b/paddle/operators/fetch_op.h @@ -24,13 +24,19 @@ class FetchKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor* input = ctx.Input("Input"); - int col = ctx.template Attr("col"); framework::Variable* g_fetch_variable = framework::GetGlobalScope()->FindVar("fetch_value"); auto* tensors = g_fetch_variable->GetMutable>(); + int col = ctx.template Attr("col"); + if (tensors->size() < static_cast(col + 1)) { + tensors->resize(col + 1); + } + PADDLE_ENFORCE_GT(tensors->size(), static_cast(col)); + (*tensors)[col].Resize(input->dims()); (*tensors)[col].mutable_data(platform::CPUPlace()); (*tensors)[col].CopyFrom(*input, platform::CPUPlace()); + // TODO(qijun): need to handle LodTensor later } };