提交 74b283c9 编写于 作者: W wanghaoshuang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into seq_expand_op

# Prune
## Motivation
We want to support running inference, training and checkpointing in one `ProgramDesc`. We implement
`void Prune(const ProgramDesc* input, ProgramDesc* output)` function, which takes a `ProgramDesc`
and generate a pruned `ProgramDesc`.
## Challenge
Pruning need to support both variables and operators being evaluation targets. Consider the following
different situations.
```python
# Case 1: run foward pass.
cost_np = session.run(target=cost)
# Case 2: run backward passing.
opts_np, _ = session.run(target=[cost, opt])
# Case 3: run checkpointing
_ = session.run(target=checkpoint)
```
## Solution
To support evaluation of operators, we add `is_target` field in the `OpDesc`.
```c++
message OpDesc {
required string type = 3;
repeated Var inputs = 1;
repeated Var outputs = 2;
repeated Attr attrs = 4;
optional bool is_target = 5 [ default = false ];
};
```
To support evaluation of variables, we add [fetch_op](https://github.com/PaddlePaddle/Paddle/pull/4599).
For each variable in the `target`, we insert a `fetch_op` into the `ProgramDesc` with `variable` being
`fetch_op`'s input. Then we also set `fetch_op` is a target.
### Algorithm
If an operator needs to be run, it must fall into one of the following cases:
1. It is the target.
2. It is depended by some other ops, meaning its output is some other op's input.
The first case can be checked by `op_desc.is_traget()` . The second case can be implement as
```c++
bool HasDependentVar(const OpDesc& op_desc, const std::set<string>& dependent_vars) {
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
if (dependent_vars.count(argu) != 0) {
return true;
}
}
}
return false;
}
```
Then the whole algorithm can be implemented as the following [code](https://github.com/tonyyang-svail/Paddle/blob/prune_impl/paddle/framework/prune.cc).
...@@ -177,9 +177,6 @@ REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, grad_op_class) ...@@ -177,9 +177,6 @@ REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, grad_op_class)
REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class)
``` ```
### USE Macros
Make sure the registration process is executed and linked.
--- ---
# Registration Process # Registration Process
1. Write an Op class and its gradient Op class, if required. 1. Write an Op class and its gradient Op class, if required.
...@@ -188,8 +185,6 @@ Make sure the registration process is executed and linked. ...@@ -188,8 +185,6 @@ Make sure the registration process is executed and linked.
1. Call maker class to complete `proto` and `checker` 1. Call maker class to complete `proto` and `checker`
2. Using the completed `proto` and `checker`, it will add a new key-value pair to the `OpInfoMap` 2. Using the completed `proto` and `checker`, it will add a new key-value pair to the `OpInfoMap`
4. Invoke the `USE` macro in which the Op is used to make sure that it is linked.
--- ---
# Backward Module (1/2) # Backward Module (1/2)
### Create Backward Operator ### Create Backward Operator
......
# Regularization in PaddlePaddle
## Introduction to Regularization
A central problem in machine learning is how to design an algorithm that will perform well not just on the training data, but also on new data. Many strategies are used by machine learning practitioners to reduce the test error, possibly at the expense of increased training error. These strategies are collectively known as **regularization**.
### Parameter Norm Penalties
Most common regularization approaches in deep learning are based on limiting the capacity of the models by adding a parameter norm penalty to the objective function `J`. This is given as follows:
<img src="./images/loss_equation.png" align="center"/><br/>
The parameter `alpha` is a hyperparameter that weights the relative contribution of the norm penalty term, `omega`, relative to the standard objective function `J`.
The most commonly used norm penalties are the L2 norm penalty and the L1 norm penalty. These are given as follows:
##### L2 Regularization:
<img src="./images/l2_regularization.png" align="center"/><br/>
##### L1 Regularization
<img src="./images/l1_regularization.png" align="center"/><br/>
A much more detailed mathematical background of reguilarization can be found [here](http://www.deeplearningbook.org/contents/regularization.html).
## How to do Regularization in PaddlePaddle
On surveying existing frameworks like Tensorflow, PyTorch, Caffe, etc, it can be seen that there are 2 common approaches of doing regularization:
1. Making regularization a part of the optimizer using an attribute like `weight_decay` that is used to control the scale of the L2 Penalty. This approach is used in PyTorch as follows:
```python
opt = torch.optim.SGD(params, lr=0.2, weight_decay=0.2)
```
At every optimization step, this code will add the gradient of the L2 Norm of the params to the gradient of the params with respect to the loss function. This can seen in the following code snippet:
```python
if weight_decay != 0:
d_p.add_(weight_decay, p.data)
```
This is a very restyrictive way of doing regularization and does not give the users enough flexibility.
**Advantages**:
- It is easy to implement for us.
- Faster execution of backward. However, it can be done manually by advanced users too.
**Disadvantages**:
- Not flexible for other regularizations such as L1/L0 regularization.
- Does not allow for different regularization coefficient for different parameters. For example, in most models, ony the weight matrices are regularized and the bias vectors are unregularized.
- Tightly coupled optimizer and regularization implementation.
2. Adding regularization ops to the graph through Python API. This approach is used by Tensorflow and Caffe. Using this approach, we manually add regularization ops to the graph and then add the regularization loss to the final loss function before sending them to the optimizer.
**Advantages**:
- Allows for greater flexibility to the users of Paddle. Using this approach, the users can put different regularization to different parameters and also choose parameters that are not a part of regularization.
- Makes it easy for the users to customize and extend the framework.
**Disadvantages**:
- Implementation requires comprehensive design and time.
## Proposal for Regularization in PaddlePaddle
### Low-Level implementation
In the new design, we propose to create new operations for regularization. For now, we can add 2 ops thgat correspond to the most frequently used regularizations:
- L2_regularization_op
- L1_regularization_op
These ops can be like any other ops with their own CPU/GPU implementations either using Eigen or separate Cpu and GPU kernels. As the initial implementation, we can implement their kernels using Eigen following the abstraction pattern implemented for [Activation Ops](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/accuracy_op.h). This abstraction pattern can make it very easy to implement new regularization schemes. other than L1 and L2 norm penalties.
The idea of building ops for regularization is in sync with the refactored Paddle philosophy of using operators to represent any computation unit. The way these ops will be added to the computation graph, will be decided by the [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) in Python API.
### Computation Graph
Below is an example of a really simple feed forward neural network.
<img src="./images/feed_forward.png" align="center"/><br/>
The Python API will modify this computation graph to add regularization operators. The modified computation graph will look as follows:
<img src="./images/feed_forward_regularized.png" align="center"/><br/>
   
### Python API implementation for Regularization
Using the low level ops, `L2_regularization_op` and `L1_regularization_op`, any user can add regularization to their computation graphs. However, this will require a lot of lines of code and we should design Python APIs that support regularization. An example of such an API can be seen in [Keras](https://keras.io/regularizers/). As per the PaddlePaddle [Python API design](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md), the layer functions are responsible for creating operators, operator parameters and variables. Since regularization is a property of parameters, it makes sense to create these in the layer functions.
#### Creation of Regularization ops
There are two possibilities for creating the regularization ops:
1. We create these ops immediately while building the computation graph.
2. We add these ops in a lazy manner, just before the backward, similar to the way the optimization ops are added.
The proposal is to add these ops in a lazy manner just before the backward pass.
#### Storage of Regularization attributes
Since we want to create the regularization ops in a lazy manner, the regularization attributes (type of regularization and weight of regularization penalty) can be stored as attributes of the [`Parameter`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/framework/framework.py#L421) class. This is because regularization is a property of the parameters and storing regularization properties with Parameters also allows for shared parameters.
#### High-level API
In PaddlePaddle Python API, users will primarily rely on [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) to create neural network layers. Hence, we lso need to provide regularization functionality in layer functions. The design of these APIs can be postponed for later right now. A good reference for these APIs can be found in [Keras](https://keras.io/regularizers/) and also by looking at Tensorflow in [`tf.contrib.layers`](https://www.tensorflow.org/api_guides/python/contrib.layers).
...@@ -20,13 +20,14 @@ proto_library(framework_proto SRCS framework.proto) ...@@ -20,13 +20,14 @@ proto_library(framework_proto SRCS framework.proto)
cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc glog)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
py_proto_compile(framework_py_proto SRCS framework.proto) py_proto_compile(framework_py_proto SRCS framework.proto)
...@@ -44,6 +45,9 @@ cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_co ...@@ -44,6 +45,9 @@ cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_co
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor) cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor)
cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place) cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place)
......
...@@ -19,19 +19,7 @@ limitations under the License. */ ...@@ -19,19 +19,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
static ProgramDesc* g_program_desc = nullptr; Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* program) {
ProgramDesc& GetProgramDesc() {
if (g_program_desc == nullptr) {
g_program_desc = new ProgramDesc();
auto root_block = g_program_desc->mutable_blocks()->Add();
root_block->set_idx(0);
root_block->set_parent_idx(-1);
}
return *g_program_desc;
}
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) { switch (attr_desc.type()) {
case framework::AttrType::BOOLEAN: { case framework::AttrType::BOOLEAN: {
return attr_desc.b(); return attr_desc.b();
...@@ -74,7 +62,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { ...@@ -74,7 +62,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
return val; return val;
} }
case framework::AttrType::BLOCK: { case framework::AttrType::BLOCK: {
return GetProgramDesc().mutable_blocks(attr_desc.block_idx()); PADDLE_ENFORCE(program != nullptr,
"Need to specify ProgramDesc when get a block attr");
return program->mutable_blocks(attr_desc.block_idx());
} }
} }
PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
......
...@@ -26,16 +26,13 @@ limitations under the License. */ ...@@ -26,16 +26,13 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
ProgramDesc& GetProgramDesc();
template <typename T> template <typename T>
inline AttrType AttrTypeID() { inline AttrType AttrTypeID() {
Attribute tmp = T(); Attribute tmp = T();
return static_cast<AttrType>(tmp.which() - 1); return static_cast<AttrType>(tmp.which() - 1);
} }
Attribute GetAttrValue(const OpDesc::Attr& attr_desc); Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* desc);
class AttrReader { class AttrReader {
public: public:
......
...@@ -309,8 +309,7 @@ static void CreateGradVarInBlock( ...@@ -309,8 +309,7 @@ static void CreateGradVarInBlock(
} }
std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad( std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
const std::unique_ptr<OpDescBind>& op_desc, const OpDescBind* op_desc, std::unordered_set<std::string>* no_grad_vars,
std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var) { std::unordered_map<std::string, std::string>* grad_to_var) {
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs; std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
// All input gradients of forwarding operator do not need to calculate. // All input gradients of forwarding operator do not need to calculate.
...@@ -357,7 +356,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -357,7 +356,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
std::unordered_set<std::string>* no_grad_vars, std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var) { std::unordered_map<std::string, std::string>* grad_to_var) {
BlockDescBind* cur_block = program_desc.Block(block_idx); BlockDescBind* cur_block = program_desc.Block(block_idx);
std::deque<std::unique_ptr<OpDescBind>>& op_descs = cur_block->ops_; std::vector<OpDescBind*> op_descs = cur_block->AllOps();
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops; std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
size_t grad_desc_idx = 0; size_t grad_desc_idx = 0;
std::vector<std::unique_ptr<OpDescBind>> backward_descs; std::vector<std::unique_ptr<OpDescBind>> backward_descs;
...@@ -375,7 +374,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -375,7 +374,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
program_desc, step_block_idx, no_grad_vars, grad_to_var); program_desc, step_block_idx, no_grad_vars, grad_to_var);
BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block); BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
for (auto& ptr : backward_block_op_descs) { for (auto& ptr : backward_block_op_descs) {
backward_block->ops_.push_back(std::move(ptr)); backward_block->AppendAllocatedOp(std::move(ptr));
} }
op_grads[0]->SetBlockAttr("step_block", *backward_block); op_grads[0]->SetBlockAttr("step_block", *backward_block);
} }
...@@ -432,7 +431,6 @@ ParamGradInfoMap AppendBackward( ...@@ -432,7 +431,6 @@ ParamGradInfoMap AppendBackward(
const int root_block_idx = 0; const int root_block_idx = 0;
auto root_block = program_desc.Block(root_block_idx); auto root_block = program_desc.Block(root_block_idx);
auto& all_ops = root_block->ops_;
// insert fill one op for target // insert fill one op for target
// TODO(qiao) add some check to the target. // TODO(qiao) add some check to the target.
...@@ -447,8 +445,8 @@ ParamGradInfoMap AppendBackward( ...@@ -447,8 +445,8 @@ ParamGradInfoMap AppendBackward(
{{"shape", target_shape}, {{"shape", target_shape},
{"value", static_cast<float>(1.0)}, {"value", static_cast<float>(1.0)},
{"data_type", framework::DataType::FP32}})); {"data_type", framework::DataType::FP32}}));
all_ops.push_back(std::move(fill_one_op)); root_block->AppendAllocatedOp(std::move(fill_one_op));
size_t forward_op_num = all_ops.size(); size_t forward_op_num = root_block->OpSize();
size_t forward_block_num = program_desc.Size(); size_t forward_block_num = program_desc.Size();
// Insert backward operators // Insert backward operators
...@@ -457,7 +455,7 @@ ParamGradInfoMap AppendBackward( ...@@ -457,7 +455,7 @@ ParamGradInfoMap AppendBackward(
&no_grad_var_names, &grad_to_var); &no_grad_var_names, &grad_to_var);
for (auto& ptr : backward_op_descs) { for (auto& ptr : backward_op_descs) {
all_ops.push_back(std::move(ptr)); root_block->AppendAllocatedOp(std::move(ptr));
} }
// Create Variable // Create Variable
......
...@@ -495,19 +495,8 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ...@@ -495,19 +495,8 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
EXPECT_EQ(bwd_net->ops_[2]->Outputs(all).size(), 0UL); EXPECT_EQ(bwd_net->ops_[2]->Outputs(all).size(), 0UL);
} }
// =================================== //
f::ProgramDesc *GetNewProgramDesc() {
auto *program_desc = new f::ProgramDesc();
auto *root_block = program_desc->add_blocks();
root_block->set_idx(0);
root_block->set_parent_idx(-1);
return program_desc;
}
TEST(Backward, simple_single_op) { TEST(Backward, simple_single_op) {
f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind program;
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op = block->AppendOp(); f::OpDescBind *op = block->AppendOp();
...@@ -543,8 +532,7 @@ TEST(Backward, simple_single_op) { ...@@ -543,8 +532,7 @@ TEST(Backward, simple_single_op) {
} }
TEST(Backward, default_attribute) { TEST(Backward, default_attribute) {
f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind program;
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op = block->AppendOp(); f::OpDescBind *op = block->AppendOp();
op->SetType("mul"); op->SetType("mul");
...@@ -570,8 +558,7 @@ TEST(Backward, default_attribute) { ...@@ -570,8 +558,7 @@ TEST(Backward, default_attribute) {
} }
TEST(Backward, simple_mult_op) { TEST(Backward, simple_mult_op) {
f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind program;
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op1 = block->AppendOp(); f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add"); op1->SetType("rowwise_add");
...@@ -654,8 +641,7 @@ TEST(Backward, simple_mult_op) { ...@@ -654,8 +641,7 @@ TEST(Backward, simple_mult_op) {
} }
TEST(Backward, intermedia_var_no_grad) { TEST(Backward, intermedia_var_no_grad) {
f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind program;
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op1 = block->AppendOp(); f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add"); op1->SetType("rowwise_add");
...@@ -725,8 +711,7 @@ TEST(Backward, intermedia_var_no_grad) { ...@@ -725,8 +711,7 @@ TEST(Backward, intermedia_var_no_grad) {
} }
TEST(Backward, var_no_grad) { TEST(Backward, var_no_grad) {
f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind program;
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op1 = block->AppendOp(); f::OpDescBind *op1 = block->AppendOp();
op1->SetType("mult_in_out"); op1->SetType("mult_in_out");
...@@ -802,8 +787,7 @@ TEST(Backward, var_no_grad) { ...@@ -802,8 +787,7 @@ TEST(Backward, var_no_grad) {
} }
TEST(Backward, shared_var) { TEST(Backward, shared_var) {
f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind program;
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op1 = block->AppendOp(); f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add"); op1->SetType("rowwise_add");
...@@ -893,8 +877,7 @@ TEST(Backward, shared_var) { ...@@ -893,8 +877,7 @@ TEST(Backward, shared_var) {
} }
TEST(Backward, half_backward) { TEST(Backward, half_backward) {
f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind program;
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.Block(0);
auto *op1 = block->AppendOp(); auto *op1 = block->AppendOp();
op1->SetType("minus"); op1->SetType("minus");
......
...@@ -19,11 +19,11 @@ namespace paddle { ...@@ -19,11 +19,11 @@ namespace paddle {
namespace framework { namespace framework {
VarDescBind *BlockDescBind::Var(const std::string &name) { VarDescBind *BlockDescBind::Var(const std::string &name) {
need_update_ = true;
auto it = vars_.find(name); auto it = vars_.find(name);
if (it != vars_.end()) { if (it != vars_.end()) {
return it->second.get(); return it->second.get();
} }
need_update_ = true;
auto *var = new VarDescBind(name); auto *var = new VarDescBind(name);
vars_[name].reset(var); vars_[name].reset(var);
return var; return var;
...@@ -55,6 +55,11 @@ OpDescBind *BlockDescBind::AppendOp() { ...@@ -55,6 +55,11 @@ OpDescBind *BlockDescBind::AppendOp() {
return ops_.back().get(); return ops_.back().get();
} }
void BlockDescBind::AppendAllocatedOp(std::unique_ptr<OpDescBind> &&op_desc) {
need_update_ = true;
ops_.emplace_back(std::move(op_desc));
}
OpDescBind *BlockDescBind::PrependOp() { OpDescBind *BlockDescBind::PrependOp() {
need_update_ = true; need_update_ = true;
ops_.emplace_front(new OpDescBind()); ops_.emplace_front(new OpDescBind());
...@@ -70,6 +75,10 @@ std::vector<OpDescBind *> BlockDescBind::AllOps() const { ...@@ -70,6 +75,10 @@ std::vector<OpDescBind *> BlockDescBind::AllOps() const {
} }
void BlockDescBind::Flush() { void BlockDescBind::Flush() {
for (auto &op_desc : ops_) {
op_desc->Flush();
}
if (need_update_) { if (need_update_) {
auto &op_field = *this->desc_->mutable_ops(); auto &op_field = *this->desc_->mutable_ops();
this->ClearPBOps(); this->ClearPBOps();
...@@ -98,6 +107,19 @@ BlockDesc *BlockDescBind::Proto() { ...@@ -98,6 +107,19 @@ BlockDesc *BlockDescBind::Proto() {
Flush(); Flush();
return desc_; return desc_;
} }
BlockDescBind::BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
ProgramDescBind *prog)
: prog_(prog), desc_(desc) {
need_update_ = true;
for (auto &op : other.ops_) {
ops_.emplace_back(new OpDescBind(*op));
}
for (auto &it : other.vars_) {
auto *var = new VarDescBind(*it.second);
vars_[it.first].reset(var);
}
}
void BlockDescBind::ClearPBOps() { void BlockDescBind::ClearPBOps() {
auto ops = this->desc_->mutable_ops(); auto ops = this->desc_->mutable_ops();
......
...@@ -16,8 +16,10 @@ limitations under the License. */ ...@@ -16,8 +16,10 @@ limitations under the License. */
#include <deque> #include <deque>
#include <memory> #include <memory>
#include <set>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/framework/op_desc.h" #include "paddle/framework/op_desc.h"
#include "paddle/framework/var_desc.h" #include "paddle/framework/var_desc.h"
#include "paddle/platform/macros.h" #include "paddle/platform/macros.h"
...@@ -36,6 +38,9 @@ class BlockDescBind { ...@@ -36,6 +38,9 @@ class BlockDescBind {
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {} : prog_(prog), desc_(desc), need_update_(false) {}
BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
ProgramDescBind *prog);
~BlockDescBind() { ~BlockDescBind() {
this->ClearPBVars(); this->ClearPBVars();
this->ClearPBOps(); this->ClearPBOps();
...@@ -51,16 +56,30 @@ class BlockDescBind { ...@@ -51,16 +56,30 @@ class BlockDescBind {
bool HasVar(const std::string &var_name) const; bool HasVar(const std::string &var_name) const;
std::set<std::string> LocalVarNames() const {
std::set<std::string> var_names;
for (auto &var : vars_) {
var_names.insert(var.first);
}
return var_names;
}
std::vector<VarDescBind *> AllVars() const; std::vector<VarDescBind *> AllVars() const;
BlockDescBind *ParentBlock() const; BlockDescBind *ParentBlock() const;
OpDescBind *AppendOp(); OpDescBind *AppendOp();
void AppendAllocatedOp(std::unique_ptr<OpDescBind> &&op_desc);
OpDescBind *PrependOp(); OpDescBind *PrependOp();
std::vector<OpDescBind *> AllOps() const; std::vector<OpDescBind *> AllOps() const;
size_t OpSize() const { return ops_.size(); }
OpDescBind *Op(int idx) { return ops_.at(idx).get(); }
void Flush(); void Flush();
BlockDesc *Proto(); BlockDesc *Proto();
...@@ -69,9 +88,7 @@ class BlockDescBind { ...@@ -69,9 +88,7 @@ class BlockDescBind {
void ClearPBOps(); void ClearPBOps();
void ClearPBVars(); void ClearPBVars();
// FIXME(yuyang18): backward will access private data of BlockDesc. private:
// Mark it public temporary. We can fix it later.
public:
ProgramDescBind *prog_; // not_own ProgramDescBind *prog_; // not_own
BlockDesc *desc_; // not_own BlockDesc *desc_; // not_own
bool need_update_; bool need_update_;
......
...@@ -75,7 +75,8 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) { ...@@ -75,7 +75,8 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
} }
for (auto& op_desc : block.ops()) { for (auto& op_desc : block.ops()) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(
op_desc, const_cast<ProgramDesc*>(&pdesc));
op->Run(local_scope, *device); op->Run(local_scope, *device);
} }
......
...@@ -55,6 +55,7 @@ message OpDesc { ...@@ -55,6 +55,7 @@ message OpDesc {
repeated Var inputs = 1; repeated Var inputs = 1;
repeated Var outputs = 2; repeated Var outputs = 2;
repeated Attr attrs = 4; repeated Attr attrs = 4;
optional bool is_target = 5 [ default = false ];
}; };
// OpProto describes a C++ framework::OperatorBase derived class. // OpProto describes a C++ framework::OperatorBase derived class.
......
...@@ -74,12 +74,12 @@ class LoDTensor : public Tensor { ...@@ -74,12 +74,12 @@ class LoDTensor : public Tensor {
LoD lod() const { return lod_; } LoD lod() const { return lod_; }
/* /*
* Get a element from LoD. * Get the start offset and end offset of an element from LoD.
*/ */
size_t lod_element(size_t level, size_t elem) const { std::pair<size_t, size_t> lod_element(size_t level, size_t elem) const {
PADDLE_ENFORCE_LT(level, NumLevels()); PADDLE_ENFORCE_LT(level, NumLevels());
PADDLE_ENFORCE_LT(elem, NumElements(level)); PADDLE_ENFORCE_LT(elem, NumElements(level));
return (lod_)[level][elem]; return std::make_pair((lod_)[level][elem], (lod_)[level][elem + 1]);
} }
/* /*
......
...@@ -36,8 +36,8 @@ TEST(LoDTensor, LoDInGPU) { ...@@ -36,8 +36,8 @@ TEST(LoDTensor, LoDInGPU) {
lod_tensor.mutable_data<float>(place); lod_tensor.mutable_data<float>(place);
lod_tensor.set_lod(src_lod); lod_tensor.set_lod(src_lod);
CHECK_EQ(lod_tensor.lod_element(0, 2), 4UL); CHECK_EQ(lod_tensor.lod_element(0, 2).first, 4UL);
CHECK_EQ(lod_tensor.lod_element(0, 4), 8UL); CHECK_EQ(lod_tensor.lod_element(0, 4).first, 8UL);
auto lod = lod_tensor.lod(); auto lod = lod_tensor.lod();
......
...@@ -43,12 +43,13 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap( ...@@ -43,12 +43,13 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
return ret_val; return ret_val;
} }
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) { std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc,
ProgramDesc* program) {
VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
AttributeMap attrs; AttributeMap attrs;
for (auto& attr : op_desc.attrs()) { for (auto& attr : op_desc.attrs()) {
attrs[attr.name()] = GetAttrValue(attr); attrs[attr.name()] = GetAttrValue(attr, program);
} }
return CreateOp(op_desc.type(), inputs, outputs, attrs); return CreateOp(op_desc.type(), inputs, outputs, attrs);
......
...@@ -20,6 +20,8 @@ limitations under the License. */ ...@@ -20,6 +20,8 @@ limitations under the License. */
#include <typeinfo> #include <typeinfo>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "glog/logging.h" // For VLOG()
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/details/op_registry.h" #include "paddle/framework/details/op_registry.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
...@@ -74,7 +76,8 @@ class OpRegistry { ...@@ -74,7 +76,8 @@ class OpRegistry {
const VariableNameMap& outputs, const VariableNameMap& outputs,
AttributeMap attrs); AttributeMap attrs);
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc,
ProgramDesc* program);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
}; };
......
...@@ -74,7 +74,7 @@ TEST(OpRegistry, CreateOp) { ...@@ -74,7 +74,7 @@ TEST(OpRegistry, CreateOp) {
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(scale); attr->set_f(scale);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
paddle::framework::Scope scope; paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
...@@ -95,7 +95,7 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -95,7 +95,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
} catch (paddle::platform::EnforceNotMet err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "larger_than check fail"; std::string msg = "larger_than check fail";
...@@ -115,7 +115,7 @@ TEST(OpRegistry, DefaultValue) { ...@@ -115,7 +115,7 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_TRUE(op_desc.IsInitialized()); ASSERT_TRUE(op_desc.IsInitialized());
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
paddle::framework::Scope scope; paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
...@@ -131,7 +131,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -131,7 +131,7 @@ TEST(OpRegistry, CustomChecker) {
// attr 'test_attr' is not set // attr 'test_attr' is not set
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
} catch (paddle::platform::EnforceNotMet err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "Attribute 'test_attr' is required!"; std::string msg = "Attribute 'test_attr' is required!";
...@@ -149,7 +149,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -149,7 +149,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(3); attr->set_i(3);
caught = false; caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
} catch (paddle::platform::EnforceNotMet err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "'test_attr' must be even!"; std::string msg = "'test_attr' must be even!";
...@@ -166,7 +166,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -166,7 +166,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_name("test_attr"); attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT); attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4); attr->set_i(4);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
paddle::framework::Scope scope; paddle::framework::Scope scope;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
......
...@@ -20,12 +20,13 @@ limitations under the License. */ ...@@ -20,12 +20,13 @@ limitations under the License. */
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "op_info.h" #include "glog/logging.h" // For VLOG
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/block_desc.h" #include "paddle/framework/block_desc.h"
#include "paddle/framework/data_type.h" #include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/framework/shape_inference.h" #include "paddle/framework/shape_inference.h"
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
...@@ -573,6 +574,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -573,6 +574,7 @@ class OperatorWithKernel : public OperatorBase {
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final { const platform::DeviceContext& dev_ctx) const final {
VLOG(3) << "Running operator " << this->Type();
RuntimeInferShapeContext infer_shape_ctx(*this, scope); RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
......
...@@ -83,7 +83,7 @@ TEST(OperatorBase, all) { ...@@ -83,7 +83,7 @@ TEST(OperatorBase, all) {
paddle::platform::CPUDeviceContext device_context; paddle::platform::CPUDeviceContext device_context;
paddle::framework::Scope scope; paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
scope.Var("OUT1"); scope.Var("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0); ASSERT_EQ(paddle::framework::op_run_num, 0);
op->Run(scope, device_context); op->Run(scope, device_context);
...@@ -208,7 +208,7 @@ TEST(OpKernel, all) { ...@@ -208,7 +208,7 @@ TEST(OpKernel, all) {
paddle::platform::CPUDeviceContext cpu_device_context; paddle::platform::CPUDeviceContext cpu_device_context;
paddle::framework::Scope scope; paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
...@@ -244,7 +244,7 @@ TEST(OpKernel, multi_inputs) { ...@@ -244,7 +244,7 @@ TEST(OpKernel, multi_inputs) {
scope.Var("y0")->GetMutable<Tensor>(); scope.Var("y0")->GetMutable<Tensor>();
scope.Var("y1")->GetMutable<Tensor>(); scope.Var("y1")->GetMutable<Tensor>();
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);
} }
......
...@@ -18,27 +18,10 @@ limitations under the License. */ ...@@ -18,27 +18,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using ProgDescMap =
std::unordered_map<ProgramDesc *, std::unique_ptr<ProgramDescBind>>;
static ProgDescMap *g_bind_map = nullptr;
ProgramDescBind &ProgramDescBind::Instance(ProgramDesc *prog) {
if (g_bind_map == nullptr) {
g_bind_map = new ProgDescMap();
}
auto &map = *g_bind_map;
auto &ptr = map[prog];
if (ptr == nullptr) {
ptr.reset(new ProgramDescBind(prog));
}
return *ptr;
}
BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) { BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
auto *b = prog_->add_blocks(); auto *b = prog_.add_blocks();
b->set_parent_idx(parent.ID()); b->set_parent_idx(parent.ID());
b->set_idx(prog_->blocks_size() - 1); b->set_idx(prog_.blocks_size() - 1);
blocks_.emplace_back(new BlockDescBind(this, b)); blocks_.emplace_back(new BlockDescBind(this, b));
return blocks_.back().get(); return blocks_.back().get();
} }
...@@ -47,13 +30,22 @@ ProgramDesc *ProgramDescBind::Proto() { ...@@ -47,13 +30,22 @@ ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) { for (auto &block : blocks_) {
block->Flush(); block->Flush();
} }
return prog_; return &prog_;
}
ProgramDescBind::ProgramDescBind() {
auto *block = prog_.mutable_blocks()->Add();
block->set_idx(0);
block->set_parent_idx(-1);
blocks_.emplace_back(new BlockDescBind(this, block));
} }
ProgramDescBind::ProgramDescBind(ProgramDesc *prog) { ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
prog_ = prog; prog_ = o.prog_;
for (auto &block : *prog->mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block)); for (int i = 0; i < prog_.blocks_size(); ++i) {
auto *block = prog_.mutable_blocks(i);
blocks_.emplace_back(new BlockDescBind(*o.blocks_[i], block, this));
} }
} }
} // namespace framework } // namespace framework
......
...@@ -26,7 +26,9 @@ class BlockDescBind; ...@@ -26,7 +26,9 @@ class BlockDescBind;
class ProgramDescBind { class ProgramDescBind {
public: public:
static ProgramDescBind &Instance(ProgramDesc *prog); ProgramDescBind();
ProgramDescBind(const ProgramDescBind &o);
BlockDescBind *AppendBlock(const BlockDescBind &parent); BlockDescBind *AppendBlock(const BlockDescBind &parent);
...@@ -37,14 +39,9 @@ class ProgramDescBind { ...@@ -37,14 +39,9 @@ class ProgramDescBind {
ProgramDesc *Proto(); ProgramDesc *Proto();
private: private:
explicit ProgramDescBind(ProgramDesc *prog); ProgramDesc prog_;
// Not owned
ProgramDesc *prog_;
std::vector<std::unique_ptr<BlockDescBind>> blocks_; std::vector<std::unique_ptr<BlockDescBind>> blocks_;
DISABLE_COPY_AND_ASSIGN(ProgramDescBind);
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/program_desc.h"
#include "gtest/gtest.h"
#include "paddle/framework/block_desc.h"
namespace paddle {
namespace framework {
TEST(ProgramDesc, copy_ctor) {
ProgramDescBind program;
auto* global_block = program.Block(0);
auto* x = global_block->Var("X");
x->SetType(VarDesc_VarType_LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(VarDesc_VarType_LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(FP32);
y->SetShape({784, 100});
auto* op = global_block->AppendOp();
op->SetType("mul");
op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()});
auto* out = global_block->Var("Out");
out->SetType(VarDesc_VarType_LOD_TENSOR);
op->SetOutput("Y", {out->Name()});
ProgramDescBind program_copy(program);
auto* global_block_copy = program_copy.Block(0);
ASSERT_NE(global_block, global_block_copy);
auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
ASSERT_TRUE(global_block_copy->HasVar(name));
auto* copy = global_block_copy->Var(name);
ASSERT_NE(copy, var_before);
ASSERT_EQ(copy->Name(), var_before->Name());
ASSERT_EQ(copy->GetType(), var_before->GetType());
ASSERT_EQ(copy->Shape(), var_before->Shape());
ASSERT_EQ(copy->Proto()->SerializeAsString(),
var_before->Proto()->SerializeAsString());
};
ASSERT_EQ(global_block->LocalVarNames(), global_block_copy->LocalVarNames());
ASSERT_EQ(3, global_block_copy->LocalVarNames().size());
assert_same_var("X", x);
assert_same_var("Y", y);
assert_same_var("Out", out);
for (size_t i = 0; i < global_block->OpSize(); ++i) {
auto op_origin = global_block->Op(i);
auto op_copy = global_block->Op(i);
ASSERT_EQ(op_origin->Type(), op_copy->Type());
ASSERT_EQ(op_origin->Inputs(), op_copy->Inputs());
ASSERT_EQ(op_origin->Outputs(), op_copy->Outputs());
ASSERT_EQ(op_copy->Proto()->SerializeAsString(),
op_origin->Proto()->SerializeAsString());
}
// Not check block's protostr are same it because the order of vars could be
// different and it is correct.
}
} // namespace framework
} // namespace paddle
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/prune.h"
#include <algorithm>
#include <set>
#include <string>
#include <vector>
#include <glog/logging.h>
namespace paddle {
namespace framework {
const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
bool HasDependentVar(const OpDesc& op_desc,
const std::set<std::string>& dependent_vars) {
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
if (dependent_vars.count(argu) != 0) {
return true;
}
}
}
return false;
}
bool IsTarget(const OpDesc& op_desc) {
if (op_desc.has_is_target()) {
return op_desc.is_target();
}
return false;
}
void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
// TODO(tonyyang-svail):
// - will change to use multiple blocks for RNN op and Cond Op
auto& block = input.blocks(block_id);
auto& ops = block.ops();
bool expect_feed = true;
for (auto& op_desc : ops) {
PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed,
"All FeedOps are at the beginning of the ProgramDesc");
expect_feed = (op_desc.type() == kFeedOpType);
}
bool expect_fetch = true;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch,
"All FetchOps must at the end of the ProgramDesc");
expect_fetch = (op_desc.type() == kFetchOpType);
}
std::set<std::string> dependent_vars;
std::vector<bool> should_run;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
if (IsTarget(op_desc) || HasDependentVar(op_desc, dependent_vars)) {
// insert its input to the dependency graph
for (auto& var : op_desc.inputs()) {
for (auto& argu : var.arguments()) {
dependent_vars.insert(argu);
}
}
should_run.push_back(true);
} else {
should_run.push_back(false);
}
}
// since we are traversing the ProgramDesc in reverse order
// we reverse the should_run vector
std::reverse(should_run.begin(), should_run.end());
output = input;
auto* op_field = output.mutable_blocks(block_id)->mutable_ops();
op_field->Clear();
for (size_t i = 0; i < should_run.size(); ++i) {
if (should_run[i]) {
*op_field->Add() = input.blocks(block_id).ops(i);
}
}
}
void Prune(const ProgramDesc& input, ProgramDesc& output) {
prune_impl(input, output, 0);
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/framework.pb.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
void Prune(const ProgramDesc& input, ProgramDesc& output);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/prune.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/operator.h"
#include "paddle/operators/net_op.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/program_desc.h"
#include <gtest/gtest.h>
namespace f = paddle::framework;
namespace ops = paddle::operators;
void AddOp(const std::string &type, const f::VariableNameMap &inputs,
const f::VariableNameMap &outputs, f::AttributeMap attrs,
paddle::framework::BlockDescBind *block) {
// insert output
for (auto kv : outputs) {
for (auto v : kv.second) {
auto var = block->Var(v);
var->SetDataType(paddle::framework::DataType::FP32);
}
}
// insert op
auto op = block->AppendOp();
op->SetType(type);
for (auto &kv : inputs) {
op->SetInput(kv.first, kv.second);
}
for (auto &kv : outputs) {
op->SetOutput(kv.first, kv.second);
}
op->SetAttrMap(attrs);
}
TEST(Prune, one_operator) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block);
f::ProgramDesc *pdesc = program.Proto();
f::ProgramDesc pruned;
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
}
TEST(Prune, forward) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block);
AddOp("one_one", {{"input", {"b"}}}, {{"output", {"c"}}}, {}, block);
AddOp("one_one", {{"input", {"c"}}}, {{"output", {"d"}}}, {}, block);
AddOp("one_one", {{"input", {"d"}}}, {{"output", {"e"}}}, {}, block);
f::ProgramDesc *pdesc = program.Proto();
for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
f::ProgramDesc pruned;
pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
}
}
TEST(Prune, multi_input_op) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
AddOp("one_one", {{"input", {"a0"}}}, {{"output", {"b0"}}}, {}, block);
AddOp("one_one", {{"input", {"a1"}}}, {{"output", {"b1"}}}, {}, block);
AddOp("one_one", {{"input", {"a2"}}}, {{"output", {"b2"}}}, {}, block);
AddOp("three_one", {{"input", {"b0", "b1", "b2"}}}, {{"output", {"c"}}}, {},
block);
f::ProgramDesc *pdesc = program.Proto();
pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
}
TEST(Prune, multi_output_op) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block);
AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block);
AddOp("one_one", {{"input", {"c"}}}, {{"output", {"c1"}}}, {}, block);
f::ProgramDesc *pdesc = program.Proto();
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
}
TEST(Prune, multi_target) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block);
AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block);
AddOp("one_one", {{"input", {"c"}}}, {{"output", {"c1"}}}, {}, block);
f::ProgramDesc *pdesc = program.Proto();
pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true);
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
}
...@@ -79,6 +79,10 @@ class VarDescBind { ...@@ -79,6 +79,10 @@ class VarDescBind {
void SetType(VarDesc::VarType type) { desc_.set_type(type); } void SetType(VarDesc::VarType type) { desc_.set_type(type); }
bool Persistable() const { return desc_.persistable(); }
void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
private: private:
const TensorDesc &tensor_desc() const; const TensorDesc &tensor_desc() const;
TensorDesc *mutable_tensor_desc(); TensorDesc *mutable_tensor_desc();
......
...@@ -62,7 +62,7 @@ namespace paddle { ...@@ -62,7 +62,7 @@ namespace paddle {
namespace framework { namespace framework {
TEST(InferVarType, sum_op) { TEST(InferVarType, sum_op) {
auto &prog = ProgramDescBind::Instance(&GetProgramDesc()); ProgramDescBind prog;
auto *op = prog.Block(0)->AppendOp(); auto *op = prog.Block(0)->AppendOp();
op->SetType("sum"); op->SetType("sum");
op->SetInput("X", {"test_a", "test_b", "test_c"}); op->SetInput("X", {"test_a", "test_b", "test_c"});
...@@ -83,7 +83,7 @@ TEST(InferVarType, sum_op) { ...@@ -83,7 +83,7 @@ TEST(InferVarType, sum_op) {
} }
TEST(InferVarType, sum_op_without_infer_var_type) { TEST(InferVarType, sum_op_without_infer_var_type) {
auto &prog = ProgramDescBind::Instance(&GetProgramDesc()); ProgramDescBind prog;
auto *op = prog.Block(0)->AppendOp(); auto *op = prog.Block(0)->AppendOp();
op->SetType("sum_without_infer_var_type"); op->SetType("sum_without_infer_var_type");
op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
......
# Batch Normalization
## What is batch normalization
Batch normalization is a frequently-used method in deep network training. It adjusts the mean and variance of a layer's output, and make the data distribution easier for next layer's training.
The principle of batch normalization can be summarized into a simple function:
```
y = (x - E[x]) / STD[x]) * scale + bias
```
`x` is a batch of output data of a certain layer. `E[x]` and `STD[x]` is the mean and standard deviation of `x`, respectively。 `scale` and `bias` are two trainable parameters. The training of batch normalization layer equals to the learning of best values of `scale` and `bias`.
In our design, we use a single operator(`batch_norm_op`) to implement the whole batch normalization in C++, and wrap it as a layer in Python.
## Differences with normal operators
`batch_norm_op` is a single operator. However, there are a few differences between `BatchNormOp` and normal operators, which we shall take into consideration in our design.
1. `batch_norm_op` shall behave differently in training and inferencing. For example, during inferencing, there is no batch data and it's impossible to compute `E[x]` and `STD[x]`, so we have to use an `estimated_mean` and an `estimated_variance` instead of them. These require our framework to be able to inform operators current running type (training/inferencing), then operators can switch their behaviors.
2. `batch_norm_op` shall have the ability to maintain `estimated_mean` and `estimated_variance` across mini-batch. In each mini-batch, `estimated_mean` is iterated by the following equations:
```
if batch_id == 0
estimated_mean = E[x]
else
estimated_mean = estimated_mean * momentum + (1.0 - momentum_) * E[x]
```
The iterating of `estimated_variance` is similar. `momentum` is an attribute, which controls estimated_mean updating speed.
## Implementation
Batch normalization is designed as a single operator is C++, and then wrapped as a layer in Python.
### C++
As most C++ operators do, `batch_norm_op` is defined by inputs, outputs, attributes and compute kernels.
#### Inputs
- `x`: The inputs data, which is generated by the previous layer.
- `estimated_mean`: The estimated mean of all previous data batches. It is updated in each forward propagation and will be used in inferencing to take the role of `E[x]`.
- `estimated_var`: The estimated standard deviation of all previous data batches. It is updated in each forward propagation and will be used in inferencing to take the role of `STD[x]`.
- `scale`: trainable parameter 'scale'
- `bias`: trainable parameter 'bias'
#### Outputs
- `y`: The output data.
- `batch_mean`: The mean value of batch data.
- `batch_var`: The standard deviation value of batch data.
- `saved_mean`: Updated `estimated_mean` with current batch data. It's supposed to share the memory with input `estimated_mean`.
- `saved_var`: Updated `estimated_var` with current batch data. It's supposed to share the memory with input `estimated_var`.
#### Attributes
- `is_infer`: *bool*. If true, run `batch_norm_op` in inferencing mode.
- `use_global_est`: *bool*. If true, use `saved_mean` and `saved_var` instead of `E[x]` and `STD[x]` in trainning.
- `epsilon`: *float*. The epsilon value to avoid division by zero.
- `momentum`: *float*. Factor used in `estimated_mean` and `estimated_var` updating. The usage is shown above.
#### Kernels
The following graph showes the training computational process of `batch_norm_op`:
<img src="./images/batch_norm_op_kernel.png" width="800"/>
cudnn provides APIs to finish the whole series of computation, we can use them in our GPU kernel.
### Python
`batch_norm_op` is warpped as a layer in Python:
```python
def batch_norm_layer(net,
input,
output,
scale,
bias,
use_global_est = False,
epsilon = 1e-6,
momentum = 0.99):
mean_cache = scope.new_var(name = 'estimated_mean', trainable = False)
var_cache = scop.new_var(name = 'estimated_var', trainable = False)
batch_mean = scope.new_var(name = 'batch_mean')
batch_var = scope.new_var(name = 'batch_var')
batch_norm_op = Operator('batch_norm_op',
x = input,
estimated_mean = mean_cache,
estimated_mean = var_cache,
scale = scale,
bias = bias,
y = output,
batch_mean = batch_mean,
batch_var = batch_var,
saved_mean = mean_cache,
saved_var = var_cache,
is_infer = False,
use_global_est = use_global_est,
epsilon = epsilon,
momentum = momentum)
net.append_op(batch_norm_op)
return output
```
Because Python API has not been finally decided, the code above can be regarded as pseudo code. There are a few key points we shall note:
1. `estimated_mean` and `estimated_var` are assigned the same variables with `saved_mean` and `saved_var` respectively. So they share same the memories. The output mean and variance values(`saved_mean` and `saved_var`) of a certain batch will be the inputs(`estimated_mean` and `estimated_var`) of the next batch.
2. `is_infer` decided whether `batch_norm_op` will run in training mode or inferencing mode. However, a network may contains both training and inferencing parts. And user may switch `batch_norm_op`'s running mode in Python `for` loop like this:
```python
for pass_id in range(PASS_NUM):
# ...
net.train() # run training model
if pass_id % 100 == 0:
net.infer(test_image) # run inferencing model
# ...
```
`is_infer` is an attribute. Once an operator is created, its attributes can not be changed. It suggests us that we shall maintain two `batch_norm_op` in the model, one's `is_infer` is `True`(we call it `infer_batch_norm_op`) and the other one's is `False`(we call it `train_batch_norm_op`). They share all parameters and variables, but be placed in two different branches. That is to say, if a network contains a `batch_norm_op`, it will fork into two branches, one go through `train_batch_norm_op` and the other one go through `infer_batch_norm_op`:
<div align=center>
<img src="./images/batch_norm_fork.png" width="500"/>
</div>
Just like what is shown in the above graph, the net forks before `batch_norm_op` and will never merge again. All the operators after `batch_norm_op` will duplicate.
When the net runs in training mode, the end of the left branch will be set as the running target, so the dependency tracking process will ignore right branch automatically. When the net runs in inferencing mode, the process is reversed.
How to set a target is related to Python API design, so I will leave it here waiting for more discussions.
...@@ -51,7 +51,7 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test { ...@@ -51,7 +51,7 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test {
CreateGlobalVariables(); CreateGlobalVariables();
auto op_desc = CreateOpDesc(); auto op_desc = CreateOpDesc();
op = paddle::framework::OpRegistry::CreateOp(op_desc); op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
dop = dynamic_cast<DynamicRecurrentOp*>(op.get()); dop = dynamic_cast<DynamicRecurrentOp*>(op.get());
InitCacheManually(); InitCacheManually();
InitStepNet(); InitStepNet();
......
digraph ImageBatchNormForkGragh {
subgraph cluster_before {
Prev [label="...", shape=plaintext];
Rnn [label="rnn_op", shape=box];
BatchNorm [label="batch_norm_op", shape=box];
Fc [label="fc_op", shape=box];
After [label="...", shape=plaintext];
Prev -> Rnn -> BatchNorm -> Fc -> After;
label="original";
}
subgraph cluster_after {
Prev2 [label="...", shape=plaintext];
Rnn2 [label="rnn_op", shape=box];
BatchNorm2_1 [label="train_batch_norm_op", shape=box];
BatchNorm2_2 [label="infer_batch_norm_op", shape=box];
Fc2_1 [label="fc_op", shape=box];
Fc2_2 [label="fc_op", shape=box];
After2_1 [label="...", shape=plaintext];
After2_2 [label="...", shape=plaintext];
Prev2 -> Rnn2 -> BatchNorm2_1 -> Fc2_1 -> After2_1;
Rnn2 -> BatchNorm2_2 ->Fc2_2 ->After2_2
label="forked";
}
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/proximal_gd_op.h"
namespace paddle {
namespace operators {
class ProximalGDOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of ProximalGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of ProximalGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of ProximalGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of ProximalGDOp should not be null.");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
"Two input of ProximalGD Op's dimension must be same.");
auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
"Learning Rate should be a scalar.");
ctx->SetOutputDim("ParamOut", param_dim);
}
};
class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ProximalGDOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated.");
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter.");
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"The learning rate should be a tensor of size 1.");
AddOutput("ParamOut", "(Tensor) Output updated parameter value.");
AddAttr<float>("l1",
"(float, default 0.0) "
"L1 regularization strength.")
.SetDefault(0.0f);
AddAttr<float>("l2",
"(float, default 0.0)"
"L2 regularization strength.")
.SetDefault(0.0f);
AddComment(R"DOC(
Optimizer that implements the proximal gradient descent algorithm.
prox_param = param - learning_rate * grad
param = sign(prox_param) / (1 + learning_rate * l2) *
max { |prox_param| - learning_rate * l1 , 0 }
The paper that proposed Proximal Gradient Descent:
(http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(proximal_gd, ops::ProximalGDOp,
ops::ProximalGDOpMaker);
REGISTER_OP_CPU_KERNEL(
proximal_gd, ops::ProximalGDOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/proximal_gd_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
proximal_gd, ops::ProximalGDOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class ProximalGDOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* param_out = ctx.Output<Tensor>("ParamOut");
param_out->mutable_data<T>(ctx.GetPlace());
auto grad = ctx.Input<Tensor>("Grad");
auto l1 = static_cast<T>(ctx.Attr<float>("l1"));
auto l2 = static_cast<T>(ctx.Attr<float>("l2"));
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto g = EigenVector<T>::Flatten(*grad);
auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));
auto p_out = EigenVector<T>::Flatten(*param_out);
auto place = ctx.GetEigenDevice<Place>();
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
auto prox_param = p - lr.broadcast(grad_dsize) * g;
if (l1 > 0) {
p_out.device(place) =
prox_param.sign() *
(((prox_param.abs() - (lr * l1).broadcast(grad_dsize))
.cwiseMax(T(0.0))) /
(1.0 + (lr * l2).broadcast(grad_dsize)));
} else {
p_out.device(place) =
prox_param / (1.0 + (lr * l2).broadcast(grad_dsize));
}
}
};
} // namespace operators
} // namespace paddle
...@@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel { ...@@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"), PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of SGDOp should not be null."); "Input(Param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"), PADDLE_ENFORCE(ctx->HasInput("Grad"),
...@@ -35,15 +35,15 @@ class SGDOp : public framework::OperatorWithKernel { ...@@ -35,15 +35,15 @@ class SGDOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element"); "Learning rate should have 1 element");
auto param_dim = ctx->GetInputDim("Param"); auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"), // TODO(qijun): check dimensions of Param and Grad at complie
"Two input of SGD Op's dimension must be same."); // and run time.
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
} }
}; };
class SGDOpMaker : public framework::OpProtoAndCheckerMaker { class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) SGDOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "Input parameter"); AddInput("Param", "Input parameter");
AddInput("LearningRate", "Learning rate of SGD"); AddInput("LearningRate", "Learning rate of SGD");
...@@ -58,6 +58,38 @@ param_out = param - learning_rate * grad; ...@@ -58,6 +58,38 @@ param_out = param - learning_rate * grad;
)DOC"); )DOC");
} }
}; };
template <typename T>
struct SparseSGDFunctor<platform::CPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input,
const framework::Tensor& learning_rate,
framework::Tensor* output) {
auto in_height = input.height();
auto out_dims = output->dims();
PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
auto& in_value = input.value();
auto& in_rows = input.rows();
int64_t in_row_numel = in_value.numel() / in_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
auto* in_data = in_value.data<T>();
auto* out_data = output->data<T>();
auto* lr = learning_rate.data<T>();
for (size_t i = 0; i < in_rows.size(); i++) {
for (int64_t j = 0; j < in_row_numel; j++) {
out_data[in_rows[i] * in_row_numel + j] -=
lr[0] * in_data[i * in_row_numel + j];
}
}
}
};
template struct SparseSGDFunctor<platform::CPUPlace, float>;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -14,6 +14,66 @@ ...@@ -14,6 +14,66 @@
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/sgd_op.h" #include "paddle/operators/sgd_op.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace {
template <typename T>
__global__ void SparseSGDFunctorKernel(const T* selected_rows,
const int64_t* rows,
const T* learning_rate, T* tensor_out,
int64_t row_numel, int block_size) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
selected_rows += ty * row_numel;
tensor_out += rows[ty] * row_numel;
for (int index = tid; index < row_numel; index += block_size) {
// Since index in rows of SelectedRows can be duplicate, we have to use
// Atomic Operation to avoid concurrent write error.
paddle::platform::CudaAtomicAdd(
tensor_out + index, -1.0 * learning_rate[0] * selected_rows[index]);
}
}
} // namespace
template <typename T>
struct SparseSGDFunctor<platform::GPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input,
const framework::Tensor& learning_rate,
framework::Tensor* output) {
auto in_height = input.height();
auto out_dims = output->dims();
PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
auto& in_value = input.value();
auto& in_rows = input.rows();
int64_t in_row_numel = in_value.numel() / in_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
auto* in_data = in_value.data<T>();
auto* out_data = output->data<T>();
int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, in_rows.size());
SparseSGDFunctorKernel<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(in_data, in_rows.data(), learning_rate.data<T>(),
out_data, in_row_numel, block_size);
}
};
template struct SparseSGDFunctor<platform::GPUPlace, float>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sgd, REGISTER_OP_GPU_KERNEL(sgd,
......
...@@ -15,20 +15,32 @@ limitations under the License. */ ...@@ -15,20 +15,32 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/selected_rows.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T>
struct SparseSGDFunctor {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input,
const framework::Tensor& learning_rate,
framework::Tensor* output);
};
template <typename Place, typename T> template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel<T> { class SGDOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<framework::Tensor>("Param"); auto* param = ctx.Input<framework::Tensor>("Param");
auto grad = ctx.Input<framework::Tensor>("Grad"); auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
auto param_out = ctx.Output<framework::Tensor>("ParamOut"); auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto* grad_var = ctx.InputVar("Grad");
// Actually, all tensors are LoDTensor except SelectedRows.
if (grad_var->IsType<framework::LoDTensor>()) {
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
auto* grad = ctx.Input<framework::Tensor>("Grad");
auto p = framework::EigenVector<T>::Flatten(*param); auto p = framework::EigenVector<T>::Flatten(*param);
auto g = framework::EigenVector<T>::Flatten(*grad); auto g = framework::EigenVector<T>::Flatten(*grad);
...@@ -38,8 +50,18 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -38,8 +50,18 @@ class SGDOpKernel : public framework::OpKernel<T> {
Eigen::DSizes<int, 1> grad_dsize(grad->numel()); Eigen::DSizes<int, 1> grad_dsize(grad->numel());
o.device(place) = p - lr.broadcast(grad_dsize) * g; o.device(place) = p - lr.broadcast(grad_dsize) * g;
} else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// This manual optimization brings difficulty to track data dependency.
// It's better to find a more elegant solution.
PADDLE_ENFORCE_EQ(param, param_out);
auto* grad = ctx.Input<framework::SelectedRows>("Grad");
SparseSGDFunctor<Place, T> functor;
functor(ctx.device_context(), *grad, *learning_rate, param_out);
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
}
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -100,21 +100,11 @@ using namespace paddle::framework; // NOLINT ...@@ -100,21 +100,11 @@ using namespace paddle::framework; // NOLINT
// Bind Methods // Bind Methods
void BindProgramDesc(py::module &m) { void BindProgramDesc(py::module &m) {
py::class_<ProgramDescBind>(m, "ProgramDesc", "") py::class_<ProgramDescBind>(m, "ProgramDesc", "")
.def_static("instance", .def(py::init<>())
[]() -> ProgramDescBind * { .def("__init__",
return &ProgramDescBind::Instance(&GetProgramDesc()); [](ProgramDescBind &self, const ProgramDescBind &other) {
}, new (&self) ProgramDescBind(other);
py::return_value_policy::reference) })
.def_static("__create_program_desc__",
[]() -> ProgramDescBind * {
// Only used for unit-test
auto *prog_desc = new ProgramDesc;
auto *block = prog_desc->mutable_blocks()->Add();
block->set_idx(0);
block->set_parent_idx(-1);
return &ProgramDescBind::Instance(prog_desc);
},
py::return_value_policy::reference)
.def("append_block", &ProgramDescBind::AppendBlock, .def("append_block", &ProgramDescBind::AppendBlock,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("append_backward", .def("append_backward",
...@@ -176,8 +166,8 @@ void BindBlockDesc(py::module &m) { ...@@ -176,8 +166,8 @@ void BindBlockDesc(py::module &m) {
py::return_value_policy::reference) py::return_value_policy::reference)
.def("all_vars", &BlockDescBind::AllVars, .def("all_vars", &BlockDescBind::AllVars,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("all_ops", &BlockDescBind::AllOps, .def("op_size", &BlockDescBind::OpSize)
py::return_value_policy::reference) .def("op", &BlockDescBind::Op, py::return_value_policy::reference)
.def("serialize_to_string", [](BlockDescBind &block_desc) -> py::bytes { .def("serialize_to_string", [](BlockDescBind &block_desc) -> py::bytes {
const BlockDesc *desc = block_desc.Proto(); const BlockDesc *desc = block_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(), PADDLE_ENFORCE(desc->IsInitialized(),
...@@ -216,7 +206,8 @@ void BindVarDsec(py::module &m) { ...@@ -216,7 +206,8 @@ void BindVarDsec(py::module &m) {
.def("set_lod_level", &VarDescBind::SetLoDLevel) .def("set_lod_level", &VarDescBind::SetLoDLevel)
.def("type", &VarDescBind::GetType) .def("type", &VarDescBind::GetType)
.def("set_type", &VarDescBind::SetType) .def("set_type", &VarDescBind::SetType)
.def("serialize_to_string", [](VarDescBind &var_desc) -> py::bytes { .def("serialize_to_string",
[](VarDescBind &var_desc) -> py::bytes {
const VarDesc *desc = var_desc.Proto(); const VarDesc *desc = var_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(), PADDLE_ENFORCE(desc->IsInitialized(),
"VarDesc has not been initialized."); "VarDesc has not been initialized.");
...@@ -225,7 +216,9 @@ void BindVarDsec(py::module &m) { ...@@ -225,7 +216,9 @@ void BindVarDsec(py::module &m) {
desc->SerializeToString(&res), desc->SerializeToString(&res),
"Serialize VarDesc Error. This could be a bug of Paddle."); "Serialize VarDesc Error. This could be a bug of Paddle.");
return res; return res;
}); })
.def("persistable", &VarDescBind::Persistable)
.def("set_persistable", &VarDescBind::SetPersistable);
py::enum_<VarDesc::VarType>(var_desc, "VarType", "") py::enum_<VarDesc::VarType>(var_desc, "VarType", "")
.value("LOD_TENSOR", VarDesc::LOD_TENSOR) .value("LOD_TENSOR", VarDesc::LOD_TENSOR)
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/framework/backward.h" #include "paddle/framework/backward.h"
#include "paddle/framework/executor.h" #include "paddle/framework/executor.h"
#include "paddle/framework/feed_fetch_method.h" #include "paddle/framework/feed_fetch_method.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/selected_rows.h" #include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor_array.h" #include "paddle/framework/tensor_array.h"
...@@ -153,7 +154,15 @@ PYBIND11_PLUGIN(core) { ...@@ -153,7 +154,15 @@ PYBIND11_PLUGIN(core) {
py::return_value_policy::reference) py::return_value_policy::reference)
.def("set_height", &SelectedRows::set_height) .def("set_height", &SelectedRows::set_height)
.def("height", &SelectedRows::height) .def("height", &SelectedRows::height)
.def("set_rows", &SelectedRows::set_rows) .def("set_rows",
[](SelectedRows &self, std::vector<int64_t> rows) {
#ifndef PADDLE_WITH_CUDA
self.set_rows(rows);
#else
Vector<int64_t> new_rows(rows);
self.set_rows(new_rows);
#endif
})
.def("rows", [](SelectedRows &self) { .def("rows", [](SelectedRows &self) {
#ifndef PADDLE_WITH_CUDA #ifndef PADDLE_WITH_CUDA
return self.rows(); return self.rows();
...@@ -186,6 +195,11 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -186,6 +195,11 @@ All parameter, weight, gradient are variables in Paddle.
return self.GetMutable<LoDTensor>(); return self.GetMutable<LoDTensor>();
}, },
py::return_value_policy::reference) py::return_value_policy::reference)
.def("get_selected_rows",
[](Variable &self) -> SelectedRows * {
return self.GetMutable<SelectedRows>();
},
py::return_value_policy::reference)
.def("get_net", .def("get_net",
[](Variable &self) -> operators::NetOp * { [](Variable &self) -> operators::NetOp * {
return self.GetMutable<operators::NetOp>(); return self.GetMutable<operators::NetOp>();
...@@ -259,7 +273,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -259,7 +273,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(), PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s", "User OpDesc is not initialized, reason %s",
desc.InitializationErrorString()); desc.InitializationErrorString());
return OpRegistry::CreateOp(desc); return OpRegistry::CreateOp(desc, nullptr);
}) })
.def("backward", .def("backward",
[](const OperatorBase &forwardOp, [](const OperatorBase &forwardOp,
...@@ -363,7 +377,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -363,7 +377,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(), PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s", "User OpDesc is not initialized, reason %s",
desc.InitializationErrorString()); desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc); auto rnn_op = OpRegistry::CreateOp(desc, nullptr);
return static_cast<operators::RecurrentOp *>(rnn_op.release()); return static_cast<operators::RecurrentOp *>(rnn_op.release());
}) })
.def("set_stepnet", [](operators::RecurrentOp &self, .def("set_stepnet", [](operators::RecurrentOp &self,
...@@ -381,7 +395,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -381,7 +395,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(), PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s", "User OpDesc is not initialized, reason %s",
desc.InitializationErrorString()); desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc); auto rnn_op = OpRegistry::CreateOp(desc, nullptr);
return static_cast<operators::DynamicRecurrentOp *>( return static_cast<operators::DynamicRecurrentOp *>(
rnn_op.release()); rnn_op.release());
}) })
...@@ -408,7 +422,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -408,7 +422,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(), PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s", "User OpDesc is not initialized, reason %s",
desc.InitializationErrorString()); desc.InitializationErrorString());
auto cond_op = OpRegistry::CreateOp(desc); auto cond_op = OpRegistry::CreateOp(desc, nullptr);
return static_cast<operators::CondOp *>(cond_op.release()); return static_cast<operators::CondOp *>(cond_op.release());
}) })
.def("set_truenet", .def("set_truenet",
......
...@@ -15,6 +15,7 @@ class Variable(object): ...@@ -15,6 +15,7 @@ class Variable(object):
shape=None, shape=None,
dtype=None, dtype=None,
lod_level=None, lod_level=None,
persistable=False,
**kwargs): **kwargs):
self.block = block self.block = block
...@@ -70,6 +71,17 @@ class Variable(object): ...@@ -70,6 +71,17 @@ class Variable(object):
"lod_level is {2}. They are not " "lod_level is {2}. They are not "
"matched".format(self.name, self.lod_level, "matched".format(self.name, self.lod_level,
lod_level)) lod_level))
if persistable is not None:
if is_new_var:
self.desc.set_persistable(persistable)
else:
if persistable != self.persistable:
raise ValueError(
"Variable {0} has been created before."
"The previous persistable is {1}; the new "
"persistable is {2}. They are not matched".format(
self.name, self.persistable, persistable))
self.block.vars[name] = self self.block.vars[name] = self
self.op = None self.op = None
...@@ -80,6 +92,10 @@ class Variable(object): ...@@ -80,6 +92,10 @@ class Variable(object):
__repr__ = __str__ __repr__ = __str__
@property
def persistable(self):
return self.desc.persistable()
@property @property
def name(self): def name(self):
return self.desc.name() return self.desc.name()
...@@ -232,7 +248,7 @@ class Operator(object): ...@@ -232,7 +248,7 @@ class Operator(object):
if attrs is not None: if attrs is not None:
for attr in proto.attrs: for attr in proto.attrs:
attr_name = attr.name attr_name = attr.name
if not attr_name in attrs: if (not attr_name in attrs) or (attrs[attr_name] is None):
continue continue
if not isinstance(attrs[attr_name], Block): if not isinstance(attrs[attr_name], Block):
self.desc.set_attr(attr_name, attrs[attr_name]) self.desc.set_attr(attr_name, attrs[attr_name])
...@@ -344,7 +360,11 @@ class Block(object): ...@@ -344,7 +360,11 @@ class Block(object):
self.create_var(name=var.name(), desc=var, type=var.type()) self.create_var(name=var.name(), desc=var, type=var.type())
# sync operators from cpp # sync operators from cpp
ops_in_cpp = self.desc.all_ops() ops_in_cpp = []
for op_idx in range(0, self.desc.op_size()):
ops_in_cpp.append(self.desc.op(op_idx))
if len(self.ops) != 0:
first_op_in_python = self.ops[0].desc first_op_in_python = self.ops[0].desc
last_op_in_python = self.ops[len(self.ops) - 1].desc last_op_in_python = self.ops[len(self.ops) - 1].desc
start_index = None start_index = None
...@@ -357,6 +377,9 @@ class Block(object): ...@@ -357,6 +377,9 @@ class Block(object):
assert start_index is not None assert start_index is not None
assert end_index is not None assert end_index is not None
assert start_index <= end_index assert start_index <= end_index
else:
start_index = 0
end_index = -1
# sync ops append to the head of cpp_ops # sync ops append to the head of cpp_ops
for index in range((start_index - 1 - 1), -1, -1): for index in range((start_index - 1 - 1), -1, -1):
...@@ -376,18 +399,8 @@ class Block(object): ...@@ -376,18 +399,8 @@ class Block(object):
class Program(object): class Program(object):
@classmethod def __init__(self):
def instance(cls): self.desc = core.ProgramDesc()
# From https://stackoverflow.com/questions/8212053
# Making Program as a Singleton class.
if not hasattr(cls, '_instance'):
cls._instance = cls()
return cls._instance
def __init__(self, desc=None):
if desc is None:
desc = core.ProgramDesc.instance()
self.desc = desc
self.blocks = [Block(self, 0)] self.blocks = [Block(self, 0)]
self.current_block_idx = 0 self.current_block_idx = 0
...@@ -396,7 +409,15 @@ class Program(object): ...@@ -396,7 +409,15 @@ class Program(object):
proto = framework_pb2.ProgramDesc.FromString(str(protostr)) proto = framework_pb2.ProgramDesc.FromString(str(protostr))
return proto.__str__() return proto.__str__()
__repr__ = __str__ def clone(self):
p = Program()
p.desc = core.ProgramDesc(self.desc)
p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
p.sync_with_cpp()
return p
def __repr__(self):
return str(self)
def global_block(self): def global_block(self):
return self.blocks[0] return self.blocks[0]
...@@ -444,7 +465,9 @@ class Parameter(Variable): ...@@ -444,7 +465,9 @@ class Parameter(Variable):
if each < 0: if each < 0:
raise ValueError("Parameter shape should not be related with " raise ValueError("Parameter shape should not be related with "
"batch-size") "batch-size")
Variable.__init__(self, block, shape=shape, dtype=dtype, **kwargs)
Variable.__init__(
self, block, persistable=True, shape=shape, dtype=dtype, **kwargs)
self.trainable = kwargs.get('trainable', True) self.trainable = kwargs.get('trainable', True)
self.init_attr = kwargs.get('initialize_attr', { self.init_attr = kwargs.get('initialize_attr', {
'type': 'uniform_random', 'type': 'uniform_random',
...@@ -469,4 +492,4 @@ class Parameter(Variable): ...@@ -469,4 +492,4 @@ class Parameter(Variable):
# program is a global instance. # program is a global instance.
g_program = Program.instance() g_program = Program()
...@@ -66,15 +66,15 @@ class LayerHelper(object): ...@@ -66,15 +66,15 @@ class LayerHelper(object):
actual = self.kwargs.get('param_attr', None) actual = self.kwargs.get('param_attr', None)
return actual if actual is not None else default return actual if actual is not None else default
def bias_attr(self, size, dtype): def bias_attr(self, shape, dtype):
bias_attr = self.kwargs.get('bias_attr', False) bias_attr = self.kwargs.get('bias_attr', None)
if bias_attr is None or bias_attr: if bias_attr is True:
bias_attr = { bias_attr = {
'name': None, 'name': None,
'init_attr': { 'init_attr': {
'type': 'fill_constant', 'type': 'fill_constant',
'value': 0.0, 'value': 0.0,
'shape': [size], 'shape': shape,
'dataType': dtype 'dataType': dtype
} }
} }
...@@ -127,15 +127,13 @@ class LayerHelper(object): ...@@ -127,15 +127,13 @@ class LayerHelper(object):
return self.program.global_block().create_var(*args, **kwargs) return self.program.global_block().create_var(*args, **kwargs)
def append_bias_op(self, input_var): def append_bias_op(self, input_var):
bias_attr = self.bias_attr( size = list(input_var.shape[1:])
self.kwargs['size'], dtype=input_var.data_type) bias_attr = self.bias_attr(size, dtype=input_var.data_type)
if not bias_attr: if not bias_attr:
return input_var return input_var
b = self.create_parameter( b = self.create_parameter(
attr=bias_attr, attr=bias_attr, shape=size, dtype=input_var.data_type, suffix='b')
shape=[self.kwargs['size']],
dtype=input_var.data_type,
suffix='b')
tmp = self.create_tmp_variable(dtype=input_var.data_type) tmp = self.create_tmp_variable(dtype=input_var.data_type)
self.append_op( self.append_op(
type='elementwise_add', type='elementwise_add',
......
...@@ -3,10 +3,10 @@ import paddle.v2.framework.core as core ...@@ -3,10 +3,10 @@ import paddle.v2.framework.core as core
from paddle.v2.framework.framework import OpProtoHolder, Variable from paddle.v2.framework.framework import OpProtoHolder, Variable
import re import re
__all__ = ['fc_layer', 'data_layer', 'cross_entropy'] __all__ = ['fc', 'data', 'cross_entropy', 'conv2d']
def fc_layer(input, def fc(input,
size, size,
param_attr=None, param_attr=None,
bias_attr=True, bias_attr=True,
...@@ -24,6 +24,7 @@ def fc_layer(input, ...@@ -24,6 +24,7 @@ def fc_layer(input,
for input_var, param_attr in helper.iter_inputs_and_params(): for input_var, param_attr in helper.iter_inputs_and_params():
input_shape = input_var.shape input_shape = input_var.shape
param_shape = list(input_shape[num_flatten_dims:]) + [size] param_shape = list(input_shape[num_flatten_dims:]) + [size]
w = helper.create_parameter( w = helper.create_parameter(
attr=param_attr, shape=param_shape, dtype=dtype) attr=param_attr, shape=param_shape, dtype=dtype)
tmp = helper.create_tmp_variable(dtype) tmp = helper.create_tmp_variable(dtype)
...@@ -50,7 +51,7 @@ def fc_layer(input, ...@@ -50,7 +51,7 @@ def fc_layer(input,
return helper.append_activation(pre_activation) return helper.append_activation(pre_activation)
def data_layer(name, def data(name,
shape, shape,
data_type='float32', data_type='float32',
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
...@@ -111,6 +112,7 @@ def _create_op_func_(op_type): ...@@ -111,6 +112,7 @@ def _create_op_func_(op_type):
_create_op_func_('mean') _create_op_func_('mean')
_create_op_func_('pool2d')
def cross_entropy(input, label, **kwargs): def cross_entropy(input, label, **kwargs):
...@@ -141,3 +143,47 @@ def square_error_cost(input, label, **kwargs): ...@@ -141,3 +143,47 @@ def square_error_cost(input, label, **kwargs):
outputs={'Y': [square_out]}, outputs={'Y': [square_out]},
attrs={'factor': 2.0}) attrs={'factor': 2.0})
return square_out return square_out
def conv2d(input,
num_filters,
name=None,
filter_size=[1, 1],
act=None,
groups=None,
stride=[1, 1],
padding=None,
bias_attr=None,
param_attr=None,
program=None):
helper = LayerHelper('conv2d', **locals())
dtype = helper.input_dtype()
num_channels = input.shape[1]
if groups is None:
num_filter_channels = num_channels
else:
if num_channels % groups is not 0:
raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels / groups
input_shape = input.shape
filter_shape = [num_filters, num_filter_channels] + filter_size
filter = helper.create_parameter(
attr=helper.param_attr, shape=filter_shape, dtype=dtype)
pre_bias = helper.create_tmp_variable(dtype)
helper.append_op(
type='conv2d',
inputs={
'Input': input,
'Filter': filter,
},
outputs={"Output": pre_bias},
attrs={'strides': stride,
'paddings': padding,
'groups': groups})
pre_act = helper.append_bias_op(pre_bias)
return helper.append_activation(pre_act)
...@@ -5,7 +5,7 @@ import paddle.v2.framework.core as core ...@@ -5,7 +5,7 @@ import paddle.v2.framework.core as core
class TestInferShape(unittest.TestCase): class TestInferShape(unittest.TestCase):
def test_sum_op(self): def test_sum_op(self):
prog = core.ProgramDesc.__create_program_desc__() prog = core.ProgramDesc()
self.assertIsNotNone(prog) self.assertIsNotNone(prog)
block = prog.block(0) block = prog.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
...@@ -33,7 +33,7 @@ class TestInferShape(unittest.TestCase): ...@@ -33,7 +33,7 @@ class TestInferShape(unittest.TestCase):
self.assertEqual(out.shape(), shape) self.assertEqual(out.shape(), shape)
def test_mul_op(self): def test_mul_op(self):
prog = core.ProgramDesc.__create_program_desc__() prog = core.ProgramDesc()
self.assertIsNotNone(prog) self.assertIsNotNone(prog)
block = prog.block(0) block = prog.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
......
from paddle.v2.framework.layers import fc_layer, data_layer, cross_entropy, mean, square_error_cost import paddle.v2.framework.layers as layers
from paddle.v2.framework.framework import Program, g_program from paddle.v2.framework.framework import Program, g_program
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
import unittest import unittest
...@@ -6,36 +6,57 @@ import unittest ...@@ -6,36 +6,57 @@ import unittest
class TestBook(unittest.TestCase): class TestBook(unittest.TestCase):
def test_fit_a_line(self): def test_fit_a_line(self):
pd = core.ProgramDesc.__create_program_desc__() program = Program()
program = Program(desc=pd) x = layers.data(
x = data_layer(
name='x', shape=[13], data_type='float32', program=program) name='x', shape=[13], data_type='float32', program=program)
y_predict = fc_layer(input=x, size=1, act=None, program=program) y_predict = layers.fc(input=x, size=1, act=None, program=program)
y = data_layer( y = layers.data(
name='y', shape=[1], data_type='float32', program=program) name='y', shape=[1], data_type='float32', program=program)
cost = square_error_cost(input=y_predict, label=y, program=program) cost = layers.square_error_cost(
input=y_predict, label=y, program=program)
avg_cost = mean(x=cost, program=program) avg_cost = layers.mean(x=cost, program=program)
self.assertIsNotNone(avg_cost) self.assertIsNotNone(avg_cost)
program.append_backward(avg_cost, set())
print str(program) print str(program)
def test_recognize_digits_mlp(self): def test_recognize_digits_mlp(self):
pd = core.ProgramDesc.__create_program_desc__() program = Program()
program = Program(desc=pd)
# Change g_program, so the rest layers use `g_program` # Change g_program, so the rest layers use `g_program`
images = data_layer( images = layers.data(
name='pixel', shape=[784], data_type='float32', program=program) name='pixel', shape=[784], data_type='float32', program=program)
label = data_layer( label = layers.data(
name='label', shape=[1], data_type='int32', program=program) name='label', shape=[1], data_type='int32', program=program)
hidden1 = fc_layer(input=images, size=128, act='relu', program=program) hidden1 = layers.fc(input=images, size=128, act='relu', program=program)
hidden2 = fc_layer(input=hidden1, size=64, act='relu', program=program) hidden2 = layers.fc(input=hidden1, size=64, act='relu', program=program)
predict = fc_layer( predict = layers.fc(input=hidden2,
input=hidden2, size=10, act='softmax', program=program) size=10,
cost = cross_entropy(input=predict, label=label, program=program) act='softmax',
avg_cost = mean(x=cost, program=program) program=program)
cost = layers.cross_entropy(input=predict, label=label, program=program)
avg_cost = layers.mean(x=cost, program=program)
self.assertIsNotNone(avg_cost) self.assertIsNotNone(avg_cost)
# print str(program)
def test_simple_conv2d(self):
pd = core.ProgramDesc.__create_program_desc__()
program = Program(desc=pd)
images = data_layer(
name='pixel', shape=[3, 48, 48], data_type='int32', program=program)
conv2d_layer(
input=images, num_filters=3, filter_size=[4, 4], program=program)
# print str(program)
def test_simple_conv2d(self):
program = Program()
images = layers.data(
name='pixel', shape=[3, 48, 48], data_type='int32', program=program)
layers.conv2d(
input=images, num_filters=3, filter_size=[4, 4], program=program)
print str(program) print str(program)
......
...@@ -34,8 +34,26 @@ class TestProgram(unittest.TestCase): ...@@ -34,8 +34,26 @@ class TestProgram(unittest.TestCase):
self.assertEqual(1, b.idx) self.assertEqual(1, b.idx)
self.assertEqual(0, b.parent_idx) self.assertEqual(0, b.parent_idx)
def test_program_clone(self):
prog = Program()
x = prog.global_block().create_var(
name='X', shape=[1000, 784], dtype='float32')
y = prog.global_block().create_var(
name='Y', shape=[784, 100], dtype='float32')
out = prog.global_block().create_var(name='Out', dtype='float32')
prog.global_block().append_op(
type="mul", inputs={'X': [x],
'Y': [y]}, outputs={'Out': [out]})
# FIXME(yuyang18): We manual compare the output string, since the order
# of variable could be changed.
print prog
print prog.clone()
def test_append_backward(self): def test_append_backward(self):
prog = Program.instance() prog = Program()
block = prog.global_block() block = prog.global_block()
mul_x = block.create_var( mul_x = block.create_var(
......
...@@ -4,7 +4,7 @@ import paddle.v2.framework.core as core ...@@ -4,7 +4,7 @@ import paddle.v2.framework.core as core
class TestOpDesc(unittest.TestCase): class TestOpDesc(unittest.TestCase):
def test_op_desc(self): def test_op_desc(self):
prog = core.ProgramDesc.__create_program_desc__() prog = core.ProgramDesc()
self.assertIsNotNone(prog) self.assertIsNotNone(prog)
block = prog.block(0) block = prog.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
...@@ -64,16 +64,16 @@ class TestOpDesc(unittest.TestCase): ...@@ -64,16 +64,16 @@ class TestOpDesc(unittest.TestCase):
class TestProgramDesc(unittest.TestCase): class TestProgramDesc(unittest.TestCase):
def test_instance(self): def test_instance(self):
program_desc = core.ProgramDesc.__create_program_desc__() program_desc = core.ProgramDesc()
self.assertIsNotNone(program_desc) self.assertIsNotNone(program_desc)
del program_desc del program_desc
program_desc = core.ProgramDesc.instance() program_desc = core.ProgramDesc()
self.assertIsNotNone(program_desc) self.assertIsNotNone(program_desc)
self.assertIsNotNone(program_desc.block(0)) self.assertIsNotNone(program_desc.block(0))
del program_desc del program_desc
def test_append_block(self): def test_append_block(self):
prog_desc = core.ProgramDesc.__create_program_desc__() prog_desc = core.ProgramDesc()
self.assertIsNotNone(prog_desc) self.assertIsNotNone(prog_desc)
block_root = prog_desc.block(0) block_root = prog_desc.block(0)
self.assertIsNotNone(block_root) self.assertIsNotNone(block_root)
...@@ -91,7 +91,7 @@ class TestProgramDesc(unittest.TestCase): ...@@ -91,7 +91,7 @@ class TestProgramDesc(unittest.TestCase):
class TestVarDesc(unittest.TestCase): class TestVarDesc(unittest.TestCase):
def test_shape(self): def test_shape(self):
program_desc = core.ProgramDesc.__create_program_desc__() program_desc = core.ProgramDesc()
block = program_desc.block(0) block = program_desc.block(0)
var = block.var('my_var') var = block.var('my_var')
var.set_type(core.VarDesc.VarType.SELECTED_ROWS) var.set_type(core.VarDesc.VarType.SELECTED_ROWS)
...@@ -102,7 +102,7 @@ class TestVarDesc(unittest.TestCase): ...@@ -102,7 +102,7 @@ class TestVarDesc(unittest.TestCase):
self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type()) self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type())
def test_data_type(self): def test_data_type(self):
program_desc = core.ProgramDesc.__create_program_desc__() program_desc = core.ProgramDesc()
block = program_desc.block(0) block = program_desc.block(0)
var = block.var('my_var') var = block.var('my_var')
var.set_type(core.VarDesc.VarType.LOD_TENSOR) var.set_type(core.VarDesc.VarType.LOD_TENSOR)
...@@ -113,7 +113,7 @@ class TestVarDesc(unittest.TestCase): ...@@ -113,7 +113,7 @@ class TestVarDesc(unittest.TestCase):
class TestBlockDesc(unittest.TestCase): class TestBlockDesc(unittest.TestCase):
def test_add_var(self): def test_add_var(self):
prog = core.ProgramDesc.__create_program_desc__() prog = core.ProgramDesc()
self.assertIsNotNone(prog) self.assertIsNotNone(prog)
block = prog.block(0) block = prog.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
...@@ -121,19 +121,21 @@ class TestBlockDesc(unittest.TestCase): ...@@ -121,19 +121,21 @@ class TestBlockDesc(unittest.TestCase):
var2 = block.var("var2") var2 = block.var("var2")
var3 = block.var("var3") var3 = block.var("var3")
all_vars = block.all_vars() all_vars = block.all_vars()
self.assertEqual(set(all_vars), set([var1, var2, var3])) self.assertEqual(set(all_vars), {var1, var2, var3})
var2_re = block.find_var("var2") var2_re = block.find_var("var2")
self.assertEqual(var2_re, var2) self.assertEqual(var2_re, var2)
def test_add_op(self): def test_add_op(self):
prog = core.ProgramDesc.__create_program_desc__() prog = core.ProgramDesc()
self.assertIsNotNone(prog) self.assertIsNotNone(prog)
block = prog.block(0) block = prog.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
op1 = block.append_op() op1 = block.append_op()
op2 = block.append_op() op2 = block.append_op()
op0 = block.prepend_op() op0 = block.prepend_op()
all_ops = block.all_ops() all_ops = []
for idx in xrange(0, block.op_size()):
all_ops.append(block.op(idx))
self.assertEqual(all_ops, [op0, op1, op2]) self.assertEqual(all_ops, [op0, op1, op2])
......
import unittest
import numpy as np
from op_test import OpTest
class TestProximalGDOp(OpTest):
def setUp(self):
self.op_type = "proximal_gd"
w = np.random.random((102, 105)).astype("float32")
g = np.random.random((102, 105)).astype("float32")
lr = np.array([0.1]).astype("float32")
l1 = 0.1
l2 = 0.2
self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr}
self.attrs = {'l1': l1, 'l2': l2}
prox_param = w - lr * g
param_out = 0.0
if l1 > 0.0:
x = np.abs(prox_param) - lr * l1
x[x < 0] = 0
param_out = np.sign(prox_param) * (x / (1.0 + lr * l2))
else:
param_out = prox_param / (1.0 + lr * l2)
self.outputs = {'ParamOut': param_out}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
...@@ -8,29 +8,30 @@ class TestSelectedRows(unittest.TestCase): ...@@ -8,29 +8,30 @@ class TestSelectedRows(unittest.TestCase):
place = core.CPUPlace() place = core.CPUPlace()
height = 10 height = 10
rows = [0, 4, 7] rows = [0, 4, 7]
row_numel = 10 row_numel = 12
selcted_rows = core.SelectedRows(rows, row_numel) selected_rows = core.SelectedRows(rows, height)
np_array = np.ones((len(rows), height)).astype("float32") np_array = np.ones((len(rows), row_numel)).astype("float32")
np_array[0, 0] = 2.0 np_array[0, 0] = 2.0
np_array[2, 8] = 4.0 np_array[2, 8] = 4.0
tensor = selcted_rows.get_tensor() tensor = selected_rows.get_tensor()
tensor.set(np_array, place) tensor.set(np_array, place)
# compare rows # compare rows
self.assertEqual(0, selcted_rows.rows()[0]) self.assertEqual(0, selected_rows.rows()[0])
self.assertEqual(4, selcted_rows.rows()[1]) self.assertEqual(4, selected_rows.rows()[1])
self.assertEqual(7, selcted_rows.rows()[2]) self.assertEqual(7, selected_rows.rows()[2])
# compare height # compare height
self.assertEqual(10, selcted_rows.height()) self.assertEqual(10, selected_rows.height())
# compare tensor # compare tensor
self.assertAlmostEqual(2.0, self.assertAlmostEqual(2.0,
selcted_rows.get_tensor().get_float_element(0)) selected_rows.get_tensor().get_float_element(0))
self.assertAlmostEqual(1.0, self.assertAlmostEqual(1.0,
selcted_rows.get_tensor().get_float_element(1)) selected_rows.get_tensor().get_float_element(1))
self.assertAlmostEqual( self.assertAlmostEqual(
4.0, selcted_rows.get_tensor().get_float_element(2 * row_numel + 8)) 4.0,
selected_rows.get_tensor().get_float_element(2 * row_numel + 8))
if __name__ == "__main__": if __name__ == "__main__":
......
import unittest import unittest
import numpy as np import numpy as np
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
from op_test import OpTest from op_test import OpTest
...@@ -17,5 +19,70 @@ class TestSGDOp(OpTest): ...@@ -17,5 +19,70 @@ class TestSGDOp(OpTest):
self.check_output() self.check_output()
class TestSparseSGDOp(unittest.TestCase):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Grad Variable
height = 10
rows = [0, 4, 7]
row_numel = 12
grad_selected_rows = scope.var('Grad').get_selected_rows()
grad_selected_rows.set_height(height)
grad_selected_rows.set_rows(rows)
np_array = np.ones((len(rows), row_numel)).astype("float32")
np_array[0, 0] = 2.0
np_array[2, 8] = 4.0
grad_tensor = grad_selected_rows.get_tensor()
grad_tensor.set(np_array, place)
# create and initialize Param Variable
param = scope.var('Param').get_tensor()
param_array = np.full((height, row_numel), 5.0).astype("float32")
param.set(param_array, place)
# create and initialize LeraningRate Variable
lr = scope.var('LearningRate').get_tensor()
lr_array = np.full((1), 2.0).astype("float32")
lr.set(lr_array, place)
# create and run sgd operator
sgd_op = Operator(
"sgd",
Param='Param',
Grad='Grad',
ParamOut='Param',
LearningRate='LearningRate')
ctx = core.DeviceContext.create(place)
sgd_op.run(scope, ctx)
# get and compare result
result_array = np.array(param)
# rows[0] = 0, 5.0 - 2.0 * 2.0
self.assertAlmostEqual(1.0, result_array[rows[0], 0])
# rows[0] = 0, 5.0 - 2.0 * 1.0
self.assertAlmostEqual(3.0, result_array[rows[0], 2])
# 5.0 - 2.0 * 0.0
self.assertAlmostEqual(5.0, result_array[1, 0])
# rows[1] = 4, 5.0 - 2.0 * 1.0
self.assertAlmostEqual(3.0, result_array[rows[1], 10])
# 5.0 - 2.0 * 0.0
self.assertAlmostEqual(5.0, result_array[5, 8])
# rows[2] = 7, 5.0 - 2.0 * 1.0
self.assertAlmostEqual(3.0, result_array[rows[2], 1])
# rows[2] = 7, 5.0 - 2.0 * 4.0
self.assertAlmostEqual(-3.0, result_array[rows[2], 8])
def test_sparse_sgd(self):
places = [core.CPUPlace()]
if core.is_compile_gpu():
places.append(core.GPUPlace(0))
for place in places:
self.check_with_place(place)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册