diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 1bf80b3e58df591376b79253c3eaf69355b3397f..148610aa2c7821542f9aa19690c3dc857ec9ab2e 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -42,5 +42,12 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
 cc_library(backward SRCS backward.cc DEPS net_op)
 cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
 
+cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward ${GLOB_OP_LIB})
+if(WITH_GPU)
+  nv_test(executor_test SRCS executor_test.cc DEPS executor)
+else()
+  cc_test(executor_test SRCS executor_test.cc DEPS executor)
+endif()
+
 cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor)
 cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place)
diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc
new file mode 100644
index 0000000000000000000000000000000000000000..886e9ab33e56c5952fd8c9d2042ba46f6422e821
--- /dev/null
+++ b/paddle/framework/executor.cc
@@ -0,0 +1,165 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/executor.h"
+
+#include <algorithm>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "paddle/framework/lod_tensor.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/framework/scope.h"
+
+namespace paddle {
+namespace framework {
+
+const std::string kFeedOpType = "feed";
+const std::string kFetchOpType = "fetch";
+
+Executor::Executor(const std::vector<platform::Place>& places) {
+  PADDLE_ENFORCE_GT(places.size(), 0);
+  device_contexts_.resize(places.size());
+  for (size_t i = 0; i < places.size(); i++) {
+    if (platform::is_cpu_place(places[i])) {
+      device_contexts_[i] = new platform::CPUDeviceContext(
+          boost::get<platform::CPUPlace>(places[i]));
+    } else if (platform::is_gpu_place(places[i])) {
+#ifdef PADDLE_WITH_CUDA
+      device_contexts_[i] = new platform::CUDADeviceContext(
+          boost::get<platform::GPUPlace>(places[i]));
+#else
+      PADDLE_THROW(
+          "'GPUPlace' is not supported. Please re-compile with the WITH_GPU "
+          "option.");
+#endif
+    }
+  }
+}
+
+Executor::~Executor() {
+  for (auto& device_context : device_contexts_) {
+    delete device_context;
+  }
+}
+
+void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
+  // TODO(tonyyang-svail):
+  //   - only runs on the first device (i.e. no inter-device communication)
+  //   - will change to use multiple blocks for RNN op and Cond Op
+  PADDLE_ENFORCE_GT(pdesc.blocks_size(), block_id);
+  auto& block = pdesc.blocks(block_id);
+  auto& device = device_contexts_[0];
+
+  // Instantiate all the vars in the global scope
+  for (auto& var : block.vars()) {
+    scope->NewVar(var.name());
+  }
+
+  Scope& local_scope = scope->NewScope();
+
+  std::vector<bool> should_run = Prune(pdesc, block_id);
+  PADDLE_ENFORCE_EQ(should_run.size(), static_cast<size_t>(block.ops_size()));
+  for (size_t i = 0; i < should_run.size(); ++i) {
+    if (should_run[i]) {
+      for (auto& var : block.ops(i).outputs()) {
+        for (auto& argu : var.arguments()) {
+          if (local_scope.FindVar(argu) == nullptr) {
+            local_scope.NewVar(argu);
+          }
+        }
+      }
+      auto op = paddle::framework::OpRegistry::CreateOp(block.ops(i));
+      op->Run(local_scope, *device);
+    }
+  }
+
+  // TODO(tonyyang-svail):
+  //   - Destroy local_scope
+}
+
+std::vector<bool> Prune(const ProgramDesc& pdesc, int block_id) {
+  // TODO(tonyyang-svail):
+  //   - will change to use multiple blocks for RNN op and Cond Op
+
+  auto& block = pdesc.blocks(block_id);
+  auto& ops = block.ops();
+
+  bool expect_feed = true;
+  for (auto& op_desc : ops) {
+    PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed,
+                   "All FeedOps must be at the beginning of the ProgramDesc");
+    expect_feed = (op_desc.type() == kFeedOpType);
+  }
+
+  bool expect_fetch = true;
+  for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
+    auto& op_desc = *op_iter;
+    PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch,
+                   "All FetchOps must be at the end of the ProgramDesc");
+    expect_fetch = (op_desc.type() == kFetchOpType);
+  }
+
+  std::set<std::string> dependent_vars;
+  std::vector<bool> should_run;
+  for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
+    auto& op_desc = *op_iter;
+
+    bool found_dependent_vars = false;
+    for (auto& var : op_desc.outputs()) {
+      for (auto& argu : var.arguments()) {
+        if (dependent_vars.count(argu) != 0) {
+          found_dependent_vars = true;
+        }
+      }
+    }
+
+    if (op_desc.type() == kFetchOpType || found_dependent_vars) {
+      // erase its outputs from the dependency graph
+      for (auto& var : op_desc.outputs()) {
+        for (auto& argu : var.arguments()) {
+          dependent_vars.erase(argu);
+        }
+      }
+
+      // insert its inputs into the dependency graph
+      for (auto& var : op_desc.inputs()) {
+        for (auto& argu : var.arguments()) {
+          dependent_vars.insert(argu);
+        }
+      }
+
+      should_run.push_back(true);
+    } else {
+      should_run.push_back(false);
+    }
+  }
+
+  // TODO(tonyyang-svail):
+  //   - check this after integration of Init
+  // PADDLE_ENFORCE(dependent_vars.empty());
+
+  // since we traverse the ProgramDesc in reverse order,
+  // we reverse the should_run vector
+  std::reverse(should_run.begin(), should_run.end());
+
+  return should_run;
+}
+
+}  // namespace framework
+}  // namespace paddle
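Note on Prune(): the backward walk above is the entire algorithm. For reference, here is a self-contained sketch of the same idea with a toy Op struct standing in for the protobuf OpDesc; the names ToyOp and PruneSketch are illustrative only and not part of this patch.

    #include <algorithm>
    #include <set>
    #include <string>
    #include <vector>

    struct ToyOp {
      std::string type;
      std::vector<std::string> inputs;
      std::vector<std::string> outputs;
    };

    // Keep an op iff it is a fetch op or produces a variable that a later
    // kept op consumes; walk backwards, then reverse the answer.
    std::vector<bool> PruneSketch(const std::vector<ToyOp>& ops) {
      std::set<std::string> dependent_vars;
      std::vector<bool> should_run;
      for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
        bool needed = (it->type == "fetch");
        for (const auto& out : it->outputs) {
          if (dependent_vars.count(out) != 0) needed = true;
        }
        if (needed) {
          for (const auto& out : it->outputs) dependent_vars.erase(out);
          for (const auto& in : it->inputs) dependent_vars.insert(in);
        }
        should_run.push_back(needed);
      }
      std::reverse(should_run.begin(), should_run.end());
      return should_run;
    }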
diff --git a/paddle/framework/executor.h b/paddle/framework/executor.h
new file mode 100644
index 0000000000000000000000000000000000000000..4e3bc2c0a59dfee5b9993037671f14a109dc8a74
--- /dev/null
+++ b/paddle/framework/executor.h
@@ -0,0 +1,55 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/framework.pb.h"
+#include "paddle/framework/op_info.h"
+#include "paddle/framework/scope.h"
+#include "paddle/framework/tensor.h"
+#include "paddle/platform/device_context.h"
+
+namespace paddle {
+namespace framework {
+
+class Executor {
+ public:
+  explicit Executor(const std::vector<platform::Place>& places);
+  ~Executor();
+
+  /* @Brief
+   * Runtime evaluation of the given ProgramDesc under a certain Scope
+   *
+   * @param
+   *  ProgramDesc
+   *  Scope
+   */
+  void Run(const ProgramDesc&, Scope*, int);
+
+ private:
+  std::vector<platform::DeviceContext*> device_contexts_;
+};
+
+/* @Brief
+ * Pruning the graph
+ *
+ * @param
+ *  ProgramDesc
+ *
+ * @return
+ *  vector<bool> Same size as ops. Indicates whether an op should be run.
+ */
+std::vector<bool> Prune(const ProgramDesc& pdesc, int block_id);
+
+}  // namespace framework
+}  // namespace paddle
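How a caller is expected to drive this interface, assuming a ProgramDesc has already been built; a minimal sketch (the function RunOnCpu is hypothetical, not part of this patch):

    #include <vector>

    #include "paddle/framework/executor.h"

    void RunOnCpu(const paddle::framework::ProgramDesc& pdesc) {
      std::vector<paddle::platform::Place> places;
      places.push_back(paddle::platform::CPUPlace());
      paddle::framework::Executor executor(places);
      // Execute block 0 of the program in the global scope.
      executor.Run(pdesc, &paddle::framework::GetGlobalScope(), 0);
    }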
diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..137e53d849542e48080228e0002931867c4d7fb2
--- /dev/null
+++ b/paddle/framework/executor_test.cc
@@ -0,0 +1,318 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/executor.h"
+
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "paddle/framework/attribute.h"
+#include "paddle/framework/backward.h"
+#include "paddle/framework/block_desc.h"
+#include "paddle/framework/op_desc.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/framework/operator.h"
+#include "paddle/memory/memory.h"
+
+USE_OP(elementwise_add);
+USE_OP(gaussian_random);
+USE_OP(feed);
+USE_OP(fetch);
+USE_OP(mul);
+USE_OP(sum);
+USE_OP(squared_l2_distance);
+USE_OP(fill_constant);
+USE_OP(sgd);
+
+using namespace paddle::platform;
+using namespace paddle::framework;
+
+void AddOp(const std::string& type, const VariableNameMap& inputs,
+           const VariableNameMap& outputs, AttributeMap attrs,
+           paddle::framework::BlockDescBind* block) {
+  // insert outputs
+  for (auto kv : outputs) {
+    for (auto v : kv.second) {
+      auto var = block->NewVar(v);
+      var->SetDataType(paddle::framework::DataType::FP32);
+    }
+  }
+
+  // insert op
+  auto op = block->AppendOp();
+  op->SetType(type);
+  for (auto& kv : inputs) {
+    op->SetInput(kv.first, kv.second);
+  }
+  for (auto& kv : outputs) {
+    op->SetOutput(kv.first, kv.second);
+  }
+  op->SetAttrMap(attrs);
+}
+
+// Tensors in the feed value variable will only be in CPUPlace,
+// so we can memcpy the data from a vector into the feed value.
+template <typename T>
+void SetFeedVariable(const std::vector<std::vector<T>>& inputs,
+                     const std::vector<std::vector<int64_t>>& dims) {
+  Variable* g_feed_value = GetGlobalScope().FindVar("feed_value");
+  auto& feed_inputs =
+      *(g_feed_value->GetMutable<std::vector<paddle::framework::Tensor>>());
+  size_t size = inputs.size();
+  feed_inputs.resize(size);
+  for (size_t i = 0; i < size; i++) {
+    T* dst = feed_inputs[i].mutable_data<T>(make_ddim(dims[i]), CPUPlace());
+    memcpy(dst, inputs[i].data(), inputs[i].size() * sizeof(T));
+  }
+}
+
+// Tensors in the fetch value variable will only be in CPUPlace,
+// so we can memcpy the data from the fetch value into a vector.
+template <typename T>
+std::vector<std::vector<T>> GetFetchVariable() {
+  Variable* g_fetch_value = GetGlobalScope().FindVar("fetch_value");
+  auto& fetch_outputs =
+      *(g_fetch_value->GetMutable<std::vector<paddle::framework::Tensor>>());
+
+  size_t size = fetch_outputs.size();
+  std::vector<std::vector<T>> result;
+  result.reserve(size);
+  for (size_t i = 0; i < size; i++) {
+    std::vector<T> tmp;
+    tmp.resize(fetch_outputs[i].numel());
+    memcpy(tmp.data(), fetch_outputs[i].data<T>(),
+           fetch_outputs[i].numel() * sizeof(T));
+    result.push_back(tmp);
+  }
+
+  return result;
+}
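The two helpers above lean on the CPU Tensor's mutable_data/data contract. Isolated, the round trip looks roughly like this; a sketch under the same assumptions as the helpers, with RoundTrip a hypothetical name:

    #include <cstring>
    #include <vector>

    #include "paddle/framework/lod_tensor.h"

    // Copies a host vector into a CPU Tensor and back out again.
    std::vector<float> RoundTrip(const std::vector<float>& src) {
      paddle::framework::Tensor t;
      float* dst = t.mutable_data<float>(
          paddle::framework::make_ddim({static_cast<int64_t>(src.size())}),
          paddle::platform::CPUPlace());
      std::memcpy(dst, src.data(), src.size() * sizeof(float));

      std::vector<float> out(t.numel());
      std::memcpy(out.data(), t.data<float>(), t.numel() * sizeof(float));
      return out;
    }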
+class ExecutorTesterRandom : public ::testing::Test {
+ public:
+  virtual void SetUp() override {
+    int input_dim = 3, batch_size = 2, embed_dim = 5;
+
+    auto temp_init_root_block = init_pdesc_.add_blocks();
+    temp_init_root_block->set_idx(0);
+    temp_init_root_block->set_parent_idx(-1);
+    paddle::framework::ProgramDescBind& init_program =
+        paddle::framework::ProgramDescBind::Instance(&init_pdesc_);
+    paddle::framework::BlockDescBind* init_root_block = init_program.Block(0);
+
+    AddOp("gaussian_random", {}, {{"Out", {"w1"}}},
+          {{"dims", std::vector<int>{input_dim, embed_dim}}}, init_root_block);
+    AddOp("gaussian_random", {}, {{"Out", {"w2"}}},
+          {{"dims", std::vector<int>{embed_dim, input_dim}}}, init_root_block);
+    AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, init_root_block);
+    AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, init_root_block);
+
+    // flush
+    init_program.Proto();
+
+    // run block
+    auto temp_root_block = pdesc_.add_blocks();
+    temp_root_block->set_idx(0);
+    temp_root_block->set_parent_idx(-1);
+    paddle::framework::ProgramDescBind& program =
+        paddle::framework::ProgramDescBind::Instance(&pdesc_);
+    paddle::framework::BlockDescBind* root_block = program.Block(0);
+
+    // feed data
+    inputs_.push_back({1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
+    dims_.push_back({batch_size, input_dim});
+    AddOp("feed", {}, {{"Out", {"a"}}},
+          {{"dims", std::vector<int>{batch_size, input_dim}}, {"col", 0}},
+          root_block);
+
+    // forward
+    AddOp("mul", {{"X", {"a"}}, {"Y", {"w1"}}}, {{"Out", {"b"}}}, {},
+          root_block);
+    AddOp("mul", {{"X", {"b"}}, {"Y", {"w2"}}}, {{"Out", {"a_out"}}}, {},
+          root_block);
+    AddOp("squared_l2_distance", {{"X", {"a"}}, {"Y", {"a_out"}}},
+          {{"Out", {"l2_distance"}}, {"sub_result", {"l2_distance_sub"}}}, {},
+          root_block);
+
+    // backward
+    AddOp("fill_constant", {}, {{"Out", {"l2_distance@GRAD"}}},
+          {{"shape", std::vector<int>{batch_size, 1}}, {"value", float(1.0)}},
+          root_block);
+    AppendBackward(program, {});
+
+    // update
+    AddOp("fill_constant", {}, {{"Out", {"learning_rate"}}},
+          {{"shape", std::vector<int>{1}}, {"value", float(0.001)}},
+          root_block);
+    AddOp("sgd", {{"Param", {"w1"}},
+                  {"LearningRate", {"learning_rate"}},
+                  {"Grad", {"w1@GRAD"}}},
+          {{"ParamOut", {"w1"}}}, {}, root_block);
+    AddOp("sgd", {{"Param", {"w2"}},
+                  {"LearningRate", {"learning_rate"}},
+                  {"Grad", {"w2@GRAD"}}},
+          {{"ParamOut", {"w2"}}}, {}, root_block);
+
+    AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, root_block);
+    AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, root_block);
+    AddOp("fetch", {{"Input", {"l2_distance"}}}, {}, {{"col", 0}}, root_block);
+
+    // flush
+    program.Proto();
+  }
+
+ protected:
+  ProgramDesc init_pdesc_;
+  ProgramDesc pdesc_;
+  std::vector<std::vector<float>> inputs_;
+  std::vector<std::vector<int64_t>> dims_;
+};
+
+class ExecutorTesterFeedAndFetch : public ::testing::Test {
+ public:
+  virtual void SetUp() override {
+    auto temp_root_block = pdesc_.add_blocks();
+    temp_root_block->set_idx(0);
+    temp_root_block->set_parent_idx(-1);
+
+    // wrap to BlockDescBind
+    paddle::framework::ProgramDescBind& program =
+        paddle::framework::ProgramDescBind::Instance(&pdesc_);
+    paddle::framework::BlockDescBind* root_block = program.Block(0);
+
+    std::vector<int> dim{6};
+
+    AddOp("feed", {}, {{"Out", {"a"}}}, {{"dims", dim}, {"col", 0}},
+          root_block);
+    AddOp("feed", {}, {{"Out", {"b"}}}, {{"dims", dim}, {"col", 1}},
+          root_block);
+    AddOp("fetch", {{"Input", {"a"}}}, {}, {{"col", 0}}, root_block);
+    AddOp("fetch", {{"Input", {"b"}}}, {}, {{"col", 1}}, root_block);
+
+    // flush
+    program.Proto();
+
+    std::vector<float> vec1 = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+    std::vector<float> vec2 = {4.0, 5.0, 6.0, 7.0, 8.0, 9.0};
+    inputs_.push_back(vec1);
+    inputs_.push_back(vec2);
+    dims_.push_back({static_cast<int64_t>(vec1.size())});
+    dims_.push_back({static_cast<int64_t>(vec2.size())});
+  }
+
+ protected:
+  ProgramDesc pdesc_;
+  std::vector<std::vector<float>> inputs_;
+  std::vector<std::vector<int64_t>> dims_;
+};
+
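The update stage above wires in two "sgd" ops. The step they perform is the plain w -= lr * dL/dw rule; written out by hand it is just the following (a sketch, not the operator's actual kernel; SgdStep is a hypothetical name):

    #include <vector>

    // Applies one SGD update elementwise over the parameter.
    void SgdStep(std::vector<float>* param, const std::vector<float>& grad,
                 float learning_rate) {
      for (size_t i = 0; i < param->size(); ++i) {
        (*param)[i] -= learning_rate * grad[i];
      }
    }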
+#ifndef PADDLE_WITH_CUDA
+TEST_F(ExecutorTesterRandom, CPU) {
+  std::vector<Place> places;
+  CPUPlace cpu_place;
+  places.push_back(cpu_place);
+
+  // We have a global Scope and BuddyAllocator, and we must ensure that the
+  // global BuddyAllocator is initialized before the global Scope. The global
+  // Scope is then destructed before the BuddyAllocator. Otherwise, a
+  // "pointer being freed was not allocated" error will appear.
+  paddle::memory::Used(cpu_place);
+
+  std::unique_ptr<Executor> executor(new Executor(places));
+
+  executor->Run(init_pdesc_, &GetGlobalScope(), 0);
+  SetFeedVariable<float>(inputs_, dims_);
+  executor->Run(pdesc_, &GetGlobalScope(), 0);
+  std::vector<std::vector<float>> result = GetFetchVariable<float>();
+}
+
+TEST_F(ExecutorTesterFeedAndFetch, CPU) {
+  std::vector<Place> places;
+  CPUPlace cpu_place;
+  places.push_back(cpu_place);
+
+  // We have a global Scope and BuddyAllocator, and we must ensure that the
+  // global BuddyAllocator is initialized before the global Scope. The global
+  // Scope is then destructed before the BuddyAllocator. Otherwise, a
+  // "pointer being freed was not allocated" error will appear.
+  paddle::memory::Used(cpu_place);
+
+  std::unique_ptr<Executor> executor(new Executor(places));
+
+  for (int batch_id = 0; batch_id < 3; batch_id++) {
+    SetFeedVariable<float>(inputs_, dims_);
+    executor->Run(pdesc_, &GetGlobalScope(), 0);
+    std::vector<std::vector<float>> result = GetFetchVariable<float>();
+    PADDLE_ENFORCE_EQ(result.size(), inputs_.size());
+    for (size_t i = 0; i < result.size(); ++i) {
+      PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size());
+      for (size_t j = 0; j < result[i].size(); ++j) {
+        PADDLE_ENFORCE_EQ(result[i][j], inputs_[i][j]);
+      }
+    }
+  }
+}
+#else
+TEST_F(ExecutorTesterRandom, GPU) {
+  std::vector<Place> places;
+  GPUPlace gpu_place(0);
+  places.push_back(gpu_place);
+
+  // We have a global Scope and BuddyAllocator, and we must ensure that the
+  // global BuddyAllocator is initialized before the global Scope. The global
+  // Scope is then destructed before the BuddyAllocator. Otherwise, a
+  // "pointer being freed was not allocated" error will appear.
+  // If paddle is compiled with GPU, both the CPU and GPU BuddyAllocators
+  // need to be touched first.
+  paddle::memory::Used(CPUPlace());
+  paddle::memory::Used(gpu_place);
+
+  std::unique_ptr<Executor> executor(new Executor(places));
+
+  executor->Run(init_pdesc_, &GetGlobalScope(), 0);
+  for (int batch_id = 0; batch_id < 3; batch_id++) {
+    SetFeedVariable<float>(inputs_, dims_);
+    executor->Run(pdesc_, &GetGlobalScope(), 0);
+  }
+}
+
+TEST_F(ExecutorTesterFeedAndFetch, GPU) {
+  std::vector<Place> places;
+  GPUPlace gpu_place(0);
+  places.push_back(gpu_place);
+
+  // We have a global Scope and BuddyAllocator, and we must ensure that the
+  // global BuddyAllocator is initialized before the global Scope. The global
+  // Scope is then destructed before the BuddyAllocator. Otherwise, a
+  // "pointer being freed was not allocated" error will appear.
+  // If paddle is compiled with GPU, both the CPU and GPU BuddyAllocators
+  // need to be touched first.
+  paddle::memory::Used(CPUPlace());
+  paddle::memory::Used(gpu_place);
+
+  std::unique_ptr<Executor> executor(new Executor(places));
+
+  for (int batch_id = 0; batch_id < 3; batch_id++) {
+    SetFeedVariable<float>(inputs_, dims_);
+    executor->Run(pdesc_, &GetGlobalScope(), 0);
+    std::vector<std::vector<float>> result = GetFetchVariable<float>();
+    PADDLE_ENFORCE_EQ(result.size(), inputs_.size());
+    for (size_t i = 0; i < result.size(); ++i) {
+      PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size());
+      for (size_t j = 0; j < result[i].size(); ++j) {
+        PADDLE_ENFORCE_EQ(result[i][j], inputs_[i][j]);
+      }
+    }
+  }
+}
+#endif
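The memory::Used(...) calls before the first GetGlobalScope() encode an initialization-order trick: function-local statics are destroyed in reverse order of construction, so whichever singleton is touched first outlives the other. A minimal standalone demonstration with toy types (not Paddle's):

    #include <iostream>

    struct Allocator { ~Allocator() { std::cout << "allocator destroyed\n"; } };
    struct GlobalScope { ~GlobalScope() { std::cout << "scope destroyed\n"; } };

    Allocator& GetAllocator() { static Allocator a; return a; }
    GlobalScope& GetScope() { static GlobalScope s; return s; }

    int main() {
      GetAllocator();  // constructed first, therefore destroyed last
      GetScope();      // destroyed first, while the allocator still exists
      return 0;
    }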
diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc
index 080b4ac621c1b8c0d4b4e7b26f394cf2be263894..5821bac928ed898971d61a3e2a86f59155d76991 100644
--- a/paddle/framework/scope.cc
+++ b/paddle/framework/scope.cc
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/framework/scope.h"
+
+#include <memory>  // for unique_ptr
+#include <mutex>   // for call_once
 #include "paddle/string/printf.h"
 
 namespace paddle {
@@ -62,5 +65,17 @@ void Scope::DropKids() {
   kids_.clear();
 }
 
+std::once_flag feed_variable_flag;
+
+framework::Scope& GetGlobalScope() {
+  static std::unique_ptr<framework::Scope> g_scope{nullptr};
+  std::call_once(feed_variable_flag, [&]() {
+    g_scope.reset(new framework::Scope());
+    g_scope->NewVar("feed_value");
+    g_scope->NewVar("fetch_value");
+  });
+  return *(g_scope.get());
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h
index 7047f0d55e9844aec19892631fe4b5b387bf89ca..a8cfb107c25ccd62039db7349cc1c1dbff772f39 100644
--- a/paddle/framework/scope.h
+++ b/paddle/framework/scope.h
@@ -73,5 +73,7 @@ class Scope {
   DISABLE_COPY_AND_ASSIGN(Scope);
 };
 
+framework::Scope& GetGlobalScope();
+
 }  // namespace framework
 }  // namespace paddle
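GetGlobalScope() is a once-initialized singleton. The same pattern in isolation, as a sketch (Registry is a stand-in type, not from this patch):

    #include <memory>
    #include <mutex>

    class Registry {};

    // Thread-safe, lazily constructed global instance.
    Registry& GetGlobalRegistry() {
      static std::unique_ptr<Registry> g_registry{nullptr};
      static std::once_flag flag;
      std::call_once(flag, [&]() { g_registry.reset(new Registry()); });
      return *g_registry;
    }

A C++11 function-local static would give the same thread-safe one-time initialization; std::call_once makes the intent explicit and lets the initialization body (here, pre-creating feed_value and fetch_value) live apart from the declaration.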
diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fa325bb28299afe24a67772473529fb76b9c73e1
--- /dev/null
+++ b/paddle/operators/feed_op.cc
@@ -0,0 +1,59 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/feed_op.h"
+
+namespace paddle {
+namespace operators {
+
+class FeedOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should not be null.");
+    auto& shape = ctx->Attrs().Get<std::vector<int>>("dims");
+    std::vector<int64_t> shape_int64(shape.size(), 0);
+    std::transform(shape.begin(), shape.end(), shape_int64.begin(),
+                   [](int a) { return static_cast<int64_t>(a); });
+    ctx->SetOutputDim("Out", framework::make_ddim(shape_int64));
+    // TODO(qijun): need to handle LoDTensor later
+  }
+
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return static_cast<framework::DataType>(Attr<int>("dataType"));
+  }
+};
+
+class FeedOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddAttr<int>("dataType", "output data type")
+        .SetDefault(framework::DataType::FP32);
+    AddAttr<int>("col", "The col in the global feed variable").SetDefault(0);
+    AddAttr<std::vector<int>>("dims", "The dimension of the feed tensor.");
+    AddOutput("Out", "The output of the feed op.");
+    AddComment(R"DOC(Feed data from the global feed variable)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(feed, ops::FeedOp, ops::FeedOpMaker);
+REGISTER_OP_CPU_KERNEL(feed, ops::FeedKernel<float>);
diff --git a/paddle/operators/feed_op.cu b/paddle/operators/feed_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7b6a2ac91e7b8d306804ca3d27b1eaf8177302f9
--- /dev/null
+++ b/paddle/operators/feed_op.cu
@@ -0,0 +1,18 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/feed_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(feed, ops::FeedKernel<float>);
diff --git a/paddle/operators/feed_op.h b/paddle/operators/feed_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..9d8158299fea97a464a7bb64321b1adf8b7b2fab
--- /dev/null
+++ b/paddle/operators/feed_op.h
@@ -0,0 +1,42 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class FeedKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
+    out->mutable_data<T>(ctx.GetPlace());
+    framework::Variable* g_feed_variable =
+        framework::GetGlobalScope().FindVar("feed_value");
+    const auto& tensors =
+        g_feed_variable->Get<std::vector<framework::Tensor>>();
+    int col = ctx.template Attr<int>("col");
+    PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
+    // TODO(qijun):
+    //   check tensors[col].dims() against the attribute,
+    //   except for the first dimension.
+    out->CopyFrom<T>(tensors[col], ctx.GetPlace());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..90737c8c550ca18f03c6a9ad0d9323d0b4d0b96d
--- /dev/null
+++ b/paddle/operators/fetch_op.cc
@@ -0,0 +1,52 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/fetch_op.h"
+
+namespace paddle {
+namespace operators {
+
+class FetchOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null.");
+  }
+
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return static_cast<framework::DataType>(Attr<int>("dataType"));
+  }
+};
+
+class FetchOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddAttr<int>("dataType", "output data type")
+        .SetDefault(framework::DataType::FP32);
+    AddAttr<int>("col", "The col in the global fetch variable").SetDefault(0);
+    AddInput("Input", "The input of the fetch op.");
+    AddComment(R"DOC(Fetch data to the global fetch variable)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(fetch, ops::FetchOp, ops::FetchOpMaker);
+REGISTER_OP_CPU_KERNEL(fetch, ops::FetchKernel<float>);
diff --git a/paddle/operators/fetch_op.cu b/paddle/operators/fetch_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ca39d24c791ded71149777acc53e3b5cc240329f
--- /dev/null
+++ b/paddle/operators/fetch_op.cu
@@ -0,0 +1,18 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/fetch_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(fetch, ops::FetchKernel<float>);
diff --git a/paddle/operators/fetch_op.h b/paddle/operators/fetch_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..eb9c3a7b593b84da7c8dc12d71c4f748269c64e6
--- /dev/null
+++ b/paddle/operators/fetch_op.h
@@ -0,0 +1,44 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class FetchKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
+    framework::Variable* g_fetch_variable =
+        framework::GetGlobalScope().FindVar("fetch_value");
+    auto* tensors =
+        g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
+    int col = ctx.template Attr<int>("col");
+    if (tensors->size() < static_cast<size_t>(col + 1)) {
+      tensors->resize(col + 1);
+    }
+    PADDLE_ENFORCE_GT(tensors->size(), static_cast<size_t>(col));
+    (*tensors)[col].Resize(input->dims());
+    (*tensors)[col].mutable_data<T>(platform::CPUPlace());
+    (*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace());
+    // TODO(qijun): need to handle LoDTensor later
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
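Feed and fetch share a col-indexed contract on the two global tensor vectors: feed requires tensors[col] to already exist, while fetch grows the vector on demand before writing. Stripped of Tensors, the contract looks like this (a toy sketch; WriteCol and ReadCol are illustrative names):

    #include <cassert>
    #include <vector>

    template <typename T>
    void WriteCol(std::vector<T>* buffer, size_t col, const T& value) {
      if (buffer->size() < col + 1) {
        buffer->resize(col + 1);  // fetch-style: grow on demand
      }
      (*buffer)[col] = value;
    }

    template <typename T>
    const T& ReadCol(const std::vector<T>& buffer, size_t col) {
      assert(buffer.size() > col);  // feed-style: the slot must already exist
      return buffer[col];
    }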
diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc
index 70ad611d5dd61937e6bf7d980e34b5c9023977b2..0cab5ffc5609bbd6fd08c74329d8370fb95f8102 100644
--- a/paddle/platform/gpu_info.cc
+++ b/paddle/platform/gpu_info.cc
@@ -43,6 +43,8 @@ int GetCurrentDeviceId() {
 }
 
 void SetDeviceId(int id) {
+  // TODO(qijun): find a better way to cache the cuda device count
+  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must be less than GPU count");
   PADDLE_ENFORCE(cudaSetDevice(id),
                  "cudaSetDevice failed in paddle::platform::SetDeviceId");
 }
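One way to address the TODO about caching the device count, sketched with a function-local static; this is an illustration, not the actual fix, and GetCUDADeviceCountCached is a hypothetical name:

    #include "paddle/platform/gpu_info.h"

    namespace paddle {
    namespace platform {

    // Queries CUDA once and reuses the result on later calls.
    int GetCUDADeviceCountCached() {
      static const int count = GetCUDADeviceCount();
      return count;
    }

    }  // namespace platform
    }  // namespace paddle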