diff --git a/doc/design/images/feed_forward.png b/doc/design/images/feed_forward.png new file mode 100644 index 0000000000000000000000000000000000000000..d312371a04c26aa6cd196e0bd1f51becb425180b Binary files /dev/null and b/doc/design/images/feed_forward.png differ diff --git a/doc/design/images/feed_forward_regularized.png b/doc/design/images/feed_forward_regularized.png new file mode 100644 index 0000000000000000000000000000000000000000..677e99bfd9f8e72ed9fe4b27127af2ced202f447 Binary files /dev/null and b/doc/design/images/feed_forward_regularized.png differ diff --git a/doc/design/images/l1_regularization.png b/doc/design/images/l1_regularization.png new file mode 100644 index 0000000000000000000000000000000000000000..e1b9c7a44f94dc027598a98da93ddb8133190972 Binary files /dev/null and b/doc/design/images/l1_regularization.png differ diff --git a/doc/design/images/l2_regularization.png b/doc/design/images/l2_regularization.png new file mode 100644 index 0000000000000000000000000000000000000000..d5c2fcbc2ccae75ad083162e5a2dceb0210be298 Binary files /dev/null and b/doc/design/images/l2_regularization.png differ diff --git a/doc/design/images/loss_equation.png b/doc/design/images/loss_equation.png new file mode 100644 index 0000000000000000000000000000000000000000..14212ec8d36c803de96bde8a9a4b5591bd20434e Binary files /dev/null and b/doc/design/images/loss_equation.png differ diff --git a/doc/design/prune.md b/doc/design/prune.md new file mode 100644 index 0000000000000000000000000000000000000000..4a5cf10c79a554779137f0cce5494fdd96ef6b7a --- /dev/null +++ b/doc/design/prune.md @@ -0,0 +1,63 @@ +# Prune + +## Motivation + +We want to support running inference, training and checkpointing in one `ProgramDesc`. We implement +`void Prune(const ProgramDesc* input, ProgramDesc* output)` function, which takes a `ProgramDesc` +and generate a pruned `ProgramDesc`. + +## Challenge + +Pruning need to support both variables and operators being evaluation targets. Consider the following +different situations. + +```python +# Case 1: run foward pass. +cost_np = session.run(target=cost) +# Case 2: run backward passing. +opts_np, _ = session.run(target=[cost, opt]) +# Case 3: run checkpointing +_ = session.run(target=checkpoint) +``` + +## Solution + +To support evaluation of operators, we add `is_target` field in the `OpDesc`. + +```c++ +message OpDesc { + required string type = 3; + repeated Var inputs = 1; + repeated Var outputs = 2; + repeated Attr attrs = 4; + optional bool is_target = 5 [ default = false ]; +}; +``` + +To support evaluation of variables, we add [fetch_op](https://github.com/PaddlePaddle/Paddle/pull/4599). +For each variable in the `target`, we insert a `fetch_op` into the `ProgramDesc` with `variable` being +`fetch_op`'s input. Then we also set `fetch_op` is a target. + +### Algorithm + +If an operator needs to be run, it must fall into one of the following cases: + +1. It is the target. +2. It is depended by some other ops, meaning its output is some other op's input. + +The first case can be checked by `op_desc.is_traget()` . The second case can be implement as + +```c++ +bool HasDependentVar(const OpDesc& op_desc, const std::set& dependent_vars) { + for (auto& var : op_desc.outputs()) { + for (auto& argu : var.arguments()) { + if (dependent_vars.count(argu) != 0) { + return true; + } + } + } + return false; +} +``` + +Then the whole algorithm can be implemented as the following [code](https://github.com/tonyyang-svail/Paddle/blob/prune_impl/paddle/framework/prune.cc). diff --git a/doc/design/refactorization.md b/doc/design/refactorization.md index ec51aa1a0ec667175ff7215dcd359023e296769f..f93d6155e1764386b01d2f0df3f141ab75cd55d4 100644 --- a/doc/design/refactorization.md +++ b/doc/design/refactorization.md @@ -177,9 +177,6 @@ REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, grad_op_class) REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) ``` -### USE Macros -Make sure the registration process is executed and linked. - --- # Registration Process 1. Write an Op class and its gradient Op class, if required. @@ -188,8 +185,6 @@ Make sure the registration process is executed and linked. 1. Call maker class to complete `proto` and `checker` 2. Using the completed `proto` and `checker`, it will add a new key-value pair to the `OpInfoMap` -4. Invoke the `USE` macro in which the Op is used to make sure that it is linked. - --- # Backward Module (1/2) ### Create Backward Operator diff --git a/doc/design/regularization.md b/doc/design/regularization.md new file mode 100644 index 0000000000000000000000000000000000000000..703a9fbdd4392aa7f44733cce2da19caa1b51e4a --- /dev/null +++ b/doc/design/regularization.md @@ -0,0 +1,103 @@ +# Regularization in PaddlePaddle + +## Introduction to Regularization +A central problem in machine learning is how to design an algorithm that will perform well not just on the training data, but also on new data. Many strategies are used by machine learning practitioners to reduce the test error, possibly at the expense of increased training error. These strategies are collectively known as **regularization**. + +### Parameter Norm Penalties +Most common regularization approaches in deep learning are based on limiting the capacity of the models by adding a parameter norm penalty to the objective function `J`. This is given as follows: + +
+ +The parameter `alpha` is a hyperparameter that weights the relative contribution of the norm penalty term, `omega`, relative to the standard objective function `J`. + +The most commonly used norm penalties are the L2 norm penalty and the L1 norm penalty. These are given as follows: + +##### L2 Regularization: +
+ +##### L1 Regularization +
+ +A much more detailed mathematical background of reguilarization can be found [here](http://www.deeplearningbook.org/contents/regularization.html). + + +## How to do Regularization in PaddlePaddle + +On surveying existing frameworks like Tensorflow, PyTorch, Caffe, etc, it can be seen that there are 2 common approaches of doing regularization: + +1. Making regularization a part of the optimizer using an attribute like `weight_decay` that is used to control the scale of the L2 Penalty. This approach is used in PyTorch as follows: + ```python + opt = torch.optim.SGD(params, lr=0.2, weight_decay=0.2) + ``` + At every optimization step, this code will add the gradient of the L2 Norm of the params to the gradient of the params with respect to the loss function. This can seen in the following code snippet: + ```python + if weight_decay != 0: + d_p.add_(weight_decay, p.data) + ``` + This is a very restyrictive way of doing regularization and does not give the users enough flexibility. + + **Advantages**: + - It is easy to implement for us. + - Faster execution of backward. However, it can be done manually by advanced users too. + + **Disadvantages**: + - Not flexible for other regularizations such as L1/L0 regularization. + - Does not allow for different regularization coefficient for different parameters. For example, in most models, ony the weight matrices are regularized and the bias vectors are unregularized. + - Tightly coupled optimizer and regularization implementation. + + +2. Adding regularization ops to the graph through Python API. This approach is used by Tensorflow and Caffe. Using this approach, we manually add regularization ops to the graph and then add the regularization loss to the final loss function before sending them to the optimizer. + + **Advantages**: + - Allows for greater flexibility to the users of Paddle. Using this approach, the users can put different regularization to different parameters and also choose parameters that are not a part of regularization. + - Makes it easy for the users to customize and extend the framework. + + **Disadvantages**: + - Implementation requires comprehensive design and time. + +## Proposal for Regularization in PaddlePaddle + +### Low-Level implementation + +In the new design, we propose to create new operations for regularization. For now, we can add 2 ops thgat correspond to the most frequently used regularizations: +- L2_regularization_op +- L1_regularization_op + +These ops can be like any other ops with their own CPU/GPU implementations either using Eigen or separate Cpu and GPU kernels. As the initial implementation, we can implement their kernels using Eigen following the abstraction pattern implemented for [Activation Ops](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/accuracy_op.h). This abstraction pattern can make it very easy to implement new regularization schemes. other than L1 and L2 norm penalties. + +The idea of building ops for regularization is in sync with the refactored Paddle philosophy of using operators to represent any computation unit. The way these ops will be added to the computation graph, will be decided by the [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) in Python API. + +### Computation Graph + +Below is an example of a really simple feed forward neural network. + +
+ +The Python API will modify this computation graph to add regularization operators. The modified computation graph will look as follows: + +
+    +### Python API implementation for Regularization + +Using the low level ops, `L2_regularization_op` and `L1_regularization_op`, any user can add regularization to their computation graphs. However, this will require a lot of lines of code and we should design Python APIs that support regularization. An example of such an API can be seen in [Keras](https://keras.io/regularizers/). As per the PaddlePaddle [Python API design](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md), the layer functions are responsible for creating operators, operator parameters and variables. Since regularization is a property of parameters, it makes sense to create these in the layer functions. + +#### Creation of Regularization ops +There are two possibilities for creating the regularization ops: +1. We create these ops immediately while building the computation graph. +2. We add these ops in a lazy manner, just before the backward, similar to the way the optimization ops are added. + +The proposal is to add these ops in a lazy manner just before the backward pass. + +#### Storage of Regularization attributes + +Since we want to create the regularization ops in a lazy manner, the regularization attributes (type of regularization and weight of regularization penalty) can be stored as attributes of the [`Parameter`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/framework/framework.py#L421) class. This is because regularization is a property of the parameters and storing regularization properties with Parameters also allows for shared parameters. + +#### High-level API + +In PaddlePaddle Python API, users will primarily rely on [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) to create neural network layers. Hence, we lso need to provide regularization functionality in layer functions. The design of these APIs can be postponed for later right now. A good reference for these APIs can be found in [Keras](https://keras.io/regularizers/) and also by looking at Tensorflow in [`tf.contrib.layers`](https://www.tensorflow.org/api_guides/python/contrib.layers). + + + + + + diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 4bc3fdeeea461ea2a1f82caa00d6c0c11a2775d0..6e32a1c99ba6b8bfac12c227e0cf66e0a9f16557 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -20,13 +20,14 @@ proto_library(framework_proto SRCS framework.proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info) +cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) -cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc) +cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc glog) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) -cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator) +cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) py_proto_compile(framework_py_proto SRCS framework.proto) @@ -44,6 +45,9 @@ cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_co cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward) +cc_library(prune SRCS prune.cc DEPS framework_proto) +cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) + cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor) cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place) diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index d6a2975aaa419406aef7b228e78381dbce78890d..29fe352ca450740e55ee87b63392e3aabac8aa40 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -19,19 +19,7 @@ limitations under the License. */ namespace paddle { namespace framework { -static ProgramDesc* g_program_desc = nullptr; - -ProgramDesc& GetProgramDesc() { - if (g_program_desc == nullptr) { - g_program_desc = new ProgramDesc(); - auto root_block = g_program_desc->mutable_blocks()->Add(); - root_block->set_idx(0); - root_block->set_parent_idx(-1); - } - return *g_program_desc; -} - -Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { +Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* program) { switch (attr_desc.type()) { case framework::AttrType::BOOLEAN: { return attr_desc.b(); @@ -74,7 +62,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { return val; } case framework::AttrType::BLOCK: { - return GetProgramDesc().mutable_blocks(attr_desc.block_idx()); + PADDLE_ENFORCE(program != nullptr, + "Need to specify ProgramDesc when get a block attr"); + return program->mutable_blocks(attr_desc.block_idx()); } } PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 8a7a949346e73ca9d2a813ca2888755a23bb7d7b..9744662b8f7229b0b17e910ae5cd997fa7d31e06 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -26,16 +26,13 @@ limitations under the License. */ namespace paddle { namespace framework { - -ProgramDesc& GetProgramDesc(); - template inline AttrType AttrTypeID() { Attribute tmp = T(); return static_cast(tmp.which() - 1); } -Attribute GetAttrValue(const OpDesc::Attr& attr_desc); +Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* desc); class AttrReader { public: diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index ac80879c541306dba987040b71b0dadd81af2a3c..fb552fe3448b3f17e97e1262b5c9a0842f68f8b9 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -309,8 +309,7 @@ static void CreateGradVarInBlock( } std::vector> MakeOpGrad( - const std::unique_ptr& op_desc, - std::unordered_set* no_grad_vars, + const OpDescBind* op_desc, std::unordered_set* no_grad_vars, std::unordered_map* grad_to_var) { std::vector> grad_op_descs; // All input gradients of forwarding operator do not need to calculate. @@ -357,7 +356,7 @@ std::vector> MakeBlockBackward( std::unordered_set* no_grad_vars, std::unordered_map* grad_to_var) { BlockDescBind* cur_block = program_desc.Block(block_idx); - std::deque>& op_descs = cur_block->ops_; + std::vector op_descs = cur_block->AllOps(); std::unordered_map> dup_out_ops; size_t grad_desc_idx = 0; std::vector> backward_descs; @@ -375,7 +374,7 @@ std::vector> MakeBlockBackward( program_desc, step_block_idx, no_grad_vars, grad_to_var); BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block); for (auto& ptr : backward_block_op_descs) { - backward_block->ops_.push_back(std::move(ptr)); + backward_block->AppendAllocatedOp(std::move(ptr)); } op_grads[0]->SetBlockAttr("step_block", *backward_block); } @@ -432,7 +431,6 @@ ParamGradInfoMap AppendBackward( const int root_block_idx = 0; auto root_block = program_desc.Block(root_block_idx); - auto& all_ops = root_block->ops_; // insert fill one op for target // TODO(qiao) add some check to the target. @@ -447,8 +445,8 @@ ParamGradInfoMap AppendBackward( {{"shape", target_shape}, {"value", static_cast(1.0)}, {"data_type", framework::DataType::FP32}})); - all_ops.push_back(std::move(fill_one_op)); - size_t forward_op_num = all_ops.size(); + root_block->AppendAllocatedOp(std::move(fill_one_op)); + size_t forward_op_num = root_block->OpSize(); size_t forward_block_num = program_desc.Size(); // Insert backward operators @@ -457,7 +455,7 @@ ParamGradInfoMap AppendBackward( &no_grad_var_names, &grad_to_var); for (auto& ptr : backward_op_descs) { - all_ops.push_back(std::move(ptr)); + root_block->AppendAllocatedOp(std::move(ptr)); } // Create Variable diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 0c35a157bcfeb93193536fa026244e39e948cc2f..10301f7e39423c8ff0eba33277edecab14c119bf 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -495,19 +495,8 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { EXPECT_EQ(bwd_net->ops_[2]->Outputs(all).size(), 0UL); } -// =================================== // - -f::ProgramDesc *GetNewProgramDesc() { - auto *program_desc = new f::ProgramDesc(); - auto *root_block = program_desc->add_blocks(); - root_block->set_idx(0); - root_block->set_parent_idx(-1); - return program_desc; -} - TEST(Backward, simple_single_op) { - f::ProgramDesc *program_desc = GetNewProgramDesc(); - f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); f::OpDescBind *op = block->AppendOp(); @@ -543,8 +532,7 @@ TEST(Backward, simple_single_op) { } TEST(Backward, default_attribute) { - f::ProgramDesc *program_desc = GetNewProgramDesc(); - f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); f::OpDescBind *op = block->AppendOp(); op->SetType("mul"); @@ -570,8 +558,7 @@ TEST(Backward, default_attribute) { } TEST(Backward, simple_mult_op) { - f::ProgramDesc *program_desc = GetNewProgramDesc(); - f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); f::OpDescBind *op1 = block->AppendOp(); op1->SetType("rowwise_add"); @@ -654,8 +641,7 @@ TEST(Backward, simple_mult_op) { } TEST(Backward, intermedia_var_no_grad) { - f::ProgramDesc *program_desc = GetNewProgramDesc(); - f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); f::OpDescBind *op1 = block->AppendOp(); op1->SetType("rowwise_add"); @@ -725,8 +711,7 @@ TEST(Backward, intermedia_var_no_grad) { } TEST(Backward, var_no_grad) { - f::ProgramDesc *program_desc = GetNewProgramDesc(); - f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); f::OpDescBind *op1 = block->AppendOp(); op1->SetType("mult_in_out"); @@ -802,8 +787,7 @@ TEST(Backward, var_no_grad) { } TEST(Backward, shared_var) { - f::ProgramDesc *program_desc = GetNewProgramDesc(); - f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); f::OpDescBind *op1 = block->AppendOp(); op1->SetType("rowwise_add"); @@ -893,8 +877,7 @@ TEST(Backward, shared_var) { } TEST(Backward, half_backward) { - f::ProgramDesc *program_desc = GetNewProgramDesc(); - f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); auto *op1 = block->AppendOp(); op1->SetType("minus"); diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index ba970254e590308851e31542e2037e9c5590c7fa..21d4fdaf0680036a484ee4258e47c6c8854967c3 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -19,11 +19,11 @@ namespace paddle { namespace framework { VarDescBind *BlockDescBind::Var(const std::string &name) { - need_update_ = true; auto it = vars_.find(name); if (it != vars_.end()) { return it->second.get(); } + need_update_ = true; auto *var = new VarDescBind(name); vars_[name].reset(var); return var; @@ -55,6 +55,11 @@ OpDescBind *BlockDescBind::AppendOp() { return ops_.back().get(); } +void BlockDescBind::AppendAllocatedOp(std::unique_ptr &&op_desc) { + need_update_ = true; + ops_.emplace_back(std::move(op_desc)); +} + OpDescBind *BlockDescBind::PrependOp() { need_update_ = true; ops_.emplace_front(new OpDescBind()); @@ -70,6 +75,10 @@ std::vector BlockDescBind::AllOps() const { } void BlockDescBind::Flush() { + for (auto &op_desc : ops_) { + op_desc->Flush(); + } + if (need_update_) { auto &op_field = *this->desc_->mutable_ops(); this->ClearPBOps(); @@ -98,6 +107,19 @@ BlockDesc *BlockDescBind::Proto() { Flush(); return desc_; } +BlockDescBind::BlockDescBind(const BlockDescBind &other, BlockDesc *desc, + ProgramDescBind *prog) + : prog_(prog), desc_(desc) { + need_update_ = true; + for (auto &op : other.ops_) { + ops_.emplace_back(new OpDescBind(*op)); + } + + for (auto &it : other.vars_) { + auto *var = new VarDescBind(*it.second); + vars_[it.first].reset(var); + } +} void BlockDescBind::ClearPBOps() { auto ops = this->desc_->mutable_ops(); diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index dd7b1228befd6b4f19c9aa3dd493b0d8005359fc..7d1d33f6860aa90518abb379a5e9964d6029c6fa 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -16,8 +16,10 @@ limitations under the License. */ #include #include +#include #include #include + #include "paddle/framework/op_desc.h" #include "paddle/framework/var_desc.h" #include "paddle/platform/macros.h" @@ -36,6 +38,9 @@ class BlockDescBind { BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) : prog_(prog), desc_(desc), need_update_(false) {} + BlockDescBind(const BlockDescBind &other, BlockDesc *desc, + ProgramDescBind *prog); + ~BlockDescBind() { this->ClearPBVars(); this->ClearPBOps(); @@ -51,16 +56,30 @@ class BlockDescBind { bool HasVar(const std::string &var_name) const; + std::set LocalVarNames() const { + std::set var_names; + for (auto &var : vars_) { + var_names.insert(var.first); + } + return var_names; + } + std::vector AllVars() const; BlockDescBind *ParentBlock() const; OpDescBind *AppendOp(); + void AppendAllocatedOp(std::unique_ptr &&op_desc); + OpDescBind *PrependOp(); std::vector AllOps() const; + size_t OpSize() const { return ops_.size(); } + + OpDescBind *Op(int idx) { return ops_.at(idx).get(); } + void Flush(); BlockDesc *Proto(); @@ -69,9 +88,7 @@ class BlockDescBind { void ClearPBOps(); void ClearPBVars(); - // FIXME(yuyang18): backward will access private data of BlockDesc. - // Mark it public temporary. We can fix it later. - public: + private: ProgramDescBind *prog_; // not_own BlockDesc *desc_; // not_own bool need_update_; diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index b3b85b586579c1e8129a40e44141ef4011e356e1..00caa6e1d53a4bcfae56c4459413bc1622321960 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -75,7 +75,8 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) { } for (auto& op_desc : block.ops()) { - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp( + op_desc, const_cast(&pdesc)); op->Run(local_scope, *device); } diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index 65760b07ada7a63a568cb8296eef35a8aa18d9ff..008fb45fb7bcb2f9b3d02376b15d2f88515f86d9 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -55,6 +55,7 @@ message OpDesc { repeated Var inputs = 1; repeated Var outputs = 2; repeated Attr attrs = 4; + optional bool is_target = 5 [ default = false ]; }; // OpProto describes a C++ framework::OperatorBase derived class. diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index c64ee94405b791393345d1187ca39e17ad72e2c5..af5e9f8abc1d0c7ef96a18d570511e9dcce69fae 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -74,12 +74,12 @@ class LoDTensor : public Tensor { LoD lod() const { return lod_; } /* - * Get a element from LoD. + * Get the start offset and end offset of an element from LoD. */ - size_t lod_element(size_t level, size_t elem) const { + std::pair lod_element(size_t level, size_t elem) const { PADDLE_ENFORCE_LT(level, NumLevels()); PADDLE_ENFORCE_LT(elem, NumElements(level)); - return (lod_)[level][elem]; + return std::make_pair((lod_)[level][elem], (lod_)[level][elem + 1]); } /* diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu index 647d07536dd070bc37137fc01f683ec07ba7d6f4..25041024cb51d4d2f360edb06571a0a99dcf29b1 100644 --- a/paddle/framework/lod_tensor_test.cu +++ b/paddle/framework/lod_tensor_test.cu @@ -36,8 +36,8 @@ TEST(LoDTensor, LoDInGPU) { lod_tensor.mutable_data(place); lod_tensor.set_lod(src_lod); - CHECK_EQ(lod_tensor.lod_element(0, 2), 4UL); - CHECK_EQ(lod_tensor.lod_element(0, 4), 8UL); + CHECK_EQ(lod_tensor.lod_element(0, 2).first, 4UL); + CHECK_EQ(lod_tensor.lod_element(0, 4).first, 8UL); auto lod = lod_tensor.lod(); diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 504afbd5dbacf7185f92e0000d19666230e2fb42..c2f2438edf6daadf26cbc6db37f6668739ab1726 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -43,12 +43,13 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap( return ret_val; } -std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { +std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc, + ProgramDesc* program) { VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); AttributeMap attrs; for (auto& attr : op_desc.attrs()) { - attrs[attr.name()] = GetAttrValue(attr); + attrs[attr.name()] = GetAttrValue(attr, program); } return CreateOp(op_desc.type(), inputs, outputs, attrs); diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index dfca46b789f3ee129d7a6e19398c808448e8afa0..ed85c386ec2632604bf5faf0ff9b1a087eb9c276 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -20,6 +20,8 @@ limitations under the License. */ #include #include #include + +#include "glog/logging.h" // For VLOG() #include "paddle/framework/attribute.h" #include "paddle/framework/details/op_registry.h" #include "paddle/framework/framework.pb.h" @@ -74,7 +76,8 @@ class OpRegistry { const VariableNameMap& outputs, AttributeMap attrs); - static std::unique_ptr CreateOp(const OpDesc& op_desc); + static std::unique_ptr CreateOp(const OpDesc& op_desc, + ProgramDesc* program); static std::unique_ptr CreateOp(const OpDescBind& op_desc); }; diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index b860fe6cac773d1e85adecc43f5dfec42b6c7661..6289125d7c782e542e5c55e1d4403836351b7e05 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -74,7 +74,7 @@ TEST(OpRegistry, CreateOp) { attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_f(scale); - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); paddle::framework::Scope scope; paddle::platform::CPUDeviceContext dev_ctx; op->Run(scope, dev_ctx); @@ -95,7 +95,7 @@ TEST(OpRegistry, IllegalAttr) { bool caught = false; try { - paddle::framework::OpRegistry::CreateOp(op_desc); + paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); } catch (paddle::platform::EnforceNotMet err) { caught = true; std::string msg = "larger_than check fail"; @@ -115,7 +115,7 @@ TEST(OpRegistry, DefaultValue) { ASSERT_TRUE(op_desc.IsInitialized()); - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); paddle::framework::Scope scope; paddle::platform::CPUDeviceContext dev_ctx; op->Run(scope, dev_ctx); @@ -131,7 +131,7 @@ TEST(OpRegistry, CustomChecker) { // attr 'test_attr' is not set bool caught = false; try { - paddle::framework::OpRegistry::CreateOp(op_desc); + paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); } catch (paddle::platform::EnforceNotMet err) { caught = true; std::string msg = "Attribute 'test_attr' is required!"; @@ -149,7 +149,7 @@ TEST(OpRegistry, CustomChecker) { attr->set_i(3); caught = false; try { - paddle::framework::OpRegistry::CreateOp(op_desc); + paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); } catch (paddle::platform::EnforceNotMet err) { caught = true; std::string msg = "'test_attr' must be even!"; @@ -166,7 +166,7 @@ TEST(OpRegistry, CustomChecker) { attr->set_name("test_attr"); attr->set_type(paddle::framework::AttrType::INT); attr->set_i(4); - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); paddle::platform::CPUDeviceContext dev_ctx; paddle::framework::Scope scope; op->Run(scope, dev_ctx); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index cf15f9933ab3bc881add3d45b7ca17194a70e0f1..12cd307297d010201a47e048089ed7be0db52647 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -20,12 +20,13 @@ limitations under the License. */ #include #include -#include "op_info.h" +#include "glog/logging.h" // For VLOG #include "paddle/framework/attribute.h" #include "paddle/framework/block_desc.h" #include "paddle/framework/data_type.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/lod_tensor.h" +#include "paddle/framework/op_info.h" #include "paddle/framework/scope.h" #include "paddle/framework/shape_inference.h" #include "paddle/framework/tensor.h" @@ -573,6 +574,7 @@ class OperatorWithKernel : public OperatorBase { void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const final { + VLOG(3) << "Running operator " << this->Type(); RuntimeInferShapeContext infer_shape_ctx(*this, scope); this->InferShape(&infer_shape_ctx); diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index d7890ac8d0af2171271a0cfccd356563c7604e72..c358f1a2b6ee3174b8c336ba1d212be7c5aa15c6 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -83,7 +83,7 @@ TEST(OperatorBase, all) { paddle::platform::CPUDeviceContext device_context; paddle::framework::Scope scope; - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); scope.Var("OUT1"); ASSERT_EQ(paddle::framework::op_run_num, 0); op->Run(scope, device_context); @@ -208,7 +208,7 @@ TEST(OpKernel, all) { paddle::platform::CPUDeviceContext cpu_device_context; paddle::framework::Scope scope; - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); op->Run(scope, cpu_device_context); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); @@ -244,7 +244,7 @@ TEST(OpKernel, multi_inputs) { scope.Var("y0")->GetMutable(); scope.Var("y1")->GetMutable(); - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); op->Run(scope, cpu_device_context); } diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc index fcb7292884275d972377983cb3ba1bcd86fb8348..e2349cefe09a6c1e0b11f77775426fe5c7594c9d 100644 --- a/paddle/framework/program_desc.cc +++ b/paddle/framework/program_desc.cc @@ -18,27 +18,10 @@ limitations under the License. */ namespace paddle { namespace framework { -using ProgDescMap = - std::unordered_map>; -static ProgDescMap *g_bind_map = nullptr; - -ProgramDescBind &ProgramDescBind::Instance(ProgramDesc *prog) { - if (g_bind_map == nullptr) { - g_bind_map = new ProgDescMap(); - } - auto &map = *g_bind_map; - auto &ptr = map[prog]; - - if (ptr == nullptr) { - ptr.reset(new ProgramDescBind(prog)); - } - return *ptr; -} - BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) { - auto *b = prog_->add_blocks(); + auto *b = prog_.add_blocks(); b->set_parent_idx(parent.ID()); - b->set_idx(prog_->blocks_size() - 1); + b->set_idx(prog_.blocks_size() - 1); blocks_.emplace_back(new BlockDescBind(this, b)); return blocks_.back().get(); } @@ -47,13 +30,22 @@ ProgramDesc *ProgramDescBind::Proto() { for (auto &block : blocks_) { block->Flush(); } - return prog_; + return &prog_; +} + +ProgramDescBind::ProgramDescBind() { + auto *block = prog_.mutable_blocks()->Add(); + block->set_idx(0); + block->set_parent_idx(-1); + blocks_.emplace_back(new BlockDescBind(this, block)); } -ProgramDescBind::ProgramDescBind(ProgramDesc *prog) { - prog_ = prog; - for (auto &block : *prog->mutable_blocks()) { - blocks_.emplace_back(new BlockDescBind(this, &block)); +ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) { + prog_ = o.prog_; + + for (int i = 0; i < prog_.blocks_size(); ++i) { + auto *block = prog_.mutable_blocks(i); + blocks_.emplace_back(new BlockDescBind(*o.blocks_[i], block, this)); } } } // namespace framework diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h index f29b1c54e7160ac477229f64e5471939131a2d8f..20cc1a2325ffd6f8ef52879a749f106c268376d4 100644 --- a/paddle/framework/program_desc.h +++ b/paddle/framework/program_desc.h @@ -26,7 +26,9 @@ class BlockDescBind; class ProgramDescBind { public: - static ProgramDescBind &Instance(ProgramDesc *prog); + ProgramDescBind(); + + ProgramDescBind(const ProgramDescBind &o); BlockDescBind *AppendBlock(const BlockDescBind &parent); @@ -37,14 +39,9 @@ class ProgramDescBind { ProgramDesc *Proto(); private: - explicit ProgramDescBind(ProgramDesc *prog); - - // Not owned - ProgramDesc *prog_; + ProgramDesc prog_; std::vector> blocks_; - - DISABLE_COPY_AND_ASSIGN(ProgramDescBind); }; } // namespace framework } // namespace paddle diff --git a/paddle/framework/program_desc_test.cc b/paddle/framework/program_desc_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..32ee275429687ae079ae8e15b3133428c6ff01b9 --- /dev/null +++ b/paddle/framework/program_desc_test.cc @@ -0,0 +1,83 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/program_desc.h" +#include "gtest/gtest.h" +#include "paddle/framework/block_desc.h" + +namespace paddle { +namespace framework { +TEST(ProgramDesc, copy_ctor) { + ProgramDescBind program; + auto* global_block = program.Block(0); + auto* x = global_block->Var("X"); + x->SetType(VarDesc_VarType_LOD_TENSOR); + x->SetLoDLevel(0); + x->SetDataType(FP32); + x->SetShape({1000, 784}); + + auto* y = global_block->Var("Y"); + y->SetType(VarDesc_VarType_LOD_TENSOR); + y->SetLoDLevel(0); + y->SetDataType(FP32); + y->SetShape({784, 100}); + + auto* op = global_block->AppendOp(); + op->SetType("mul"); + op->SetInput("X", {x->Name()}); + op->SetInput("Y", {y->Name()}); + + auto* out = global_block->Var("Out"); + out->SetType(VarDesc_VarType_LOD_TENSOR); + op->SetOutput("Y", {out->Name()}); + + ProgramDescBind program_copy(program); + + auto* global_block_copy = program_copy.Block(0); + ASSERT_NE(global_block, global_block_copy); + + auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) { + ASSERT_TRUE(global_block_copy->HasVar(name)); + auto* copy = global_block_copy->Var(name); + ASSERT_NE(copy, var_before); + ASSERT_EQ(copy->Name(), var_before->Name()); + ASSERT_EQ(copy->GetType(), var_before->GetType()); + ASSERT_EQ(copy->Shape(), var_before->Shape()); + ASSERT_EQ(copy->Proto()->SerializeAsString(), + var_before->Proto()->SerializeAsString()); + }; + + ASSERT_EQ(global_block->LocalVarNames(), global_block_copy->LocalVarNames()); + ASSERT_EQ(3, global_block_copy->LocalVarNames().size()); + assert_same_var("X", x); + assert_same_var("Y", y); + assert_same_var("Out", out); + + for (size_t i = 0; i < global_block->OpSize(); ++i) { + auto op_origin = global_block->Op(i); + auto op_copy = global_block->Op(i); + + ASSERT_EQ(op_origin->Type(), op_copy->Type()); + ASSERT_EQ(op_origin->Inputs(), op_copy->Inputs()); + ASSERT_EQ(op_origin->Outputs(), op_copy->Outputs()); + + ASSERT_EQ(op_copy->Proto()->SerializeAsString(), + op_origin->Proto()->SerializeAsString()); + } + + // Not check block's protostr are same it because the order of vars could be + // different and it is correct. +} +} // namespace framework +} // namespace paddle \ No newline at end of file diff --git a/paddle/framework/prune.cc b/paddle/framework/prune.cc new file mode 100644 index 0000000000000000000000000000000000000000..95833692925af4477fe575d6bd908a2ce7653c1b --- /dev/null +++ b/paddle/framework/prune.cc @@ -0,0 +1,109 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/prune.h" + +#include +#include +#include +#include + +#include + +namespace paddle { +namespace framework { + +const std::string kFeedOpType = "feed"; +const std::string kFetchOpType = "fetch"; + +bool HasDependentVar(const OpDesc& op_desc, + const std::set& dependent_vars) { + for (auto& var : op_desc.outputs()) { + for (auto& argu : var.arguments()) { + if (dependent_vars.count(argu) != 0) { + return true; + } + } + } + return false; +} + +bool IsTarget(const OpDesc& op_desc) { + if (op_desc.has_is_target()) { + return op_desc.is_target(); + } + return false; +} + +void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) { + // TODO(tonyyang-svail): + // - will change to use multiple blocks for RNN op and Cond Op + + auto& block = input.blocks(block_id); + auto& ops = block.ops(); + + bool expect_feed = true; + for (auto& op_desc : ops) { + PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed, + "All FeedOps are at the beginning of the ProgramDesc"); + expect_feed = (op_desc.type() == kFeedOpType); + } + + bool expect_fetch = true; + for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) { + auto& op_desc = *op_iter; + PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch, + "All FetchOps must at the end of the ProgramDesc"); + expect_fetch = (op_desc.type() == kFetchOpType); + } + + std::set dependent_vars; + std::vector should_run; + for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) { + auto& op_desc = *op_iter; + + if (IsTarget(op_desc) || HasDependentVar(op_desc, dependent_vars)) { + // insert its input to the dependency graph + for (auto& var : op_desc.inputs()) { + for (auto& argu : var.arguments()) { + dependent_vars.insert(argu); + } + } + + should_run.push_back(true); + } else { + should_run.push_back(false); + } + } + + // since we are traversing the ProgramDesc in reverse order + // we reverse the should_run vector + std::reverse(should_run.begin(), should_run.end()); + + output = input; + auto* op_field = output.mutable_blocks(block_id)->mutable_ops(); + op_field->Clear(); + for (size_t i = 0; i < should_run.size(); ++i) { + if (should_run[i]) { + *op_field->Add() = input.blocks(block_id).ops(i); + } + } +} + +void Prune(const ProgramDesc& input, ProgramDesc& output) { + prune_impl(input, output, 0); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/prune.h b/paddle/framework/prune.h new file mode 100644 index 0000000000000000000000000000000000000000..9414ac64f9491c07aabb216a4c81dfe6e78e8043 --- /dev/null +++ b/paddle/framework/prune.h @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/framework.pb.h" +#include "paddle/platform/enforce.h" + +namespace paddle { +namespace framework { + +void Prune(const ProgramDesc& input, ProgramDesc& output); + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/prune_test.cc b/paddle/framework/prune_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3ab4b43d9256af5880083b00df446c451e3f598b --- /dev/null +++ b/paddle/framework/prune_test.cc @@ -0,0 +1,138 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/prune.h" + +#include "paddle/framework/attribute.h" +#include "paddle/framework/operator.h" +#include "paddle/operators/net_op.h" + +#include "paddle/framework/block_desc.h" +#include "paddle/framework/op_desc.h" +#include "paddle/framework/program_desc.h" + +#include + +namespace f = paddle::framework; +namespace ops = paddle::operators; + +void AddOp(const std::string &type, const f::VariableNameMap &inputs, + const f::VariableNameMap &outputs, f::AttributeMap attrs, + paddle::framework::BlockDescBind *block) { + // insert output + for (auto kv : outputs) { + for (auto v : kv.second) { + auto var = block->Var(v); + var->SetDataType(paddle::framework::DataType::FP32); + } + } + + // insert op + auto op = block->AppendOp(); + op->SetType(type); + for (auto &kv : inputs) { + op->SetInput(kv.first, kv.second); + } + for (auto &kv : outputs) { + op->SetOutput(kv.first, kv.second); + } + op->SetAttrMap(attrs); +} + +TEST(Prune, one_operator) { + f::ProgramDescBind program; + f::BlockDescBind *block = program.Block(0); + + AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block); + + f::ProgramDesc *pdesc = program.Proto(); + f::ProgramDesc pruned; + + Prune(*pdesc, pruned); + PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0); + + pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true); + Prune(*pdesc, pruned); + PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1); +} + +TEST(Prune, forward) { + f::ProgramDescBind program; + f::BlockDescBind *block = program.Block(0); + + AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block); + AddOp("one_one", {{"input", {"b"}}}, {{"output", {"c"}}}, {}, block); + AddOp("one_one", {{"input", {"c"}}}, {{"output", {"d"}}}, {}, block); + AddOp("one_one", {{"input", {"d"}}}, {{"output", {"e"}}}, {}, block); + + f::ProgramDesc *pdesc = program.Proto(); + + for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) { + f::ProgramDesc pruned; + pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true); + Prune(*pdesc, pruned); + PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1); + } +} + +TEST(Prune, multi_input_op) { + f::ProgramDescBind program; + f::BlockDescBind *block = program.Block(0); + + AddOp("one_one", {{"input", {"a0"}}}, {{"output", {"b0"}}}, {}, block); + AddOp("one_one", {{"input", {"a1"}}}, {{"output", {"b1"}}}, {}, block); + AddOp("one_one", {{"input", {"a2"}}}, {{"output", {"b2"}}}, {}, block); + AddOp("three_one", {{"input", {"b0", "b1", "b2"}}}, {{"output", {"c"}}}, {}, + block); + + f::ProgramDesc *pdesc = program.Proto(); + pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true); + + f::ProgramDesc pruned; + Prune(*pdesc, pruned); + PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4); +} + +TEST(Prune, multi_output_op) { + f::ProgramDescBind program; + f::BlockDescBind *block = program.Block(0); + + AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block); + AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block); + AddOp("one_one", {{"input", {"c"}}}, {{"output", {"c1"}}}, {}, block); + + f::ProgramDesc *pdesc = program.Proto(); + pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); + + f::ProgramDesc pruned; + Prune(*pdesc, pruned); + PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2); +} + +TEST(Prune, multi_target) { + f::ProgramDescBind program; + f::BlockDescBind *block = program.Block(0); + + AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block); + AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block); + AddOp("one_one", {{"input", {"c"}}}, {{"output", {"c1"}}}, {}, block); + + f::ProgramDesc *pdesc = program.Proto(); + pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true); + pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); + + f::ProgramDesc pruned; + Prune(*pdesc, pruned); + PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3); +} diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h index 688a46f83982fc464c7602ec1041ad3f42122211..af4c26ca0a77b444852cc01545a8b585a5c3afcc 100644 --- a/paddle/framework/var_desc.h +++ b/paddle/framework/var_desc.h @@ -79,6 +79,10 @@ class VarDescBind { void SetType(VarDesc::VarType type) { desc_.set_type(type); } + bool Persistable() const { return desc_.persistable(); } + + void SetPersistable(bool persistable) { desc_.set_persistable(persistable); } + private: const TensorDesc &tensor_desc() const; TensorDesc *mutable_tensor_desc(); diff --git a/paddle/framework/var_type_inference_test.cc b/paddle/framework/var_type_inference_test.cc index 87399208e924d85ed8463df9a8f2eb49b1277fe9..918de1fd055e32888f71ffea1f33993ba1210e86 100644 --- a/paddle/framework/var_type_inference_test.cc +++ b/paddle/framework/var_type_inference_test.cc @@ -62,7 +62,7 @@ namespace paddle { namespace framework { TEST(InferVarType, sum_op) { - auto &prog = ProgramDescBind::Instance(&GetProgramDesc()); + ProgramDescBind prog; auto *op = prog.Block(0)->AppendOp(); op->SetType("sum"); op->SetInput("X", {"test_a", "test_b", "test_c"}); @@ -83,7 +83,7 @@ TEST(InferVarType, sum_op) { } TEST(InferVarType, sum_op_without_infer_var_type) { - auto &prog = ProgramDescBind::Instance(&GetProgramDesc()); + ProgramDescBind prog; auto *op = prog.Block(0)->AppendOp(); op->SetType("sum_without_infer_var_type"); op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); diff --git a/paddle/operators/batch_norm_op.md b/paddle/operators/batch_norm_op.md new file mode 100644 index 0000000000000000000000000000000000000000..80948adf2b9047a9685dbdd90b2296b5a955f9c1 --- /dev/null +++ b/paddle/operators/batch_norm_op.md @@ -0,0 +1,134 @@ +# Batch Normalization + +## What is batch normalization + +Batch normalization is a frequently-used method in deep network training. It adjusts the mean and variance of a layer's output, and make the data distribution easier for next layer's training. + +The principle of batch normalization can be summarized into a simple function: + +``` +y = (x - E[x]) / STD[x]) * scale + bias +``` + +`x` is a batch of output data of a certain layer. `E[x]` and `STD[x]` is the mean and standard deviation of `x`, respectively。 `scale` and `bias` are two trainable parameters. The training of batch normalization layer equals to the learning of best values of `scale` and `bias`. + +In our design, we use a single operator(`batch_norm_op`) to implement the whole batch normalization in C++, and wrap it as a layer in Python. + +## Differences with normal operators + +`batch_norm_op` is a single operator. However, there are a few differences between `BatchNormOp` and normal operators, which we shall take into consideration in our design. + +1. `batch_norm_op` shall behave differently in training and inferencing. For example, during inferencing, there is no batch data and it's impossible to compute `E[x]` and `STD[x]`, so we have to use an `estimated_mean` and an `estimated_variance` instead of them. These require our framework to be able to inform operators current running type (training/inferencing), then operators can switch their behaviors. + +2. `batch_norm_op` shall have the ability to maintain `estimated_mean` and `estimated_variance` across mini-batch. In each mini-batch, `estimated_mean` is iterated by the following equations: + +``` +if batch_id == 0 + estimated_mean = E[x] +else + estimated_mean = estimated_mean * momentum + (1.0 - momentum_) * E[x] +``` + +The iterating of `estimated_variance` is similar. `momentum` is an attribute, which controls estimated_mean updating speed. + +## Implementation + +Batch normalization is designed as a single operator is C++, and then wrapped as a layer in Python. + +### C++ + +As most C++ operators do, `batch_norm_op` is defined by inputs, outputs, attributes and compute kernels. + +#### Inputs + +- `x`: The inputs data, which is generated by the previous layer. +- `estimated_mean`: The estimated mean of all previous data batches. It is updated in each forward propagation and will be used in inferencing to take the role of `E[x]`. +- `estimated_var`: The estimated standard deviation of all previous data batches. It is updated in each forward propagation and will be used in inferencing to take the role of `STD[x]`. +- `scale`: trainable parameter 'scale' +- `bias`: trainable parameter 'bias' + +#### Outputs + +- `y`: The output data. +- `batch_mean`: The mean value of batch data. +- `batch_var`: The standard deviation value of batch data. +- `saved_mean`: Updated `estimated_mean` with current batch data. It's supposed to share the memory with input `estimated_mean`. +- `saved_var`: Updated `estimated_var` with current batch data. It's supposed to share the memory with input `estimated_var`. + +#### Attributes + +- `is_infer`: *bool*. If true, run `batch_norm_op` in inferencing mode. +- `use_global_est`: *bool*. If true, use `saved_mean` and `saved_var` instead of `E[x]` and `STD[x]` in trainning. +- `epsilon`: *float*. The epsilon value to avoid division by zero. +- `momentum`: *float*. Factor used in `estimated_mean` and `estimated_var` updating. The usage is shown above. + +#### Kernels + +The following graph showes the training computational process of `batch_norm_op`: + + + +cudnn provides APIs to finish the whole series of computation, we can use them in our GPU kernel. + +### Python + +`batch_norm_op` is warpped as a layer in Python: + +```python +def batch_norm_layer(net, + input, + output, + scale, + bias, + use_global_est = False, + epsilon = 1e-6, + momentum = 0.99): + mean_cache = scope.new_var(name = 'estimated_mean', trainable = False) + var_cache = scop.new_var(name = 'estimated_var', trainable = False) + batch_mean = scope.new_var(name = 'batch_mean') + batch_var = scope.new_var(name = 'batch_var') + batch_norm_op = Operator('batch_norm_op', + x = input, + estimated_mean = mean_cache, + estimated_mean = var_cache, + scale = scale, + bias = bias, + y = output, + batch_mean = batch_mean, + batch_var = batch_var, + saved_mean = mean_cache, + saved_var = var_cache, + is_infer = False, + use_global_est = use_global_est, + epsilon = epsilon, + momentum = momentum) + net.append_op(batch_norm_op) + return output +``` + +Because Python API has not been finally decided, the code above can be regarded as pseudo code. There are a few key points we shall note: + +1. `estimated_mean` and `estimated_var` are assigned the same variables with `saved_mean` and `saved_var` respectively. So they share same the memories. The output mean and variance values(`saved_mean` and `saved_var`) of a certain batch will be the inputs(`estimated_mean` and `estimated_var`) of the next batch. + +2. `is_infer` decided whether `batch_norm_op` will run in training mode or inferencing mode. However, a network may contains both training and inferencing parts. And user may switch `batch_norm_op`'s running mode in Python `for` loop like this: + +```python +for pass_id in range(PASS_NUM): + # ... + net.train() # run training model + if pass_id % 100 == 0: + net.infer(test_image) # run inferencing model + # ... +``` + +`is_infer` is an attribute. Once an operator is created, its attributes can not be changed. It suggests us that we shall maintain two `batch_norm_op` in the model, one's `is_infer` is `True`(we call it `infer_batch_norm_op`) and the other one's is `False`(we call it `train_batch_norm_op`). They share all parameters and variables, but be placed in two different branches. That is to say, if a network contains a `batch_norm_op`, it will fork into two branches, one go through `train_batch_norm_op` and the other one go through `infer_batch_norm_op`: + +
+ +
+ +Just like what is shown in the above graph, the net forks before `batch_norm_op` and will never merge again. All the operators after `batch_norm_op` will duplicate. + +When the net runs in training mode, the end of the left branch will be set as the running target, so the dependency tracking process will ignore right branch automatically. When the net runs in inferencing mode, the process is reversed. + +How to set a target is related to Python API design, so I will leave it here waiting for more discussions. diff --git a/paddle/operators/dynamic_recurrent_op_test.cc b/paddle/operators/dynamic_recurrent_op_test.cc index 83a5ba36d9af2ef81ebcbb33e056de2e0b98cbc1..36f405568d7e4ed9a469c3af7a80192b83142b7a 100644 --- a/paddle/operators/dynamic_recurrent_op_test.cc +++ b/paddle/operators/dynamic_recurrent_op_test.cc @@ -51,7 +51,7 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test { CreateGlobalVariables(); auto op_desc = CreateOpDesc(); - op = paddle::framework::OpRegistry::CreateOp(op_desc); + op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); dop = dynamic_cast(op.get()); InitCacheManually(); InitStepNet(); diff --git a/paddle/operators/images/batch_norm_fork.dot b/paddle/operators/images/batch_norm_fork.dot new file mode 100644 index 0000000000000000000000000000000000000000..4bc47713cba2cb23f1b34fffe6426ef10ac3a9df --- /dev/null +++ b/paddle/operators/images/batch_norm_fork.dot @@ -0,0 +1,25 @@ +digraph ImageBatchNormForkGragh { + subgraph cluster_before { + Prev [label="...", shape=plaintext]; + Rnn [label="rnn_op", shape=box]; + BatchNorm [label="batch_norm_op", shape=box]; + Fc [label="fc_op", shape=box]; + After [label="...", shape=plaintext]; + Prev -> Rnn -> BatchNorm -> Fc -> After; + label="original"; + } + + subgraph cluster_after { + Prev2 [label="...", shape=plaintext]; + Rnn2 [label="rnn_op", shape=box]; + BatchNorm2_1 [label="train_batch_norm_op", shape=box]; + BatchNorm2_2 [label="infer_batch_norm_op", shape=box]; + Fc2_1 [label="fc_op", shape=box]; + Fc2_2 [label="fc_op", shape=box]; + After2_1 [label="...", shape=plaintext]; + After2_2 [label="...", shape=plaintext]; + Prev2 -> Rnn2 -> BatchNorm2_1 -> Fc2_1 -> After2_1; + Rnn2 -> BatchNorm2_2 ->Fc2_2 ->After2_2 + label="forked"; + } +} diff --git a/paddle/operators/images/batch_norm_fork.png b/paddle/operators/images/batch_norm_fork.png new file mode 100644 index 0000000000000000000000000000000000000000..aded62bce5bc268b7a3ef4dc96c89fe21d6ea955 Binary files /dev/null and b/paddle/operators/images/batch_norm_fork.png differ diff --git a/paddle/operators/images/batch_norm_op_kernel.png b/paddle/operators/images/batch_norm_op_kernel.png new file mode 100644 index 0000000000000000000000000000000000000000..a99ce81ff3bf42880ebbd6a1297de3bf038e09b2 Binary files /dev/null and b/paddle/operators/images/batch_norm_op_kernel.png differ diff --git a/paddle/operators/proximal_gd_op.cc b/paddle/operators/proximal_gd_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e4b014b9f5866ec0791cba9b3998b1734066eeeb --- /dev/null +++ b/paddle/operators/proximal_gd_op.cc @@ -0,0 +1,93 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/proximal_gd_op.h" + +namespace paddle { +namespace operators { + +class ProximalGDOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(Param) of ProximalGDOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of ProximalGDOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of ProximalGDOp should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of ProximalGDOp should not be null."); + + auto param_dim = ctx->GetInputDim("Param"); + PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"), + "Two input of ProximalGD Op's dimension must be same."); + + auto lr_dim = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1, + "Learning Rate should be a scalar."); + + ctx->SetOutputDim("ParamOut", param_dim); + } +}; + +class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ProximalGDOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Param", + "(Tensor, default Tensor) " + "Input parameter value that has to be updated."); + AddInput("Grad", + "(Tensor, default Tensor) " + "Input gradient of the parameter."); + AddInput("LearningRate", + "(Tensor, default Tensor) " + "The learning rate should be a tensor of size 1."); + + AddOutput("ParamOut", "(Tensor) Output updated parameter value."); + + AddAttr("l1", + "(float, default 0.0) " + "L1 regularization strength.") + .SetDefault(0.0f); + AddAttr("l2", + "(float, default 0.0)" + "L2 regularization strength.") + .SetDefault(0.0f); + AddComment(R"DOC( + +Optimizer that implements the proximal gradient descent algorithm. + +prox_param = param - learning_rate * grad +param = sign(prox_param) / (1 + learning_rate * l2) * + max { |prox_param| - learning_rate * l1 , 0 } + +The paper that proposed Proximal Gradient Descent: +(http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf) +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(proximal_gd, ops::ProximalGDOp, + ops::ProximalGDOpMaker); +REGISTER_OP_CPU_KERNEL( + proximal_gd, ops::ProximalGDOpKernel); diff --git a/paddle/operators/proximal_gd_op.cu b/paddle/operators/proximal_gd_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..26f4ebaa0f43620fee7ece2d71755be94a0e01a5 --- /dev/null +++ b/paddle/operators/proximal_gd_op.cu @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed +under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/proximal_gd_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + proximal_gd, ops::ProximalGDOpKernel); diff --git a/paddle/operators/proximal_gd_op.h b/paddle/operators/proximal_gd_op.h new file mode 100644 index 0000000000000000000000000000000000000000..bebda0204173ec5c3ec9a7a9da6fb623171f4cea --- /dev/null +++ b/paddle/operators/proximal_gd_op.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +template +class ProximalGDOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* param_out = ctx.Output("ParamOut"); + + param_out->mutable_data(ctx.GetPlace()); + + auto grad = ctx.Input("Grad"); + + auto l1 = static_cast(ctx.Attr("l1")); + auto l2 = static_cast(ctx.Attr("l2")); + + auto p = EigenVector::Flatten(*ctx.Input("Param")); + auto g = EigenVector::Flatten(*grad); + auto lr = EigenVector::Flatten(*ctx.Input("LearningRate")); + + auto p_out = EigenVector::Flatten(*param_out); + auto place = ctx.GetEigenDevice(); + + Eigen::DSizes grad_dsize(grad->numel()); + + auto prox_param = p - lr.broadcast(grad_dsize) * g; + if (l1 > 0) { + p_out.device(place) = + prox_param.sign() * + (((prox_param.abs() - (lr * l1).broadcast(grad_dsize)) + .cwiseMax(T(0.0))) / + (1.0 + (lr * l2).broadcast(grad_dsize))); + } else { + p_out.device(place) = + prox_param / (1.0 + (lr * l2).broadcast(grad_dsize)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index 0f78eeab9bc643a1a70c4b6ab02a160bbeda2b33..2acb96d1b4f5903ff6c57b10e7621c8adaf73171 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of SGDOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), @@ -35,15 +35,15 @@ class SGDOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, "Learning rate should have 1 element"); auto param_dim = ctx->GetInputDim("Param"); - PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"), - "Two input of SGD Op's dimension must be same."); + // TODO(qijun): check dimensions of Param and Grad at complie + // and run time. ctx->SetOutputDim("ParamOut", param_dim); } }; class SGDOpMaker : public framework::OpProtoAndCheckerMaker { public: - SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + SGDOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Param", "Input parameter"); AddInput("LearningRate", "Learning rate of SGD"); @@ -58,6 +58,38 @@ param_out = param - learning_rate * grad; )DOC"); } }; + +template +struct SparseSGDFunctor { + void operator()(const platform::DeviceContext& context, + const framework::SelectedRows& input, + const framework::Tensor& learning_rate, + framework::Tensor* output) { + auto in_height = input.height(); + auto out_dims = output->dims(); + PADDLE_ENFORCE_EQ(in_height, out_dims[0]); + + auto& in_value = input.value(); + auto& in_rows = input.rows(); + + int64_t in_row_numel = in_value.numel() / in_rows.size(); + PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height); + + auto* in_data = in_value.data(); + auto* out_data = output->data(); + auto* lr = learning_rate.data(); + + for (size_t i = 0; i < in_rows.size(); i++) { + for (int64_t j = 0; j < in_row_numel; j++) { + out_data[in_rows[i] * in_row_numel + j] -= + lr[0] * in_data[i * in_row_numel + j]; + } + } + } +}; + +template struct SparseSGDFunctor; + } // namespace operators } // namespace paddle diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu index f5ba6d3c29f8dfbfdea4fbf2c3d5fd7f5b358666..106f9b746ba6614d8fa68b677c47ec04ed26fb81 100644 --- a/paddle/operators/sgd_op.cu +++ b/paddle/operators/sgd_op.cu @@ -14,6 +14,66 @@ #define EIGEN_USE_GPU #include "paddle/operators/sgd_op.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { + +namespace { +template +__global__ void SparseSGDFunctorKernel(const T* selected_rows, + const int64_t* rows, + const T* learning_rate, T* tensor_out, + int64_t row_numel, int block_size) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + + selected_rows += ty * row_numel; + tensor_out += rows[ty] * row_numel; + + for (int index = tid; index < row_numel; index += block_size) { + // Since index in rows of SelectedRows can be duplicate, we have to use + // Atomic Operation to avoid concurrent write error. + paddle::platform::CudaAtomicAdd( + tensor_out + index, -1.0 * learning_rate[0] * selected_rows[index]); + } +} +} // namespace + +template +struct SparseSGDFunctor { + void operator()(const platform::DeviceContext& context, + const framework::SelectedRows& input, + const framework::Tensor& learning_rate, + framework::Tensor* output) { + auto in_height = input.height(); + auto out_dims = output->dims(); + PADDLE_ENFORCE_EQ(in_height, out_dims[0]); + + auto& in_value = input.value(); + auto& in_rows = input.rows(); + + int64_t in_row_numel = in_value.numel() / in_rows.size(); + PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height); + + auto* in_data = in_value.data(); + auto* out_data = output->data(); + + int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid(1, in_rows.size()); + SparseSGDFunctorKernel< + T><<(context) + .stream()>>>(in_data, in_rows.data(), learning_rate.data(), + out_data, in_row_numel, block_size); + } +}; + +template struct SparseSGDFunctor; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(sgd, diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index 26f4012f258771794c736dbfad4af174b017f410..78b595fc6c63d775b627f23cafa9458f1dadd4e5 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -15,31 +15,53 @@ limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/framework/selected_rows.h" namespace paddle { namespace operators { +template +struct SparseSGDFunctor { + void operator()(const platform::DeviceContext& context, + const framework::SelectedRows& input, + const framework::Tensor& learning_rate, + framework::Tensor* output); +}; + template class SGDOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto param = ctx.Input("Param"); - auto grad = ctx.Input("Grad"); - auto param_out = ctx.Output("ParamOut"); - auto learning_rate = ctx.Input("LearningRate"); + auto* param = ctx.Input("Param"); + auto* param_out = ctx.Output("ParamOut"); + auto* learning_rate = ctx.Input("LearningRate"); - param_out->mutable_data(ctx.GetPlace()); + auto* grad_var = ctx.InputVar("Grad"); + // Actually, all tensors are LoDTensor except SelectedRows. + if (grad_var->IsType()) { + param_out->mutable_data(ctx.GetPlace()); + auto* grad = ctx.Input("Grad"); - auto p = framework::EigenVector::Flatten(*param); - auto g = framework::EigenVector::Flatten(*grad); - auto o = framework::EigenVector::Flatten(*param_out); - auto lr = framework::EigenVector::Flatten(*learning_rate); - auto place = ctx.GetEigenDevice(); + auto p = framework::EigenVector::Flatten(*param); + auto g = framework::EigenVector::Flatten(*grad); + auto o = framework::EigenVector::Flatten(*param_out); + auto lr = framework::EigenVector::Flatten(*learning_rate); + auto place = ctx.GetEigenDevice(); - Eigen::DSizes grad_dsize(grad->numel()); - o.device(place) = p - lr.broadcast(grad_dsize) * g; + Eigen::DSizes grad_dsize(grad->numel()); + o.device(place) = p - lr.broadcast(grad_dsize) * g; + } else if (grad_var->IsType()) { + // TODO(qijun): In Sparse SGD operator, in-place update is enforced. + // This manual optimization brings difficulty to track data dependency. + // It's better to find a more elegant solution. + PADDLE_ENFORCE_EQ(param, param_out); + auto* grad = ctx.Input("Grad"); + SparseSGDFunctor functor; + functor(ctx.device_context(), *grad, *learning_rate, param_out); + } else { + PADDLE_THROW("Unsupported Variable Type of Grad"); + } } }; - } // namespace operators } // namespace paddle diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 82aae72ba9378d6f05a9d4c0b95c7afd40650026..58739d888ac537222fdfeff0c52bd90ab58362c1 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -100,21 +100,11 @@ using namespace paddle::framework; // NOLINT // Bind Methods void BindProgramDesc(py::module &m) { py::class_(m, "ProgramDesc", "") - .def_static("instance", - []() -> ProgramDescBind * { - return &ProgramDescBind::Instance(&GetProgramDesc()); - }, - py::return_value_policy::reference) - .def_static("__create_program_desc__", - []() -> ProgramDescBind * { - // Only used for unit-test - auto *prog_desc = new ProgramDesc; - auto *block = prog_desc->mutable_blocks()->Add(); - block->set_idx(0); - block->set_parent_idx(-1); - return &ProgramDescBind::Instance(prog_desc); - }, - py::return_value_policy::reference) + .def(py::init<>()) + .def("__init__", + [](ProgramDescBind &self, const ProgramDescBind &other) { + new (&self) ProgramDescBind(other); + }) .def("append_block", &ProgramDescBind::AppendBlock, py::return_value_policy::reference) .def("append_backward", @@ -176,8 +166,8 @@ void BindBlockDesc(py::module &m) { py::return_value_policy::reference) .def("all_vars", &BlockDescBind::AllVars, py::return_value_policy::reference) - .def("all_ops", &BlockDescBind::AllOps, - py::return_value_policy::reference) + .def("op_size", &BlockDescBind::OpSize) + .def("op", &BlockDescBind::Op, py::return_value_policy::reference) .def("serialize_to_string", [](BlockDescBind &block_desc) -> py::bytes { const BlockDesc *desc = block_desc.Proto(); PADDLE_ENFORCE(desc->IsInitialized(), @@ -216,16 +206,19 @@ void BindVarDsec(py::module &m) { .def("set_lod_level", &VarDescBind::SetLoDLevel) .def("type", &VarDescBind::GetType) .def("set_type", &VarDescBind::SetType) - .def("serialize_to_string", [](VarDescBind &var_desc) -> py::bytes { - const VarDesc *desc = var_desc.Proto(); - PADDLE_ENFORCE(desc->IsInitialized(), - "VarDesc has not been initialized."); - std::string res; - PADDLE_ENFORCE( - desc->SerializeToString(&res), - "Serialize VarDesc Error. This could be a bug of Paddle."); - return res; - }); + .def("serialize_to_string", + [](VarDescBind &var_desc) -> py::bytes { + const VarDesc *desc = var_desc.Proto(); + PADDLE_ENFORCE(desc->IsInitialized(), + "VarDesc has not been initialized."); + std::string res; + PADDLE_ENFORCE( + desc->SerializeToString(&res), + "Serialize VarDesc Error. This could be a bug of Paddle."); + return res; + }) + .def("persistable", &VarDescBind::Persistable) + .def("set_persistable", &VarDescBind::SetPersistable); py::enum_(var_desc, "VarType", "") .value("LOD_TENSOR", VarDesc::LOD_TENSOR) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index fcae92ad99f7b393104ac04fd725ad3d43db04ad..16661b93e56da30ecd3848d28a0f4667b710e80c 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/framework/backward.h" #include "paddle/framework/executor.h" #include "paddle/framework/feed_fetch_method.h" +#include "paddle/framework/framework.pb.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/selected_rows.h" #include "paddle/framework/tensor_array.h" @@ -153,7 +154,15 @@ PYBIND11_PLUGIN(core) { py::return_value_policy::reference) .def("set_height", &SelectedRows::set_height) .def("height", &SelectedRows::height) - .def("set_rows", &SelectedRows::set_rows) + .def("set_rows", + [](SelectedRows &self, std::vector rows) { +#ifndef PADDLE_WITH_CUDA + self.set_rows(rows); +#else + Vector new_rows(rows); + self.set_rows(new_rows); +#endif + }) .def("rows", [](SelectedRows &self) { #ifndef PADDLE_WITH_CUDA return self.rows(); @@ -186,6 +195,11 @@ All parameter, weight, gradient are variables in Paddle. return self.GetMutable(); }, py::return_value_policy::reference) + .def("get_selected_rows", + [](Variable &self) -> SelectedRows * { + return self.GetMutable(); + }, + py::return_value_policy::reference) .def("get_net", [](Variable &self) -> operators::NetOp * { return self.GetMutable(); @@ -259,7 +273,7 @@ All parameter, weight, gradient are variables in Paddle. PADDLE_ENFORCE(desc.IsInitialized(), "User OpDesc is not initialized, reason %s", desc.InitializationErrorString()); - return OpRegistry::CreateOp(desc); + return OpRegistry::CreateOp(desc, nullptr); }) .def("backward", [](const OperatorBase &forwardOp, @@ -363,7 +377,7 @@ All parameter, weight, gradient are variables in Paddle. PADDLE_ENFORCE(desc.IsInitialized(), "User OpDesc is not initialized, reason %s", desc.InitializationErrorString()); - auto rnn_op = OpRegistry::CreateOp(desc); + auto rnn_op = OpRegistry::CreateOp(desc, nullptr); return static_cast(rnn_op.release()); }) .def("set_stepnet", [](operators::RecurrentOp &self, @@ -381,7 +395,7 @@ All parameter, weight, gradient are variables in Paddle. PADDLE_ENFORCE(desc.IsInitialized(), "User OpDesc is not initialized, reason %s", desc.InitializationErrorString()); - auto rnn_op = OpRegistry::CreateOp(desc); + auto rnn_op = OpRegistry::CreateOp(desc, nullptr); return static_cast( rnn_op.release()); }) @@ -408,7 +422,7 @@ All parameter, weight, gradient are variables in Paddle. PADDLE_ENFORCE(desc.IsInitialized(), "User OpDesc is not initialized, reason %s", desc.InitializationErrorString()); - auto cond_op = OpRegistry::CreateOp(desc); + auto cond_op = OpRegistry::CreateOp(desc, nullptr); return static_cast(cond_op.release()); }) .def("set_truenet", diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py index 3fb6efe42a2c69e37274dbb655fa47d78eabdf26..a24c78171e07249beeb8131d2d338b01c30e5c1f 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/framework/framework.py @@ -15,6 +15,7 @@ class Variable(object): shape=None, dtype=None, lod_level=None, + persistable=False, **kwargs): self.block = block @@ -70,6 +71,17 @@ class Variable(object): "lod_level is {2}. They are not " "matched".format(self.name, self.lod_level, lod_level)) + if persistable is not None: + if is_new_var: + self.desc.set_persistable(persistable) + else: + if persistable != self.persistable: + raise ValueError( + "Variable {0} has been created before." + "The previous persistable is {1}; the new " + "persistable is {2}. They are not matched".format( + self.name, self.persistable, persistable)) + self.block.vars[name] = self self.op = None @@ -80,6 +92,10 @@ class Variable(object): __repr__ = __str__ + @property + def persistable(self): + return self.desc.persistable() + @property def name(self): return self.desc.name() @@ -232,7 +248,7 @@ class Operator(object): if attrs is not None: for attr in proto.attrs: attr_name = attr.name - if not attr_name in attrs: + if (not attr_name in attrs) or (attrs[attr_name] is None): continue if not isinstance(attrs[attr_name], Block): self.desc.set_attr(attr_name, attrs[attr_name]) @@ -344,19 +360,26 @@ class Block(object): self.create_var(name=var.name(), desc=var, type=var.type()) # sync operators from cpp - ops_in_cpp = self.desc.all_ops() - first_op_in_python = self.ops[0].desc - last_op_in_python = self.ops[len(self.ops) - 1].desc - start_index = None - end_index = None - for index in range(len(ops_in_cpp)): - if first_op_in_python == ops_in_cpp[index]: - start_index = index - if last_op_in_python == ops_in_cpp[index]: - end_index = index - assert start_index is not None - assert end_index is not None - assert start_index <= end_index + ops_in_cpp = [] + for op_idx in range(0, self.desc.op_size()): + ops_in_cpp.append(self.desc.op(op_idx)) + + if len(self.ops) != 0: + first_op_in_python = self.ops[0].desc + last_op_in_python = self.ops[len(self.ops) - 1].desc + start_index = None + end_index = None + for index in range(len(ops_in_cpp)): + if first_op_in_python == ops_in_cpp[index]: + start_index = index + if last_op_in_python == ops_in_cpp[index]: + end_index = index + assert start_index is not None + assert end_index is not None + assert start_index <= end_index + else: + start_index = 0 + end_index = -1 # sync ops append to the head of cpp_ops for index in range((start_index - 1 - 1), -1, -1): @@ -376,18 +399,8 @@ class Block(object): class Program(object): - @classmethod - def instance(cls): - # From https://stackoverflow.com/questions/8212053 - # Making Program as a Singleton class. - if not hasattr(cls, '_instance'): - cls._instance = cls() - return cls._instance - - def __init__(self, desc=None): - if desc is None: - desc = core.ProgramDesc.instance() - self.desc = desc + def __init__(self): + self.desc = core.ProgramDesc() self.blocks = [Block(self, 0)] self.current_block_idx = 0 @@ -396,7 +409,15 @@ class Program(object): proto = framework_pb2.ProgramDesc.FromString(str(protostr)) return proto.__str__() - __repr__ = __str__ + def clone(self): + p = Program() + p.desc = core.ProgramDesc(self.desc) + p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())] + p.sync_with_cpp() + return p + + def __repr__(self): + return str(self) def global_block(self): return self.blocks[0] @@ -444,7 +465,9 @@ class Parameter(Variable): if each < 0: raise ValueError("Parameter shape should not be related with " "batch-size") - Variable.__init__(self, block, shape=shape, dtype=dtype, **kwargs) + + Variable.__init__( + self, block, persistable=True, shape=shape, dtype=dtype, **kwargs) self.trainable = kwargs.get('trainable', True) self.init_attr = kwargs.get('initialize_attr', { 'type': 'uniform_random', @@ -469,4 +492,4 @@ class Parameter(Variable): # program is a global instance. -g_program = Program.instance() +g_program = Program() diff --git a/python/paddle/v2/framework/layer_helper.py b/python/paddle/v2/framework/layer_helper.py index 26d3e04310b6b0415af31d1630575b32dce186d5..6615bdcd3b1afa493c9ad05c789664818e64d2f2 100644 --- a/python/paddle/v2/framework/layer_helper.py +++ b/python/paddle/v2/framework/layer_helper.py @@ -66,15 +66,15 @@ class LayerHelper(object): actual = self.kwargs.get('param_attr', None) return actual if actual is not None else default - def bias_attr(self, size, dtype): - bias_attr = self.kwargs.get('bias_attr', False) - if bias_attr is None or bias_attr: + def bias_attr(self, shape, dtype): + bias_attr = self.kwargs.get('bias_attr', None) + if bias_attr is True: bias_attr = { 'name': None, 'init_attr': { 'type': 'fill_constant', 'value': 0.0, - 'shape': [size], + 'shape': shape, 'dataType': dtype } } @@ -127,15 +127,13 @@ class LayerHelper(object): return self.program.global_block().create_var(*args, **kwargs) def append_bias_op(self, input_var): - bias_attr = self.bias_attr( - self.kwargs['size'], dtype=input_var.data_type) + size = list(input_var.shape[1:]) + bias_attr = self.bias_attr(size, dtype=input_var.data_type) if not bias_attr: return input_var + b = self.create_parameter( - attr=bias_attr, - shape=[self.kwargs['size']], - dtype=input_var.data_type, - suffix='b') + attr=bias_attr, shape=size, dtype=input_var.data_type, suffix='b') tmp = self.create_tmp_variable(dtype=input_var.data_type) self.append_op( type='elementwise_add', diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index 44b587b116e2ebfc5329348027492a4ee27b04e5..c7397716c47dd7088d840edb00d96dda2fe88f1d 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -3,17 +3,17 @@ import paddle.v2.framework.core as core from paddle.v2.framework.framework import OpProtoHolder, Variable import re -__all__ = ['fc_layer', 'data_layer', 'cross_entropy'] +__all__ = ['fc', 'data', 'cross_entropy', 'conv2d'] -def fc_layer(input, - size, - param_attr=None, - bias_attr=True, - name=None, - act=None, - num_flatten_dims=1, - program=None): +def fc(input, + size, + param_attr=None, + bias_attr=True, + name=None, + act=None, + num_flatten_dims=1, + program=None): # create helper helper = LayerHelper('fc', **locals()) @@ -24,6 +24,7 @@ def fc_layer(input, for input_var, param_attr in helper.iter_inputs_and_params(): input_shape = input_var.shape param_shape = list(input_shape[num_flatten_dims:]) + [size] + w = helper.create_parameter( attr=param_attr, shape=param_shape, dtype=dtype) tmp = helper.create_tmp_variable(dtype) @@ -50,11 +51,11 @@ def fc_layer(input, return helper.append_activation(pre_activation) -def data_layer(name, - shape, - data_type='float32', - type=core.VarDesc.VarType.LOD_TENSOR, - program=None): +def data(name, + shape, + data_type='float32', + type=core.VarDesc.VarType.LOD_TENSOR, + program=None): helper = LayerHelper('data', **locals()) shape = [-1] + shape # append batch size as -1 return helper.create_global_variable( @@ -111,6 +112,7 @@ def _create_op_func_(op_type): _create_op_func_('mean') +_create_op_func_('pool2d') def cross_entropy(input, label, **kwargs): @@ -141,3 +143,47 @@ def square_error_cost(input, label, **kwargs): outputs={'Y': [square_out]}, attrs={'factor': 2.0}) return square_out + + +def conv2d(input, + num_filters, + name=None, + filter_size=[1, 1], + act=None, + groups=None, + stride=[1, 1], + padding=None, + bias_attr=None, + param_attr=None, + program=None): + helper = LayerHelper('conv2d', **locals()) + dtype = helper.input_dtype() + + num_channels = input.shape[1] + if groups is None: + num_filter_channels = num_channels + else: + if num_channels % groups is not 0: + raise ValueError("num_channels must be divisible by groups.") + num_filter_channels = num_channels / groups + + input_shape = input.shape + filter_shape = [num_filters, num_filter_channels] + filter_size + filter = helper.create_parameter( + attr=helper.param_attr, shape=filter_shape, dtype=dtype) + pre_bias = helper.create_tmp_variable(dtype) + + helper.append_op( + type='conv2d', + inputs={ + 'Input': input, + 'Filter': filter, + }, + outputs={"Output": pre_bias}, + attrs={'strides': stride, + 'paddings': padding, + 'groups': groups}) + + pre_act = helper.append_bias_op(pre_bias) + + return helper.append_activation(pre_act) diff --git a/python/paddle/v2/framework/tests/test_infer_shape.py b/python/paddle/v2/framework/tests/test_infer_shape.py index 19bb45acef9a7443a974bf5f11afab5d067321f7..5cfb9e6687f733353cfdbfbd1ad830c2bed8463b 100644 --- a/python/paddle/v2/framework/tests/test_infer_shape.py +++ b/python/paddle/v2/framework/tests/test_infer_shape.py @@ -5,7 +5,7 @@ import paddle.v2.framework.core as core class TestInferShape(unittest.TestCase): def test_sum_op(self): - prog = core.ProgramDesc.__create_program_desc__() + prog = core.ProgramDesc() self.assertIsNotNone(prog) block = prog.block(0) self.assertIsNotNone(block) @@ -33,7 +33,7 @@ class TestInferShape(unittest.TestCase): self.assertEqual(out.shape(), shape) def test_mul_op(self): - prog = core.ProgramDesc.__create_program_desc__() + prog = core.ProgramDesc() self.assertIsNotNone(prog) block = prog.block(0) self.assertIsNotNone(block) diff --git a/python/paddle/v2/framework/tests/test_layers.py b/python/paddle/v2/framework/tests/test_layers.py index 1ef2591cca066788deed8a1c0f6443850251fb80..dbbb6535389e2336d156734cc672e0cc7bba175c 100644 --- a/python/paddle/v2/framework/tests/test_layers.py +++ b/python/paddle/v2/framework/tests/test_layers.py @@ -1,4 +1,4 @@ -from paddle.v2.framework.layers import fc_layer, data_layer, cross_entropy, mean, square_error_cost +import paddle.v2.framework.layers as layers from paddle.v2.framework.framework import Program, g_program import paddle.v2.framework.core as core import unittest @@ -6,36 +6,57 @@ import unittest class TestBook(unittest.TestCase): def test_fit_a_line(self): - pd = core.ProgramDesc.__create_program_desc__() - program = Program(desc=pd) - x = data_layer( + program = Program() + x = layers.data( name='x', shape=[13], data_type='float32', program=program) - y_predict = fc_layer(input=x, size=1, act=None, program=program) + y_predict = layers.fc(input=x, size=1, act=None, program=program) - y = data_layer( + y = layers.data( name='y', shape=[1], data_type='float32', program=program) - cost = square_error_cost(input=y_predict, label=y, program=program) + cost = layers.square_error_cost( + input=y_predict, label=y, program=program) - avg_cost = mean(x=cost, program=program) + avg_cost = layers.mean(x=cost, program=program) self.assertIsNotNone(avg_cost) + program.append_backward(avg_cost, set()) print str(program) def test_recognize_digits_mlp(self): - pd = core.ProgramDesc.__create_program_desc__() - program = Program(desc=pd) + program = Program() # Change g_program, so the rest layers use `g_program` - images = data_layer( + images = layers.data( name='pixel', shape=[784], data_type='float32', program=program) - label = data_layer( + label = layers.data( name='label', shape=[1], data_type='int32', program=program) - hidden1 = fc_layer(input=images, size=128, act='relu', program=program) - hidden2 = fc_layer(input=hidden1, size=64, act='relu', program=program) - predict = fc_layer( - input=hidden2, size=10, act='softmax', program=program) - cost = cross_entropy(input=predict, label=label, program=program) - avg_cost = mean(x=cost, program=program) + hidden1 = layers.fc(input=images, size=128, act='relu', program=program) + hidden2 = layers.fc(input=hidden1, size=64, act='relu', program=program) + predict = layers.fc(input=hidden2, + size=10, + act='softmax', + program=program) + cost = layers.cross_entropy(input=predict, label=label, program=program) + avg_cost = layers.mean(x=cost, program=program) self.assertIsNotNone(avg_cost) + # print str(program) + + def test_simple_conv2d(self): + pd = core.ProgramDesc.__create_program_desc__() + program = Program(desc=pd) + images = data_layer( + name='pixel', shape=[3, 48, 48], data_type='int32', program=program) + conv2d_layer( + input=images, num_filters=3, filter_size=[4, 4], program=program) + + # print str(program) + + def test_simple_conv2d(self): + program = Program() + images = layers.data( + name='pixel', shape=[3, 48, 48], data_type='int32', program=program) + layers.conv2d( + input=images, num_filters=3, filter_size=[4, 4], program=program) + print str(program) diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/framework/tests/test_program.py index c98dc3492b9506f5a1639d26e3f0f1c2c9ad75d7..c55dd8de7282d4c941777054ad9d6437c87f0bc6 100644 --- a/python/paddle/v2/framework/tests/test_program.py +++ b/python/paddle/v2/framework/tests/test_program.py @@ -34,8 +34,26 @@ class TestProgram(unittest.TestCase): self.assertEqual(1, b.idx) self.assertEqual(0, b.parent_idx) + def test_program_clone(self): + prog = Program() + + x = prog.global_block().create_var( + name='X', shape=[1000, 784], dtype='float32') + + y = prog.global_block().create_var( + name='Y', shape=[784, 100], dtype='float32') + out = prog.global_block().create_var(name='Out', dtype='float32') + prog.global_block().append_op( + type="mul", inputs={'X': [x], + 'Y': [y]}, outputs={'Out': [out]}) + + # FIXME(yuyang18): We manual compare the output string, since the order + # of variable could be changed. + print prog + print prog.clone() + def test_append_backward(self): - prog = Program.instance() + prog = Program() block = prog.global_block() mul_x = block.create_var( diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index c775b1a398dabb096845b4a8730152c682b2f0dd..2fd3d5d165ada5026510e0dc3e2c55b6e0596ff3 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -4,7 +4,7 @@ import paddle.v2.framework.core as core class TestOpDesc(unittest.TestCase): def test_op_desc(self): - prog = core.ProgramDesc.__create_program_desc__() + prog = core.ProgramDesc() self.assertIsNotNone(prog) block = prog.block(0) self.assertIsNotNone(block) @@ -64,16 +64,16 @@ class TestOpDesc(unittest.TestCase): class TestProgramDesc(unittest.TestCase): def test_instance(self): - program_desc = core.ProgramDesc.__create_program_desc__() + program_desc = core.ProgramDesc() self.assertIsNotNone(program_desc) del program_desc - program_desc = core.ProgramDesc.instance() + program_desc = core.ProgramDesc() self.assertIsNotNone(program_desc) self.assertIsNotNone(program_desc.block(0)) del program_desc def test_append_block(self): - prog_desc = core.ProgramDesc.__create_program_desc__() + prog_desc = core.ProgramDesc() self.assertIsNotNone(prog_desc) block_root = prog_desc.block(0) self.assertIsNotNone(block_root) @@ -91,7 +91,7 @@ class TestProgramDesc(unittest.TestCase): class TestVarDesc(unittest.TestCase): def test_shape(self): - program_desc = core.ProgramDesc.__create_program_desc__() + program_desc = core.ProgramDesc() block = program_desc.block(0) var = block.var('my_var') var.set_type(core.VarDesc.VarType.SELECTED_ROWS) @@ -102,7 +102,7 @@ class TestVarDesc(unittest.TestCase): self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type()) def test_data_type(self): - program_desc = core.ProgramDesc.__create_program_desc__() + program_desc = core.ProgramDesc() block = program_desc.block(0) var = block.var('my_var') var.set_type(core.VarDesc.VarType.LOD_TENSOR) @@ -113,7 +113,7 @@ class TestVarDesc(unittest.TestCase): class TestBlockDesc(unittest.TestCase): def test_add_var(self): - prog = core.ProgramDesc.__create_program_desc__() + prog = core.ProgramDesc() self.assertIsNotNone(prog) block = prog.block(0) self.assertIsNotNone(block) @@ -121,19 +121,21 @@ class TestBlockDesc(unittest.TestCase): var2 = block.var("var2") var3 = block.var("var3") all_vars = block.all_vars() - self.assertEqual(set(all_vars), set([var1, var2, var3])) + self.assertEqual(set(all_vars), {var1, var2, var3}) var2_re = block.find_var("var2") self.assertEqual(var2_re, var2) def test_add_op(self): - prog = core.ProgramDesc.__create_program_desc__() + prog = core.ProgramDesc() self.assertIsNotNone(prog) block = prog.block(0) self.assertIsNotNone(block) op1 = block.append_op() op2 = block.append_op() op0 = block.prepend_op() - all_ops = block.all_ops() + all_ops = [] + for idx in xrange(0, block.op_size()): + all_ops.append(block.op(idx)) self.assertEqual(all_ops, [op0, op1, op2]) diff --git a/python/paddle/v2/framework/tests/test_proximal_gd_op.py b/python/paddle/v2/framework/tests/test_proximal_gd_op.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca79ce6b3b710244e4f65db70b305231a9f3fcf --- /dev/null +++ b/python/paddle/v2/framework/tests/test_proximal_gd_op.py @@ -0,0 +1,33 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestProximalGDOp(OpTest): + def setUp(self): + self.op_type = "proximal_gd" + w = np.random.random((102, 105)).astype("float32") + g = np.random.random((102, 105)).astype("float32") + lr = np.array([0.1]).astype("float32") + l1 = 0.1 + l2 = 0.2 + + self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr} + self.attrs = {'l1': l1, 'l2': l2} + prox_param = w - lr * g + param_out = 0.0 + if l1 > 0.0: + x = np.abs(prox_param) - lr * l1 + x[x < 0] = 0 + param_out = np.sign(prox_param) * (x / (1.0 + lr * l2)) + else: + param_out = prox_param / (1.0 + lr * l2) + + self.outputs = {'ParamOut': param_out} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_selected_rows.py b/python/paddle/v2/framework/tests/test_selected_rows.py index 661e81817951f5605ba3ca7fb0cc667074b1e37c..e8a930cb08c42b48f678bdd7bdb7698923535d4f 100644 --- a/python/paddle/v2/framework/tests/test_selected_rows.py +++ b/python/paddle/v2/framework/tests/test_selected_rows.py @@ -8,29 +8,30 @@ class TestSelectedRows(unittest.TestCase): place = core.CPUPlace() height = 10 rows = [0, 4, 7] - row_numel = 10 - selcted_rows = core.SelectedRows(rows, row_numel) - np_array = np.ones((len(rows), height)).astype("float32") + row_numel = 12 + selected_rows = core.SelectedRows(rows, height) + np_array = np.ones((len(rows), row_numel)).astype("float32") np_array[0, 0] = 2.0 np_array[2, 8] = 4.0 - tensor = selcted_rows.get_tensor() + tensor = selected_rows.get_tensor() tensor.set(np_array, place) # compare rows - self.assertEqual(0, selcted_rows.rows()[0]) - self.assertEqual(4, selcted_rows.rows()[1]) - self.assertEqual(7, selcted_rows.rows()[2]) + self.assertEqual(0, selected_rows.rows()[0]) + self.assertEqual(4, selected_rows.rows()[1]) + self.assertEqual(7, selected_rows.rows()[2]) # compare height - self.assertEqual(10, selcted_rows.height()) + self.assertEqual(10, selected_rows.height()) # compare tensor self.assertAlmostEqual(2.0, - selcted_rows.get_tensor().get_float_element(0)) + selected_rows.get_tensor().get_float_element(0)) self.assertAlmostEqual(1.0, - selcted_rows.get_tensor().get_float_element(1)) + selected_rows.get_tensor().get_float_element(1)) self.assertAlmostEqual( - 4.0, selcted_rows.get_tensor().get_float_element(2 * row_numel + 8)) + 4.0, + selected_rows.get_tensor().get_float_element(2 * row_numel + 8)) if __name__ == "__main__": diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py index 2dd881e5e107249277a91bd8e3a72567269e1cd4..01262bba4d43adaed179baef88ccab6e69b0884b 100644 --- a/python/paddle/v2/framework/tests/test_sgd_op.py +++ b/python/paddle/v2/framework/tests/test_sgd_op.py @@ -1,5 +1,7 @@ import unittest import numpy as np +import paddle.v2.framework.core as core +from paddle.v2.framework.op import Operator from op_test import OpTest @@ -17,5 +19,70 @@ class TestSGDOp(OpTest): self.check_output() +class TestSparseSGDOp(unittest.TestCase): + def check_with_place(self, place): + scope = core.Scope() + + # create and initialize Grad Variable + height = 10 + rows = [0, 4, 7] + row_numel = 12 + + grad_selected_rows = scope.var('Grad').get_selected_rows() + grad_selected_rows.set_height(height) + grad_selected_rows.set_rows(rows) + np_array = np.ones((len(rows), row_numel)).astype("float32") + np_array[0, 0] = 2.0 + np_array[2, 8] = 4.0 + + grad_tensor = grad_selected_rows.get_tensor() + grad_tensor.set(np_array, place) + + # create and initialize Param Variable + param = scope.var('Param').get_tensor() + param_array = np.full((height, row_numel), 5.0).astype("float32") + param.set(param_array, place) + + # create and initialize LeraningRate Variable + lr = scope.var('LearningRate').get_tensor() + lr_array = np.full((1), 2.0).astype("float32") + lr.set(lr_array, place) + + # create and run sgd operator + sgd_op = Operator( + "sgd", + Param='Param', + Grad='Grad', + ParamOut='Param', + LearningRate='LearningRate') + ctx = core.DeviceContext.create(place) + sgd_op.run(scope, ctx) + + # get and compare result + result_array = np.array(param) + + # rows[0] = 0, 5.0 - 2.0 * 2.0 + self.assertAlmostEqual(1.0, result_array[rows[0], 0]) + # rows[0] = 0, 5.0 - 2.0 * 1.0 + self.assertAlmostEqual(3.0, result_array[rows[0], 2]) + # 5.0 - 2.0 * 0.0 + self.assertAlmostEqual(5.0, result_array[1, 0]) + # rows[1] = 4, 5.0 - 2.0 * 1.0 + self.assertAlmostEqual(3.0, result_array[rows[1], 10]) + # 5.0 - 2.0 * 0.0 + self.assertAlmostEqual(5.0, result_array[5, 8]) + # rows[2] = 7, 5.0 - 2.0 * 1.0 + self.assertAlmostEqual(3.0, result_array[rows[2], 1]) + # rows[2] = 7, 5.0 - 2.0 * 4.0 + self.assertAlmostEqual(-3.0, result_array[rows[2], 8]) + + def test_sparse_sgd(self): + places = [core.CPUPlace()] + if core.is_compile_gpu(): + places.append(core.GPUPlace(0)) + for place in places: + self.check_with_place(place) + + if __name__ == "__main__": unittest.main()