From 45c4dcaabb4cbf140384dcffe3392d2e10b2a6d7 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 5 Oct 2017 15:54:44 -0700 Subject: [PATCH] add fetch operator --- paddle/framework/executor.cc | 18 ++++---- paddle/framework/executor_test.cc | 67 ++++++++++++++++++++++++++++++ paddle/framework/scope.cc | 5 ++- paddle/operators/activation_op.cu | 18 ++++---- paddle/operators/feed_op.cc | 6 +-- paddle/operators/fetch_op.cc | 68 +++++++++++++++++++++++++++++++ paddle/operators/fetch_op.cu | 18 ++++++++ paddle/operators/fetch_op.h | 40 ++++++++++++++++++ 8 files changed, 218 insertions(+), 22 deletions(-) create mode 100644 paddle/operators/fetch_op.cc create mode 100644 paddle/operators/fetch_op.cu create mode 100644 paddle/operators/fetch_op.h diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index aafef12554f..51ddb7e58e2 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -75,15 +75,15 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) { device_context->Wait(); } // // print tensor value - for (auto& var : block.vars()) { - std::cout << var.name() << std::endl; - auto v = scope->FindVar(var.name()); - const LoDTensor& t = v->Get(); - for (int i = 0; i < t.numel(); ++i) { - std::cout << t.data()[i] << " "; - } - std::cout << std::endl; - } + // for (auto& var : block.vars()) { + // std::cout << var.name() << std::endl; + // auto v = scope->FindVar(var.name()); + // const LoDTensor& t = v->Get(); + // for (int i = 0; i < t.numel(); ++i) { + // std::cout << t.data()[i] << " "; + // } + // std::cout << std::endl; + // } } } // namespace framework diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc index 0856d1f32e9..980f5f579c8 100644 --- a/paddle/framework/executor_test.cc +++ b/paddle/framework/executor_test.cc @@ -25,6 +25,7 @@ limitations under the License. */ USE_OP(elementwise_add); USE_OP(gaussian_random); USE_OP(feed); +USE_OP(fetch); using std::string; using namespace paddle::platform; @@ -94,6 +95,41 @@ void add_feed_op(string var_name, int index, proto_block* block) { Out->add_arguments(var_name); } +void add_fetch_op(string var_name, int index, proto_block* block) { + std::vector dim{3}; + + // insert variable + auto a = block->add_vars(); + a->set_name(var_name); + auto a_lt = a->mutable_lod_tensor(); + a_lt->set_data_type(paddle::framework::DataType::FP32); + for (int i : dim) { + a_lt->add_dims(i); + } + + // insert operation + auto op = block->add_ops(); + op->set_type("fetch"); + + // set dims attr + auto dims = op->add_attrs(); + dims->set_name("dims"); + dims->set_type(paddle::framework::AttrType::INTS); + for (int i : dim) { + dims->add_ints(i); + } + + // set col attr + auto col = op->add_attrs(); + col->set_name("col"); + col->set_type(paddle::framework::AttrType::INT); + col->set_i(index); + + auto Out = op->add_inputs(); + Out->set_parameter("Input"); + Out->add_arguments(var_name); +} + std::once_flag set_variable_flag; template @@ -119,6 +155,27 @@ void set_feed_variable(const std::vector>& inputs) { } } +template +std::vector> get_fetch_variable() { + typedef std::vector FetchOutputs; + Variable* g_fetch_value = GetScope()->FindVar("fetch_value"); + FetchOutputs& fetch_outputs = *(g_fetch_value->GetMutable()); + auto size = fetch_outputs.size(); + + std::vector> result; + result.reserve(size); + + for (size_t i = 0; i < size; i++) { + std::vector tmp; + tmp.reserve(fetch_outputs[i].numel()); + memcpy(tmp.data(), fetch_outputs[i].data(), + fetch_outputs[i].numel() * sizeof(T)); + result.push_back(tmp); + } + + return result; +} + class ExecutorTesterRandom : public ::testing::Test { public: virtual void SetUp() override { @@ -181,6 +238,8 @@ class ExecutorTesterFeed : public ::testing::Test { Out->set_parameter("Out"); Out->add_arguments("c"); + add_fetch_op("c", 0, root_block); + std::vector vec1 = {1.0, 2.0, 3.0}; std::vector vec2 = {4.0, 5.0, 6.0}; inputs_.push_back(vec1); @@ -213,8 +272,16 @@ TEST_F(ExecutorTesterFeed, CPU) { // 3 mini-batch for (int i = 0; i < 3; i++) { // need to set feed variable before Executor::Run + std::cout << "start mini-batch " << i << std::endl; set_feed_variable(inputs_); executor->Run(pdesc_, GetScope()); + std::vector> result = get_fetch_variable(); + for (auto& vec : result) { + for (auto& num : vec) { + std::cout << num << " "; + } + std::cout << std::endl; + } } delete executor; diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index b04120abf27..2c416570cf6 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -74,7 +74,10 @@ std::unique_ptr make_unique(Args&&... args) { framework::Scope* GetScope() { static std::unique_ptr g_scope = make_unique(); - std::call_once(feed_variable_flag, [&]() { g_scope->NewVar("feed_value"); }); + std::call_once(feed_variable_flag, [&]() { + g_scope->NewVar("feed_value"); + g_scope->NewVar("fetch_value"); + }); return g_scope.get(); } diff --git a/paddle/operators/activation_op.cu b/paddle/operators/activation_op.cu index 44a6aaf9cb6..93e9f1c694b 100644 --- a/paddle/operators/activation_op.cu +++ b/paddle/operators/activation_op.cu @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/operators/activation_op.h" diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc index 5ae882bc8ac..a61855cb994 100644 --- a/paddle/operators/feed_op.cc +++ b/paddle/operators/feed_op.cc @@ -49,9 +49,9 @@ class FeedOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("data_type", "output data type") .SetDefault(framework::DataType::FP32); AddAttr("col", "The col in global feed variable").SetDefault(0); - AddAttr>("dims", "The dimension of random tensor."); - AddOutput("Out", "The output of dropout op."); - AddComment(R"DOC(Feed data to global feed variable)DOC"); + AddAttr>("dims", "The dimension of feed tensor."); + AddOutput("Out", "The output of feed op."); + AddComment(R"DOC(Feed data from global feed variable)DOC"); } }; diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc new file mode 100644 index 00000000000..68e8d26dbe1 --- /dev/null +++ b/paddle/operators/fetch_op.cc @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/fetch_op.h" + +namespace paddle { +namespace operators { + +class FetchOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override { + typedef std::vector FetchOutputs; + PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null."); + int col = ctx->Attrs().Get("col"); + framework::Variable* g_fetch_variable = + framework::GetScope()->FindVar("fetch_value"); + + FetchOutputs* tensors = g_fetch_variable->GetMutable(); + if (tensors->size() < col) { + tensors->resize(col); + } + + auto input_dim = ctx->GetInputDim("Input"); + framework::Tensor tmp; + tmp.Resize(input_dim); + (*tensors)[col].Resize(input_dim); + // need to handle LodTensor later + } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return static_cast(Attr("data_type")); + } +}; + +class FetchOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("data_type", "output data type") + .SetDefault(framework::DataType::FP32); + AddAttr("col", "The col in global fetch variable").SetDefault(0); + AddAttr>("dims", "The dimension of fetch tensor."); + AddInput("Input", "The output of fetch op."); + AddComment(R"DOC(Fetch data to global fetch variable)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(fetch, ops::FetchOp, ops::FetchOpMaker); +REGISTER_OP_CPU_KERNEL(fetch, ops::FetchKernel); diff --git a/paddle/operators/fetch_op.cu b/paddle/operators/fetch_op.cu new file mode 100644 index 00000000000..2e24d3a8adc --- /dev/null +++ b/paddle/operators/fetch_op.cu @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/feed_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(fetch, ops::FetchKernel); diff --git a/paddle/operators/fetch_op.h b/paddle/operators/fetch_op.h new file mode 100644 index 00000000000..95e7986a22b --- /dev/null +++ b/paddle/operators/fetch_op.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class FetchKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + typedef std::vector FetchOutputs; + Tensor* input = ctx.Output("Input"); + int col = ctx.template Attr("col"); + framework::Variable* g_fetch_variable = + framework::GetScope()->FindVar("fetch_value"); + FetchOutputs tensors = g_fetch_variable->Get(); + tensors[col].mutable_data(platform::CPUPlace()); + tensors[col].CopyFrom(*input, platform::CPUPlace()); + } +}; + +} // namespace operators +} // namespace paddle -- GitLab