From 4df6cf4d16ad271101bd37de7f84fb054f1d788a Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 16 Oct 2017 09:59:58 -0700 Subject: [PATCH] Rewrite feed/fetch op (#4815) * Feed/Fetch op just plain operator, not a OpWithKernel * Do not register OpInfoMaker since Feed/Fetch will never be configured by users * Feed/Fetch op has empty gradient * Feed/Fetch op do not hard code `feed_variable`, `fetch_variable` as its input and output, make it as a plain Operator input/output --- paddle/framework/executor_test.cc | 55 ++++++++++------- paddle/framework/feed_fetch_type.h | 24 ++++++++ paddle/framework/grad_op_desc_maker.h | 8 +++ paddle/framework/op_desc.cc | 9 ++- paddle/framework/tensor.h | 16 +++++ paddle/operators/feed_op.cc | 86 +++++++++++++-------------- paddle/operators/feed_op.cu | 18 ------ paddle/operators/feed_op.h | 42 ------------- paddle/operators/fetch_op.cc | 78 ++++++++++++++---------- paddle/operators/fetch_op.cu | 18 ------ paddle/operators/fetch_op.h | 45 -------------- 11 files changed, 173 insertions(+), 226 deletions(-) create mode 100644 paddle/framework/feed_fetch_type.h delete mode 100644 paddle/operators/feed_op.cu delete mode 100644 paddle/operators/feed_op.h delete mode 100644 paddle/operators/fetch_op.cu delete mode 100644 paddle/operators/fetch_op.h diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc index fcd2e47cff..e08d31e361 100644 --- a/paddle/framework/executor_test.cc +++ b/paddle/framework/executor_test.cc @@ -28,8 +28,8 @@ limitations under the License. */ USE_OP(elementwise_add); USE_OP(gaussian_random); -USE_OP(feed); -USE_OP(fetch); +USE_NO_KERNEL_OP(feed); +USE_NO_KERNEL_OP(fetch); USE_OP(mul); USE_OP(sum); USE_OP(squared_l2_distance); @@ -37,6 +37,9 @@ USE_OP(fill_constant); USE_OP(mean); USE_OP(sgd); +constexpr auto kFeedValueName = "feed_value"; +constexpr auto kFetchValueName = "fetch_value"; + using namespace paddle::platform; using namespace paddle::framework; @@ -77,9 +80,9 @@ void AddOp(const std::string& type, const VariableNameMap& inputs, template void SetFeedVariable(const std::vector>& inputs, const std::vector>& dims) { - Variable* g_feed_value = GetGlobalScope().FindVar("feed_value"); + Variable* g_feed_value = GetGlobalScope().FindVar(kFeedValueName); auto& feed_inputs = - *(g_feed_value->GetMutable>()); + *(g_feed_value->GetMutable>()); size_t size = inputs.size(); feed_inputs.resize(size); for (size_t i = 0; i < size; i++) { @@ -92,9 +95,9 @@ void SetFeedVariable(const std::vector>& inputs, // So we can memcpy the data from fetch_value to vector template std::vector> GetFetchVariable() { - Variable* g_fetch_value = GetGlobalScope().FindVar("fetch_value"); + Variable* g_fetch_value = GetGlobalScope().FindVar(kFetchValueName); auto& fetch_outputs = - *(g_fetch_value->GetMutable>()); + *(g_fetch_value->GetMutable>()); size_t size = fetch_outputs.size(); std::vector> result; @@ -126,8 +129,10 @@ class ExecutorTesterRandom : public ::testing::Test { {{"dims", std::vector{input_dim, embed_dim}}}, init_root_block); AddOp("gaussian_random", {}, {{"Out", {"w2"}}}, {{"dims", std::vector{embed_dim, input_dim}}}, init_root_block); - AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, init_root_block); - AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, init_root_block); + AddOp("fetch", {{"Input", {"w1"}}}, {{"Out", {kFetchValueName}}}, + {{"col", 0}}, init_root_block); + AddOp("fetch", {{"Input", {"w2"}}}, {{"Out", {kFetchValueName}}}, + {{"col", 1}}, init_root_block); // flush init_program.Proto(); @@ -143,7 +148,7 @@ class ExecutorTesterRandom : public ::testing::Test { // feed data inputs_.push_back({1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); dims_.push_back({batch_size, input_dim}); - AddOp("feed", {}, {{"Out", {"a"}}}, + AddOp("feed", {{"Input", {kFeedValueName}}}, {{"Out", {"a"}}}, {{"dims", std::vector{batch_size, input_dim}}, {"col", 0}}, root_block); @@ -175,9 +180,12 @@ class ExecutorTesterRandom : public ::testing::Test { {"Grad", {"w2@GRAD"}}}, {{"ParamOut", {"w2"}}}, {}, root_block); - AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, root_block); - AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, root_block); - AddOp("fetch", {{"Input", {"l2_distance"}}}, {}, {{"col", 0}}, root_block); + AddOp("fetch", {{"Input", {"w1"}}}, {{"Out", {kFetchValueName}}}, + {{"col", 0}}, root_block); + AddOp("fetch", {{"Input", {"w2"}}}, {{"Out", {kFetchValueName}}}, + {{"col", 1}}, root_block); + AddOp("fetch", {{"Input", {"l2_distance"}}}, {{"Out", {kFetchValueName}}}, + {{"col", 0}}, root_block); // flush program.Proto(); @@ -204,12 +212,14 @@ class ExecutorTesterFeedAndFetch : public ::testing::Test { std::vector dim{6}; - AddOp("feed", {}, {{"Out", {"a"}}}, {{"dims", dim}, {"col", 0}}, - root_block); - AddOp("feed", {}, {{"Out", {"b"}}}, {{"dims", dim}, {"col", 1}}, - root_block); - AddOp("fetch", {{"Input", {"a"}}}, {}, {{"col", 0}}, root_block); - AddOp("fetch", {{"Input", {"b"}}}, {}, {{"col", 1}}, root_block); + AddOp("feed", {{"Input", {kFeedValueName}}}, {{"Out", {"a"}}}, + {{"dims", dim}, {"col", 0}}, root_block); + AddOp("feed", {{"Input", {kFeedValueName}}}, {{"Out", {"b"}}}, + {{"dims", dim}, {"col", 1}}, root_block); + AddOp("fetch", {{"Input", {"a"}}}, {{"Out", {kFetchValueName}}}, + {{"col", 0}}, root_block); + AddOp("fetch", {{"Input", {"b"}}}, {{"Out", {kFetchValueName}}}, + {{"col", 1}}, root_block); // flush program.Proto(); @@ -241,7 +251,6 @@ TEST_F(ExecutorTesterRandom, CPU) { paddle::memory::Used(cpu_place); std::unique_ptr executor(new Executor(places)); - executor->Run(init_pdesc_, &GetGlobalScope(), 0); SetFeedVariable(inputs_, dims_); executor->Run(pdesc_, &GetGlobalScope(), 0); @@ -251,7 +260,7 @@ TEST_F(ExecutorTesterRandom, CPU) { TEST_F(ExecutorTesterFeedAndFetch, CPU) { std::vector places; CPUPlace cpu_place; - places.push_back(cpu_place); + places.emplace_back(cpu_place); // We have a global Scope and BuddyAllocator, and we must ensure // global BuddyAllocator is initialized before global Scope. Thus, @@ -265,11 +274,11 @@ TEST_F(ExecutorTesterFeedAndFetch, CPU) { SetFeedVariable(inputs_, dims_); executor->Run(pdesc_, &GetGlobalScope(), 0); std::vector> result = GetFetchVariable(); - PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); + ASSERT_EQ(result.size(), inputs_.size()); for (size_t i = 0; i < result.size(); ++i) { - PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size()); + ASSERT_EQ(result[i].size(), inputs_[i].size()); for (size_t j = 0; j < result[i].size(); ++j) { - PADDLE_ENFORCE_EQ(result[i][j], inputs_[i][j]); + ASSERT_EQ(result[i][j], inputs_[i][j]); } } } diff --git a/paddle/framework/feed_fetch_type.h b/paddle/framework/feed_fetch_type.h new file mode 100644 index 0000000000..bc4ae440fc --- /dev/null +++ b/paddle/framework/feed_fetch_type.h @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include "paddle/framework/lod_tensor.h" + +namespace paddle { +namespace framework { +using FeedFetchType = LoDTensor; +using FeedFetchList = std::vector; +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/grad_op_desc_maker.h b/paddle/framework/grad_op_desc_maker.h index 1219e04875..94944c79b6 100644 --- a/paddle/framework/grad_op_desc_maker.h +++ b/paddle/framework/grad_op_desc_maker.h @@ -149,5 +149,13 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker { } }; +class EmptyGradOpMaker : public GradOpDescMakerBase { + public: + using GradOpDescMakerBase::GradOpDescMakerBase; + std::vector> operator()() const override { + return {}; + } +}; + } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index ef207dc54e..7f7cebb026 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -220,9 +220,12 @@ static InferShapeFuncMap &InferShapeFuncs() { void OpDescBind::CheckAttrs() { PADDLE_ENFORCE(!Type().empty(), "CheckAttr() can not be called before type is setted."); - const auto *checker = OpInfoMap::Instance().Get(Type()).Checker(); - PADDLE_ENFORCE_NOT_NULL(checker, "Operator \"%s\" has no registered checker.", - Type()); + auto *checker = OpInfoMap::Instance().Get(Type()).Checker(); + if (checker == nullptr) { + // checker is not configured. That operator could be generated by Paddle, + // not by users. + return; + } checker->Check(attrs_); } diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 3304d857ae..bc430852de 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -100,6 +100,22 @@ class Tensor { inline void CopyFrom(const Tensor& src, const platform::Place& dst_place, const platform::DeviceContext& ctx); + // FIXME(yuyang18): CopyFrom should without template T, use the replace + // `CopyFrom` with `CopyFromTensor` + inline void CopyFromTensor(const Tensor& src, + const platform::Place& dst_place, + const platform::DeviceContext& ctx) { + // NOLINTNEXTLINES_8 cpplint.py will recognize below lines as functions. + // That is a bug of cpplint.py. Just ignore lint these lines. + if (src.type() == std::type_index(typeid(double))) { + CopyFrom(src, dst_place, ctx); + } else if (src.type() == std::type_index(typeid(float))) { + CopyFrom(src, dst_place, ctx); + } else if (src.type() == std::type_index(typeid(int))) { + CopyFrom(src, dst_place, ctx); + } + } + /** * @brief Copy the content of an external vector to a tensor. * diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc index fa325bb282..d742bbe51b 100644 --- a/paddle/operators/feed_op.cc +++ b/paddle/operators/feed_op.cc @@ -1,59 +1,57 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ -#include "paddle/operators/feed_op.h" +#include "paddle/framework/feed_fetch_type.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" namespace paddle { namespace operators { - -class FeedOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null."); - auto& shape = ctx->Attrs().Get>("dims"); - std::vector shape_int64(shape.size(), 0); - std::transform(shape.begin(), shape.end(), shape_int64.begin(), - [](int a) { return static_cast(a); }); - ctx->SetOutputDim("Out", framework::make_ddim(shape_int64)); - // TODO(qijun): need to handle LodTensor later - } - - framework::DataType IndicateDataType( - const framework::ExecutionContext& ctx) const override { - return static_cast(Attr("dataType")); - } -}; - -class FeedOpMaker : public framework::OpProtoAndCheckerMaker { +class FeedOp : public framework::OperatorBase { public: - FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddAttr("dataType", "output data type") - .SetDefault(framework::DataType::FP32); - AddAttr("col", "The col in global feed variable").SetDefault(0); - AddAttr>("dims", "The dimension of feed tensor."); - AddOutput("Out", "The output of feed op."); - AddComment(R"DOC(Feed data from global feed variable)DOC"); + FeedOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto feed_var_name = Input("Input"); + auto *feed_var = scope.FindVar(feed_var_name); + PADDLE_ENFORCE(feed_var != nullptr, + "Cannot find feed_var in scope, feed_var_name is %s", + feed_var_name); + + auto out_name = this->Output("Out"); + auto *out_var = scope.FindVar(out_name); + PADDLE_ENFORCE(out_var != nullptr, + "Cannot find out_var in scope, out_var_name is %s", + out_name); + + auto col = Attr("col"); + + auto &feed_list = feed_var->Get(); + auto &feed_item = feed_list.at(static_cast(col)); + auto *out_item = out_var->GetMutable(); + out_item->CopyFromTensor(feed_item, dev_ctx.GetPlace(), dev_ctx); + out_item->set_lod(feed_item.lod()); } }; } // namespace operators } // namespace paddle -namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(feed, ops::FeedOp, ops::FeedOpMaker); -REGISTER_OP_CPU_KERNEL(feed, ops::FeedKernel); +// We do not need to register OpInfoMaker, +// since feed operator will not be used by end users directly +REGISTER_OPERATOR(feed, paddle::operators::FeedOp, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/operators/feed_op.cu b/paddle/operators/feed_op.cu deleted file mode 100644 index 7b6a2ac91e..0000000000 --- a/paddle/operators/feed_op.cu +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/operators/feed_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(feed, ops::FeedKernel); diff --git a/paddle/operators/feed_op.h b/paddle/operators/feed_op.h deleted file mode 100644 index e756cd1842..0000000000 --- a/paddle/operators/feed_op.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/framework/eigen.h" -#include "paddle/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -class FeedKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - framework::Tensor* out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - framework::Variable* g_feed_variable = - framework::GetGlobalScope().FindVar("feed_value"); - const auto& tensors = - g_feed_variable->Get>(); - int col = ctx.template Attr("col"); - PADDLE_ENFORCE_GT(tensors.size(), static_cast(col)); - // TODO(qijun): - // check tensors[col].dims() with attribute, - // except the first dimenson. - out->CopyFrom(tensors[col], ctx.GetPlace(), ctx.device_context()); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc index 90737c8c55..55d6ac0939 100644 --- a/paddle/operators/fetch_op.cc +++ b/paddle/operators/fetch_op.cc @@ -1,52 +1,64 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ -#include "paddle/operators/fetch_op.h" +#include "paddle/framework/feed_fetch_type.h" +#include "paddle/framework/op_registry.h" namespace paddle { namespace operators { -class FetchOp : public framework::OperatorWithKernel { +class FetchOp : public framework::OperatorBase { public: - using framework::OperatorWithKernel::OperatorWithKernel; + FetchOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null."); - } + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto fetch_var_name = Input("Input"); + auto *fetch_var = scope.FindVar(fetch_var_name); + PADDLE_ENFORCE(fetch_var != nullptr, + "Cannot find fetch variable in scope, fetch_var_name is %s", + fetch_var_name); - framework::DataType IndicateDataType( - const framework::ExecutionContext& ctx) const override { - return static_cast(Attr("dataType")); - } -}; + auto out_name = this->Output("Out"); + auto *out_var = scope.FindVar(out_name); + PADDLE_ENFORCE(out_var != nullptr, + "Cannot find out_var in scope, out_var_name is %s", + out_name); -class FetchOpMaker : public framework::OpProtoAndCheckerMaker { - public: - FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddAttr("dataType", "output data type") - .SetDefault(framework::DataType::FP32); - AddAttr("col", "The col in global fetch variable").SetDefault(0); - AddInput("Input", "The output of fetch op."); - AddComment(R"DOC(Fetch data to global fetch variable)DOC"); + auto col = static_cast(Attr("col")); + + auto *fetch_list = out_var->GetMutable(); + auto &src_item = fetch_var->Get(); + + if (col >= fetch_list->size()) { + fetch_list->resize(col + 1); + } + auto &dst_item = fetch_list->at(col); + + // FIXME(yuyang18): Should we assume the fetch operator always generate + // CPU outputs? + dst_item.CopyFromTensor(src_item, platform::CPUPlace(), dev_ctx); } }; } // namespace operators } // namespace paddle -namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(fetch, ops::FetchOp, ops::FetchOpMaker); -REGISTER_OP_CPU_KERNEL(fetch, ops::FetchKernel); +// We do not need to register OpInfoMaker, +// since fetch operator will not be used by end users directly +REGISTER_OPERATOR(fetch, paddle::operators::FetchOp, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/operators/fetch_op.cu b/paddle/operators/fetch_op.cu deleted file mode 100644 index ca39d24c79..0000000000 --- a/paddle/operators/fetch_op.cu +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/operators/fetch_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(fetch, ops::FetchKernel); diff --git a/paddle/operators/fetch_op.h b/paddle/operators/fetch_op.h deleted file mode 100644 index b2a6e95875..0000000000 --- a/paddle/operators/fetch_op.h +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/framework/eigen.h" -#include "paddle/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -class FetchKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const framework::Tensor* input = ctx.Input("Input"); - framework::Variable* g_fetch_variable = - framework::GetGlobalScope().FindVar("fetch_value"); - auto* tensors = - g_fetch_variable->GetMutable>(); - int col = ctx.template Attr("col"); - if (tensors->size() < static_cast(col + 1)) { - tensors->resize(col + 1); - } - PADDLE_ENFORCE_GT(tensors->size(), static_cast(col)); - (*tensors)[col].Resize(input->dims()); - (*tensors)[col].mutable_data(platform::CPUPlace()); - (*tensors)[col].CopyFrom(*input, platform::CPUPlace(), - ctx.device_context()); - // TODO(qijun): need to handle LodTensor later - } -}; - -} // namespace operators -} // namespace paddle -- GitLab