diff --git a/.gitignore b/.gitignore index 018654c8daedf4f18a4ff8b78219eb90bd5be594..351b8204100dfd71e94cb3efa2e946b44b9e4285 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ cmake-build-* # generated while compiling python/paddle/v2/framework/core.so +paddle/pybind/pybind.h CMakeFiles cmake_install.cmake paddle/.timestamp diff --git a/cmake/cpplint.cmake b/cmake/cpplint.cmake index 8d5d533126c9b7fa84c725d614cf3486126d0284..4823dc3e91390002aefac70f7931b4197db05789 100644 --- a/cmake/cpplint.cmake +++ b/cmake/cpplint.cmake @@ -26,9 +26,9 @@ set(IGNORE_PATTERN .*ImportanceSampler.* .*cblas\\.h.* .*\\.pb\\.txt - .*LtrDataProvider.* .*MultiDataProvider.* - .*pb.*) + .*pb.* + .*pybind.h) # add_style_check_target # diff --git a/doc/design/block.md b/doc/design/block.md new file mode 100644 index 0000000000000000000000000000000000000000..be8800122035984df281692fc40009c397565046 --- /dev/null +++ b/doc/design/block.md @@ -0,0 +1,338 @@ +# Design Doc: Block and Scope + +## The Representation of Computation + +Both deep learning systems and programming languages help users describe computation procedures. These systems use various representations of computation: + +- Caffe, Torch, and Paddle: sequences of layers. +- TensorFlow, Caffe2, Mxnet: graphs of operators. +- PaddlePaddle: nested blocks, like C++ and Java programs. + +## Block in Programming Languages and Deep Learning + +In programming languages, a block is a pair of curly braces that includes local variables definitions and a sequence of instructions, or operators. + +Blocks work with control flow structures like `if`, `else`, and `for`, which have equivalents in deep learning: + +| programming languages | PaddlePaddle | +|-----------------------|-----------------------| +| for, while loop | RNN, WhileOp | +| if, if-else, switch | IfElseOp, SwitchOp | +| sequential execution | a sequence of layers | + +A key difference is that a C++ program describes a one pass computation, whereas a deep learning program describes both the forward and backward passes. + +## Stack Frames and the Scope Hierarchy + +The existence of the backward makes the execution of a block of traditional programs and PaddlePaddle different to each other: + +| programming languages | PaddlePaddle | +|-----------------------|-------------------------------| +| stack | scope hierarchy | +| stack frame | scope | +| push at entering block| push at entering block | +| pop at leaving block | destroy at minibatch completes| + +1. In traditional programs: + + - When the execution enters the left curly brace of a block, the runtime pushes a frame into the stack, where it realizes local variables. + - After the execution leaves the right curly brace, the runtime pops the frame. + - The maximum number of frames in the stack is the maximum depth of nested blocks. + +1. In PaddlePaddle + + - When the execution enters a block, PaddlePaddle adds a new scope, where it realizes variables. + - PaddlePaddle doesn't pop a scope after the execution of the block because variables therein are to be used by the backward pass. So it has a stack forest known as a *scope hierarchy*. + - The height of the highest tree is the maximum depth of nested blocks. + - After the process of a minibatch, PaddlePaddle destroys the scope hierarchy. + +## Use Blocks in C++ and PaddlePaddle Programs + +Let us consolidate the discussion by presenting some examples. + +### Blocks with `if-else` and `IfElseOp` + +The following C++ programs shows how blocks are used with the `if-else` structure: + +```c++ +int x = 10; +int y = 20; +int out; +bool cond = false; +if (cond) { + int z = x + y; + out = softmax(z); +} else { + int z = fc(x); + out = z; +} +``` + +An equivalent PaddlePaddle program from the design doc of the [IfElseOp operator](./if_else_op.md) is as follows: + +```python +import paddle as pd + +x = var(10) +y = var(20) +cond = var(false) +ie = pd.create_ifelseop(inputs=[x], output_num=1) +with ie.true_block(): + x = ie.inputs(true, 0) + z = operator.add(x, y) + ie.set_output(true, 0, operator.softmax(z)) +with ie.false_block(): + x = ie.inputs(false, 0) + z = layer.fc(x) + ie.set_output(true, 0, operator.softmax(z)) +out = b(cond) +``` + +In both examples, the left branch computes `softmax(x+y)` and the right branch computes `fc(x)`. + +A difference is that variables in the C++ program contain scalar values, whereas those in the PaddlePaddle programs are mini-batches of instances. The `ie.input(true, 0)` invocation returns instances in the 0-th input, `x`, that corresponds to true values in `cond` as the local variable `x`, where `ie.input(false, 0)` returns instances corresponding to false values. + +### Blocks with `for` and `RNNOp` + +The following RNN model from the [RNN design doc](./rnn.md) + +```python +x = sequence([10, 20, 30]) +m = var(0) +W = tensor() +U = tensor() + +rnn = create_rnn(inputs=[input]) +with rnn.stepnet() as net: + x = net.set_inputs(0) + h = net.add_memory(init=m) + fc_out = pd.matmul(W, x) + hidden_out = pd.matmul(U, h.pre(n=1)) + sum = pd.add_two(fc_out, hidden_out) + act = pd.sigmoid(sum) + h.update(act) # update memory with act + net.set_outputs(0, act, hidden_out) # two outputs + +o1, o2 = rnn() +print o1, o2 +``` + +has its equivalent C++ program as follows + +```c++ +int* x = {10, 20, 30}; +int m = 0; +int W = some_value(); +int U = some_other_value(); + +int mem[sizeof(x) / sizeof(x[0]) + 1]; +int o1[sizeof(x) / sizeof(x[0]) + 1]; +int o2[sizeof(x) / sizeof(x[0]) + 1]; +for (int i = 1; i <= sizeof(x)/sizeof(x[0]); ++i) { + int x = x[i-1]; + if (i == 1) mem[0] = m; + int fc_out = W * x; + int hidden_out = Y * mem[i-1]; + int sum = fc_out + hidden_out; + int act = sigmoid(sum); + mem[i] = act; + o1[i] = act; + o2[i] = hidden_out; +} + +print_array(o1); +print_array(o2); +``` + + +## Compilation and Execution + +Like TensorFlow programs, a PaddlePaddle program is written in Python. The first part describes a neural network as a protobuf message, and the rest part executes the message for training or inference. + +The generation of this protobuf message is like what a compiler generates a binary executable file. The execution of the message that the OS executes the binary file. + +## The "Binary Executable File Format" + +The definition of the protobuf message is as follows: + +```protobuf +message BlockDesc { + repeated VarDesc vars = 1; + repeated OpDesc ops = 2; +} +``` + +The step net in above RNN example would look like + +``` +BlockDesc { + vars = { + VarDesc {...} // x + VarDesc {...} // h + VarDesc {...} // fc_out + VarDesc {...} // hidden_out + VarDesc {...} // sum + VarDesc {...} // act + } + ops = { + OpDesc {...} // matmul + OpDesc {...} // add_two + OpDesc {...} // sigmoid + } +}; +``` + +Also, the RNN operator in above example is serialized into a protobuf message of type `OpDesc` and would look like: + +``` +OpDesc { + inputs = {0} // the index of x + outputs = {5, 3} // indices of act and hidden_out + attrs { + "memories" : {1} // the index of h + "step_net" : + } +}; +``` + +This `OpDesc` value is in the `ops` field of the `BlockDesc` value representing the global block. + + +## The Compilation of Blocks + +During the generation of the Protobuf message, the Block should store VarDesc (the Protobuf message which describes Variable) and OpDesc (the Protobuf message which describes Operator). + +VarDesc in a block should have its name scope to avoid local variables affect parent block's name scope. +Child block's name scopes should inherit the parent's so that OpDesc in child block can reference a VarDesc that stored in parent block. For example + +```python +a = pd.Varaible(shape=[20, 20]) +b = pd.fc(a, params=["fc.w", "fc.b"]) + +rnn = pd.create_rnn() +with rnn.stepnet() as net: + x = net.set_inputs(a) + # reuse fc's parameter + fc_without_b = pd.get_variable("fc.w") + net.set_outputs(fc_without_b) + +out = rnn() +``` +the method `pd.get_variable` can help retrieve a Variable by a name, a Variable may store in a parent block, but might be retrieved in a child block, so block should have a variable scope that supports inheritance. + +In compiler design, the symbol table is a data structure created and maintained by compilers to store information about the occurrence of various entities such as variable names, function names, classes, etc. + +To store the definition of variables and operators, we define a C++ class `SymbolTable`, like the one used in compilers. + +`SymbolTable` can do the following stuff: + +- store the definitions (some names and attributes) of variables and operators, +- to verify if a variable was declared, +- to make it possible to implement type checking (offer Protobuf message pointers to `InferShape` handlers). + + +```c++ +// Information in SymbolTable is enough to trace the dependency graph. So maybe +// the Eval() interface takes a SymbolTable is enough. +class SymbolTable { + public: + SymbolTable(SymbolTable* parent) : parent_(parent) {} + + OpDesc* NewOp(const string& name=""); + + // TODO determine whether name is generated by python or C++ + // currently assume that a unique name will be generated by C++ if the + // argument name left default. + VarDesc* NewVar(const string& name=""); + + // find a VarDesc by name, if recursive true, find parent's SymbolTable + // recursively. + // this interface is introduced to support InferShape, find protobuf messages + // of variables and operators, pass pointers into InferShape. + // operator + // + // NOTE maybe some C++ classes such as VarDescBuilder and OpDescBuilder should + // be proposed and embedded into pybind to enable python operate on C++ pointers. + VarDesc* FindVar(const string& name, bool recursive=true); + + OpDesc* FindOp(const string& name); + + BlockDesc Compile() const; + + private: + SymbolTable* parent_; + + map ops_; + map vars_; +}; +``` + +After all the description of variables and operators is added into SymbolTable, +the block has enough information to run. + +The `Block` class takes a `BlockDesc` as input, and provide `Run` and `InferShape` functions. + + +```c++ +namespace { + +class Block : OperatorBase { +public: + Block(const BlockDesc& desc) desc_(desc) {} + + void InferShape(const framework::Scope& scope) const override { + if (!symbols_ready_) { + CreateVariables(scope); + CreateOperators(); + } + // should run InferShape first. + for (auto& op : runtime_table_.ops()) { + op->InferShape(scope); + } + } + + void Run(const framework::Scope& scope, + const platform::DeviceContext& dev_ctx) const override { + PADDLE_ENFORCE(symbols_ready_, "operators and variables should be created first."); + for (auto& op : runtime_table_.ops()) { + op->Run(scope, dev_ctx); + } + } + + void CreateVariables(const framework::Scope& scope); + void CreateOperators(); + + // some other necessary interfaces of NetOp are list below + // ... + +private: + BlockDesc desc_; + bool symbols_ready_{false}; +}; +``` + +## The Execution of Blocks + +Block inherits from OperatorBase, which has a Run method. +Block's Run method will run its operators sequentially. + +There is another important interface called `Eval`, which take some arguments called targets, and generate a minimal graph which takes targets as the end points and creates a new Block, +after `Run`, `Eval` will get the latest value and return the targets. + +The definition of Eval is as follows: + +```c++ +// clean a block description by targets using the corresponding dependency graph. +// return a new BlockDesc with minimal number of operators. +// NOTE not return a Block but the block's description so that this can be distributed +// to a cluster. +BlockDesc Prune(const BlockDesc& desc, vector targets); + +void Block::Eval(const vector& targets, + const framework::Scope& scope, + const platform::DeviceContext& dev_ctx) { + BlockDesc min_desc = Prune(desc_, targets); + Block min_block(min_desc); + min_block.Run(scope, dev_ctx); +} +``` diff --git a/doc/design/if_else_op.md b/doc/design/if_else_op.md index 7370c2a24fa644a64e738f202bac9b9209642e08..954a19c0733358c235eae3cffe134c23dac94c95 100644 --- a/doc/design/if_else_op.md +++ b/doc/design/if_else_op.md @@ -1,22 +1,4 @@ -IfOp should have only one branch. An IfOp operator takes a `cond` variable whose value must be a vector of N boolean elements. Its return value has M (M<=N) instances, each corresponds to a true element in `cond`. - -```python -import paddle as pd - -x = var() -y = var() -cond = var() - -b = pd.create_ifop(inputs=[x], output_num=1) -with b.true_block(): - x = b.inputs(0) - z = operator.add(x, y) - b.set_output(0, operator.softmax(z)) - -out = b(cond) -``` - -If we want the output still has N instances, we can use IfElseOp with a default value, whose minibatch size must be N: +IfOp should have only one branch. An IfOp operator takes a `cond` variable whose value must be a vector of N boolean elements. Its return value has N instances. If cond[i] == True, input instance input[i] will go through true_block() and generate output[i]; otherwise it will produce output from false_bloack(). ```python import paddle as pd @@ -39,7 +21,7 @@ with b.false_block(): out = b(cond) ``` -If only true_block is set in an IfElseOp, we can have a default value for false as: +If only true_block is set in an IfElseOp, a special case is that we can have a default value for false as: ```python import paddle as pd diff --git a/doc/design/ops/images/2_level_rnn.dot b/doc/design/ops/images/2_level_rnn.dot new file mode 100644 index 0000000000000000000000000000000000000000..a498e882a3d85a33d44dbad7474fa2a340e33976 --- /dev/null +++ b/doc/design/ops/images/2_level_rnn.dot @@ -0,0 +1,56 @@ +digraph G { + + rnn [label="1-th level RNN" shape=box] + + subgraph cluster0 { + label = "time step 0" + + sent0 [label="sentence"] + sent1 [label="sentence"] + + rnn1 [label="2-th level RNN" shape=box] + + sent0 -> rnn1 + sent1 -> rnn1 + } + + subgraph cluster1 { + label = "time step 1" + + sent2 [label="sentence"] + sent3 [label="sentence"] + + rnn2 [label="2-th level RNN" shape=box] + + sent2 -> rnn2 + sent3 -> rnn2 + } + + subgraph cluster2 { + label = "time step 2" + + sent4 [label="sentence"] + sent5 [label="sentence"] + + rnn3 [label="2-th level RNN" shape=box] + + sent4 -> rnn3 + sent5 -> rnn3 + } + + + para0 [label="paragraph info 0"] + para1 [label="paragraph info 1"] + para2 [label="paragraph info 2"] + + rnn1 -> para0 + rnn2 -> para1 + rnn3 -> para2 + + para0 -> rnn + para1 -> rnn + para2 -> rnn + + chapter [label="chapter info"] + rnn -> chapter +} diff --git a/doc/design/ops/images/2_level_rnn.png b/doc/design/ops/images/2_level_rnn.png new file mode 100644 index 0000000000000000000000000000000000000000..0537a75beb175c0c284717421f7aa908da2a5038 Binary files /dev/null and b/doc/design/ops/images/2_level_rnn.png differ diff --git a/doc/design/ops/images/rnn.dot b/doc/design/ops/images/rnn.dot new file mode 100644 index 0000000000000000000000000000000000000000..c1141cd9c981bb3cbf50d8bf7a6ed210280d79a5 --- /dev/null +++ b/doc/design/ops/images/rnn.dot @@ -0,0 +1,87 @@ +digraph G { + label = "simple RNN implementation" + + ranksep=2; + + //graph [nodesep=1, ranksep=1]; + + node[nodesep=1] + + subgraph cluster0 { + label = "global scope" + rankdir = TB + W + boot_memory + input + output + } + + subgraph cluster1 { + label = "step-scope 0" + rankdir = TB + memory0[label="memory"] + prememory0[label="pre-memory"] + step_input0[label="step input"] + step_output0[label="step output"] + } + + subgraph cluster2 { + label = "step-scope 1" + rankdir = TB + memory1[label="memory"] + prememory1[label="pre-memory"] + step_input1[label="step input"] + step_output1[label="step output"] + } + + subgraph cluster3 { + label = "step-scope 2" + rankdir = TB + memory2[label="memory"] + prememory2[label="pre-memory"] + step_input2[label="step input"] + step_output2[label="step output"] + } + + stepnet [shape=box] + stepnet0 [shape=box, style=dashed] + stepnet1 [shape=box, style=dashed] + stepnet2 [shape=box, style=dashed] + + + edge[color=blue] + boot_memory -> prememory0 [label="init" color="blue"] + memory0 -> prememory1 [label="copy/reference" color="blue"] + memory1 -> prememory2 [label="copy/reference" color="blue"] + + edge[color=black] + W -> stepnet0[constraint=false, style=dashed] + W -> stepnet1[constraint=false, style=dashed] + W -> stepnet2[constraint=false, style=dashed] + + memory0 -> stepnet0[style=dashed] + prememory0 -> stepnet0 -> step_output0[style=dashed] + + memory1 -> stepnet1[style=dashed] + prememory1 -> stepnet1 -> step_output1[style=dashed] + + memory2 -> stepnet2[style=dashed] + prememory2 -> stepnet2 -> step_output2[style=dashed] + + input -> step_input0 + input -> step_input1 + input -> step_input2 + + step_input0 -> stepnet0 [style=dashed] + step_input1 -> stepnet1[style=dashed] + step_input2 -> stepnet2[style=dashed] + + step_output0 -> output + step_output1 -> output + step_output2 -> output + + stepnet0 -> stepnet[style=dashed] + stepnet1 -> stepnet[style=dashed] + stepnet2 -> stepnet[style=dashed] + +} diff --git a/doc/design/ops/images/rnn.jpg b/doc/design/ops/images/rnn.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9867e404cf959df0dce6ded5222b466c788fb840 Binary files /dev/null and b/doc/design/ops/images/rnn.jpg differ diff --git a/doc/design/ops/images/rnn.png b/doc/design/ops/images/rnn.png new file mode 100644 index 0000000000000000000000000000000000000000..e139e373fe8396782044cfd936fdde624f8c66fe Binary files /dev/null and b/doc/design/ops/images/rnn.png differ diff --git a/doc/design/ops/images/rnn_2level_data.dot b/doc/design/ops/images/rnn_2level_data.dot new file mode 100644 index 0000000000000000000000000000000000000000..1d85ae2617a915ad0ad8288d848b607cc37ad297 --- /dev/null +++ b/doc/design/ops/images/rnn_2level_data.dot @@ -0,0 +1,75 @@ +digraph G { + chapter [label="chapter"] + + subgraph cluster0 { + label = "paragraph 0" + + top_rnn0[label="top rnn step 0" shape=box] + + p0 [label="paragraph 0"] + p1 [label="paragraph 1"] + } + + subgraph cluster1{ + label = "paragraph 1" + + top_rnn1[label="top rnn step 1" shape=box] + + p2 [label="paragraph 0"] + p3 [label="paragraph 1"] + } + + subgraph cluster_p0 { + label = "sentence 0" + + low_rnn0 [label="low rnn step 0" shape=box] + s00 [label="sentence 0"] + s01 [label="sentence 1"] + + low_rnn0 -> s00 + low_rnn0 -> s01 + } + + subgraph cluster_p1 { + label = "sentence 1" + low_rnn1 [label="low rnn step 1" shape=box] + s10 [label="sentence 0"] + s11 [label="sentence 1"] + low_rnn1 -> s10 + low_rnn1 -> s11 + } + + subgraph cluster_p2 { + label = "sentence 1" + low_rnn2 [label="low rnn step 0" shape=box] + s20 [label="sentence 0"] + s21 [label="sentence 1"] + low_rnn2 -> s20 + low_rnn2 -> s21 + } + + subgraph cluster_p3 { + label = "sentence 1" + low_rnn3 [label="low rnn step 1" shape=box] + s30 [label="sentence 0"] + s31 [label="sentence 1"] + low_rnn3 -> s30 + low_rnn3 -> s31 + } + + + chapter -> top_rnn0 + chapter -> top_rnn1 + + top_rnn0 -> p0 + top_rnn0 -> p1 + top_rnn1 -> p2 + top_rnn1 -> p3 + + + p0 -> low_rnn0 + p1 -> low_rnn1 + p2 -> low_rnn2 + p3 -> low_rnn3 + +} diff --git a/doc/design/ops/images/rnn_2level_data.png b/doc/design/ops/images/rnn_2level_data.png new file mode 100644 index 0000000000000000000000000000000000000000..4be81b2430717a6a506342a09fc26899568574c6 Binary files /dev/null and b/doc/design/ops/images/rnn_2level_data.png differ diff --git a/doc/design/ops/rnn.md b/doc/design/ops/rnn.md new file mode 100644 index 0000000000000000000000000000000000000000..a78eea7d45e9e9553d153170aa31da55ec6e8289 --- /dev/null +++ b/doc/design/ops/rnn.md @@ -0,0 +1,153 @@ +# RNNOp design + +This document is about an RNN operator which requires that instances in a mini-batch have the same length. We will have a more flexible RNN operator. + +## RNN Algorithm Implementation + +

+ +

+ +The above diagram shows an RNN unrolled into a full network. + +There are several important concepts: + +- *step-net*: the sub-graph to run at each step, +- *memory*, $h_t$, the state of the current step, +- *ex-memory*, $h_{t-1}$, the state of the previous step, +- *initial memory value*, the ex-memory of the first step. + +### Step-scope + +There could be local variables defined in step-nets. PaddlePaddle runtime realizes these variables in *step-scopes* -- scopes created for each step. + +

+
+Figure 2 the RNN's data flow +

+ +Please be aware that all steps run the same step-net. Each step + +1. creates the step-scope, +2. realizes local variables, including step-outputs, in the step-scope, and +3. runs the step-net, which could use these variables. + +The RNN operator will compose its output from step outputs in step scopes. + +### Memory and Ex-memory + +Let's give more details about memory and ex-memory via a simply example: + +$$ +h_t = U h_{t-1} + W x_t +$$, + +where $h_t$ and $h_{t-1}$ are the memory and ex-memory of step $t$'s respectively. + +In the implementation, we can make an ex-memory variable either "refers to" the memory variable of the previous step, +or copy the value of the previous memory value to the current ex-memory variable. + +### Usage in Python + +For more information on Block, please refer to the [design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.md). + +We can define an RNN's step-net using Block: + +```python +import paddle as pd + +X = some_op() # x is some operator's output, and is a LoDTensor +a = some_op() + +# declare parameters +W = pd.Variable(shape=[20, 30]) +U = pd.Variable(shape=[20, 30]) + +rnn = pd.create_rnn_op(output_num=1) +with rnn.stepnet(): + x = rnn.add_input(X) + # declare a memory (rnn's step) + h = rnn.add_memory(init=a) + # h.pre_state() means previous memory of rnn + new_state = pd.add_two( pd.matmul(W, x) + pd.matmul(U, h.pre_state())) + # update current memory + h.update(new_state) + # indicate that h variables in all step scopes should be merged + rnn.add_outputs(h) + +out = rnn() +``` + +Python API functions in above example: + +- `rnn.add_input` indicates the parameter is a variable that will be segmented into step-inputs. +- `rnn.add_memory` creates a variable used as the memory. +- `rnn.add_outputs` mark the variables that will be concatenated across steps into the RNN output. + +### Nested RNN and LoDTensor + +An RNN whose step-net includes other RNN operators is known as an *nested RNN*. + +For example, we could have a 2-level RNN, where the top level corresponds to paragraphs, and the lower level corresponds to sentences. + +The following figure illustrates the feeding of text into the lower level, one sentence each step, and the feeding of step outputs to the top level. The final top level output is about the whole text. + +

+ +

+ +```python +import paddle as pd + +W = pd.Variable(shape=[20, 30]) +U = pd.Variable(shape=[20, 30]) + +W0 = pd.Variable(shape=[20, 30]) +U0 = pd.Variable(shape=[20, 30]) + +# a is output of some op +a = some_op() + +# chapter_data is a set of 128-dim word vectors +# the first level of LoD is sentence +# the second level of LoD is chapter +chapter_data = pd.Variable(shape=[None, 128], type=pd.lod_tensor, level=2) + +def lower_level_rnn(paragraph): + ''' + x: the input + ''' + rnn = pd.create_rnn_op(output_num=1) + with rnn.stepnet(): + sentence = rnn.add_input(paragraph, level=0) + h = rnn.add_memory(shape=[20, 30]) + h.update( + pd.matmul(W, sentence) + pd.matmul(U, h.pre_state())) + # get the last state as sentence's info + rnn.add_outputs(h) + return rnn + +top_level_rnn = pd.create_rnn_op(output_num=1) +with top_level_rnn.stepnet(): + paragraph_data = rnn.add_input(chapter_data, level=1) + low_rnn = lower_level_rnn(paragraph_data) + paragraph_out = low_rnn() + + h = rnn.add_memory(init=a) + h.update( + pd.matmul(W0, paragraph_data) + pd.matmul(U0, h.pre_state())) + top_level_rnn.add_outputs(h) + +# just output the last step +chapter_out = top_level_rnn(output_all_steps=False) +``` + +in above example, the construction of the `top_level_rnn` calls `lower_level_rnn`. The input is a LoD Tensor. The top level RNN segments input text data into paragraphs, and the lower level RNN segments each paragraph into sentences. + +By default, the `RNNOp` will concatenate the outputs from all the time steps, +if the `output_all_steps` set to False, it will only output the final time step. + + +

+ +

diff --git a/doc/howto/dev/new_op_cn.md b/doc/howto/dev/new_op_cn.md index e3892849abe21fc207d2fcbe4adc65184ba771f4..c6570b89aedfaac1aef9b00e889b0b3ed21d8d65 100644 --- a/doc/howto/dev/new_op_cn.md +++ b/doc/howto/dev/new_op_cn.md @@ -34,7 +34,7 @@ Kernel实现 | CPU、GPU共享Kernel实现在`.h`文件中,否则,CPU 注册Op | Op注册实现在`.cc`文件;Kernel注册CPU实现在`.cc`文件中,GPU实现在`.cu`文件中 -实现新的op都添加至目录[paddle/operators](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators)下,文件命名以`*_op.h`(如有) 、 `*_op.cc` 、`*_op.cu`(如有)结尾。 +实现新的op都添加至目录[paddle/operators](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators)下,文件命名以`*_op.h`(如有) 、 `*_op.cc` 、`*_op.cu`(如有)结尾。**系统会根据文件名自动构建op和其对应的Python扩展。** 下面以矩阵乘操作,即[MulOp](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc)为例来介绍如何写带Kernel的Operator。 @@ -224,45 +224,15 @@ MulOp(const std::string &type, const framework::VariableNameMap &inputs, ### 5. 编译 -- 简单**无特殊依赖**的OP无需修改CMakeList.txt文件。[paddle/operators/CMakeLists.txt](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/CMakeLists.txt) 会自动将 `paddle/operators` 目录下新增的 `*_op.cc` 文件加入编译。 -- 较为复杂、**有额外依赖** 的operator仍需要修改[paddle/operators/CMakeLists.txt](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/CMakeLists.txt)。如,`mul_op` 依赖 `math_function`,需要在`CMakeLists.txt`中添加如下内容: +运行下面命令可以进行编译: - ``` - op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function) + - ``` - -- 运行下面命令可以进行编译: - - ``` - make mul_op - ``` +``` +make mul_op +``` ## 绑定Python -- 绑定Python - - 在 [`paddle/pybind/pybind.cc -`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/pybind/pybind.cc) 使用`USE_OP`告知编译器需要链接的Op,具体解释参考[代码注释](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/op_registry.h#L81)。 - - ``` - USE_OP(mul); - ``` - 如果只实现了CPU版本,则使用`USE_CPU_ONLY_OP`: - - ``` - USE_CPU_ONLY_OP(gather); - ``` - - 如果OP不带Kernel,则使用`USE_NO_KENREL_OP`: - - ``` - USE_NO_KENREL_OP(recurrent); - ``` - - - - 生成库 - - `paddle/operators` 目录下新增的 `*_op.cc` 文件会被自动添加链接到生成的lib库中。 +系统会对新增的op自动绑定Python,并链接到生成的lib库中。 ## 实现单元测试 @@ -367,3 +337,10 @@ make test ARGS="-R test_mul_op -V" ```bash ctest -R test_mul_op ``` + +## 注意事项 + +- 为每个Op创建单独的`*_op.h`(如有)、`*_op.cc`和`*_op.cu`(如有)。不允许一个文件中包含多个Op,这将会导致编译出错。 +- 注册Op时的类型名,需要和该Op的名字一样。即不允许在`A_op.cc`里面,注册`REGISTER_OP(B, ...)`等,这将会导致单元测试出错。 +- 如果Op没有实现GPU Kernel,请不要创建空的`*_op.cu`,这将会导致单元测试出错。 +- 如果多个Op依赖一些共用的函数,可以创建非`*_op.*`格式的文件来存放,如`gather.h`文件。 diff --git a/paddle/cuda/include/hl_cuda_cudnn.h b/paddle/cuda/include/hl_cuda_cudnn.h index 3f68c62de6d9b3aaadc9180d86159089dc728ea9..b44b071bd1b3b6e9e5539d5dc0c2b155c524fd57 100644 --- a/paddle/cuda/include/hl_cuda_cudnn.h +++ b/paddle/cuda/include/hl_cuda_cudnn.h @@ -22,10 +22,10 @@ limitations under the License. */ */ typedef enum { HL_POOLING_MAX = 0, - // average includes padded values - HL_POOLING_AVERAGE = 1, // average does not include padded values - HL_POOLING_AVERAGE_EXCLUDE_PADDING = 2, + HL_POOLING_AVERAGE = 1, + // average includes padded values + HL_POOLING_AVERAGE_INCLUDE_PADDING = 2, HL_POOLING_END } hl_pooling_mode_t; diff --git a/paddle/cuda/include/hl_tensor_ops.h b/paddle/cuda/include/hl_tensor_ops.h index 93d38b7d2299d994cde0934213668a525bffa80c..b2bf334dab9799153fe1d4fe2c74cce9d57168b9 100644 --- a/paddle/cuda/include/hl_tensor_ops.h +++ b/paddle/cuda/include/hl_tensor_ops.h @@ -461,7 +461,7 @@ class add { public: INLINE float32x4_t operator()(const float32x4_t a, const float32x4_t b) const { - return vmulq_f32(a, b); + return vaddq_f32(a, b); } }; diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index 9ba3d142617537c0160f6dccb86ddca43ada15a5..58674febdc4a094c95ff03701e4586c32729847d 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -211,13 +211,11 @@ __global__ void KeAvgPoolForward(const int nthreads, int hstart = ph * strideH - padH; int wstart = pw * strideW - padW; - int hend = min(hstart + sizeY, height + padH); - int wend = min(wstart + sizeX, width + padW); - int pool_size = (hend - hstart) * (wend - wstart); + int hend = min(hstart + sizeY, height); + int wend = min(wstart + sizeX, width); hstart = max(hstart, 0); wstart = max(wstart, 0); - hend = min(hend, height); - wend = min(wend, width); + int pool_size = (hend - hstart) * (wend - wstart); real aveval = 0; inputData += (frameNum * channels + c) * height * width; @@ -299,12 +297,14 @@ __global__ void KeAvgPoolBackward(const int nthreads, outGrad += (frameNum * outStride + offsetC * pooledH * pooledW); for (int ph = phstart; ph < phend; ++ph) { + int hstart = ph * strideH - padH; + int hend = min(hstart + sizeY, height); + hstart = max(hstart, 0); for (int pw = pwstart; pw < pwend; ++pw) { // figure out the pooling size - int hstart = ph * strideH - padH; int wstart = pw * strideW - padW; - int hend = min(hstart + sizeY, height + padH); - int wend = min(wstart + sizeX, width + padW); + int wend = min(wstart + sizeX, width); + wstart = max(wstart, 0); int poolsize = (hend - hstart) * (wend - wstart); gradient += outGrad[ph * pooledW + pw] / poolsize; } @@ -600,16 +600,13 @@ __global__ void KeAvgPool3DForward(const int nthreads, int dstart = pd * strideD - padD; int hstart = ph * strideH - padH; int wstart = pw * strideW - padW; - int dend = min(dstart + sizeZ, depth + padD); - int hend = min(hstart + sizeY, height + padH); - int wend = min(wstart + sizeX, width + padW); - int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + int dend = min(dstart + sizeZ, depth); + int hend = min(hstart + sizeY, height); + int wend = min(wstart + sizeX, width); dstart = max(dstart, 0); hstart = max(hstart, 0); wstart = max(wstart, 0); - dend = min(dend, depth); - hend = min(hend, height); - wend = min(wend, width); + int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); real aveval = 0; inputData += (frameNum * channels + c) * depth * height * width; @@ -712,15 +709,18 @@ __global__ void KeAvgPool3DBackward(const int nthreads, outGrad += (frameNum * channels + offsetC) * pooledD * pooledH * pooledW; for (int pd = pdstart; pd < pdend; ++pd) { + int dstart = pd * strideD - padD; + int dend = min(dstart + sizeZ, depth); + dstart = max(dstart, 0); for (int ph = phstart; ph < phend; ++ph) { + int hstart = ph * strideH - padH; + int hend = min(hstart + sizeY, height); + hstart = max(hstart, 0); for (int pw = pwstart; pw < pwend; ++pw) { // figure out the pooling size - int dstart = pd * strideD - padD; - int hstart = ph * strideH - padH; int wstart = pw * strideW - padW; - int dend = min(dstart + sizeZ, depth + padD); - int hend = min(hstart + sizeY, height + padH); - int wend = min(wstart + sizeX, width + padW); + int wend = min(wstart + sizeX, width); + wstart = max(wstart, 0); int poolsize = (dend - dstart) * (hend - hstart) * (wend - wstart); gradient += outGrad[(pd * pooledH + ph) * pooledW + pw] / poolsize; } diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc index f38ef692558b908ed65d2c84821bbb7c3b439742..b8caf48f9c06094e85765f7aa5a3f4195d0ca931 100644 --- a/paddle/cuda/src/hl_cuda_cudnn.cc +++ b/paddle/cuda/src/hl_cuda_cudnn.cc @@ -432,11 +432,11 @@ void hl_create_pooling_descriptor(hl_pooling_descriptor* pooling_desc, cudnn_mode = CUDNN_POOLING_MAX; break; case HL_POOLING_AVERAGE: - cudnn_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - break; - case HL_POOLING_AVERAGE_EXCLUDE_PADDING: cudnn_mode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; break; + case HL_POOLING_AVERAGE_INCLUDE_PADDING: + cudnn_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + break; default: LOG(FATAL) << "parameter mode error"; } diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index 568f4e89819c8345d8908634f6fa56f09483a763..fac5cd20aa7f9db0792f8102bb442192ab1ad63f 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -51,18 +51,15 @@ bool operator==(const LoD& a, const LoD& b); * LoDTensor (Level of details Tensor) * see https://en.wikipedia.org/wiki/Level_of_details for reference. */ -class LoDTensor { +class LoDTensor : public Tensor { public: LoDTensor() {} - LoDTensor(const LoD& lod, Tensor* t) : lod_(lod), tensor_(t) {} - void set_lod(const LoD& lod) { lod_ = lod; } - - void set_tensor(Tensor* tensor) { tensor_ = tensor; } + explicit LoDTensor(const LoD& lod) : lod_(lod) {} - Tensor& tensor() { return *tensor_; } + void set_lod(const LoD& lod) { lod_ = lod; } - LoD lod() { return lod_; } + LoD lod() const { return lod_; } /* * Get a element from LoD. @@ -104,7 +101,6 @@ class LoDTensor { private: LoD lod_; - Tensor* tensor_; // not owned }; } // namespace framework } // namespace paddle diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc index 1da8553134f377f7a4fbe8008d12fe8d4a0e47f4..7915326b27a22e9280e3f09d9bbfc2a58f46aff7 100644 --- a/paddle/framework/lod_tensor_test.cc +++ b/paddle/framework/lod_tensor_test.cc @@ -36,69 +36,64 @@ class LoDTensorTester : public ::testing::Test { ASSERT_EQ(lod.size(), 3UL); - tensor.Resize({20 /*batch size*/, 128 /*dim*/}); + lod_tensor_.Resize({20 /*batch size*/, 128 /*dim*/}); // malloc memory - tensor.mutable_data(place); + lod_tensor_.mutable_data(place); - lod_tensor.set_lod(lod); - lod_tensor.set_tensor(&tensor); + lod_tensor_.set_lod(lod); } protected: platform::CPUPlace place; - Tensor tensor; - LoDTensor lod_tensor; + LoDTensor lod_tensor_; }; -TEST_F(LoDTensorTester, NumLevels) { ASSERT_EQ(lod_tensor.NumLevels(), 3UL); } +TEST_F(LoDTensorTester, NumLevels) { ASSERT_EQ(lod_tensor_.NumLevels(), 3UL); } TEST_F(LoDTensorTester, NumElements) { - ASSERT_EQ(lod_tensor.NumElements(0), 2UL); - ASSERT_EQ(lod_tensor.NumElements(1), 4UL); - ASSERT_EQ(lod_tensor.NumElements(2), 8UL); + ASSERT_EQ(lod_tensor_.NumElements(0), 2UL); + ASSERT_EQ(lod_tensor_.NumElements(1), 4UL); + ASSERT_EQ(lod_tensor_.NumElements(2), 8UL); } TEST_F(LoDTensorTester, SliceLevels) { // slice 1 level for (size_t level = 0; level < 3UL; ++level) { - LoDTensor new_lod_tensor = lod_tensor; + LoDTensor new_lod_tensor = lod_tensor_; new_lod_tensor.SliceLevels(level, level + 1); ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL); - ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level)); - ASSERT_EQ(new_lod_tensor.tensor().data(), - lod_tensor.tensor().data()); + ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor_.NumElements(level)); + ASSERT_EQ(new_lod_tensor.data(), lod_tensor_.data()); } // slice 2 level for (size_t level = 0; level < 2UL; ++level) { - LoDTensor new_lod_tensor = lod_tensor; + LoDTensor new_lod_tensor = lod_tensor_; new_lod_tensor.SliceLevels(level, level + 2); ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); - ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor.NumElements(level)); - ASSERT_EQ(new_lod_tensor.NumElements(1), lod_tensor.NumElements(level + 1)); - ASSERT_EQ(new_lod_tensor.tensor().data(), - lod_tensor.tensor().data()); + ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor_.NumElements(level)); + ASSERT_EQ(new_lod_tensor.NumElements(1), + lod_tensor_.NumElements(level + 1)); + ASSERT_EQ(new_lod_tensor.data(), lod_tensor_.data()); } } TEST_F(LoDTensorTester, SliceInLevel) { size_t level = 0; - LoDTensor new_lod_tensor = lod_tensor; + LoDTensor new_lod_tensor = lod_tensor_; new_lod_tensor.SliceInLevel(level, 0, 2); EXPECT_EQ(new_lod_tensor.NumLevels(), 3UL); EXPECT_EQ(new_lod_tensor.NumElements(0), 2UL); EXPECT_EQ(new_lod_tensor.NumElements(1), 4UL); EXPECT_EQ(new_lod_tensor.NumElements(2), 8UL); - ASSERT_EQ(new_lod_tensor.tensor().data(), - lod_tensor.tensor().data()); + ASSERT_EQ(new_lod_tensor.data(), lod_tensor_.data()); level = 1; - new_lod_tensor = lod_tensor; + new_lod_tensor = lod_tensor_; new_lod_tensor.SliceInLevel(level, 0, 2); ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL); ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL); - ASSERT_EQ(new_lod_tensor.tensor().data(), - lod_tensor.tensor().data()); + ASSERT_EQ(new_lod_tensor.data(), lod_tensor_.data()); } } // namespace framework diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu index 1079a36a2e7b24f6f8a5bcbb296355567305a765..97e69cdb2e5e1e64031c899f5e04020665485ba8 100644 --- a/paddle/framework/lod_tensor_test.cu +++ b/paddle/framework/lod_tensor_test.cu @@ -26,18 +26,16 @@ __global__ void test(size_t* a, int size) { } TEST(LoDTensor, LoDInGPU) { - paddle::framework::Tensor tensor; paddle::framework::LoDTensor lod_tensor; paddle::platform::GPUPlace place(0); paddle::framework::LoD src_lod; src_lod.push_back(std::vector{0, 2, 4, 6, 8, 10, 12, 14}); - tensor.Resize({14, 16}); - tensor.mutable_data(place); + lod_tensor.Resize({14, 16}); + lod_tensor.mutable_data(place); lod_tensor.set_lod(src_lod); - lod_tensor.set_tensor(&tensor); CHECK_EQ(lod_tensor.lod_element(0, 2), 4); CHECK_EQ(lod_tensor.lod_element(0, 4), 8); diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 84fdb592a66e7ed96d2c0b81860fdcd5ceb1cf47..f8a64a786611ef872dbbfced10919e00c4d46715 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -186,6 +186,48 @@ void OperatorBase::GenerateTemporaryNames() { } } +template <> +const Tensor* InferShapeContext::Input(const std::string& name) const { + auto* var = InputVar(name); + return var == nullptr ? nullptr : GetTensorFromVar(var); +} + +template <> +const std::vector InferShapeContext::MultiInput( + const std::string& name) const { + auto names = op().Inputs(name); + std::vector res; + res.reserve(names.size()); + std::transform(names.begin(), names.end(), std::back_inserter(res), + [&](const std::string& sub_name) { + auto var = scope_.FindVar(sub_name); + return var == nullptr ? nullptr : GetTensorFromVar(var); + }); + return res; +} + +template <> +Tensor* ExecutionContext::Output(const std::string& name) const { + auto* var = OutputVar(name); + return var == nullptr ? nullptr : const_cast(GetTensorFromVar(var)); +} + +template <> +std::vector ExecutionContext::MultiOutput( + const std::string& name) const { + auto names = op().Outputs(name); + std::vector res; + res.reserve(names.size()); + std::transform(names.begin(), names.end(), std::back_inserter(res), + [&](const std::string& sub_name) { + auto var = scope().FindVar(sub_name); + return var == nullptr + ? nullptr + : const_cast(GetTensorFromVar(var)); + }); + return res; +} + void OpProtoAndCheckerMaker::Validate() { validated_ = true; CheckNoDuplicatedInOutAttrs(); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index f14d901f17226032766f58a3e8c580dd6b753b70..b7c9c39402d57daf0aec97d98535ac8a8d9c0150 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -22,6 +22,7 @@ limitations under the License. */ #include "op_info.h" #include "paddle/framework/attribute.h" #include "paddle/framework/framework.pb.h" +#include "paddle/framework/lod_tensor.h" #include "paddle/framework/scope.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" @@ -326,11 +327,27 @@ class InferShapeContext { return res; } + const Tensor* GetTensorFromVar(const Variable* var) const { + if (var->IsType()) { + return &var->Get(); + } + PADDLE_ENFORCE(var->IsType(), + "The Input(%s) must be LoDTensor or Tensor."); + return &var->Get(); + } + private: const OperatorBase& op_; const Scope& scope_; }; +template <> +const Tensor* InferShapeContext::Input(const std::string& name) const; + +template <> +const std::vector InferShapeContext::MultiInput( + const std::string& name) const; + template struct EigenDeviceConverter; @@ -363,10 +380,38 @@ class ExecutionContext : public InferShapeContext { return device_context_; } + // redefine Output function, + // use Variable::Get instead of Variable::GetMutable + template + T* Output(const std::string& name) const { + auto var = OutputVar(name); + return var == nullptr ? nullptr : const_cast(&var->Get()); + } + + // redefine MultiOutput function. + // use Variable::Get instead of Variable::GetMutable + template + std::vector MultiOutput(const std::string& name) const { + auto names = op().Outputs(name); + std::vector res; + res.reserve(names.size()); + std::transform( + names.begin(), names.end(), std::back_inserter(res), + [&](const std::string& sub_name) { return Output(sub_name); }); + return res; + } + private: const platform::DeviceContext& device_context_; }; +template <> +Tensor* ExecutionContext::Output(const std::string& name) const; + +template <> +std::vector ExecutionContext::MultiOutput( + const std::string& name) const; + class OpKernel { public: /** diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 642b53efc7095d25712ca324638f5fe9b8316c0c..ed166935f76be9d25062b5e69536c7b7ac19045d 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -22,7 +22,7 @@ namespace framework { template inline void Tensor::check_memory_size() const { PADDLE_ENFORCE_NOT_NULL( - holder_, "Tenosr holds no memory. Call Tensor::mutable_data first."); + holder_, "Tensor holds no memory. Call Tensor::mutable_data first."); PADDLE_ENFORCE_GE( holder_->size(), numel() * sizeof(T) + offset_, "Tensor's dims_ is out of bound. Call Tensor::mutable_data " diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index 55302ea47120f420e952b26830c8ea4cbcce6435..e2ec738de35c90c6a06c9a46b062d4cce55f5eda 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -36,7 +36,7 @@ TEST(Tensor, DataAssert) { } catch (paddle::platform::EnforceNotMet err) { caught = true; std::string msg = - "holder_ should not be null\nTenosr holds no memory. Call " + "holder_ should not be null\nTensor holds no memory. Call " "Tensor::mutable_data first."; const char* what = err.what(); for (size_t i = 0; i < msg.length(); ++i) { @@ -112,7 +112,7 @@ TEST(Tensor, ShareDataWith) { } catch (paddle::platform::EnforceNotMet err) { caught = true; std::string msg = - "holder_ should not be null\nTenosr holds no memory. Call " + "holder_ should not be null\nTensor holds no memory. Call " "Tensor::mutable_data first."; const char* what = err.what(); for (size_t i = 0; i < msg.length(); ++i) { @@ -274,4 +274,4 @@ TEST(Tensor, ReshapeToMatrix) { Tensor res = ReshapeToMatrix(src, 2); ASSERT_EQ(res.dims()[0], 2 * 3); ASSERT_EQ(res.dims()[1], 4 * 9); -} \ No newline at end of file +} diff --git a/paddle/gserver/layers/CudnnPoolLayer.cpp b/paddle/gserver/layers/CudnnPoolLayer.cpp index 4adb2d4709e585a6fec052435c33714d6e3a3f0e..810a1af2d09c63c3787a1ac225c2c7de4238d609 100644 --- a/paddle/gserver/layers/CudnnPoolLayer.cpp +++ b/paddle/gserver/layers/CudnnPoolLayer.cpp @@ -29,9 +29,9 @@ bool CudnnPoolLayer::typeCheck(const std::string &poolType, if (mode) { *mode = HL_POOLING_AVERAGE; } - } else if (poolType == "cudnn-avg-excl-pad-pool") { + } else if (poolType == "cudnn-avg-incl-pad-pool") { if (mode) { - *mode = HL_POOLING_AVERAGE_EXCLUDE_PADDING; + *mode = HL_POOLING_AVERAGE_INCLUDE_PADDING; } } else { return false; diff --git a/paddle/gserver/layers/ExpandConvBaseLayer.cpp b/paddle/gserver/layers/ExpandConvBaseLayer.cpp deleted file mode 100644 index 2b7bef0a757d7c706be3815c539b036b094596cf..0000000000000000000000000000000000000000 --- a/paddle/gserver/layers/ExpandConvBaseLayer.cpp +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "ExpandConvBaseLayer.h" - -#include "paddle/utils/Logging.h" -namespace paddle { - -bool ExpandConvBaseLayer::init(const LayerMap &layerMap, - const ParameterMap ¶meterMap) { - /* Initialize the basic convolutional parent class */ - ConvBaseLayer::init(layerMap, parameterMap); - - int index = 0; - for (auto &inputConfig : config_.inputs()) { - const ConvConfig &conf = inputConfig.conv_conf(); - /* Consistent caffe mode for multiple input */ - caffeMode_ = conf.caffe_mode(); - - // create a new weight - size_t height, width; - height = filterPixels_[index] * filterChannels_[index]; - width = (!isDeconv_) ? numFilters_ : channels_[index]; - CHECK_EQ(parameters_[index]->getSize(), width * height); - Weight *w = new Weight(height, width, parameters_[index]); - weights_.emplace_back(w); - index++; - } - if (biasParameter_.get()) { - if (sharedBiases_) { - CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); - biases_ = - std::unique_ptr(new Weight(numFilters_, 1, biasParameter_)); - } else { - biases_ = - std::unique_ptr(new Weight(getSize(), 1, biasParameter_)); - } - } - getOutputSize(); - - return true; -} - -size_t ExpandConvBaseLayer::getOutputSize() { - CHECK_NE(inputLayers_.size(), 0UL); - size_t layerSize = ConvBaseLayer::calOutputSize(); - return layerSize; -} - -void ExpandConvBaseLayer::addSharedBias() { - size_t mapW = getOutputSize() / numFilters_; - size_t mapH = getOutputValue()->getElementCnt() / mapW; - MatrixPtr out = - Matrix::create(getOutputValue()->getData(), mapH, mapW, false, useGpu_); - - Matrix::resizeOrCreate(transOutValue_, mapW, mapH, false, useGpu_); - - out->transpose(transOutValue_, false); // false means no memory allocation - transOutValue_->reshape(transOutValue_->getElementCnt() / numFilters_, - numFilters_); - - MatrixPtr bias = Matrix::create(biases_->getW()->getData(), - 1, - biases_->getW()->getElementCnt(), - false, - useGpu_); - transOutValue_->addBias(*bias, 1.0f); - - transOutValue_->reshape(mapW, mapH); - transOutValue_->transpose(out, false); // false means no memory allocation - - out->clear(); - bias->clear(); -} - -void ExpandConvBaseLayer::addUnsharedBias() { - MatrixPtr outValue = getOutputValue(); - MatrixPtr bias = Matrix::create(biases_->getW()->getData(), - 1, - biases_->getW()->getElementCnt(), - false, - useGpu_); - outValue->addBias(*bias, 1.0f); -} - -void ExpandConvBaseLayer::bpropSharedBias(MatrixPtr biases, MatrixPtr v) { - size_t mapW = getOutputSize() / numFilters_; - size_t mapH = v->getElementCnt() / mapW; - MatrixPtr vTmp = Matrix::create(v->getData(), mapH, mapW, false, useGpu_); - - Matrix::resizeOrCreate(transOutValue_, mapW, mapH, false, useGpu_); - - vTmp->transpose(transOutValue_, false); // false means no memory allocation - transOutValue_->reshape(transOutValue_->getElementCnt() / numFilters_, - numFilters_); - biases->collectBias(*transOutValue_, 1.0f); -} - -void ExpandConvBaseLayer::bpropBiases(MatrixPtr v) { - MatrixPtr biases = Matrix::create(biases_->getWGrad()->getData(), - 1, - biases_->getWGrad()->getElementCnt(), - false, - useGpu_); - if (sharedBiases_) { - bpropSharedBias(biases, v); - } else { - biases->collectBias(*v, 1.0f); - } - biases->clear(); -} - -} // namespace paddle diff --git a/paddle/gserver/layers/ExpandConvBaseLayer.h b/paddle/gserver/layers/ExpandConvBaseLayer.h deleted file mode 100644 index 01c699d2344443a1887ec0b5005125f617cbe279..0000000000000000000000000000000000000000 --- a/paddle/gserver/layers/ExpandConvBaseLayer.h +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include "ConvBaseLayer.h" -#include "paddle/math/Matrix.h" - -namespace paddle { - -/** - * @brief A subclass of ConvBaseLayer that is a superclass of both - * ExpandConvLayer and ExpandConvTransLayer - */ -class ExpandConvBaseLayer : public ConvBaseLayer { -protected: - /// The transpose of output, which is an auxiliary matrix. - MatrixPtr transOutValue_; - -public: - explicit ExpandConvBaseLayer(const LayerConfig& config) - : ConvBaseLayer(config) {} - - ~ExpandConvBaseLayer() {} - - bool init(const LayerMap& layerMap, - const ParameterMap& parameterMap) override; - - size_t getOutputSize(); - - /** - * Add shared bias. - */ - void addSharedBias(); - - /** - * Add unshared bias. - */ - void addUnsharedBias(); - - void bpropSharedBias(MatrixPtr biases, MatrixPtr v); - void bpropBiases(MatrixPtr v); -}; - -} // namespace paddle diff --git a/paddle/gserver/layers/ExpandConvLayer.cpp b/paddle/gserver/layers/ExpandConvLayer.cpp index 20de475fc3f6b6f3c05ac26bea8363daff0cf110..48dfcb49a4c2c46891bb5236fc1f8e644c03f327 100644 --- a/paddle/gserver/layers/ExpandConvLayer.cpp +++ b/paddle/gserver/layers/ExpandConvLayer.cpp @@ -36,7 +36,36 @@ inline bool isDepthwiseConv(int channels, int groups) { bool ExpandConvLayer::init(const LayerMap &layerMap, const ParameterMap ¶meterMap) { /* Initialize the basic convolutional parent class */ - ExpandConvBaseLayer::init(layerMap, parameterMap); + ConvBaseLayer::init(layerMap, parameterMap); + + int index = 0; + for (auto &inputConfig : config_.inputs()) { + const ConvConfig &conf = inputConfig.conv_conf(); + /* Consistent caffe mode for multiple input */ + caffeMode_ = conf.caffe_mode(); + + // create a new weight + size_t height, width; + height = filterPixels_[index] * filterChannels_[index]; + width = (!isDeconv_) ? numFilters_ : channels_[index]; + CHECK_EQ(parameters_[index]->getSize(), width * height); + Weight *w = new Weight(height, width, parameters_[index]); + weights_.emplace_back(w); + index++; + } + + if (biasParameter_.get()) { + if (sharedBiases_) { + CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); + biases_ = std::unique_ptr( + new Weight(1, numFilters_, biasParameter_, 0)); + } else { + biases_ = + std::unique_ptr(new Weight(1, getSize(), biasParameter_, 0)); + } + } + + getOutputSize(); size_t numInputs = config_.inputs_size(); inputShape_.resize(numInputs); @@ -108,6 +137,12 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, return true; } +size_t ExpandConvLayer::getOutputSize() { + CHECK_NE(inputLayers_.size(), 0UL); + size_t layerSize = ConvBaseLayer::calOutputSize(); + return layerSize; +} + // i is the index of input layers #define BACKWARD_INPUT(i, inputs, outputs) \ backward_[2 * i]->calc(inputs, outputs) @@ -155,11 +190,7 @@ void ExpandConvLayer::forward(PassType passType) { /* add the bias-vector */ if (biases_.get()) { - if (sharedBiases_) { - addSharedBias(); - } else { - addUnsharedBias(); - } + output_.value->addBias(*biases_->getW(), 1.0, sharedBiases_); } /* activation */ @@ -171,7 +202,7 @@ void ExpandConvLayer::backward(const UpdateCallback &callback) { MatrixPtr outGrad = getOutputGrad(); if (biases_ && biases_->getWGrad()) { - bpropBiases(outGrad); + biases_->getWGrad()->collectBias(*getOutputGrad(), 1, sharedBiases_); /* Increasing the number of gradient */ biases_->getParameterPtr()->incUpdate(callback); } diff --git a/paddle/gserver/layers/ExpandConvLayer.h b/paddle/gserver/layers/ExpandConvLayer.h index a1f943d1521547af0f82cec7da8a4efe9037cd71..a0873de19253f2496bc0c2fba550b3199dfc7486 100644 --- a/paddle/gserver/layers/ExpandConvLayer.h +++ b/paddle/gserver/layers/ExpandConvLayer.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include -#include "ExpandConvBaseLayer.h" +#include "ConvBaseLayer.h" #include "paddle/math/Matrix.h" namespace paddle { @@ -28,10 +28,9 @@ namespace paddle { * The config file api is img_conv_layer. */ -class ExpandConvLayer : public ExpandConvBaseLayer { +class ExpandConvLayer : public ConvBaseLayer { public: - explicit ExpandConvLayer(const LayerConfig& config) - : ExpandConvBaseLayer(config) {} + explicit ExpandConvLayer(const LayerConfig& config) : ConvBaseLayer(config) {} ~ExpandConvLayer() {} @@ -41,6 +40,8 @@ public: void forward(PassType passType) override; void backward(const UpdateCallback& callback) override; + size_t getOutputSize(); + protected: std::vector inputShape_; std::vector filterShape_; diff --git a/paddle/gserver/layers/MKLDNNConvLayer.cpp b/paddle/gserver/layers/MKLDNNConvLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9088744beebd25ac105737fe3b012de143c66a7c --- /dev/null +++ b/paddle/gserver/layers/MKLDNNConvLayer.cpp @@ -0,0 +1,544 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "MKLDNNConvLayer.h" +#include "paddle/math/MathUtils.h" +#include "paddle/utils/Logging.h" + +using namespace mkldnn; // NOLINT +typedef memory::format format; + +namespace paddle { + +REGISTER_LAYER(mkldnn_conv, MKLDNNConvLayer); + +bool MKLDNNConvLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + if (!MKLDNNLayer::init(layerMap, parameterMap)) { + return false; + } + CHECK_EQ(inputLayers_.size(), 1) << "Only support one input layer yet"; + CHECK_EQ(inputLayers_.size(), parameters_.size()); + CHECK(config_.shared_biases()) << "Only support shared biases yet"; + + oc_ = config_.num_filters(); + const ConvConfig& conf = config_.inputs(0).conv_conf(); + ic_ = conf.channels(); + fw_ = conf.filter_size(); + fh_ = conf.filter_size_y(); + pw_ = conf.padding(); + ph_ = conf.padding_y(); + dw_ = conf.dilation(); + dh_ = conf.dilation_y(); + sw_ = conf.stride(); + sh_ = conf.stride_y(); + gp_ = conf.groups(); + oh_ = conf.output_y(); + ow_ = conf.output_x(); + ih_ = conf.img_size_y(); + iw_ = conf.img_size(); + caffeMode_ = conf.caffe_mode(); + CHECK(caffeMode_) << "Only support caffe mode yet"; + CHECK(dh_ == 1 && dw_ == 1) << "Only support dilation 1 yet"; + // check group setting + CHECK_EQ((oc_ / gp_) * gp_, oc_) << "group is indivisible for oc"; + CHECK_EQ((ic_ / gp_) * gp_, ic_) << "group is indivisible for ic"; + + // create weight + size_t height = oc_ / gp_; + size_t width = ic_ * fh_ * fw_; + CHECK_EQ(parameters_[0]->getSize(), height * width); + weight_ = + std::unique_ptr(new Weight(height, width, parameters_[0], 0)); + + // create biases + if (biasParameter_.get() != NULL) { + biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_)); + } + return true; +} + +void MKLDNNConvLayer::convertWeightsFromPaddle() { + if (hasInitedWgt_) { + return; + } + + CHECK(wgtVal_) << "should have been initialized"; + // the paddle weight format is oihw or goihw + auto targetDim = wgtVal_->getDims(); + auto srcFmt = (gp_ == 1) ? memory::format::oihw : memory::format::goihw; + wgtVal_->reorderDataFrom(wgtVal_, srcFmt, targetDim); + hasInitedWgt_ = true; +} + +void MKLDNNConvLayer::convertWeightsToPaddle() { + CHECK(wgtVal_) << "should have been initialized"; + auto targetDim = wgtVal_->getDims(); + auto dstFmt = (gp_ == 1) ? memory::format::oihw : memory::format::goihw; + wgtVal_->reorderDataTo(wgtVal_, dstFmt, targetDim); +} + +void MKLDNNConvLayer::reshape( + int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) { + reshapeInput(bs, ih, iw); + + // cal output sizes + // oc can not be changed + int fh = (fh_ - 1) * dh_ + 1; + int fw = (fw_ - 1) * dw_ + 1; + oh = outputSize(ih, fh, ph_, sh_, caffeMode_); + ow = outputSize(iw, fw, pw_, sw_, caffeMode_); + + reshapeOutput(oh, ow); + resizeOutput(bs, oc * oh * ow); + + printSizeInfo(); +} + +void MKLDNNConvLayer::resetFwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + resetFwdPD(fwdPD_); + + resetFwdBuffers(fwdPD_, in, wgt, bias, out); + + resetFwdPipeline(pipeline, fwdPD_, in, wgt, bias, out); + + printValueFormatFlow(); +} + +void MKLDNNConvLayer::resetBwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + std::shared_ptr bwdWgtPD; + std::shared_ptr bwdDataPD; + + resetBwdWgtPD(bwdWgtPD); + + resetBwdDataPD(bwdDataPD); + + resetBwdBuffers(bwdWgtPD, bwdDataPD, in, wgt, bias, out); + + resetBwdPipeline(pipeline, bwdWgtPD, bwdDataPD, in, wgt, bias, out); + + printGradFormatFlow(); +} + +void MKLDNNConvLayer::updateInputData() { + cpuInVal_->setData(getInputValue(0, CPU_DEVICE)->getData()); +} + +void MKLDNNConvLayer::updateWeights(const UpdateCallback& callback) { + weight_->getParameterPtr()->incUpdate(callback); + if (biases_ && biases_->getWGrad()) { + biases_->getParameterPtr()->incUpdate(callback); + } +} + +void MKLDNNConvLayer::loadConvSettings(memory::dims& wgt, + memory::dims& bias, + memory::dims& stride, + memory::dims& dilation, + memory::dims& padL, + memory::dims& padR) { + wgt = (gp_ == 1) ? memory::dims{oc_, ic_, fh_, fw_} + : memory::dims{gp_, oc_ / gp_, ic_ / gp_, fh_, fw_}; + bias = memory::dims{oc_}; + stride = memory::dims{sh_, sw_}; + padL = memory::dims{ph_, pw_}; + padR = getPaddingR(); + // note: mkldnn dilation start from 0 + dilation = memory::dims{dh_ - 1, dw_ - 1}; +} + +void MKLDNNConvLayer::resetFwdPD( + std::shared_ptr& pd) { + // dims for conv + memory::dims inDims = memory::dims{bs_, ic_, ih_, iw_}; + memory::dims outDims = memory::dims{bs_, oc_, oh_, ow_}; + memory::dims wgtDims, biasDims, strides, dilations, padL, padR; + loadConvSettings(wgtDims, biasDims, strides, dilations, padL, padR); + + prop_kind pk = passType_ == PASS_TEST ? prop_kind::forward_scoring + : prop_kind::forward_training; + algorithm algo = algorithm::convolution_direct; + padding_kind padKind = padding_kind::zero; + conv_fwd::desc fwdDesc = + biases_ && biases_->getW() + ? conv_fwd::desc(pk, + algo, + MKLDNNMatrix::createMemoryDesc(inDims), + MKLDNNMatrix::createMemoryDesc(wgtDims), + MKLDNNMatrix::createMemoryDesc(biasDims), + MKLDNNMatrix::createMemoryDesc(outDims), + strides, + dilations, + padL, + padR, + padKind) + : conv_fwd::desc(pk, + algo, + MKLDNNMatrix::createMemoryDesc(inDims), + MKLDNNMatrix::createMemoryDesc(wgtDims), + MKLDNNMatrix::createMemoryDesc(outDims), + strides, + dilations, + padL, + padR, + padKind); + pd.reset(new conv_fwd::primitive_desc(fwdDesc, engine_)); +} + +void MKLDNNConvLayer::resetFwdBuffers( + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + CHECK(pd); + resetInValue(pd, in); + + resetWgtBiasValue(pd, wgt, bias); + + resetOutValue(pd, out); +} + +void MKLDNNConvLayer::resetFwdPipeline( + std::vector& pipeline, + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + pipeline.clear(); + + if (cvtInVal_) { + pipeline.push_back(*cvtInVal_); + } + + if (bias) { + fwd_.reset(new conv_fwd(*pd, *in, *wgt, *bias, *out)); + } else { + fwd_.reset(new conv_fwd(*pd, *in, *wgt, *out)); + } + pipeline.push_back(*fwd_); + + if (cvtOutVal_) { + pipeline.push_back(*cvtOutVal_); + } +} + +void MKLDNNConvLayer::resetInValue( + std::shared_ptr& pd, MKLDNNMatrixPtr& in) { + const MatrixPtr& inMat = inputLayers_[0]->getOutput().value; + in = MKLDNNMatrix::create(inMat, pd->src_primitive_desc()); + + // create buffer and reorder if input value do not match + cpuInVal_ = nullptr; + cvtInVal_ = nullptr; + if (inputIsOnlyMKLDNN()) { + MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast(inMat); + CHECK(dnnIn) << "Input should be MKLDNNMatrix"; + if (dnnIn->getPrimitiveDesc() != in->getPrimitiveDesc()) { + CHECK_EQ(dnnIn->getFormat(), format::nc); + CHECK(ih_ == 1 && iw_ == 1) << "when input is nc format"; + // create a new one with nchw format and same data + memory::dims inDims = memory::dims{bs_, ic_, 1, 1}; + dnnIn = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_); + CHECK(dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()); + } + in = dnnIn; + } else { + const MatrixPtr& cpuIn = getInputValue(0, CPU_DEVICE); + memory::dims inDims = memory::dims{bs_, ic_, ih_, iw_}; + cpuInVal_ = MKLDNNMatrix::create(cpuIn, inDims, format::nchw, engine_); + if (cpuInVal_->getPrimitiveDesc() != in->getPrimitiveDesc()) { + // create new mkldnn matrix + in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc()); + cvtInVal_ = MKLDNNMatrix::createReorder(cpuInVal_, in); + CHECK(cvtInVal_) << "should not be emptry"; + } else { + in = cpuInVal_; + } + } +} + +void MKLDNNConvLayer::resetWgtBiasValue( + std::shared_ptr& pd, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias) { + wgt = MKLDNNMatrix::create(weight_->getW(), pd->weights_primitive_desc()); + VLOG(MKLDNN_FMTS) << "Weight value format: " << wgt->getFormat(); + + bias = (biases_ && biases_->getW()) + ? MKLDNNMatrix::create(biases_->getW(), pd->bias_primitive_desc()) + : nullptr; +} + +void MKLDNNConvLayer::resetOutValue( + std::shared_ptr& pd, MKLDNNMatrixPtr& out) { + out = MKLDNNMatrix::create(output_.value, pd->dst_primitive_desc()); + + // change original output value from cpu matrix to mkldnn matrix + output_.value = std::dynamic_pointer_cast(out); + + // create reorder if output value has cpu device and pd do not match + cpuOutVal_ = nullptr; + cpuOutVal_ = nullptr; + if (!outputIsOnlyMKLDNN()) { + const MatrixPtr& cpuOut = getOutput(CPU_DEVICE).value; + memory::dims outDims = memory::dims{bs_, oc_, oh_, ow_}; + cpuOutVal_ = MKLDNNMatrix::create(cpuOut, outDims, format::nchw, engine_); + if (cpuOutVal_->getPrimitiveDesc() != out->getPrimitiveDesc()) { + cvtOutVal_ = MKLDNNMatrix::createReorder(out, cpuOutVal_); + CHECK(cvtOutVal_) << "should not be emptry"; + } else { + // CPU output share the same data of MKLDNN output + cpuOut->setData(out->getData()); + cpuOutVal_ = out; + } + } +} + +void MKLDNNConvLayer::resetBwdWgtPD( + std::shared_ptr& pd) { + memory::dims wgtDims, biasDims, strides, dilations, padL, padR; + loadConvSettings(wgtDims, biasDims, strides, dilations, padL, padR); + + // create backward weight using input, output and weight value memory desc + CHECK(inVal_) << "Should have input value"; + CHECK(outVal_) << "Should have output value"; + CHECK(wgtVal_) << "Should have weight value"; + algorithm algo = algorithm::convolution_direct; + padding_kind padKind = padding_kind::zero; + auto bwdWgtDesc = biasVal_ != nullptr + ? conv_bwdWgt::desc(algo, + inVal_->getMemoryDesc(), + wgtVal_->getMemoryDesc(), + biasVal_->getMemoryDesc(), + outVal_->getMemoryDesc(), + strides, + padL, + padR, + padKind) + : conv_bwdWgt::desc(algo, + inVal_->getMemoryDesc(), + wgtVal_->getMemoryDesc(), + outVal_->getMemoryDesc(), + strides, + padL, + padR, + padKind); + pd.reset(new conv_bwdWgt::primitive_desc(bwdWgtDesc, engine_, *fwdPD_)); + CHECK(pd->src_primitive_desc() == inVal_->getPrimitiveDesc()) + << "primitive desc of in value should equal"; + CHECK(pd->diff_dst_primitive_desc() == outVal_->getPrimitiveDesc()) + << "primitive desc of out grad should equal the out value"; + CHECK(pd->diff_weights_primitive_desc() == wgtVal_->getPrimitiveDesc()) + << "primitive desc of weight grad should equal the weight value"; +} + +void MKLDNNConvLayer::resetBwdDataPD( + std::shared_ptr& pd) { + pd = nullptr; + if (inputLayers_[0]->getOutput().grad == nullptr) { + return; + } + + memory::dims wgtDims, biasDims, strides, dilations, padL, padR; + loadConvSettings(wgtDims, biasDims, strides, dilations, padL, padR); + CHECK(inVal_) << "Should have input value"; + CHECK(outVal_) << "Should have output value"; + // create backward data using input and output value memory desc + // but using weight memory desc with any format + auto bwdDataDesc = conv_bwdData::desc(algorithm::convolution_direct, + inVal_->getMemoryDesc(), + MKLDNNMatrix::createMemoryDesc(wgtDims), + outVal_->getMemoryDesc(), + strides, + padL, + padR, + padding_kind::zero); + pd.reset(new conv_bwdData::primitive_desc(bwdDataDesc, engine_, *fwdPD_)); + CHECK(pd->diff_src_primitive_desc() == inVal_->getPrimitiveDesc()) + << "primitive desc of in grad should equal the in value"; + CHECK(pd->diff_dst_primitive_desc() == outVal_->getPrimitiveDesc()) + << "primitive desc of out grad should equal"; +} + +void MKLDNNConvLayer::resetBwdBuffers( + std::shared_ptr& wgtPD, + std::shared_ptr& dataPD, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + CHECK(wgtPD); + resetOutGrad(wgtPD, out); + + resetWgtBiasGrad(wgtPD, wgt, bias); + + resetInGrad(dataPD, in); + + resetWgtValBwdData(dataPD, wgtValBwdData_); +} + +void MKLDNNConvLayer::resetBwdPipeline( + std::vector& pipeline, + std::shared_ptr& wgtPD, + std::shared_ptr& dataPD, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + pipeline.clear(); + + if (cvtOutGrad_) { + pipeline.push_back(*cvtOutGrad_); + } + + // add bwdWgt handle + if (bias) { + bwdWgt_.reset(new conv_bwdWgt(*wgtPD, *inVal_, *out, *wgt, *bias)); + } else { + bwdWgt_.reset(new conv_bwdWgt(*wgtPD, *inVal_, *out, *wgt)); + } + pipeline.push_back(*bwdWgt_); + + if (dataPD == nullptr) { + return; + } + + if (cvtWgtVal_) { + pipeline.push_back(*cvtWgtVal_); + } + + // add bwdData handle + CHECK(wgtValBwdData_) << "Should have weight memory"; + bwdData_.reset(new conv_bwdData(*dataPD, *out, *wgtValBwdData_, *in)); + pipeline.push_back(*bwdData_); + + if (cvtInGrad_) { + pipeline.push_back(*cvtInGrad_); + } +} + +void MKLDNNConvLayer::resetOutGrad( + std::shared_ptr& wgtPD, MKLDNNMatrixPtr& out) { + const MatrixPtr& outMat = output_.grad; + out = MKLDNNMatrix::create(outMat, wgtPD->diff_dst_primitive_desc()); + CHECK(outVal_ != nullptr && + out->getPrimitiveDesc() == outVal_->getPrimitiveDesc()) + << "primitive desc of out grad and value should be equal"; + + // TODO(TJ): merge outgrad + // create reorder if has output grad does not match + cpuOutGrad_ = nullptr; + cvtOutGrad_ = nullptr; + if (!outputIsOnlyMKLDNN()) { + const MatrixPtr& cpuOut = getOutput(CPU_DEVICE).grad; + // same PrimitiveDesc with cpuInVal_ + CHECK(cpuOutVal_); + cpuOutGrad_ = MKLDNNMatrix::create(cpuOut, cpuOutVal_->getPrimitiveDesc()); + if (cpuOutGrad_->getPrimitiveDesc() == out->getPrimitiveDesc()) { + outMat->setData(cpuOut->getData()); + out = cpuOutGrad_; + } else { + cvtOutGrad_ = MKLDNNMatrix::createReorder(cpuOutGrad_, out); + CHECK(cvtOutGrad_); + } + } +} + +void MKLDNNConvLayer::resetWgtBiasGrad( + std::shared_ptr& wgtPD, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias) { + wgt = MKLDNNMatrix::create(weight_->getWGrad(), + wgtPD->diff_weights_primitive_desc()); + CHECK(nullptr != wgtVal_ && + wgt->getPrimitiveDesc() == wgtVal_->getPrimitiveDesc()) + << "primitive desc of weight grad and value should be equal"; + VLOG(MKLDNN_FMTS) << "weight grad format: " << wgt->getFormat(); + + bias = nullptr; + if (biasVal_ == nullptr) { + return; + } + bias = MKLDNNMatrix::create(biases_->getWGrad(), + wgtPD->diff_bias_primitive_desc()); + CHECK(bias->getPrimitiveDesc() == biasVal_->getPrimitiveDesc()) + << "primitive desc of bias grad should equal the bias value"; +} + +void MKLDNNConvLayer::resetInGrad( + std::shared_ptr& dataPD, + MKLDNNMatrixPtr& in) { + if (dataPD == nullptr) { + return; + } + + // TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done + in = MKLDNNMatrix::create(inputLayers_[0]->getOutput().grad, + dataPD->diff_src_primitive_desc()); + CHECK(nullptr != inVal_ && + in->getPrimitiveDesc() == inVal_->getPrimitiveDesc()) + << "primitive desc of input grad and value should be equal"; + + // create reorder if has output grad does not match + cpuInGrad_ = nullptr; + cvtInGrad_ = nullptr; + if (!inputIsOnlyMKLDNN()) { + const MatrixPtr& cpuIn = getInputGrad(0, CPU_DEVICE); + // same PrimitiveDesc with cpuInVal_ + CHECK(cpuInVal_); + cpuInGrad_ = MKLDNNMatrix::create(cpuIn, cpuInVal_->getPrimitiveDesc()); + if (cpuInGrad_->getPrimitiveDesc() != in->getPrimitiveDesc()) { + const MatrixPtr& dnnIn = getInputGrad(0, MKLDNN_DEVICE); + in = MKLDNNMatrix::create(dnnIn, in->getPrimitiveDesc()); + cvtInGrad_ = MKLDNNMatrix::createReorder(in, cpuInGrad_); + CHECK(cvtInGrad_); + } else { + in = cpuInGrad_; + } + } +} + +void MKLDNNConvLayer::resetWgtValBwdData( + std::shared_ptr& dataPD, + MKLDNNMatrixPtr& wgt) { + if (dataPD == nullptr) { + return; + } + + // create new weight value for backward data, and create reorder if necessary + // since the primitive_desc would be different with wgtVal_ + CHECK(wgtVal_) << "should have weight value"; + if (dataPD->weights_primitive_desc() != wgtVal_->getPrimitiveDesc()) { + wgtValBwdData_ = + MKLDNNMatrix::create(nullptr, dataPD->weights_primitive_desc()); + cvtWgtVal_ = MKLDNNMatrix::createReorder(wgtVal_, wgtValBwdData_); + CHECK(cvtWgtVal_); + } else { + wgtValBwdData_ = wgtVal_; + } + VLOG(MKLDNN_FMTS) << "weight value format for backward data" + << wgtValBwdData_->getFormat(); +} + +} // namespace paddle diff --git a/paddle/gserver/layers/MKLDNNConvLayer.h b/paddle/gserver/layers/MKLDNNConvLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..f84f2f737c47a1b8adc2b83360a0396ffbc6ae24 --- /dev/null +++ b/paddle/gserver/layers/MKLDNNConvLayer.h @@ -0,0 +1,253 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "MKLDNNLayer.h" +#include "mkldnn.hpp" + +namespace paddle { +typedef mkldnn::convolution_forward conv_fwd; +typedef mkldnn::convolution_backward_weights conv_bwdWgt; +typedef mkldnn::convolution_backward_data conv_bwdData; + +/** + * @brief A subclass of MKLDNNLayer conv layer. + * + * The config file api is mkldnn_conv + */ +class MKLDNNConvLayer : public MKLDNNLayer { +protected: + // padding height and width + int ph_, pw_; + // stride height and width + int sh_, sw_; + // dilation height and width + int dh_, dw_; + // filter(kenerl) height and width + int fh_, fw_; + // group number + int gp_; + + // in resetBwdData, the format of wgtValBwdData_ is different with wgtVal_ + MKLDNNMatrixPtr wgtValBwdData_; + // convert handle from wgtVal_ to wgtValBwdData_ + std::shared_ptr cvtWgtVal_; + + // save forward primitive_desc, which can be used backward + std::shared_ptr fwdPD_; + + // MKLDNNMatrixPtr which should be created from CPU Device + MKLDNNMatrixPtr cpuInVal_; + MKLDNNMatrixPtr cpuInGrad_; + MKLDNNMatrixPtr cpuOutVal_; + MKLDNNMatrixPtr cpuOutGrad_; + // convert handle between CPU device and MKLDNN device + std::shared_ptr cvtInVal_; + std::shared_ptr cvtInGrad_; + std::shared_ptr cvtOutVal_; + std::shared_ptr cvtOutGrad_; + + // whether the weight has been init + bool hasInitedWgt_; + + // true by default, which impact the calculation of output image size. + // details can refer to mathUtil.h + bool caffeMode_; + + // weight and bias + std::unique_ptr weight_; + std::unique_ptr biases_; + +public: + explicit MKLDNNConvLayer(const LayerConfig& config) + : MKLDNNLayer(config), hasInitedWgt_(false), caffeMode_(true) {} + + ~MKLDNNConvLayer() {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void reshape( + int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override; + + void resetFwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) override; + + void resetBwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) override; + + void updateInputData() override; + + void updateWeights(const UpdateCallback& callback) override; + + void convertWeightsFromPaddle() override; + + void convertWeightsToPaddle() override; + + void printSizeInfo() override { + MKLDNNLayer::printSizeInfo(); + VLOG(MKLDNN_SIZES) << getName() << ": fh: " << fh_ << ", fw: " << fw_ + << ": ph: " << ph_ << ", pw: " << pw_ << ", sh: " << sh_ + << ", sw: " << sw_ << ", dh: " << dh_ << ", dw: " << dw_; + } + + void printValueFormatFlow() override { + if (cpuInVal_) { + VLOG(MKLDNN_FMTS) << cpuInVal_->getFormat() << " >>>"; + } + MKLDNNLayer::printValueFormatFlow(); + if (cpuOutVal_) { + VLOG(MKLDNN_FMTS) << " >>> " << cpuOutVal_->getFormat(); + } + } + + void printGradFormatFlow() override { + if (cpuInGrad_) { + VLOG(MKLDNN_FMTS) << cpuInGrad_->getFormat() << " <<<"; + } + MKLDNNLayer::printGradFormatFlow(); + if (cpuOutGrad_) { + VLOG(MKLDNN_FMTS) << " <<< " << cpuOutGrad_->getFormat(); + } + } + +protected: + /** + * load the dims settings of this conv + */ + void loadConvSettings(mkldnn::memory::dims& wgt, + mkldnn::memory::dims& bias, + mkldnn::memory::dims& stride, + mkldnn::memory::dims& dilation, + mkldnn::memory::dims& padL, + mkldnn::memory::dims& padR); + + /** + * reset the forward primitive descriptor. + */ + void resetFwdPD(std::shared_ptr& pd); + /** + * reset the MKLDNNMatrix buffers used in forward. + */ + void resetFwdBuffers(std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out); + /** + * reset the forward pipeline. + */ + void resetFwdPipeline(std::vector& pipeline, + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out); + + /** + * reset MKLDNNMatrix of input value + */ + void resetInValue(std::shared_ptr& pd, + MKLDNNMatrixPtr& in); + /** + * reset MKLDNNMatrix of weight and bias value + */ + void resetWgtBiasValue(std::shared_ptr& pd, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias); + /** + * reset MKLDNNMatrix of output value + */ + void resetOutValue(std::shared_ptr& pd, + MKLDNNMatrixPtr& out); + + /** + * reset the backward weight primitive descriptor. + */ + void resetBwdWgtPD(std::shared_ptr& pd); + /** + * reset the backward data primitive descriptor. + */ + void resetBwdDataPD(std::shared_ptr& pd); + /** + * reset the MKLDNNMatrix buffers used in backward. + */ + void resetBwdBuffers(std::shared_ptr& wgtPD, + std::shared_ptr& dataPD, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out); + /** + * reset the backward pipeline. + */ + void resetBwdPipeline(std::vector& pipeline, + std::shared_ptr& wgtPD, + std::shared_ptr& dataPD, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out); + + /** + * reset MKLDNNMatrix of output grad + */ + void resetOutGrad(std::shared_ptr& wgtPD, + MKLDNNMatrixPtr& out); + /** + * reset MKLDNNMatrix of weight and bias grad + */ + void resetWgtBiasGrad(std::shared_ptr& wgtPD, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias); + /** + * reset MKLDNNMatrix of input grad + */ + void resetInGrad(std::shared_ptr& dataPD, + MKLDNNMatrixPtr& in); + /** + * reset MKLDNNMatrix of weight value for backward data + * since the primitive_desc would be different with wgtVal_ + */ + void resetWgtValBwdData(std::shared_ptr& dataPD, + MKLDNNMatrixPtr& wgt); + + /** + * get padding_r according to + * https://github.com/01org/mkl-dnn/blob/master/tests/gtests/ + * test_convolution_forward_common.hpp + * @note: mkldnn dilation start from 0 while paddle start from 1 + */ + mkldnn::memory::dims getPaddingR() const { + mkldnn::memory::dims padR = {ph_, pw_}; + for (int i = 0; i < 2; ++i) { + if ((ih_ - ((fh_ - 1) * dh_ + 1) + ph_ + padR[0]) / sh_ + 1 != oh_) { + ++padR[0]; + } + if ((iw_ - ((fw_ - 1) * dw_ + 1) + pw_ + padR[1]) / sw_ + 1 != ow_) { + ++padR[1]; + } + } + return padR; + } +}; + +} // namespace paddle diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index f70343251ad4fbb99f9614618f6d1bff1174f15e..f60e221a6ec2ff513789a24e9f59bb25aef437b5 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -17,9 +17,6 @@ limitations under the License. */ using namespace mkldnn; // NOLINT typedef memory::format format; -typedef inner_product_forward fc_fwd; -typedef inner_product_backward_weights fc_bwdWgt; -typedef inner_product_backward_data fc_bwdData; namespace paddle { @@ -93,35 +90,88 @@ void MKLDNNFcLayer::reshape( printSizeInfo(); } -void MKLDNNFcLayer::resetFwd(std::vector& pipeline, +void MKLDNNFcLayer::resetFwd(std::vector& pipeline, MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out) { - pipeline.clear(); - bool hasBias = biases_ && biases_->getW(); - const MatrixPtr& wgtVal = weight_->getW(); - const MatrixPtr& biasVal = hasBias ? biases_->getW() : nullptr; - const MatrixPtr& outVal = output_.value; + resetFwdBuffers(in, wgt, bias, out); + + resetFwdPD(fwdPD_, in, wgt, bias, out); + + resetFwdPipeline(pipeline, fwdPD_, in, wgt, bias, out); + + printValueFormatFlow(); +} + +void MKLDNNFcLayer::resetBwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + std::shared_ptr bwdWgtPD; + std::shared_ptr bwdDataPD; + + resetBwdBuffers(in, wgt, bias, out); + + resetBwdWgtPD(bwdWgtPD, wgt, bias, out); + + resetBwdDataPD(bwdDataPD, in, out); + + resetBwdPipeline(pipeline, bwdWgtPD, bwdDataPD, in, wgt, bias, out); + + printGradFormatFlow(); +} + +void MKLDNNFcLayer::updateInputData() { + inVal_->setData(getInputValue(0, CPU_DEVICE)->getData()); +} +void MKLDNNFcLayer::updateWeights(const UpdateCallback& callback) { + weight_->getParameterPtr()->incUpdate(callback); + if (biases_ && biases_->getWGrad()) { + biases_->getParameterPtr()->incUpdate(callback); + } +} + +void MKLDNNFcLayer::resetFwdBuffers(MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + resetInValue(in); + + resetWgtBiasValue(wgt, bias); + + resetOutValue(out); +} + +void MKLDNNFcLayer::resetInValue(MKLDNNMatrixPtr& in) { if (inputIsOnlyMKLDNN()) { - const MatrixPtr& inVal = getInputValue(0); - in = std::dynamic_pointer_cast(inVal); + const MatrixPtr& dnnIn = getInputValue(0); + in = std::dynamic_pointer_cast(dnnIn); CHECK(in) << "Input should be MKLDNNMatrix"; } else { CHECK_EQ(getPrev(0)->getDeviceId(), CPU_DEVICE) << "Only support CPU yet"; - const MatrixPtr& inVal = getInputValue(0, CPU_DEVICE); + const MatrixPtr& cpuIn = getInputValue(0, CPU_DEVICE); in = MKLDNNMatrix::create( - inVal, memory::dims{bs_, ic_, ih_, iw_}, format::nchw, engine_); + cpuIn, {bs_, ic_, ih_, iw_}, format::nchw, engine_); } in->downSpatial(); +} + +void MKLDNNFcLayer::resetWgtBiasValue(MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias) { wgt = MKLDNNMatrix::create( - wgtVal, memory::dims{oc_, ic_, ih_, iw_}, format::oihw, engine_); + weight_->getW(), {oc_, ic_, ih_, iw_}, format::oihw, engine_); wgt->downSpatial(); - bias = hasBias ? MKLDNNMatrix::create(biasVal, {oc_}, format::x, engine_) - : nullptr; - out = MKLDNNMatrix::create(outVal, {bs_, oc_}, format::nc, engine_); + bias = (biases_ && biases_->getW()) + ? MKLDNNMatrix::create(biases_->getW(), {oc_}, format::x, engine_) + : nullptr; +} + +void MKLDNNFcLayer::resetOutValue(MKLDNNMatrixPtr& out) { + out = MKLDNNMatrix::create(output_.value, {bs_, oc_}, format::nc, engine_); // change original output value to mkldnn output value output_.value = std::dynamic_pointer_cast(out); if (!outputIsOnlyMKLDNN()) { @@ -129,46 +179,59 @@ void MKLDNNFcLayer::resetFwd(std::vector& pipeline, // just share point getOutput(CPU_DEVICE).value->setData(output_.value->getData()); } +} - // create forward handle +void MKLDNNFcLayer::resetFwdPD(std::shared_ptr& pd, + MKLDNNMatrixPtr in, + MKLDNNMatrixPtr wgt, + MKLDNNMatrixPtr bias, + MKLDNNMatrixPtr out) { + CHECK(in); + CHECK(wgt); + CHECK(out); prop_kind pk = prop_kind::forward; - fc_fwd::desc fwdDesc = hasBias ? fc_fwd::desc(pk, - in->getMemoryDesc(), - wgt->getMemoryDesc(), - bias->getMemoryDesc(), - out->getMemoryDesc()) - : fc_fwd::desc(pk, - in->getMemoryDesc(), - wgt->getMemoryDesc(), - out->getMemoryDesc()); - fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_); - if (hasBias) { - fwd_.reset(new fc_fwd(fwdPD, *in, *wgt, *bias, *out)); + fc_fwd::desc fwdDesc = bias != nullptr ? fc_fwd::desc(pk, + in->getMemoryDesc(), + wgt->getMemoryDesc(), + bias->getMemoryDesc(), + out->getMemoryDesc()) + : fc_fwd::desc(pk, + in->getMemoryDesc(), + wgt->getMemoryDesc(), + out->getMemoryDesc()); + pd.reset(new fc_fwd::primitive_desc(fwdDesc, engine_)); +} + +void MKLDNNFcLayer::resetFwdPipeline( + std::vector& pipeline, + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + pipeline.clear(); + + if (bias) { + fwd_.reset(new fc_fwd(*pd, *in, *wgt, *bias, *out)); } else { - fwd_.reset(new fc_fwd(fwdPD, *in, *wgt, *out)); + fwd_.reset(new fc_fwd(*pd, *in, *wgt, *out)); } - printValueFormatFlow(); pipeline.push_back(*fwd_); } -void MKLDNNFcLayer::resetBwd(std::vector& pipeline, - MKLDNNMatrixPtr& in, - MKLDNNMatrixPtr& wgt, - MKLDNNMatrixPtr& bias, - MKLDNNMatrixPtr& out) { - pipeline.clear(); - if (!needResetBwd_) { - return; - } - needResetBwd_ = false; - bool hasBias = biases_ && biases_->getWGrad(); +void MKLDNNFcLayer::resetBwdBuffers(MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + resetOutGrad(out); + + resetWgtBiasGrad(wgt, bias); - /// backward weight - CHECK(inVal_) << "Should have input value"; - const MatrixPtr& wgtGrad = weight_->getWGrad(); - const MatrixPtr& biasGrad = hasBias ? biases_->getWGrad() : nullptr; + resetInGrad(in); +} +void MKLDNNFcLayer::resetOutGrad(MKLDNNMatrixPtr& out) { // TODO(TJ): merge outgrad int device = outputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE; // for MKLDNN device: @@ -178,66 +241,88 @@ void MKLDNNFcLayer::resetBwd(std::vector& pipeline, // for CPU device: // fc do not need to convert from cpu device since output is always nc format // only need create from cpu device - const MatrixPtr& outGrad = getOutput(device).grad; - out = MKLDNNMatrix::create(outGrad, outVal_->getPrimitiveDesc()); - wgt = MKLDNNMatrix::create(wgtGrad, wgtVal_->getPrimitiveDesc()); - bias = hasBias ? MKLDNNMatrix::create(biasGrad, biasVal_->getPrimitiveDesc()) - : nullptr; - - // create memory primitive desc - fc_fwd::desc fwdDesc = fc_fwd::desc(prop_kind::forward, - inVal_->getMemoryDesc(), - wgt->getMemoryDesc(), - out->getMemoryDesc()); - fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_); - fc_bwdWgt::desc bwdWgtDesc = hasBias - ? fc_bwdWgt::desc(inVal_->getMemoryDesc(), - wgt->getMemoryDesc(), - bias->getMemoryDesc(), - out->getMemoryDesc()) - : fc_bwdWgt::desc(inVal_->getMemoryDesc(), - wgt->getMemoryDesc(), - out->getMemoryDesc()); - fc_bwdWgt::primitive_desc bwdWgtPD = - fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD); - - if (hasBias) { - bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *out, *wgt, *bias)); - } else { - bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *out, *wgt)); + CHECK(outVal_); + out = + MKLDNNMatrix::create(getOutput(device).grad, outVal_->getPrimitiveDesc()); +} + +void MKLDNNFcLayer::resetWgtBiasGrad(MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias) { + CHECK(wgtVal_); + wgt = MKLDNNMatrix::create(weight_->getWGrad(), wgtVal_->getPrimitiveDesc()); + + bias = nullptr; + if (biasVal_ == nullptr) { + return; } - pipeline.push_back(*bwdWgt_); + bias = + MKLDNNMatrix::create(biases_->getWGrad(), biasVal_->getPrimitiveDesc()); +} - /// backward data +void MKLDNNFcLayer::resetInGrad(MKLDNNMatrixPtr& in) { + in = nullptr; const MatrixPtr& inGrad = inputLayers_[0]->getOutput().grad; if (inGrad == nullptr) { return; } - if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) { - // TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done - } else { - in = MKLDNNMatrix::create(inGrad, inVal_->getPrimitiveDesc()); - } - - fc_bwdData::desc bwdDataDesc = fc_bwdData::desc( - inVal_->getMemoryDesc(), wgt->getMemoryDesc(), out->getMemoryDesc()); - fc_bwdData::primitive_desc bwdDataPD = - fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD); + // TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done + CHECK(inVal_); + in = MKLDNNMatrix::create(inGrad, inVal_->getPrimitiveDesc()); +} - CHECK(wgtVal_) << "Should have weight memory"; - bwdData_.reset(new fc_bwdData(bwdDataPD, *out, *wgtVal_, *in)); - printGradFormatFlow(); - pipeline.push_back(*bwdData_); +void MKLDNNFcLayer::resetBwdWgtPD( + std::shared_ptr& pd, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + CHECK(inVal_); + fc_bwdWgt::desc bwdWgtDesc = bias ? fc_bwdWgt::desc(inVal_->getMemoryDesc(), + wgt->getMemoryDesc(), + bias->getMemoryDesc(), + out->getMemoryDesc()) + : fc_bwdWgt::desc(inVal_->getMemoryDesc(), + wgt->getMemoryDesc(), + out->getMemoryDesc()); + pd.reset(new fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, *fwdPD_)); } -void MKLDNNFcLayer::updateInputData() { - inVal_->setData(getInputValue(0, CPU_DEVICE)->getData()); +void MKLDNNFcLayer::resetBwdDataPD( + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& out) { + pd = nullptr; + if (in == nullptr) { + return; + } + CHECK(wgtVal_); + fc_bwdData::desc bwdDataDesc = fc_bwdData::desc( + in->getMemoryDesc(), wgtVal_->getMemoryDesc(), out->getMemoryDesc()); + pd.reset(new fc_bwdData::primitive_desc(bwdDataDesc, engine_, *fwdPD_)); } -void MKLDNNFcLayer::updateWeights(const UpdateCallback& callback) { - weight_->getParameterPtr()->incUpdate(callback); - if (biases_ && biases_->getWGrad()) { - biases_->getParameterPtr()->incUpdate(callback); +void MKLDNNFcLayer::resetBwdPipeline( + std::vector& pipeline, + std::shared_ptr& bwdWgtPD, + std::shared_ptr& bwdDataPD, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + pipeline.clear(); + CHECK(inVal_); + if (bias) { + bwdWgt_.reset(new fc_bwdWgt(*bwdWgtPD, *inVal_, *out, *wgt, *bias)); + } else { + bwdWgt_.reset(new fc_bwdWgt(*bwdWgtPD, *inVal_, *out, *wgt)); + } + pipeline.push_back(*bwdWgt_); + + if (bwdDataPD == nullptr) { + return; } + CHECK(wgtVal_) << "Should have weight memory"; + bwdData_.reset(new fc_bwdData(*bwdDataPD, *out, *wgtVal_, *in)); + pipeline.push_back(*bwdData_); } + } // namespace paddle diff --git a/paddle/gserver/layers/MKLDNNFcLayer.h b/paddle/gserver/layers/MKLDNNFcLayer.h index 3119f863496df092da13c08bf733f13c42e53780..c76878aafab7e986d2bf478eaba02f2f0aced293 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.h +++ b/paddle/gserver/layers/MKLDNNFcLayer.h @@ -18,6 +18,9 @@ limitations under the License. */ #include "mkldnn.hpp" namespace paddle { +typedef mkldnn::inner_product_forward fc_fwd; +typedef mkldnn::inner_product_backward_weights fc_bwdWgt; +typedef mkldnn::inner_product_backward_data fc_bwdData; /** * @brief A subclass of MKLDNNLayer fc layer. @@ -32,6 +35,9 @@ protected: // if has already init the weight bool hasInitedWgt_; + // save forward primitive_desc, which can be used backward + std::shared_ptr fwdPD_; + // fc weight and bias std::unique_ptr weight_; std::unique_ptr biases_; @@ -67,6 +73,59 @@ public: void convertWeightsFromPaddle() override; void convertWeightsToPaddle() override; + +protected: + /** + * Forward functions: reset buffers(input, output, weight and bias), + * reset primitive descriptor, + * reset pipeline. + */ + void resetFwdBuffers(MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out); + void resetInValue(MKLDNNMatrixPtr& in); + void resetWgtBiasValue(MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias); + void resetOutValue(MKLDNNMatrixPtr& out); + void resetFwdPD(std::shared_ptr& pd, + MKLDNNMatrixPtr in, + MKLDNNMatrixPtr wgt, + MKLDNNMatrixPtr bias, + MKLDNNMatrixPtr out); + void resetFwdPipeline(std::vector& pipeline, + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out); + + /** + * Backward functions: reset buffers(input, output, weight and bias), + * reset primitive descriptor for backward weight, + * reset primitive descriptor for backward data, + * reset pipeline. + */ + void resetBwdBuffers(MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out); + void resetOutGrad(MKLDNNMatrixPtr& out); + void resetWgtBiasGrad(MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias); + void resetInGrad(MKLDNNMatrixPtr& in); + void resetBwdWgtPD(std::shared_ptr& pd, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out); + void resetBwdDataPD(std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& out); + void resetBwdPipeline(std::vector& pipeline, + std::shared_ptr& bwdWgtPD, + std::shared_ptr& bwdDataPD, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out); }; } // namespace paddle diff --git a/paddle/gserver/tests/test_MKLDNN.cpp b/paddle/gserver/tests/test_MKLDNN.cpp index e1d2270df24331914f3a51acc90a518084b3ce4e..e70802881e3f22160a87b7a4babda07ffbcf9d6f 100644 --- a/paddle/gserver/tests/test_MKLDNN.cpp +++ b/paddle/gserver/tests/test_MKLDNN.cpp @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "MKLDNNTester.h" #include "ModelConfig.pb.h" +#include "paddle/math/MathUtils.h" using namespace paddle; // NOLINT @@ -63,6 +64,83 @@ TEST(MKLDNNLayer, FcLayer) { testFcLayer({/*bs*/ 15, /*ic*/ 3, /*oc*/ 6, /*ih*/ 16, /*iw*/ 16}); } +struct testConvDesc { + int bs, gp; + int ic, ih, iw; + int oc, oh, ow; + int fh, fw; + int ph, pw; + int sh, sw; + int dh, dw; +}; + +void testConvLayer(const testConvDesc& pm) { + const std::string compareTypes[] = {"mkldnn_conv", "exconv"}; + TestConfig cfg; + cfg.layerConfig.set_type(compareTypes[0]); + cfg.layerConfig.set_num_filters(pm.oc); + cfg.layerConfig.set_size(pm.oc * pm.oh * pm.ow); + // cfg.layerConfig.set_partial_sum(1); // TODO: check it + cfg.layerConfig.set_shared_biases(true); + cfg.inputDefs.push_back( + {INPUT_DATA, + "layer_0", + /* size of input layer= */ size_t(pm.ic * pm.ih * pm.iw), + /* size of weight= */ size_t(pm.oc * pm.ic * pm.fh * pm.fw / pm.gp)}); + LayerInputConfig* input = cfg.layerConfig.add_inputs(); + ConvConfig* conv = input->mutable_conv_conf(); + conv->set_groups(pm.gp); + conv->set_img_size(pm.iw); + conv->set_img_size_y(pm.ih); + conv->set_output_x(pm.ow); + conv->set_output_y(pm.oh); + conv->set_filter_size(pm.fw); + conv->set_filter_size_y(pm.fh); + conv->set_channels(pm.ic); + conv->set_padding(pm.pw); + conv->set_padding_y(pm.ph); + conv->set_stride(pm.sw); + conv->set_stride_y(pm.sh); + conv->set_dilation(pm.dw); + conv->set_dilation_y(pm.dh); + conv->set_caffe_mode(true); + conv->set_filter_channels(conv->channels() / conv->groups()); + CHECK_EQ(conv->filter_channels() * pm.gp, conv->channels()) + << "it is indivisible"; + + int fh = (pm.fh - 1) * pm.dh + 1; + int fw = (pm.fw - 1) * pm.dw + 1; + int ow = outputSize(pm.iw, fw, pm.pw, pm.sw, true); + int oh = outputSize(pm.ih, fh, pm.ph, pm.sh, true); + CHECK_EQ(ow, pm.ow) << "output size check failed"; + CHECK_EQ(oh, pm.oh) << "output size check failed"; + + MKLDNNTester tester; + for (auto biasSize : {pm.oc, 0}) { + cfg.biasSize = biasSize; + TestConfig ref = cfg; + ref.layerConfig.set_type(compareTypes[1]); + for (auto bs : {pm.bs, 1}) { + tester.run(cfg, ref, bs, pm.ih, pm.iw); + } + } +} + +TEST(MKLDNNLayer, ConvLayer) { + /* bs, gp, ic, ih, iw, oc, oh, ow, fh, fw, ph, pw, sh, sw, dh, dw */ + testConvLayer({2, 1, 3, 32, 32, 16, 32, 32, 3, 3, 1, 1, 1, 1, 1, 1}); + testConvLayer({2, 1, 8, 16, 16, 8, 16, 16, 3, 3, 1, 1, 1, 1, 1, 1}); + testConvLayer({3, 1, 16, 32, 32, 3, 32, 32, 3, 3, 1, 1, 1, 1, 1, 1}); + testConvLayer({8, 1, 16, 18, 18, 32, 18, 18, 3, 3, 1, 1, 1, 1, 1, 1}); + testConvLayer({16, 1, 1, 42, 31, 32, 23, 11, 4, 5, 3, 2, 2, 3, 1, 1}); + testConvLayer({2, 1, 8, 16, 16, 8, 8, 8, 3, 3, 1, 1, 2, 2, 1, 1}); + testConvLayer({3, 1, 8, 13, 13, 8, 7, 7, 3, 3, 1, 1, 2, 2, 1, 1}); + // with groups + testConvLayer({2, 2, 4, 5, 5, 8, 5, 5, 3, 3, 1, 1, 1, 1, 1, 1}); + testConvLayer({2, 3, 3, 5, 5, 3, 5, 5, 3, 3, 1, 1, 1, 1, 1, 1}); + testConvLayer({4, 4, 16, 3, 3, 16, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1}); +} + // TODO(TJ): add branch test int main(int argc, char** argv) { diff --git a/paddle/math/BaseMatrix.cu b/paddle/math/BaseMatrix.cu index 5435808fb7f70fdf1ac98815f7fe8890fb85527c..53dd5383601782231e6e742784007d1c9154dc6b 100644 --- a/paddle/math/BaseMatrix.cu +++ b/paddle/math/BaseMatrix.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "BaseMatrix.h" #include "MathFunctions.h" +#include "NEONFunctions.h" #include "SIMDFunctions.h" #include "hl_matrix_apply.cuh" #include "hl_matrix_base.cuh" @@ -666,6 +667,13 @@ void BaseMatrixT::relu(BaseMatrixT& b) { applyBinary(binary::Relu(), b); } +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +template <> +void BaseMatrixT::relu(BaseMatrixT& b) { + neon::relu(data_, b.data_, height_ * width_); +} +#endif + DEFINE_MATRIX_BINARY_OP(ReluDerivative, a *= (b > 0.0f ? 1.0f : 0.0f)); template void BaseMatrixT::reluDerivative(BaseMatrixT& b) { diff --git a/paddle/math/MKLDNNMatrix.cpp b/paddle/math/MKLDNNMatrix.cpp index c4063e5069854242d9f93886b66580385557ca73..0778bb63b7b3bca9b3d2647ca43dad72d783950a 100644 --- a/paddle/math/MKLDNNMatrix.cpp +++ b/paddle/math/MKLDNNMatrix.cpp @@ -49,6 +49,27 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, return create(m, memory::primitive_desc(memory::desc(dims, dtype, fmt), eg)); } +std::shared_ptr MKLDNNMatrix::createReorder(const MKLDNNMatrixPtr& src, + const MKLDNNMatrixPtr& dst, + bool checkData) { + if (src == dst || src->getPrimitiveDesc() == dst->getPrimitiveDesc()) { + return nullptr; + } + + if (checkData && (src->getData() == dst->getData())) { + LOG(FATAL) << "can not create reorder with inplace data"; + return nullptr; + } + + memory::dims srcDims = src->getDims(); + memory::dims dstDims = dst->getDims(); + CHECK_EQ(srcDims.size(), dstDims.size()); + for (size_t i = 0; i < srcDims.size(); ++i) { + CHECK_EQ(srcDims[i], dstDims[i]); + } + return std::make_shared(*src, *dst); +} + void MKLDNNMatrix::reorderDataFrom(const MKLDNNMatrixPtr& m, memory::format srcFmt, memory::dims targetDim) { diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h index eef3b429e6fa0087aeac3f5aed9dff983b06e826..c843115eb9a5be50d6ff873f1510844228c9d89f 100644 --- a/paddle/math/MKLDNNMatrix.h +++ b/paddle/math/MKLDNNMatrix.h @@ -52,6 +52,32 @@ public: mkldnn::engine& eg, mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32); + /** + * Create Memory descriptor. + * default with any format and f32 dtype + */ + static mkldnn::memory::desc createMemoryDesc( + const mkldnn::memory::dims& dims, + const mkldnn::memory::format& fmt = mkldnn::memory::format::any, + const mkldnn::memory::data_type& dtype = mkldnn::memory::data_type::f32) { + return mkldnn::memory::desc(dims, dtype, fmt); + } + + /** + * Create reorder primitive. + * Create a mkldnn::reorder handle for converting src MKLDNNMatrix to dst. + * checkData: whether to check the data handle of src and dst. + * if true, it will check the data and do not allow them equal; + * otherwise, it will not check them, then the reorder created + * may have inplace buffer. + * Do not set false, if you can not guarantee the inplace logical + * would work with your reorder. + */ + static std::shared_ptr createReorder( + const MKLDNNMatrixPtr& src, + const MKLDNNMatrixPtr& dst, + bool checkData = true); + public: /** * Reorder this MKLDNNMatrix from other format. diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 4a2132c8d1bfa329ced575f9b78052bdbfe3e4d5..0023b4d0f5da500f380ecb836b7c54e050b13d67 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1033,17 +1033,15 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat, real* inputData = inputMat.getData(); size_t frameNum = inputMat.getHeight(); - size_t width = imgSizeW; - size_t height = imgSizeH; - CHECK(height * width * channels == inputMat.getWidth()); + CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth()); CHECK(height_ == inputMat.getHeight()); CHECK(width_ == outputH * outputW * channels); hl_maxpool_forward(frameNum, inputData, channels, - height, - width, + imgSizeH, + imgSizeW, outputH, outputW, sizeX, @@ -1080,11 +1078,8 @@ void GpuMatrix::maxPoolBackward(Matrix& inputMat, real* outDiff = outGrad.getData(); size_t frameNum = inputMat.getHeight(); size_t channels = outV.getWidth() / outputH / outputW; - size_t width = imgSizeW; - size_t height = imgSizeH; - CHECK(height * width * channels == inputMat.getWidth()); + CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth()); CHECK(height_ == inputMat.getHeight()); - CHECK(width_ == width * height * channels); CHECK(outGrad.getHeight() == outV.getHeight() && outGrad.getWidth() == outV.getWidth()); @@ -1093,8 +1088,8 @@ void GpuMatrix::maxPoolBackward(Matrix& inputMat, outData, outDiff, channels, - height, - width, + imgSizeH, + imgSizeW, outputH, outputW, sizeX, @@ -1125,17 +1120,15 @@ void GpuMatrix::avgPoolForward(Matrix& inputMat, real* inputData = inputMat.getData(); size_t frameNum = inputMat.getHeight(); - size_t height = imgSizeH; - size_t width = imgSizeW; - CHECK(height * width * channels == inputMat.getWidth()); + CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth()); CHECK(height_ == inputMat.getHeight()); CHECK(width_ == outputH * outputW * channels); hl_avgpool_forward(frameNum, inputData, channels, - height, - width, + imgSizeH, + imgSizeW, outputH, outputW, sizeX, @@ -1166,17 +1159,15 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad, real* outDiff = outGrad.getData(); size_t frameNum = outGrad.getHeight(); size_t channels = outGrad.getWidth() / outputH / outputW; - size_t height = imgSizeH; - size_t width = imgSizeW; - CHECK(height * width * channels == width_); + CHECK(imgSizeH * imgSizeW * channels == width_); CHECK(height_ == outGrad.getHeight()); CHECK(outGrad.getWidth() == outputH * outputW * channels); hl_avgpool_backward(frameNum, outDiff, channels, - height, - width, + imgSizeH, + imgSizeW, outputH, outputW, sizeX, @@ -1214,19 +1205,16 @@ void GpuMatrix::maxPool3DForward(Matrix& inputMat, real* inputData = inputMat.getData(); real* maxPoolIdxData = maxPoolIdx.getData(); size_t num = inputMat.getHeight(); - size_t width = imgSizeW; - size_t height = imgSizeH; - size_t depth = imgSizeD; - CHECK(depth * height * width * channels == inputMat.getWidth()); + CHECK(imgSizeD * imgSizeH * imgSizeW * channels == inputMat.getWidth()); CHECK(height_ == inputMat.getHeight()); CHECK(width_ == outputD * outputH * outputW * channels); hl_maxpool3D_forward(num, inputData, channels, - depth, - height, - width, + imgSizeD, + imgSizeH, + imgSizeW, outputD, outputH, outputW, @@ -1269,20 +1257,16 @@ void GpuMatrix::maxPool3DBackward(Matrix& outGrad, real* maxPoolIdxData = maxPoolIdx.getData(); size_t frameNum = getHeight(); size_t channels = outGrad.getWidth() / outputD / outputH / outputW; - size_t width = imgSizeW; - size_t height = imgSizeH; - size_t depth = imgSizeD; - CHECK(depth * height * width * channels == getWidth()); - CHECK(width_ == depth * width * height * channels); + CHECK(imgSizeD * imgSizeH * imgSizeW * channels == getWidth()); CHECK(outGrad.getHeight() == maxPoolIdx.getHeight() && outGrad.getWidth() == maxPoolIdx.getWidth()); hl_maxpool3D_backward(frameNum, outDiff, channels, - depth, - height, - width, + imgSizeD, + imgSizeH, + imgSizeW, outputD, outputH, outputW, @@ -1323,19 +1307,16 @@ void GpuMatrix::avgPool3DForward(Matrix& inputMat, real* inputData = inputMat.getData(); size_t frameNum = inputMat.getHeight(); - size_t height = imgSizeH; - size_t width = imgSizeW; - size_t depth = imgSizeD; - CHECK(depth * height * width * channels == inputMat.getWidth()); + CHECK(imgSizeD * imgSizeH * imgSizeW * channels == inputMat.getWidth()); CHECK(height_ == inputMat.getHeight()); CHECK(width_ == outputD * outputH * outputW * channels); hl_avgpool3D_forward(frameNum, inputData, channels, - depth, - height, - width, + imgSizeD, + imgSizeH, + imgSizeW, outputD, outputH, outputW, @@ -1375,19 +1356,16 @@ void GpuMatrix::avgPool3DBackward(Matrix& outGrad, real* outDiff = outGrad.getData(); size_t frameNum = outGrad.getHeight(); size_t channels = outGrad.getWidth() / outputD / outputH / outputW; - size_t height = imgSizeH; - size_t width = imgSizeW; - size_t depth = imgSizeD; - CHECK(depth * height * width * channels == width_); + CHECK(imgSizeD * imgSizeH * imgSizeW * channels == width_); CHECK(height_ == outGrad.getHeight()); CHECK(outGrad.getWidth() == outputD * outputH * outputW * channels); hl_avgpool3D_backward(frameNum, outDiff, channels, - depth, - height, - width, + imgSizeD, + imgSizeH, + imgSizeW, outputD, outputH, outputW, @@ -1999,11 +1977,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, real* inputData = inputMat.getData(); real* outData = data_; size_t num = inputMat.getHeight(); - size_t inWidth = imgSizeW; - size_t inHeight = imgSizeH; - CHECK(inHeight * inWidth == inputMat.getWidth() / channels); + size_t inLength = imgSizeH * imgSizeW; + size_t outLength = outputH * outputW; + CHECK(inLength == inputMat.getWidth() / channels); CHECK_EQ(num, this->getHeight()); - CHECK_EQ(channels * outputH * outputW, this->getWidth()); + CHECK_EQ(channels * outLength, this->getWidth()); size_t outStride = getStride(); /* initialize the data_ */ @@ -2020,24 +1998,24 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, } for (size_t c = 0; c < channels; ++c) { // channel by channel for (size_t ph = 0; ph < outputH; ++ph) { + int hstart = ph * strideH - paddingH; + int hend = std::min(hstart + sizeY, imgSizeH); + hstart = std::max(hstart, 0); for (size_t pw = 0; pw < outputW; ++pw) { - int hstart = ph * strideH - paddingH; int wstart = pw * strideW - paddingW; - int hend = std::min(hstart + sizeY, inHeight); - int wend = std::min(wstart + sizeX, inWidth); - hstart = std::max(hstart, 0); + int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - outData[ph * outputW + pw] = std::max(outData[ph * outputW + pw], - inputData[h * inWidth + w]); + outData[ph * outputW + pw] = std::max( + outData[ph * outputW + pw], inputData[h * imgSizeW + w]); } } } } // compute offset - inputData += inHeight * inWidth; - outData += outputH * outputW; + inputData += inLength; + outData += outLength; } } } @@ -2058,8 +2036,10 @@ void CpuMatrix::maxPoolBackward(Matrix& image, size_t paddingH, size_t paddingW) { size_t num = image.getHeight(); - size_t channels = size_t(width_ / imgSizeH / imgSizeW); - CHECK(image.getWidth() == imgSizeH * imgSizeW * channels); + size_t inLength = imgSizeH * imgSizeW; + size_t outLength = outputH * outputW; + size_t channels = size_t(width_ / inLength); + CHECK(image.getWidth() == inLength * channels); CHECK(image.getHeight() == height_ && image.getWidth() == width_); CHECK(outV.getHeight() == outGrad.getHeight() && outV.getWidth() == outGrad.getWidth()); @@ -2080,12 +2060,12 @@ void CpuMatrix::maxPoolBackward(Matrix& image, } for (size_t c = 0; c < channels; ++c) { for (size_t ph = 0; ph < outputH; ++ph) { + int hstart = ph * strideH - paddingH; + int hend = std::min(hstart + sizeY, imgSizeH); + hstart = std::max(hstart, 0); for (size_t pw = 0; pw < outputW; ++pw) { - int hstart = ph * strideH - paddingH; int wstart = pw * strideW - paddingW; - int hend = std::min(hstart + sizeY, imgSizeH); int wend = std::min(wstart + sizeX, imgSizeW); - hstart = std::max(hstart, 0); wstart = std::max(wstart, 0); for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { @@ -2098,10 +2078,10 @@ void CpuMatrix::maxPoolBackward(Matrix& image, } } // offset - inData += imgSizeH * imgSizeW; - tgtGrad += imgSizeH * imgSizeW; - otData += outputH * outputW; - otGrad += outputH * outputW; + inData += inLength; + tgtGrad += inLength; + otData += outLength; + otGrad += outLength; } } } @@ -2120,10 +2100,10 @@ void CpuMatrix::avgPoolForward(Matrix& input, size_t paddingW) { // The main loop size_t num = input.getHeight(); - size_t inHeight = imgSizeH; - size_t inWidth = imgSizeW; - CHECK(inHeight * inWidth * channels == input.getWidth()); - CHECK(outputH * outputW * channels * num == height_ * width_); + size_t inLength = imgSizeH * imgSizeW; + size_t outLength = outputH * outputW; + CHECK(inLength * channels == input.getWidth()); + CHECK(outLength * channels * num == height_ * width_); real* tgtData = data_; real* inData = input.getData(); @@ -2133,30 +2113,27 @@ void CpuMatrix::avgPoolForward(Matrix& input, } for (size_t c = 0; c < channels; ++c) { for (size_t ph = 0; ph < outputH; ++ph) { + int hstart = ph * strideH - paddingH; + int hend = std::min(hstart + sizeY, imgSizeH); + hstart = std::max(hstart, 0); for (size_t pw = 0; pw < outputW; ++pw) { - int hstart = ph * strideH - paddingH; int wstart = pw * strideW - paddingW; - int hend = std::min(hstart + sizeY, inHeight + paddingH); - int wend = std::min(wstart + sizeX, inWidth + paddingW); - int poolSize = (hend - hstart) * (wend - wstart); - hstart = std::max(hstart, 0); + int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); - hend = std::min(hend, static_cast(inHeight)); - wend = std::min(wend, static_cast(inWidth)); - - CHECK(poolSize); tgtData[ph * outputW + pw] = 0; // clear for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - tgtData[ph * outputW + pw] += inData[h * inWidth + w]; + tgtData[ph * outputW + pw] += inData[h * imgSizeW + w]; } } + int poolSize = (hend - hstart) * (wend - wstart); + CHECK(poolSize); tgtData[ph * outputW + pw] /= poolSize; } } // compute offset - inData += inHeight * inWidth; - tgtData += outputH * outputW; + inData += inLength; + tgtData += outLength; } } } @@ -2176,7 +2153,9 @@ void CpuMatrix::avgPoolBackward(Matrix& input, size_t paddingW) { size_t num = input.getHeight(); size_t channels = input.getWidth() / outputH / outputW; - CHECK(imgSizeH * imgSizeW * channels == getWidth()); + size_t inLength = imgSizeH * imgSizeW; + size_t outLength = outputH * outputW; + CHECK(inLength * channels == getWidth()); real* inData = input.getData(); real* outData = getData(); @@ -2186,16 +2165,14 @@ void CpuMatrix::avgPoolBackward(Matrix& input, } for (size_t c = 0; c < channels; ++c) { for (size_t ph = 0; ph < outputH; ++ph) { + int hstart = ph * strideH - paddingH; + int hend = std::min(hstart + sizeY, imgSizeH); + hstart = std::max(hstart, 0); for (size_t pw = 0; pw < outputW; ++pw) { - int hstart = ph * strideH - paddingH; int wstart = pw * strideW - paddingW; - int hend = std::min(hstart + sizeY, imgSizeH + paddingH); - int wend = std::min(wstart + sizeX, imgSizeW + paddingW); - int poolSize = (hend - hstart) * (wend - wstart); - hstart = std::max(hstart, 0); + int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); - hend = std::min(hend, static_cast(imgSizeH)); - wend = std::min(wend, static_cast(imgSizeW)); + int poolSize = (hend - hstart) * (wend - wstart); CHECK(poolSize); for (int h = hstart; h < hend; ++h) { @@ -2206,8 +2183,8 @@ void CpuMatrix::avgPoolBackward(Matrix& input, } } // offset - outData += imgSizeH * imgSizeW; - inData += outputH * outputW; + outData += inLength; + inData += outLength; } } } @@ -2234,12 +2211,11 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat, real* outData = getData(); real* maxPoolIdxData = maxPoolIdx.getData(); size_t num = inputMat.getHeight(); - size_t inWidth = imgSizeW; - size_t inHeight = imgSizeH; - size_t inDepth = imgSizeD; - CHECK(inHeight * inWidth * inDepth == inputMat.getWidth() / channels); + size_t inLength = imgSizeH * imgSizeW * imgSizeD; + size_t outLength = outputH * outputW * outputD; + CHECK(inLength == inputMat.getWidth() / channels); CHECK_EQ(num, this->getHeight()); - CHECK_EQ(channels * outputH * outputW * outputD, this->getWidth()); + CHECK_EQ(channels * outLength, this->getWidth()); size_t outStride = getStride(); /* initialize the data_ */ @@ -2258,16 +2234,16 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat, } for (size_t c = 0; c < channels; ++c) { // channel by channel for (size_t pd = 0; pd < outputD; ++pd) { + int dstart = pd * strideD - paddingD; + int dend = std::min(dstart + sizeZ, imgSizeD); + dstart = std::max(dstart, 0); for (size_t ph = 0; ph < outputH; ++ph) { + int hstart = ph * strideH - paddingH; + int hend = std::min(hstart + sizeY, imgSizeH); + hstart = std::max(hstart, 0); for (size_t pw = 0; pw < outputW; ++pw) { - int dstart = pd * strideD - paddingD; - int hstart = ph * strideH - paddingH; int wstart = pw * strideW - paddingW; - int dend = std::min(dstart + sizeZ, inDepth); - int hend = std::min(hstart + sizeY, inHeight); - int wend = std::min(wstart + sizeX, inWidth); - dstart = std::max(dstart, 0); - hstart = std::max(hstart, 0); + int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); int maxIdx = -1; real maxOutData = outData[(pd * outputH + ph) * outputW + pw]; @@ -2275,9 +2251,9 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat, for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { if (maxOutData < - inputData[(d * inHeight + h) * inWidth + w]) { - maxOutData = inputData[(d * inHeight + h) * inWidth + w]; - maxIdx = (d * inHeight + h) * inWidth + w; + inputData[(d * imgSizeH + h) * imgSizeW + w]) { + maxOutData = inputData[(d * imgSizeH + h) * imgSizeW + w]; + maxIdx = (d * imgSizeH + h) * imgSizeW + w; } } } @@ -2288,9 +2264,9 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat, } } // compute offset - inputData += inDepth * inHeight * inWidth; - outData += outputD * outputH * outputW; - maxPoolIdxData += outputD * outputH * outputW; + inputData += inLength; + outData += outLength; + maxPoolIdxData += outLength; } } } @@ -2315,7 +2291,9 @@ void CpuMatrix::maxPool3DBackward(Matrix& outGrad, real scaleTargets, real scaleOutput) { size_t num = getHeight(); - size_t channels = size_t(width_ / imgSizeD / imgSizeH / imgSizeW); + size_t inLength = imgSizeH * imgSizeW * imgSizeD; + size_t outLength = outputH * outputW * outputD; + size_t channels = size_t(width_ / inLength); CHECK(maxPoolIdx.getHeight() == outGrad.getHeight() && maxPoolIdx.getWidth() == outGrad.getWidth()); @@ -2341,9 +2319,9 @@ void CpuMatrix::maxPool3DBackward(Matrix& outGrad, } } // offset - tgtGrad += imgSizeD * imgSizeH * imgSizeW; - otGrad += outputD * outputH * outputW; - maxPoolIdxData += outputD * outputH * outputW; + tgtGrad += inLength; + otGrad += outLength; + maxPoolIdxData += outLength; } } } @@ -2367,11 +2345,10 @@ void CpuMatrix::avgPool3DForward(Matrix& input, size_t paddingW) { // The main loop size_t num = input.getHeight(); - size_t inDepth = imgSizeD; - size_t inHeight = imgSizeH; - size_t inWidth = imgSizeW; - CHECK(inDepth * inHeight * inWidth * channels == input.getWidth()); - CHECK(outputD * outputH * outputW * channels * num == height_ * width_); + size_t inLength = imgSizeH * imgSizeW * imgSizeD; + size_t outLength = outputH * outputW * outputD; + CHECK(inLength * channels == input.getWidth()); + CHECK(outLength * channels * num == height_ * width_); real* tgtData = getData(); real* inData = input.getData(); @@ -2381,39 +2358,36 @@ void CpuMatrix::avgPool3DForward(Matrix& input, } for (size_t c = 0; c < channels; ++c) { for (size_t pd = 0; pd < outputD; ++pd) { + int dstart = pd * strideD - paddingD; + int dend = std::min(dstart + sizeZ, imgSizeD); + dstart = std::max(dstart, 0); for (size_t ph = 0; ph < outputH; ++ph) { + int hstart = ph * strideH - paddingH; + int hend = std::min(hstart + sizeY, imgSizeH); + hstart = std::max(hstart, 0); for (size_t pw = 0; pw < outputW; ++pw) { - int dstart = pd * strideD - paddingD; - int hstart = ph * strideH - paddingH; int wstart = pw * strideW - paddingW; - int dend = std::min(dstart + sizeZ, inDepth + paddingD); - int hend = std::min(hstart + sizeY, inHeight + paddingH); - int wend = std::min(wstart + sizeX, inWidth + paddingW); - int poolSize = (dend - dstart) * (hend - hstart) * (wend - wstart); - dstart = std::max(dstart, 0); - hstart = std::max(hstart, 0); + int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); - dend = std::min(dend, static_cast(inDepth)); - hend = std::min(hend, static_cast(inHeight)); - wend = std::min(wend, static_cast(inWidth)); - CHECK(poolSize); tgtData[(pd * outputH + ph) * outputW + pw] = 0; // clear for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { tgtData[(pd * outputH + ph) * outputW + pw] += - inData[(d * inHeight + h) * inWidth + w]; + inData[(d * imgSizeH + h) * imgSizeW + w]; } } } + int poolSize = (dend - dstart) * (hend - hstart) * (wend - wstart); + CHECK(poolSize); tgtData[(pd * outputH + ph) * outputW + pw] /= poolSize; } } } // compute offset - inData += inDepth * inHeight * inWidth; - tgtData += outputD * outputH * outputW; + inData += inLength; + tgtData += outLength; } } } @@ -2437,8 +2411,10 @@ void CpuMatrix::avgPool3DBackward(Matrix& input, real scaleTargets, real scaleOutput) { size_t num = input.getHeight(); - size_t channels = input.getWidth() / outputD / outputH / outputW; - CHECK(imgSizeD * imgSizeH * imgSizeW * channels == getWidth()); + size_t inLength = imgSizeH * imgSizeW * imgSizeD; + size_t outLength = outputH * outputW * outputD; + size_t channels = input.getWidth() / outLength; + CHECK(inLength * channels == getWidth()); real* inData = input.getData(); real* outData = getData(); @@ -2448,21 +2424,18 @@ void CpuMatrix::avgPool3DBackward(Matrix& input, } for (size_t c = 0; c < channels; ++c) { for (size_t pd = 0; pd < outputD; ++pd) { + int dstart = pd * strideD - paddingD; + int dend = std::min(dstart + sizeZ, imgSizeD); + dstart = std::max(dstart, 0); for (size_t ph = 0; ph < outputH; ++ph) { + int hstart = ph * strideH - paddingH; + int hend = std::min(hstart + sizeY, imgSizeH); + hstart = std::max(hstart, 0); for (size_t pw = 0; pw < outputW; ++pw) { - int dstart = pd * strideD - paddingD; - int hstart = ph * strideH - paddingH; int wstart = pw * strideW - paddingW; - int dend = std::min(dstart + sizeZ, imgSizeD + paddingD); - int hend = std::min(hstart + sizeY, imgSizeH + paddingH); - int wend = std::min(wstart + sizeX, imgSizeW + paddingW); - int poolSize = (dend - dstart) * (hend - hstart) * (wend - wstart); - dstart = std::max(dstart, 0); - hstart = std::max(hstart, 0); + int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); - dend = std::min(dend, static_cast(imgSizeD)); - hend = std::min(hend, static_cast(imgSizeH)); - wend = std::min(wend, static_cast(imgSizeW)); + int poolSize = (dend - dstart) * (hend - hstart) * (wend - wstart); CHECK(poolSize); for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { @@ -2476,8 +2449,8 @@ void CpuMatrix::avgPool3DBackward(Matrix& input, } } // offset - outData += imgSizeD * imgSizeH * imgSizeW; - inData += outputD * outputH * outputW; + outData += inLength; + inData += outLength; } } } diff --git a/paddle/math/NEONFunctions.cpp b/paddle/math/NEONFunctions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3bf47901f1069ac228fa1b877e29848d8cc130e8 --- /dev/null +++ b/paddle/math/NEONFunctions.cpp @@ -0,0 +1,55 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + +#include "NEONFunctions.h" +#include + +namespace paddle { +namespace neon { + +// b[i] = a[i] > 0.0f ? a[i] : 0.0f +void relu(const float* a, float* b, int len) { + int offset = len % 16; + float32x4_t ma0, ma1, ma2, ma3; + float32x4_t mb0, mb1, mb2, mb3; + + float32x4_t zero = vdupq_n_f32(0.f); + for (int k = 0; k < len / 16; k++, a += 16, b += 16) { + ma0 = vld1q_f32(a); + ma1 = vld1q_f32(a + 4); + ma2 = vld1q_f32(a + 8); + ma3 = vld1q_f32(a + 12); + + mb0 = vmaxq_f32(ma0, zero); + mb1 = vmaxq_f32(ma1, zero); + mb2 = vmaxq_f32(ma2, zero); + mb3 = vmaxq_f32(ma3, zero); + + vst1q_f32(b, mb0); + vst1q_f32(b + 4, mb1); + vst1q_f32(b + 8, mb2); + vst1q_f32(b + 12, mb3); + } + + for (int i = 0; i < offset; i++) { + b[i] = a[i] > 0.0f ? a[i] : 0.0f; + } +} + +} // namespace neon +} // namespace paddle + +#endif diff --git a/paddle/operators/concat_op.cu b/paddle/math/NEONFunctions.h similarity index 74% rename from paddle/operators/concat_op.cu rename to paddle/math/NEONFunctions.h index 38fee7473dbb2ba97fe95b6632db7a1749cf3bbe..69085e333547a31a341fbfde247f1e30adb957ee 100644 --- a/paddle/operators/concat_op.cu +++ b/paddle/math/NEONFunctions.h @@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -12,8 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU -#include "paddle/operators/concat_op.h" +#pragma once -namespace ops = paddle::operators; -// TODO(Yancey1989) Add GPU kernel +namespace paddle { +namespace neon { + +void relu(const float* a, float* b, int len); + +} // namespace neon +} // namespace paddle diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 103f06acc57d7a23f019f5e713f6cacf2179e9e0..061fb22e3fd744d9d9895fd1008089e4a6ce6a0f 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -825,9 +825,8 @@ void testMaxPoolFwdBwd(int numSamples, int strideW, int padH, int padW) { - int outH = 0, outW = 0; - outH = (imgSizeH - ksizeH + 2 * padH + strideH - 1) / strideH + 1; - outW = (imgSizeW - ksizeW + 2 * padW + strideW - 1) / strideW + 1; + int outH = outputSize(imgSizeH, ksizeH, padH, strideH, true); + int outW = outputSize(imgSizeW, ksizeW, padW, strideW, true); int inWidth = imgSizeH * imgSizeW * channels; MatrixPtr input = CpuMatrix::create(numSamples, inWidth, false, false); @@ -927,9 +926,8 @@ void testAvgPoolFwdBwd(int numSamples, int strideW, int padH, int padW) { - int outH = 0, outW = 0; - outH = (imgSizeH - ksizeH + 2 * padH + strideH - 1) / strideH + 1; - outW = (imgSizeW - ksizeW + 2 * padW + strideW - 1) / strideW + 1; + int outH = outputSize(imgSizeH, ksizeH, padH, strideH, true); + int outW = outputSize(imgSizeW, ksizeW, padW, strideW, true); int inWidth = imgSizeH * imgSizeW * channels; MatrixPtr input = CpuMatrix::create(numSamples, inWidth, false, false); diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc index a19a3e3675e3e2e7cc0c3594f21191f932d6379f..19ec9ba9b26f5919796181a19a048b7edb508bdd 100644 --- a/paddle/memory/memcpy.cc +++ b/paddle/memory/memcpy.cc @@ -62,6 +62,24 @@ void Copy(platform::GPUPlace dst_place, } } +template <> +void Copy(platform::CPUPlace dst_place, + void* dst, + platform::GPUPlace src_place, + const void* src, size_t num) { + platform::SetDeviceId(src_place.device); + platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost); +} + +template <> +void Copy(platform::GPUPlace dst_place, + void* dst, + platform::CPUPlace src_place, + const void* src, size_t num) { + platform::SetDeviceId(dst_place.device); + platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice); +} + #endif // PADDLE_ONLY_CPU } // namespace memory diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f9ea25ab045a02be5ab9ed81ef9c679126d3a188..e3e934bcccd1a5f34d88a2f33f3708a46ddabe05 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -1,5 +1,7 @@ file(GLOB GENERAL_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") string(REPLACE ".cc" "" GENERAL_OPS "${GENERAL_OPS}") +set(pybind_file ${PADDLE_SOURCE_DIR}/paddle/pybind/pybind.h) +file(WRITE ${pybind_file} "// Generated by the paddle/operator/CMakeLists.txt. DO NOT EDIT!\n\n") function(op_library TARGET) # op_library is a function to create op library. The interface is same as # cc_library. But it handle split GPU/CPU code and link some common library @@ -7,10 +9,11 @@ function(op_library TARGET) set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} PARENT_SCOPE) set(cc_srcs) set(cu_srcs) - set(op_common_deps operator op_registry) + set(op_common_deps operator op_registry math_function) set(options "") set(oneValueArgs "") set(multiValueArgs SRCS DEPS) + set(pybind_flag 0) cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -46,22 +49,42 @@ function(op_library TARGET) cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) endif() + + # net_op doesn't need pybind + if ("${TARGET}" STREQUAL "net_op") + set(pybind_flag 1) + endif() + + # pybind USE_NO_KERNEL_OP + file(READ ${TARGET}.cc TARGET_CONTENT) + string(REGEX MATCH "OperatorWithKernel" regex_result "${TARGET_CONTENT}") + string(REPLACE "_op" "" TARGET "${TARGET}") + if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "") + file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n") + set(pybind_flag 1) + endif() + + # pybind USE_CPU_ONLY_OP + list(LENGTH cu_srcs cu_srcs_len) + if (${pybind_flag} EQUAL 0 AND ${cu_srcs_len} EQUAL 0) + file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") + set(pybind_flag 1) + endif() + + # pybind USE_OP + if (${pybind_flag} EQUAL 0) + file(APPEND ${pybind_file} "USE_OP(${TARGET});\n") + endif() endfunction() add_subdirectory(math) set(DEPS_OPS - identity_op - minus_op - mul_op recurrent_op - scale_op) -op_library(identity_op DEPS scale_op) -op_library(minus_op DEPS scale_op) -op_library(mul_op DEPS math_function) + cond_op) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc - DEPS framework_proto tensor operator net_op) -op_library(scale_op DEPS net_op) + DEPS framework_proto tensor net_op) +op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..0c813748b2989a8f0c00a359345747242dd21dd8 --- /dev/null +++ b/paddle/operators/accuracy_op.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/accuracy_op.h" + +namespace paddle { +namespace operators { + +class AccuracyOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar("Inference"), + "Input(Inference) of AccuracyOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), + "Input(Label) of AccuracyOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Accuracy"), + "Output(Accuracy) of AccuracyOp should not be null."); + + auto *inference = ctx.Input("Inference"); + auto *label = ctx.Input("Label"); + + PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector"); + PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0], + "inference size must be the same as label size"); + + ctx.Output("Accuracy")->Resize({1}); + } +}; + +class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker { + public: + AccuracyOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + // TODO(typhoonzero): support both inference value and indices. + AddInput("Inference", "topk(indices) the network output"); + AddInput("Label", "Label of the training data"); + // TODO(typhoonzero): AddInput("Weight", ... + AddOutput("Accuracy", "The accuracy of current batch"); + + AddComment( + R"DOC(Accuracy. It will print accuracy rate for classification. +The accuracy is: +.. math:: +accuracy = \\frac{NumOfCorrectPredicts}{NumOfAllSamples})DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker); +REGISTER_OP_CPU_KERNEL(accuracy, + ops::AccuracyKernel); diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..0a6a0fd15c73330902552f7a9aa6339de24c1a18 --- /dev/null +++ b/paddle/operators/accuracy_op.cu @@ -0,0 +1,81 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/operators/accuracy_op.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +using platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata, + const int* labeldata, float* accuracy) { + int count = 0; + __shared__ int total[BlockSize]; + + // support only 1 block + for (int i = threadIdx.x; i < (N); i += BlockSize) { + for (int j = 0; j < D; ++j) { + if (Xdata[i * D + j] == labeldata[i]) { + ++count; + break; + } + } + } + total[threadIdx.x] = count; + __syncthreads(); + + // reduce the count with init value 0, and output accuracy. + int result = thrust::reduce(thrust::device, total, total + BlockSize, 0); + if (threadIdx.x == 0) { + *accuracy = static_cast(result) / static_cast(N); + } +} + +template +class AccuracyOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + auto* inference = ctx.Input("Inference"); + auto* label = ctx.Input("Label"); + auto* accuracy = ctx.Output("Accuracy"); + // FIXME(typhoonzero): only support indices currently + // if add support for output values, how to detect the data type? + const int* inference_data = inference->data(); + const int* label_data = label->data(); + float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); + + size_t num_samples = inference->dims()[0]; + size_t infer_width = inference->dims()[1]; + cudaMemset((void**)&accuracy_data, 0, sizeof(float)); + + if (num_samples == 0) { + return; + } + + AccuracyCudaKernel<<<1, PADDLE_CUDA_NUM_THREADS>>>( + num_samples, infer_width, inference_data, label_data, accuracy_data); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_GPU_KERNEL(accuracy, + paddle::operators::AccuracyOpCUDAKernel); diff --git a/paddle/operators/accuracy_op.h b/paddle/operators/accuracy_op.h new file mode 100644 index 0000000000000000000000000000000000000000..fe704efe1c979f4fc6a5a37184e51b416f5e517f --- /dev/null +++ b/paddle/operators/accuracy_op.h @@ -0,0 +1,77 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenMatrix = framework::EigenMatrix; + +template +using EigenVector = framework::EigenVector; + +template +using EigenScalar = framework::EigenScalar; + +template +class AccuracyKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* inference = ctx.Input("Inference"); + auto* label = ctx.Input("Label"); + auto* accuracy = ctx.Output("Accuracy"); + + float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); + + const T* inference_data = inference->data(); + const T* label_data = label->data(); + + size_t num_samples = inference->dims()[0]; + size_t class_dim = inference->dims()[1]; + *accuracy_data = 0.0f; + + if (num_samples == 0) { + return; + } + + int num_correct = 0; + // assume inference is already the topk of the output + for (size_t i = 0; i < num_samples; ++i) { + PADDLE_ENFORCE_GE(label_data[i], 0, "label must >= 0"); + for (size_t j = 0; j < class_dim; ++j) { + if (inference_data[i * class_dim + j] == label_data[i]) { + ++num_correct; + break; + } + } + } + + // FIXME(typhoonzero): we don't accumulate the accuracy for now. + *accuracy_data = + static_cast(num_correct) / static_cast(num_samples); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 8dbd47cf0dfbc265032a9966343eed5c7bd8692e..e83c1efeaf897889d18a37a6bd2ca2f8f012db25 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -23,10 +23,18 @@ class AddOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of AddOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), + "Input(Y) of AddOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of AddOp should not be null."); + PADDLE_ENFORCE_EQ(ctx.Input("X")->dims(), ctx.Input("Y")->dims(), "Two input of Add Op's dimension must be same."); - ctx.Output("Out")->Resize(ctx.Input("X")->dims()); + ctx.Output("Out")->Resize( + ctx.Input("X")->dims()); } }; diff --git a/paddle/operators/concat_op.cc b/paddle/operators/concat_op.cc index 0ebefbab26ec8fdf316f852fbb7f6d9f3bbc48eb..223bb0ffe6e75ce71919eb5f4cca06bedbb00764 100644 --- a/paddle/operators/concat_op.cc +++ b/paddle/operators/concat_op.cc @@ -25,8 +25,11 @@ class ConcatOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of ConcatOp should not be null."); + auto ins = ctx.MultiInput("X"); - auto *out = ctx.Output("Out"); + auto *out = ctx.Output("Out"); size_t axis = static_cast(ctx.Attr("axis")); size_t n = ins.size(); diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8262a7a5c8c13c86c5f6c123a14fa89696358c57 --- /dev/null +++ b/paddle/operators/cond_op.cc @@ -0,0 +1,229 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/cond_op.h" + +#include +#include + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/gather.h" +#include "paddle/operators/net_op.h" +#include "paddle/operators/scatter.h" + +namespace paddle { +namespace operators { + +using Scope = framework::Scope; +using Variable = framework::Variable; +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using DDim = framework::DDim; + +void CondOp::CreateScope(const Scope& scope) const { + auto sub_scopes_var = scope.FindVar("SubScopes"); + PADDLE_ENFORCE_NOT_NULL(sub_scopes_var, + "Output(SubScopes) of CondOp should not be null."); + auto sub_scopes = sub_scopes_var->GetMutable>(); + auto& sub_scope = scope.NewScope(); + sub_scopes->push_back(&sub_scope); +} + +void CondOp::CreateIndexTensor(const Scope& scope) const { + auto index_tensors_var = scope.FindVar("IndexTensors"); + PADDLE_ENFORCE_NOT_NULL(index_tensors_var, + "Output(IndexTensors) of CondOp should not be null."); + auto& index_tensors = + *index_tensors_var->GetMutable>(); + index_tensors.push_back(LoDTensor()); +} + +void CondOp::InferShape(const Scope& scope) const { + auto sub_scopes_var = scope.FindVar("SubScopes"); + PADDLE_ENFORCE_NOT_NULL(sub_scopes_var, + "Output(SubScopes) of CondOp should not be null."); + auto& sub_scopes = *sub_scopes_var->GetMutable>(); + + for (int i = 0; i < 2; ++i) { + // Create two sub scopes for true and false branches + // sub_scopes[0] for the true branch and sub_scopes[1] for the false + // branch + CreateScope(scope); + + // Create two tensors for true and false indices + // index_tensors[0] for the true branch and index_tensors[1] for the false + // branch + CreateIndexTensor(scope); + + PADDLE_ENFORCE(!Inputs("Xs").empty(), + "Inputs(Xs) of CondOp can't be empty."); + for (auto& input : Inputs("Xs")) { + // Create a new tensor in sub-scope for input-type tensor + Variable* v = sub_scopes[i]->NewVar(input); + LoDTensor* sub_input = v->GetMutable(); + sub_input->Resize(scope.FindVar(input)->GetMutable()->dims()); + } + + for (auto& output : (*sub_net_op_[i]).Outputs()) { + for (auto& var_name : output.second) { + sub_scopes[i]->NewVar(var_name); + } + } + + // each net calls InferShape + sub_net_op_[i]->InferShape(*sub_scopes[i]); + } + + for (auto& output : Outputs("Outs")) { + LoDTensor* tensor_t_out = + sub_scopes[0]->FindVar(output)->GetMutable(); + PADDLE_ENFORCE_NOT_NULL(tensor_t_out, "True output should not be NULL"); + LoDTensor* tensor_f_out = + sub_scopes[1]->FindVar(output)->GetMutable(); + PADDLE_ENFORCE_NOT_NULL(tensor_f_out, "False output should not be NULL"); + + auto* tensor_out_var = scope.FindVar(output); + PADDLE_ENFORCE_NOT_NULL(tensor_out_var, "Output not found"); + LoDTensor* tensor_out = tensor_out_var->GetMutable(); + PADDLE_ENFORCE_NOT_NULL(tensor_t_out, + "True output tensor should not be NULL"); + + // check output size should be same + PADDLE_ENFORCE_EQ(tensor_t_out->dims(), tensor_f_out->dims(), + "Outputs not of the same shape"); + tensor_out->Resize(tensor_t_out->dims()); + // tensor_out->mutable_data(tensor_out->dims(), + // platform::CPUPlace()); + tensor_out->mutable_data(platform::CPUPlace()); + } +} + +void CondOp::Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const { + auto* sub_scopes_var = scope.FindVar("SubScopes"); + PADDLE_ENFORCE_NOT_NULL(sub_scopes_var, + "Output(SubScopes) of CondOp should not be null."); + auto sub_scopes = sub_scopes_var->Get>(); + auto* index_tensors_var = scope.FindVar("IndexTensors"); + PADDLE_ENFORCE_NOT_NULL(index_tensors_var, + "Output(IndexTensors) of CondOp should not be null."); + auto index_tensors = index_tensors_var->Get>(); + + std::string cond_name = Input("Cond"); + Variable* cond_var = scope.FindVar(cond_name); + PADDLE_ENFORCE_NOT_NULL(cond_var, + "Input(Cond) of CondOp should not be null."); + const LoDTensor* cond = cond_var->GetMutable(); + + // Step 1: get the true/false index at runtime + // index_[0]: vector, contains all index for cond[i] == true + // index_[1]: vector, contains all index for cond[i] == false + for (int i = 0; i < 2; ++i) index_[i].clear(); + + const int* cond_data = cond->data(); + for (int i = 0; i < cond->dims()[0]; ++i) { + if (cond_data[i]) + index_[0].push_back(i); + else + index_[1].push_back(i); + } + + // put index_[0] and index_[1] into two tensors: + // index_tensor_[0] and index_tensor_[1] + DDim dim = paddle::framework::make_ddim({0}); + for (int i = 0; i < 2; ++i) { + dim[0] = index_[i].size(); + int* tmp_ptr = + index_tensors[i].mutable_data(dim, platform::CPUPlace()); + index_tensors[i].Resize(dim); + memcpy(tmp_ptr, index_[i].data(), dim[0] * sizeof(int)); + } + + // Step 2: collect data by calling gather + for (int i = 0; i < 2; ++i) { + // i= 0/i for True and False branches respectively + for (auto& input : Inputs("Xs")) { + // find Tensor + Variable* v = scope.FindVar(input); + PADDLE_ENFORCE_NOT_NULL(v); + LoDTensor* tensor_parent = v->GetMutable(); + + v = sub_scopes[i]->FindVar(input); + PADDLE_ENFORCE_NOT_NULL(v); + LoDTensor* tensor_child = v->GetMutable(); + + // Resize child + DDim dim = tensor_child->dims(); + dim[0] = index_[i].size(); + tensor_child->Resize(dim); + tensor_child->mutable_data(dim, platform::CPUPlace()); + + Gather(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i], + tensor_child); + } + } + + // Step 3: run + for (int i = 0; i < 2; ++i) { + sub_net_op_[i]->Run(*sub_scopes[i], dev_ctx); + } + + // Step 4: merge output results + PADDLE_ENFORCE(!Outputs("Outs").empty(), + "Outputs(Outs) of CondOp can't be empty."); + for (int i = 0; i < 2; ++i) { + // i= 0/i for True and False branches respectively + for (auto& output : Outputs("Outs")) { + // find Tensor + Variable* v = scope.FindVar(output); + PADDLE_ENFORCE_NOT_NULL(v); + LoDTensor* tensor_parent = v->GetMutable(); + + v = sub_scopes[i]->FindVar(output); + PADDLE_ENFORCE_NOT_NULL(v); + LoDTensor* tensor_child = v->GetMutable(); + + ScatterUpdate(dev_ctx.GetPlace(), tensor_child, &index_tensors[i], + tensor_parent); + } + } +} + +class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker { + public: + CondOpProtoAndCheckerMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Cond", "The condition, which is a bool vector"); + AddInput("Xs", "Inputs of Subnets").AsDuplicable(); + AddOutput("Outs", "Outputs of Cond_Op after merge").AsDuplicable(); + + AddOutput("SubScopes", "sub scopes for true and false branches"); + AddOutput("IndexTensors", "Index Tensors contains indices for true/false"); + + AddComment(R"DOC( +Sample dependent Cond Operator: +Given Cond[i] as a 1/0 vector to indicate true/false +The equation is: +Out[i] = subnet_t[i], if Cond[i] == true +Out[i] = subnet_t[i], if Cond[i] == false +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_WITHOUT_GRADIENT(cond, paddle::operators::CondOp, + paddle::operators::CondOpProtoAndCheckerMaker); diff --git a/paddle/operators/cond_op.h b/paddle/operators/cond_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b09e32331e66c53555c88c06d7b1456276050eaa --- /dev/null +++ b/paddle/operators/cond_op.h @@ -0,0 +1,91 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "glog/logging.h" +#include "paddle/framework/ddim.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/tensor.h" +#include "paddle/operators/net_op.h" + +namespace paddle { +namespace operators { + +/* + * @brief CondOp is a dynamic if-else Operator + * + * It has a input tensor named cond indicating which netop each instance will + * run. + * + * if cond == 1, it will run true_net, which is a NetOp. + * + * if cond == 0, it will run false_net, which is another NetOp. + */ +class CondOp : public framework::OperatorBase { + public: + CondOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) { + index_.resize(2); + sub_net_op_.resize(2); + } + + CondOp(const CondOp& o) + : framework::OperatorBase( + static_cast(o)) { + // TODO(yuyang18): Implement copy ctor well. + PADDLE_THROW("Not implemented"); + } + + void CreateScope(const framework::Scope& scope) const; + + void CreateIndexTensor(const framework::Scope& scope) const; + + /* + * InferShape must be called before Run. + */ + void InferShape(const framework::Scope& scope) const override; + + /* + * Set True Block + */ + void set_truenet(std::unique_ptr&& net) { + sub_net_op_[0] = std::move(net); + } + + /* + * Set False Block + */ + void set_falsenet(std::unique_ptr&& net) { + sub_net_op_[1] = std::move(net); + } + + void Run(const framework::Scope& scope, + const platform::DeviceContext& dev_ctx) const override; + + private: + // sub_net_op_[0]: subnet_t + // sub_net_op_[1]: subnet_f + std::vector> sub_net_op_; + + // index_[0]: True_index; + // index_[1]: False_index; + mutable std::vector> index_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc index c033af3b741ae26ad9d37b2164f87aa6e8651c6e..72c446493684246959656dc048e7f0e761665423 100644 --- a/paddle/operators/cos_sim_op.cc +++ b/paddle/operators/cos_sim_op.cc @@ -25,16 +25,38 @@ class CosSimOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null."); - PADDLE_ENFORCE_EQ(ctx.Input("X")->dims(), - ctx.Input("Y")->dims(), - "Dimensions of Input(X) and Input(Y) must be the same."); - - auto dims = ctx.Input("X")->dims(); - ctx.Output("Out")->Resize({dims[0], 1}); - ctx.Output("XNorm")->Resize({dims[0], 1}); - ctx.Output("YNorm")->Resize({dims[0], 1}); + // notnull check + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of CosSimOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), + "Input(Y) of CosSimOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of CosSimOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("XNorm"), + "Output(XNorm) of CosSimOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("YNorm"), + "Output(YNorm) of CosSimOp should not be null."); + + // shape check + auto x_dims = ctx.Input("X")->dims(); + auto y_dims = ctx.Input("Y")->dims(); + + PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(), + "Ranks of Input(X) and Input(Y) must be equal."); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + "Rank of Input(X) must not be less than 2."); + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()), + framework::slice_ddim(y_dims, 1, y_dims.size()), + "All dimensions except the 1st of Input(X) and Input(Y) " + "must be equal."); + PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1, + "The 1st dimension of Input(Y) must be equal to Input(X) or" + " just 1 (which will be broadcasted to match Input(X))."); + + // resize tensor + ctx.Output("Out")->Resize({x_dims[0], 1}); + ctx.Output("XNorm")->Resize({x_dims[0], 1}); + ctx.Output("YNorm")->Resize({y_dims[0], 1}); } }; @@ -42,16 +64,27 @@ class CosSimOpMaker : public framework::OpProtoAndCheckerMaker { public: CosSimOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The first input of cos_sim op."); - AddInput("Y", "The second input of cos_sim op."); + AddInput("X", "The 1st input of cos_sim op."); + AddInput("Y", "The 2nd input of cos_sim op."); AddOutput("Out", "The output of cos_sim op."); - AddOutput("XNorm", "Row norm of the first input.").AsIntermediate(); - AddOutput("YNorm", "Row norm of the second input.").AsIntermediate(); + AddOutput("XNorm", + "Norm of the first input, reduced along the 1st " + "dimension.") + .AsIntermediate(); + AddOutput("YNorm", + "Norm of the second input, reduced along the 1st " + "dimension.") + .AsIntermediate(); AddComment(R"DOC( Cosine Similarity Operator. -The equation is: Out = X^T * Y / (sqrt(X^T * X) * sqrt(Y^T * Y)) +The equation is: Out = X^T * Y / (sqrt(X^T * X) * sqrt(Y^T * Y)). + +Input(X) and Input(Y) must have the same shape, except that the 1st dimension +of Input(Y) could be just 1 (different from Input(X)), which will be +broadcasted to match the shape of Input(X) before computing their cosine +similarity. )DOC"); } }; @@ -62,34 +95,54 @@ class CosSimOpGrad : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + // notnull check PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("XNorm"), "Input(XNorm) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("YNorm"), "Input(YNorm) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Out"), + "Input(Out) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), "Input(Out@GRAD) must not be null."); + // shape check auto x_dims = ctx.Input("X")->dims(); auto y_dims = ctx.Input("Y")->dims(); auto xnorm_dims = ctx.Input("XNorm")->dims(); auto ynorm_dims = ctx.Input("YNorm")->dims(); - auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); - PADDLE_ENFORCE_EQ(x_dims, y_dims, - "Dimensions of Input(X) and Input(Y) must be the same."); - PADDLE_ENFORCE_EQ(xnorm_dims[0], x_dims[0], - "1st dimension of XNorm must equal that of Input(X)."); - PADDLE_ENFORCE_EQ(xnorm_dims[1], 1, "2st dimension of XNorm must be one."); - PADDLE_ENFORCE_EQ(ynorm_dims[0], y_dims[0], - "1st dimension of YNorm must equal that of Input(Y)."); - PADDLE_ENFORCE_EQ(ynorm_dims[1], 1, "2st dimension of YNorm must be one."); - PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0], - "1st dimension of Out@GRAD must equal that of Input(X)"); - PADDLE_ENFORCE_EQ(out_dims[1], 1, "1st dimension of Out@GRAD must be one."); - - auto *x_grad = ctx.Output(framework::GradVarName("X")); - auto *y_grad = ctx.Output(framework::GradVarName("Y")); + auto out_dims = ctx.Input("Out")->dims(); + auto out_grad_dims = + ctx.Input(framework::GradVarName("Out"))->dims(); + + PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), + "Ranks of Input(X) and Input(Y) must be equal."); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + "Rank of Input(X) must not be less than 2."); + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()), + framework::slice_ddim(y_dims, 1, y_dims.size()), + "All dimensions except the 1st of Input(X) and Input(Y) " + "must be equal."); + PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1, + "The 1st dimension of Input(Y) must be equal to Input(X) or" + " just 1 (which will be broadcasted to match Input(X))."); + auto target_xnorm_dims = framework::make_ddim({x_dims[0], 1}); + auto target_ynorm_dims = framework::make_ddim({y_dims[0], 1}); + PADDLE_ENFORCE_EQ(xnorm_dims, target_xnorm_dims, + "Shape of Input(XNorm) must be [X.Dim(0), 1]."); + PADDLE_ENFORCE_EQ(ynorm_dims, target_ynorm_dims, + "Shape of Input(YNorm) must be [Y.Dim(0), 1]."); + PADDLE_ENFORCE_EQ(out_dims, target_xnorm_dims, + "Shape of Input(Out) must be [X.Dim(0), 1]."); + PADDLE_ENFORCE_EQ(out_grad_dims, target_xnorm_dims, + "Shape of Input(Out@Grad) must be [X.Dim(0), 1]."); + + // resize tensor + auto *x_grad = + ctx.Output(framework::GradVarName("X")); + auto *y_grad = + ctx.Output(framework::GradVarName("Y")); if (x_grad) x_grad->Resize(x_dims); if (y_grad) y_grad->Resize(y_dims); } diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h index 0dc509952578497671a128374f77ce616a520909..bcf6f758cae561a2e22f5be6c7a242647ef1c144 100644 --- a/paddle/operators/cos_sim_op.h +++ b/paddle/operators/cos_sim_op.h @@ -31,30 +31,38 @@ template class CosSimKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* input_x = context.Input("X"); - auto* input_y = context.Input("Y"); - auto* output_z = context.Output("Out"); - auto* output_x_norm = context.Output("XNorm"); - auto* output_y_norm = context.Output("YNorm"); + // get Tensor + auto* in_x = context.Input("X"); + auto* in_y = context.Input("Y"); + auto* out_z = context.Output("Out"); + auto* out_x_norm = context.Output("XNorm"); + auto* out_y_norm = context.Output("YNorm"); + out_z->mutable_data(context.GetPlace()); + out_x_norm->mutable_data(context.GetPlace()); + out_y_norm->mutable_data(context.GetPlace()); - output_z->mutable_data(context.GetPlace()); - output_x_norm->mutable_data(context.GetPlace()); - output_y_norm->mutable_data(context.GetPlace()); - - auto dims = input_x->dims(); - int64_t size = input_x->numel(); - auto new_dims = framework::make_ddim({dims[0], size / dims[0]}); - auto x = EigenMatrix::From(*input_x, new_dims); - auto y = EigenMatrix::From(*input_y, new_dims); - auto z = EigenVector::Flatten(*output_z); - auto x_norm = EigenVector::Flatten(*output_x_norm); - auto y_norm = EigenVector::Flatten(*output_y_norm); + // convert Tensor to Eigen Tensor + int rows_x = in_x->dims()[0]; + int rows_y = in_y->dims()[0]; + auto x = EigenMatrix::Reshape(*in_x, 1); + auto y = EigenMatrix::Reshape(*in_y, 1); + auto z = EigenVector::Flatten(*out_z); + auto x_norm = EigenVector::Flatten(*out_x_norm); + auto y_norm = EigenVector::Flatten(*out_y_norm); + // compute auto place = context.GetEigenDevice(); - auto xy = (x * y).sum(Eigen::array({{1}})); - x_norm.device(place) = x.square().sum(Eigen::array({{1}})).sqrt(); - y_norm.device(place) = y.square().sum(Eigen::array({{1}})).sqrt(); - z.device(place) = xy / x_norm / y_norm; + auto row_along = Eigen::array({{1}}); + x_norm.device(place) = x.square().sum(row_along).sqrt(); + y_norm.device(place) = y.square().sum(row_along).sqrt(); + if (rows_x == rows_y) { + auto xy = (x * y).sum(Eigen::array({{1}})); + z.device(place) = xy / x_norm / y_norm; + } else { + Eigen::DSizes bcast(rows_x, 1); + auto xy = (x * y.broadcast(bcast)).sum(row_along); + z.device(place) = xy / x_norm / y_norm.broadcast(bcast); + } } }; @@ -62,43 +70,72 @@ template class CosSimGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* input_x = context.Input("X"); - auto* input_y = context.Input("Y"); - auto* input_z = context.Input("Out"); - auto* input_x_norm = context.Input("XNorm"); - auto* input_y_norm = context.Input("YNorm"); - auto* output_grad_x = context.Output(framework::GradVarName("X")); - auto* output_grad_y = context.Output(framework::GradVarName("Y")); - auto* input_grad_z = context.Input(framework::GradVarName("Out")); + // get Tensor + auto* in_x = context.Input("X"); + auto* in_y = context.Input("Y"); + auto* in_z = context.Input("Out"); + auto* in_x_norm = context.Input("XNorm"); + auto* in_y_norm = context.Input("YNorm"); + auto* out_grad_x = context.Output(framework::GradVarName("X")); + auto* out_grad_y = context.Output(framework::GradVarName("Y")); + auto* in_grad_z = context.Input(framework::GradVarName("Out")); - auto dims = input_x->dims(); - int64_t size = input_x->numel(); - auto new_dims = framework::make_ddim({dims[0], size / dims[0]}); - auto x = EigenMatrix::From(*input_x, new_dims); - auto y = EigenMatrix::From(*input_y, new_dims); - auto z = EigenMatrix::From(*input_z); - auto x_norm = EigenMatrix::From(*input_x_norm); - auto y_norm = EigenMatrix::From(*input_y_norm); - auto dz = EigenMatrix::From(*input_grad_z); + // convert Tensor to Eigen Tensor + auto x = EigenMatrix::Reshape(*in_x, 1); + auto y = EigenMatrix::Reshape(*in_y, 1); + auto z = EigenMatrix::Reshape(*in_z, 1); + auto x_norm = EigenMatrix::Reshape(*in_x_norm, 1); + auto y_norm = EigenMatrix::Reshape(*in_y_norm, 1); + auto dz = EigenMatrix::Reshape(*in_grad_z, 1); - Eigen::DSizes bcast(1, new_dims[1]); - auto z_bcast = z.broadcast(bcast); - auto dz_bcast = dz.broadcast(bcast); + // compute gradident + int rows_x = in_x->dims()[0]; + int rows_y = in_y->dims()[0]; + int cols = framework::product(in_x->dims()) / rows_x; + Eigen::DSizes bcast_cols(1, cols); + auto z_bcast = z.broadcast(bcast_cols); + auto dz_bcast = dz.broadcast(bcast_cols); + auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols); auto place = context.GetEigenDevice(); - auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast); - auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast); - auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast); - if (output_grad_x) { - output_grad_x->mutable_data(context.GetPlace()); - auto dx = EigenMatrix::From(*output_grad_x, new_dims); - dx.device(place) = - dz_bcast * (y / norm_prod_bcast - z_bcast * x / x_snorm_bcast); - } - if (output_grad_y) { - output_grad_y->mutable_data(context.GetPlace()); - auto dy = EigenMatrix::From(*output_grad_y, new_dims); - dy.device(place) = - dz_bcast * (x / norm_prod_bcast - z_bcast * y / y_snorm_bcast); + if (rows_x == rows_y) { + auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols); + auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols); + // compute dx + if (out_grad_x) { + out_grad_x->mutable_data(context.GetPlace()); + auto dx = EigenMatrix::Reshape(*out_grad_x, 1); + auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast; + dx.device(place) = dz_bcast * grad; + } + // compute dy + if (out_grad_y) { + out_grad_y->mutable_data(context.GetPlace()); + auto dy = EigenMatrix::Reshape(*out_grad_y, 1); + auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast; + dy.device(place) = dz_bcast * grad; + } + } else { + Eigen::DSizes bcast_rows(rows_x, 1); + Eigen::DSizes bcast_rows_cols(rows_x, cols); + auto y_bcast = y.broadcast(bcast_rows); + auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_rows_cols); + auto norm_prod_bcast = (x_norm * y_norm.eval().broadcast(bcast_rows)) + .eval() + .broadcast(bcast_cols); + // compute dx + if (out_grad_x) { + out_grad_x->mutable_data(context.GetPlace()); + auto dx = EigenMatrix::Reshape(*out_grad_x, 1); + auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast; + dx.device(place) = dz_bcast * grad; + } + // compute dy + if (out_grad_y) { + out_grad_y->mutable_data(context.GetPlace()); + auto dy = EigenMatrix::Reshape(*out_grad_y, 1); + auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast; + dy.device(place) = (dz_bcast * grad).sum(Eigen::array({{0}})); + } } } }; diff --git a/paddle/operators/elementwise_mul_op.cc b/paddle/operators/elementwise_mul_op.cc index 1742925545d29df5d7df719faaea3b754680ab61..ee6e975b443691bf71cec904565ced20406f3fba 100644 --- a/paddle/operators/elementwise_mul_op.cc +++ b/paddle/operators/elementwise_mul_op.cc @@ -25,13 +25,19 @@ class ElementWiseMulOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of ElementWiseMulOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), + "Input(Y) of ElementWiseMulOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Out"), + "Output(Out) of ElementWiseMulOp should not be null."); + auto x_dim = ctx.Input("X")->dims(); auto y_dim = ctx.Input("Y")->dims(); PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), "Rank of first input must >= rank of second input.") - ctx.Output("Out")->Resize(x_dim); + ctx.Output("Out")->Resize(x_dim); } }; @@ -80,8 +86,10 @@ class ElementWiseMulOpGrad : public framework::OperatorWithKernel { auto x_dims = ctx.Input("X")->dims(); auto y_dims = ctx.Input("Y")->dims(); auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); - auto *x_grad = ctx.Output(framework::GradVarName("X")); - auto *y_grad = ctx.Output(framework::GradVarName("Y")); + auto *x_grad = + ctx.Output(framework::GradVarName("X")); + auto *y_grad = + ctx.Output(framework::GradVarName("Y")); PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), "Rank of first input must >= rank of second input.") diff --git a/paddle/operators/elementwise_mul_op.h b/paddle/operators/elementwise_mul_op.h index e9ed6791799240039f9af42c1a4339be7126ee65..6d58da580b81b9e0a8ae170eec1a73638b190df8 100644 --- a/paddle/operators/elementwise_mul_op.h +++ b/paddle/operators/elementwise_mul_op.h @@ -13,10 +13,8 @@ limitations under the License. */ #pragma once -#include #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" -#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { diff --git a/paddle/operators/fc_op.cc b/paddle/operators/fc_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e5d0f3c3724262a60a463ef3beadd9906d3ebaf6 --- /dev/null +++ b/paddle/operators/fc_op.cc @@ -0,0 +1,197 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/net_op.h" + +namespace paddle { +namespace operators { + +class FCOp : public NetOp { + public: + FCOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : NetOp(type, inputs, outputs, attrs) { + PADDLE_ENFORCE(!Inputs("X").empty(), + "Inputs(X) of FCOp should not be null."); + PADDLE_ENFORCE(!Inputs("W").empty(), + "Inputs(W) of FCOp should not be null."); + PADDLE_ENFORCE(!Outputs("MulOut").empty(), + "Outputs(MulOut) of FCOp should not be null."); + PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName, + "Output(Out) of FCOp should not be null."); + + auto x = Inputs("X"); + auto w = Inputs("W"); + auto mul_out = Outputs("MulOut"); + PADDLE_ENFORCE_EQ( + x.size(), w.size(), + "The size of inputs X(%d) should be the same as that of weights W(%d).", + x.size(), w.size()); + PADDLE_ENFORCE_EQ(mul_out.size(), x.size(), + "The size of intermediate mul_out(%d) should be the same " + "as that of inputs X(%d).", + mul_out.size(), x.size()); + + size_t n = x.size(); + PADDLE_ENFORCE_GE(n, static_cast(1), + "The size of inputs X(%d) should be no less than 1.", n); + + auto x_num_col_dims = Attr>("xNumColDims"); + + // Set all values or set no values (use the default value) + if (!x_num_col_dims.empty()) { + PADDLE_ENFORCE_EQ(x_num_col_dims.size(), n, + "The size of attribute xNumColDims(%d) should be the " + "same as that of inputs X(%d).", + x_num_col_dims.size(), n); + } else { + x_num_col_dims.resize(n); + for (size_t i = 0; i < n; i++) { + x_num_col_dims[i] = 1; + } + } + + // mul_out[i] = X[i] * W[i] + for (size_t i = 0; i < n; i++) { + framework::AttributeMap mul_attr; + mul_attr["x_num_col_dims"] = static_cast(x_num_col_dims[i]); + mul_attr["y_num_col_dims"] = static_cast(1); + AppendOp( + framework::OpRegistry::CreateOp("mul", {{"X", {x[i]}}, {"Y", {w[i]}}}, + {{"Out", {mul_out[i]}}}, mul_attr)); + } + + // sum_out = X[0] * W[0] + ... + X[n-1] * W[n-1] + auto sum_out = mul_out[0]; + if (n > 1) { + PADDLE_ENFORCE_NE(Output("SumOut"), framework::kEmptyVarName, + "Output(SumOut) of FCOp should not be null when the " + "size of Inputs(X) > 1."); + + sum_out = Output("SumOut"); + AppendOp(framework::OpRegistry::CreateOp("sum", {{"X", {mul_out}}}, + {{"Out", {sum_out}}}, {})); + } else { + if (Output("SumOut") != framework::kEmptyVarName) { + this->Rename(Output("SumOut"), framework::kEmptyVarName); + } + } + + // add_out = sum_out + b + auto b = Input("B"); + auto add_out = sum_out; + if (b != framework::kEmptyVarName) { + PADDLE_ENFORCE_NE( + Output("AddOut"), framework::kEmptyVarName, + "Output(AddOut) of FCOp should not be null when Input(B) is set."); + + add_out = Output("AddOut"); + AppendOp(framework::OpRegistry::CreateOp( + "rowwise_add", {{"X", {sum_out}}, {"b", {Input("B")}}}, + {{"Out", {add_out}}}, {})); + } else { + if (Output("AddOut") != framework::kEmptyVarName) { + this->Rename(Output("AddOut"), framework::kEmptyVarName); + } + } + + auto activation = Attr("activation"); + AppendOp(framework::OpRegistry::CreateOp(activation, {{"X", {add_out}}}, + {{"Y", {Output("Out")}}}, {})); + CompleteAddOp(false); + } +}; + +class FCOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FCOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(A vector of Tensors) each input Tensor can be of arbitrary " + "dimension, and will be reshaped to a 2-D matrix of size " + "(minibatch, number_of_input_features) according to attribute " + "xNumColDims.") + .AsDuplicable(); + AddInput("W", + "(A vector of Tensors) the weights of FC operator, a " + "vector of 2-D matrix of size " + "(number_of_input_features, number_of_neurons).") + .AsDuplicable(); + AddInput("B", + "(Tensor) the bias of FC operator, a 1-D vector of size " + "number_of_neurons."); + + AddOutput("Out", + "(Tensor) the activated output matrix of FC operator, a 2-D " + "matrix of size (minibatch, number_of_neurons)."); + AddOutput("MulOut", + "(A vector of Tensors) the intermediate outputs of FC operator, " + "each Tensor saving the product of X_i * W_i.") + .AsIntermediate() + .AsDuplicable(); + AddOutput( + "SumOut", + "(Tensor) the intermediate output of FC operator, " + "saving the sum of the products of X and W, that is sum{X_i * W_i}.") + .AsIntermediate(); + AddOutput("AddOut", + "(Tensor) the non-actived output of FC operator, " + "saving sum{X_i * W_i} + B.") + .AsIntermediate(); + AddAttr( + "activation", + "(string, default identity) the activation type of FC operator.") + .SetDefault("identity") + .InEnum({"identity", "sigmoid", "softmax"}); + AddAttr>( + "xNumColDims", + "(std::vector) The inputs Tensors of FC operator can be of " + "more than 2 dimensions. In that case, each input Tensor `X_i` will be " + "reshaped to a 2-D matrix. The matrix's first dimension " + "(the length of column) will be the product of `X_i`'s last " + "`xNumColDims_i` dimensions, that is " + "`X_i.dims[0] x ... x X_i.dims[xNumColDims_i - 1]`. " + "The matrix's second dimension (the length of row) will be the product " + "of `X_i`'s first `rank - xNumColDims_i` dimensions, that is " + "`X_i.dims[xNumColDims_i] x ... x X_i.dims[rank - 1]`)") + .SetDefault(std::vector{}); + + AddComment(R"DOC( +Fully Connected Operator, known as Fully Connected Layer or Inner Product Layer +in Convolutional Neural Networks. Neurons in a fully connected layer have +full connections to all activations in the previous layer. +It computes an inner product of a set of +learned weights with a matrix multiplication followed by a bias offset +(optionally). + +Equation: + Out = Act(sum_n{X_i * W_i} + B) + +where X_i is Tensor that will be reshaped to a 2-D matrix of size (M x K), +usually M is the minibatch size and K is the number of input features. +W_i is a 2-D matrix of size (K x N), where N means the number of neurons +in the fully connected layer. B is a 1-D vector of size N. +Thus, the output Out is a 2-D matrix of size (M x N). +Activation type can be set to `identity` (default), `sigmoid` or `softmax`. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(fc, ops::FCOp, ops::FCOpMaker); diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc index 9d51f6e3a16fe96125599bb440d40237aeb9a028..ba7857cc65f6860a6156674c6addc2bfdce21a99 100644 --- a/paddle/operators/fill_zeros_like_op.cc +++ b/paddle/operators/fill_zeros_like_op.cc @@ -23,7 +23,14 @@ class FillZerosLikeOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output("Dst")->Resize( + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar("Src"), + "Input(Src) of FillZerosLikeOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Dst"), + "Output(Dst) of FillZerosLikeOp should not be null."); + + ctx.Output("Dst")->Resize( ctx.Input("Src")->dims()); } }; diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index 123bed296c462c30bddd3bfbd530098fdbfe4856..d445b61c1657356f2cdcf1e98d756607de2bd042 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -24,11 +24,18 @@ class GatherOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of GatherOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Index"), + "Input(Index) of GatherOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of GatherOp should not be null."); + int batch_size = ctx.Input("Index")->dims()[0]; PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0"); framework::DDim output_dims(ctx.Input("X")->dims()); output_dims[0] = batch_size; - ctx.Output("Out")->Resize(output_dims); + ctx.Output("Out")->Resize(output_dims); } }; @@ -38,7 +45,7 @@ class GatherGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto X_grad = ctx.Output(framework::GradVarName("X")); + auto X_grad = ctx.Output(framework::GradVarName("X")); auto X = ctx.Input("X"); X_grad->Resize(X->dims()); diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index 3d76516405960c502a46997108049b2db5cab6bf..c0e161bbc0c5486eb10408e43e6388f1b287abf8 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -43,8 +43,12 @@ class GaussianRandomOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& context) const override { - auto* tensor = context.Output("Out"); + void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Out"), + "Output(Out) of GaussianRandomOp should not be null."); + + auto* tensor = ctx.Output("Out"); auto dims = Attr>("dims"); std::vector temp; temp.reserve(dims.size()); diff --git a/paddle/operators/identity_op.cc b/paddle/operators/identity_op.cc index 7d9d4fa519d1c690feacbadc5175aeab49082282..2cc632205e63abbe412b09af4b894420ac512ec5 100644 --- a/paddle/operators/identity_op.cc +++ b/paddle/operators/identity_op.cc @@ -27,7 +27,7 @@ class IdentityOpMaker : public framework::OpProtoAndCheckerMaker { framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input tensor of identity operator."); - AddOutput("Out", "The output tensor of identity operator."); + AddOutput("Y", "The output tensor of identity operator."); AddComment(R"DOC( The identity operator is an alias of the scale operator with the attribute scale fixed to 1.0. @@ -42,9 +42,15 @@ class IdentityOp : public NetOp { const framework::VariableNameMap &outputs, const framework::AttributeMap &attrs) : NetOp(type, inputs, outputs, attrs) { + PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName, + "Input(X) of IdentityOp should not be null."); + PADDLE_ENFORCE_NE(Output("Y"), framework::kEmptyVarName, + "Output(Y) of IdentityOp should not be null."); + AppendOp(framework::OpRegistry::CreateOp( - "scale", {{"X", {Input("X")}}}, {{"Out", {Output("Out")}}}, + "scale", {{"X", {Input("X")}}}, {{"Out", {Output("Y")}}}, {{"scale", static_cast(1)}})); + CompleteAddOp(false); } }; diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc index 94d40890a765413e88a35a6ad995ca97ac84dcda..07f6dfabca5879e3de6004e59d2e87f7fa68d66c 100644 --- a/paddle/operators/lookup_table_op.cc +++ b/paddle/operators/lookup_table_op.cc @@ -22,10 +22,17 @@ class LookupTableOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &context) const override { - auto table_t = context.Input("W"); - auto ids_t = context.Input("Ids"); - auto output_t = context.Output("Out"); + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("W"), + "Input(W) of LookupTableOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"), + "Input(Ids) of LookupTableOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of LookupTableOp should not be null."); + + auto table_t = ctx.Input("W"); + auto ids_t = ctx.Input("Ids"); + auto output_t = ctx.Output("Out"); output_t->Resize({ids_t->dims()[0], table_t->dims()[1]}); } @@ -56,7 +63,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &context) const override { auto table = context.Input("W"); - auto d_table = context.Output(framework::GradVarName("W")); + auto d_table = + context.Output(framework::GradVarName("W")); d_table->Resize(table->dims()); } }; diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index d3d0e55a674587fb04f43f24d0790de4358f035a..7d7eeb59a23435036dc33c1e4fe6dd1c4a1a2f62 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -24,8 +24,10 @@ class MeanOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input of MeanOp must be initialized."); - ctx.Output("Out")->Resize({1}); + "Input(X) of MeanOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of MeanOp should not be null."); + ctx.Output("Out")->Resize({1}); } }; @@ -45,7 +47,7 @@ class MeanGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output(framework::GradVarName("X")) + ctx.Output(framework::GradVarName("X")) ->Resize(ctx.Input("X")->dims()); } }; diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc index a4876feb2edf77bd422fa2a7687b0fa7d55dae47..a97bbecdca1779df330d1053cf359bb658aa75c2 100644 --- a/paddle/operators/minus_op.cc +++ b/paddle/operators/minus_op.cc @@ -27,13 +27,20 @@ class MinusOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of MinusOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), + "Input(Y) of MinusOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of MinusOp should not be null."); + auto *left_tensor = ctx.Input("X"); auto *right_tensor = ctx.Input("Y"); PADDLE_ENFORCE_EQ( left_tensor->numel(), right_tensor->numel(), "Minus operator must take two tensor with same num of elements"); - ctx.Output("Out")->Resize(left_tensor->dims()); + ctx.Output("Out")->Resize(left_tensor->dims()); } }; @@ -64,7 +71,7 @@ class MinusGradOp : public NetOp { // x_grad = out_grad AppendOp(framework::OpRegistry::CreateOp("identity", {{"X", {out_grad}}}, - {{"Out", {x_grad}}}, {})); + {{"Y", {x_grad}}}, {})); framework::AttributeMap scale_attr; scale_attr["scale"] = static_cast(-1); @@ -77,8 +84,6 @@ class MinusGradOp : public NetOp { } // namespace operators } // namespace paddle -USE_OP(scale); -USE_NO_KERNEL_OP(identity); namespace ops = paddle::operators; REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad, ops::MinusGradOp); diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 710a56a0e8e2d17162d7d000df226f1537104eb9..b6d320b415e02549e85cb36ab517b0b5433887d5 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { using framework::Tensor; +using framework::LoDTensor; class MulOp : public framework::OperatorWithKernel { public: @@ -25,6 +26,13 @@ class MulOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of MulOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), + "Input(Y) of MulOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of MulOp should not be null."); + auto x_dims = ctx.Input("X")->dims(); auto y_dims = ctx.Input("Y")->dims(); int x_num_col_dims = Attr("x_num_col_dims"); @@ -45,7 +53,8 @@ class MulOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( x_mat_dims[1], y_mat_dims[0], "First matrix's width must be equal with second matrix's height."); - ctx.Output("Out")->Resize({x_mat_dims[0], y_mat_dims[1]}); + ctx.Output("Out")->Resize( + {x_mat_dims[0], y_mat_dims[1]}); } }; @@ -94,8 +103,10 @@ class MulOpGrad : public framework::OperatorWithKernel { auto x_dims = ctx.Input("X")->dims(); auto y_dims = ctx.Input("Y")->dims(); auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); - auto *x_grad = ctx.Output(framework::GradVarName("X")); - auto *y_grad = ctx.Output(framework::GradVarName("Y")); + auto *x_grad = + ctx.Output(framework::GradVarName("X")); + auto *y_grad = + ctx.Output(framework::GradVarName("Y")); auto x_mat_dims = framework::flatten_to_2d(x_dims, Attr("x_num_col_dims")); diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/onehot_cross_entropy_op.cc similarity index 81% rename from paddle/operators/cross_entropy_op.cc rename to paddle/operators/onehot_cross_entropy_op.cc index ab1e1c101a10e09a81f7785d2f1514822e3bdf15..f38be3549f3c5d2443f61739fc32cdca74197649 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/onehot_cross_entropy_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/cross_entropy_op.h" +#include "paddle/operators/onehot_cross_entropy_op.h" namespace paddle { namespace operators { @@ -23,13 +23,23 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar("X"), + "Input(X) of OnehotCrossEntropyOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar("label"), + "Input(label) of OnehotCrossEntropyOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Y"), + "Output(Y) of OnehotCrossEntropyOp should not be null."); + auto *X = ctx.Input("X"); auto *label = ctx.Input("label"); PADDLE_ENFORCE_EQ(X->dims().size(), 2, "X's dimension must be 2."); PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label's dimension must be 1."); PADDLE_ENFORCE_EQ(X->dims()[0], label->dims()[0]); - ctx.Output("Y")->Resize({X->dims()[0]}); + ctx.Output("Y")->Resize({X->dims()[0], 1}); } }; @@ -39,7 +49,7 @@ class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto dX = ctx.Output(framework::GradVarName("X")); + auto dX = ctx.Output(framework::GradVarName("X")); auto X = ctx.Input("X"); dX->Resize(X->dims()); diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/onehot_cross_entropy_op.cu similarity index 100% rename from paddle/operators/cross_entropy_op.cu rename to paddle/operators/onehot_cross_entropy_op.cu diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/onehot_cross_entropy_op.h similarity index 100% rename from paddle/operators/cross_entropy_op.h rename to paddle/operators/onehot_cross_entropy_op.h diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index 7e78b6ec133981494a65b5e16316ae8fdbd61a60..a0b1c6b631d97a40d774f7d2ff9550fda9c32db4 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -25,6 +25,11 @@ class PadOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of PadOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of PadOp should not be null."); + auto x_dim = ctx.Input("X")->dims(); auto paddings = Attr>("paddings"); PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()), @@ -34,7 +39,8 @@ class PadOp : public framework::OperatorWithKernel { for (int i = 0; i < x_dim.size(); ++i) { out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1]; } - ctx.Output("Out")->Resize(framework::make_ddim(out_dims)); + ctx.Output("Out")->Resize( + framework::make_ddim(out_dims)); } }; @@ -95,9 +101,9 @@ class PadOpGrad : public framework::OperatorWithKernel { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); auto x_dims = ctx.Input("X")->dims(); - auto *x_grad = ctx.Output(framework::GradVarName("X")); - if (x_grad != nullptr) { - x_grad->Resize(x_dims); + auto *x_g = ctx.Output(framework::GradVarName("X")); + if (x_g != nullptr) { + x_g->Resize(x_dims); } } }; diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index e826703c60ca82e1fe690eb78c3d4f92981ef3a2..d3413d7cb9305732e9ddf3cb1bc267f7203097f3 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -26,10 +26,11 @@ namespace operators { using Scope = framework::Scope; using Variable = framework::Variable; using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; void RecurrentAlgorithm::InferShape(const Scope& scope) const { seq_len_ = scope.FindVar((arg_->inlinks[0]).external) - ->GetMutable() + ->GetMutable() ->dims()[0]; CreateScopes(scope); auto step_scopes = GetStepScopes(scope); @@ -88,7 +89,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { // the weight are located in parent scope for (auto& var_name : input.second) { if (!step_scope.FindVar(var_name)) { - step_scope.NewVar(var_name)->GetMutable(); + step_scope.NewVar(var_name)->GetMutable(); } } } @@ -106,11 +107,12 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { void RecurrentAlgorithm::InitMemories(Scope* step_scope, bool infer_shape_mode) const { for (auto& attr : arg_->memories) { - Tensor* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable(); + auto* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable(); PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr, "memory [%s]'s boot variable [%s] not exists", attr.var, attr.boot_var); - Tensor* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable(); + auto* boot_mem = + step_scope->FindVar(attr.boot_var)->GetMutable(); if (infer_shape_mode) { pre_mem->Resize(boot_mem->dims()); PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2); @@ -192,9 +194,9 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients( "memory variable [%s] does not exists", attr.var); PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr, "boot variable [%s] does not exists", attr.boot_var); - Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable(); - Tensor* boot_mem_grad = - step_scope->NewVar(attr.boot_var)->GetMutable(); + auto* mem_grad = step_scope->NewVar(attr.var)->GetMutable(); + auto* boot_mem_grad = + step_scope->NewVar(attr.boot_var)->GetMutable(); if (infer_shape_mode) { boot_mem_grad->Resize(mem_grad->dims()); } else { @@ -205,7 +207,7 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients( void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { seq_len_ = scope.FindVar((arg_->inlinks[0]).external) - ->GetMutable() + ->GetMutable() ->dims()[0]; auto step_scopes = GetStepScopes(scope); rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index b7061153d2bf13982f14f233e87a87daeeebf5fd..0d05e344148c68f5625dd819ec59c5991892e4ce 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -28,7 +28,11 @@ class ReshapeOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { // input check - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of ReshapeOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of ReshapeOp should not be null."); + auto shape = ctx.Attr>("shape"); PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); for (auto dim : shape) { @@ -46,7 +50,7 @@ class ReshapeOp : public framework::OperatorWithKernel { std::transform(shape.begin(), shape.end(), shape_int64.begin(), [](int a) { return static_cast(a); }); auto out_dims = framework::make_ddim(shape_int64); - ctx.Output("Out")->Resize(out_dims); + ctx.Output("Out")->Resize(out_dims); } }; @@ -90,7 +94,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), "Input(Out@GRAD) shouldn't be null."); auto dims = ctx.Input("X")->dims(); - auto *d_in = ctx.Output(framework::GradVarName("X")); + auto *d_in = ctx.Output(framework::GradVarName("X")); d_in->Resize(dims); } }; diff --git a/paddle/operators/rnn/recurrent_op_utils.cc b/paddle/operators/rnn/recurrent_op_utils.cc index 97872c67ac99fbf6c9c177d52f1d4069163e8548..6c082cb1825e04accb09019fef28eb2ec6523a5b 100644 --- a/paddle/operators/rnn/recurrent_op_utils.cc +++ b/paddle/operators/rnn/recurrent_op_utils.cc @@ -21,6 +21,7 @@ namespace rnn { namespace f = paddle::framework; using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; void SegmentInputs(const std::vector& step_scopes, const std::vector& inlinks, const size_t seq_len, @@ -31,7 +32,7 @@ void SegmentInputs(const std::vector& step_scopes, PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.", inlinks[i].external); - Tensor* input = input_var->GetMutable(); + LoDTensor* input = input_var->GetMutable(); f::DDim dims = input->dims(); PADDLE_ENFORCE(static_cast(dims[0]) == seq_len, "all the inlinks must have same length"); @@ -40,6 +41,8 @@ void SegmentInputs(const std::vector& step_scopes, Tensor* step_input = step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable(); if (!infer_shape_mode) { + // The input of operators of each step is Tensor here. + // Maybe need to modify Slice function. *step_input = input->Slice(j, j + 1); } step_input->Resize(step_dims); @@ -54,21 +57,23 @@ void ConcatOutputs(const std::vector& step_scopes, auto output_var = step_scopes[0]->FindVar(outlinks[i].external); PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.", outlinks[i].external); - Tensor* output = output_var->GetMutable(); + LoDTensor* output = output_var->GetMutable(); if (infer_shape_mode) { auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal); PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope", outlinks[i].internal); - f::DDim step_dims = step_scope_var->template GetMutable()->dims(); + f::DDim step_dims = + step_scope_var->template GetMutable()->dims(); std::vector dims_vec = vectorize(step_dims); dims_vec.insert(dims_vec.begin(), seq_len); output->Resize(f::make_ddim(dims_vec)); } else { output->mutable_data(platform::CPUPlace()); for (size_t j = 0; j < seq_len; j++) { - Tensor* step_output = - step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable(); + LoDTensor* step_output = step_scopes[j] + ->FindVar(outlinks[i].internal) + ->GetMutable(); // TODO(luotao02) data type and platform::DeviceContext() should set // correctly (output->Slice(j, j + 1)) @@ -94,8 +99,8 @@ void LinkMemories(const std::vector& scopes, auto scope = scopes[step_id]; auto linked_scope = scopes[step_id + offset]; for (auto& attr : memories) { - auto mem = scope->FindVar(attr.pre_var)->GetMutable(); - auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable(); + auto mem = scope->FindVar(attr.pre_var)->GetMutable(); + auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable(); if (infer_shape_mode) { mem->Resize(linked_mem->dims()); } else { diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index fa8f0ff1a858143af427b51025279c726f1628e0..2a3fd3be941d91aaa6b014df91d3025f07767577 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -25,6 +25,13 @@ class RowwiseAddOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of RowwiseAddOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), + "Input(b) of RowwiseAddOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of RowwiseAddOp should not be null."); + auto x_dims = ctx.Input("X")->dims(); auto b_dims = ctx.Input("b")->dims(); PADDLE_ENFORCE_GT( @@ -37,7 +44,7 @@ class RowwiseAddOp : public framework::OperatorWithKernel { framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims, "The width of two operands must be same"); PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1"); - ctx.Output("Out")->Resize(x_dims); + ctx.Output("Out")->Resize(x_dims); } }; @@ -76,8 +83,8 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims, "The width of two operands must be same"); - auto *dx = ctx.Output(framework::GradVarName("X")); - auto *db = ctx.Output(framework::GradVarName("b")); + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *db = ctx.Output(framework::GradVarName("b")); if (dx) dx->Resize(x_dims); if (db) db->Resize(b_dims); } diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc index ea991f683d841b3dc4624a0d8aa3c88367fd3c6d..d1f42e8662537d35e17429f9d436fdc0e5a1dc11 100644 --- a/paddle/operators/scale_op.cc +++ b/paddle/operators/scale_op.cc @@ -27,8 +27,13 @@ class ScaleOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of ScaleOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of ScaleOp should not be null."); + auto *in = ctx.Input("X"); - auto *out = ctx.Output("Out"); + auto *out = ctx.Output("Out"); out->Resize(in->dims()); } }; diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc index f901edefa22dc9a252e87116df756d04767a7162..8820262732327306f4f807702751708bd1e2aa36 100644 --- a/paddle/operators/scatter_op.cc +++ b/paddle/operators/scatter_op.cc @@ -24,6 +24,15 @@ class ScatterOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ref"), + "Input(Ref) of ScatterOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Index"), + "Input(Index) of ScatterOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Updates"), + "Input(Updates) of ScatterOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of ScatterOp should not be null."); + PADDLE_ENFORCE_EQ(ctx.Input("Index")->dims().size(), 1, "Update Index should be 1-D."); PADDLE_ENFORCE_EQ(ctx.Input("Ref")->dims().size(), @@ -35,7 +44,8 @@ class ScatterOp : public framework::OperatorWithKernel { framework::DDim data_dim(ctx.Input("Updates")->dims()); for (int i = 1; i < data_dim.size(); ++i) PADDLE_ENFORCE_EQ(data_dim[i], ctx.Input("Updates")->dims()[i]); - ctx.Output("Out")->Resize(ctx.Input("Ref")->dims()); + ctx.Output("Out")->Resize( + ctx.Input("Ref")->dims()); } }; @@ -45,9 +55,11 @@ class ScatterGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto *dUpdates = ctx.Output(framework::GradVarName("Updates")); + auto *dUpdates = + ctx.Output(framework::GradVarName("Updates")); auto *Updates = ctx.Input("Updates"); - auto *dRef = ctx.Output(framework::GradVarName("Ref")); + auto *dRef = + ctx.Output(framework::GradVarName("Ref")); auto *Ref = ctx.Input("Ref"); dRef->Resize(Ref->dims()); diff --git a/paddle/operators/sequence_avg_pool_op.cc b/paddle/operators/sequence_avg_pool_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9815b8f3a8d813959949bbfedc79f404721a8216 --- /dev/null +++ b/paddle/operators/sequence_avg_pool_op.cc @@ -0,0 +1,95 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/sequence_avg_pool_op.h" + +namespace paddle { +namespace operators { + +class SequenceAvgPoolOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar("X"), "Input(X) of SequenceAvgPoolOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Out"), + "Output(Out) of SequenceAvgPoolOp should not be null."); + + auto* x = ctx.Input("X"); + auto dims = x->dims(); + auto lod = x->lod(); + PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + PADDLE_ENFORCE_GE( + dims[0], + /*batch size = */ static_cast(lod[0].size() - 1), + "The first dimension of Input(X) must be large than batch size."); + dims[0] = lod[0].size() - 1; + ctx.Output("Out")->Resize({dims}); + } +}; + +class SequenceAvgPoolOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SequenceAvgPoolOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of SequenceAvgPoolOp."); + AddOutput("Out", "The output of SequenceAvgPoolOp."); + AddComment(R"DOC( + SequenceAvgPoolOp averages features of all time-steps of each instance. + More detailed comments will be added later. + )DOC"); + } +}; + +class SequenceAvgPoolGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Gradient of Out should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "The input X should not be null."); + auto og_dims = + ctx.Input(framework::GradVarName("Out"))->dims(); + auto x_dims = ctx.Input("X")->dims(); + PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(), + "The rank of output grad must equal to Input(X)."); + for (int64_t i = 1; i < og_dims.size(); ++i) { + PADDLE_ENFORCE_EQ(og_dims[i], x_dims[i], "The dimension mismatch."); + } + auto* x_grad = + ctx.Output(framework::GradVarName("X")); + x_grad->Resize(x_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(sequence_avg_pool, ops::SequenceAvgPoolOp, + ops::SequenceAvgPoolOpMaker, sequence_avg_pool_grad, + ops::SequenceAvgPoolGradOp); +REGISTER_OP_CPU_KERNEL( + sequence_avg_pool, + ops::SequenceAvgPoolKernel); +REGISTER_OP_CPU_KERNEL( + sequence_avg_pool_grad, + ops::SequenceAvgPoolGradKernel); diff --git a/paddle/operators/sequence_avg_pool_op.cu b/paddle/operators/sequence_avg_pool_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..bc9d1611fccd17c99b914b6ef59995288a9ebbd6 --- /dev/null +++ b/paddle/operators/sequence_avg_pool_op.cu @@ -0,0 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU + +#include "paddle/operators/sequence_avg_pool_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + sequence_avg_pool, + ops::SequenceAvgPoolKernel); +REGISTER_OP_GPU_KERNEL( + sequence_avg_pool_grad, + ops::SequenceAvgPoolGradKernel); diff --git a/paddle/operators/sequence_avg_pool_op.h b/paddle/operators/sequence_avg_pool_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ebe0956344eb71d0fb2836f1b4a989ac546d9f78 --- /dev/null +++ b/paddle/operators/sequence_avg_pool_op.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +template +using EigenVector = framework::EigenVector; +template +using EigenMatrix = framework::EigenMatrix; + +template +class SequenceAvgPoolKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + + auto dims = in->dims(); + auto lod = in->lod(); + int64_t w = in->numel() / dims[0]; + + out->mutable_data(context.GetPlace()); + auto place = context.GetEigenDevice(); + for (int i = 0; i < static_cast(lod[0].size()) - 1; ++i) { + Tensor in_t = in->Slice(static_cast(lod[0][i]), + static_cast(lod[0][i + 1])); + Tensor out_t = out->Slice(i, i + 1); + int64_t h = static_cast(lod[0][i + 1] - lod[0][i]); + auto in_e = EigenMatrix::From(in_t, framework::make_ddim({h, w})); + auto out_e = EigenVector::Flatten(out_t); + out_e.device(place) = in_e.mean(Eigen::array({{0}})); + } + } +}; + +template +class SequenceAvgPoolGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out_g = context.Input(framework::GradVarName("Out")); + auto* in_g = context.Output(framework::GradVarName("X")); + + auto dims = in->dims(); + auto lod = in->lod(); + int64_t w = in->numel() / dims[0]; + + in_g->mutable_data(context.GetPlace()); + auto place = context.GetEigenDevice(); + for (int i = 0; i < static_cast(lod[0].size()) - 1; ++i) { + auto in_g_t = in_g->Slice(static_cast(lod[0][i]), + static_cast(lod[0][i + 1])); + auto out_g_t = out_g->Slice(i, i + 1); + int64_t h = static_cast(lod[0][i + 1] - lod[0][i]); + auto in_g_e = EigenMatrix::From(in_g_t, {h, w}); + auto out_g_e = EigenMatrix::From(out_g_t, {1, w}); + Eigen::DSizes bcast(h, 1); + in_g_e.device(place) = (out_g_e / static_cast(h)).broadcast(bcast); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index ad267e7f087943ff3b8326a7baf2ce3955fa51c2..1232e64c7f0132b9ea19b3d7e1ebe9531e1e25a5 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -23,10 +23,18 @@ class SGDOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE( - ctx.Input("param")->dims() == ctx.Input("grad")->dims(), - "Two input of SGD Op's dimension must be same."); - ctx.Output("param_out")->Resize(ctx.Input("param")->dims()); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("param"), + "Input(param) of SGDOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("grad"), + "Input(grad) of SGDOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("param_out"), + "Output(param_out) of SGDOp should not be null."); + + PADDLE_ENFORCE_EQ(ctx.Input("param")->dims(), + ctx.Input("grad")->dims(), + "Two input of SGD Op's dimension must be same."); + ctx.Output("param_out") + ->Resize(ctx.Input("param")->dims()); } }; diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc index 761c6de8d4d2150b30b97b58da95da3d5f33db63..992b19965e0ca9ce7dba1b8b3c5b7780af06eb45 100644 --- a/paddle/operators/sigmoid_op.cc +++ b/paddle/operators/sigmoid_op.cc @@ -23,7 +23,13 @@ class SigmoidOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output("Y")->Resize(ctx.Input("X")->dims()); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of SigmoidOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), + "Output(Y) of SigmoidOp should not be null."); + + ctx.Output("Y")->Resize( + ctx.Input("X")->dims()); } }; @@ -44,7 +50,7 @@ class SigmoidOpGrad : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output(framework::GradVarName("X")) + ctx.Output(framework::GradVarName("X")) ->Resize(ctx.Input("Y")->dims()); } }; diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 7166b2f60be8a6088ab3a81686f7bed1b7181d97..c67eb028c882ed82ca4e6a4dd70cdea9f69cdc24 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -23,9 +23,15 @@ class SoftmaxOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of SoftmaxOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), + "Output(Y) of SoftmaxOp should not be null."); + PADDLE_ENFORCE(ctx.Input("X")->dims().size() == 2UL, "The input of softmax op must be a matrix."); - ctx.Output("Y")->Resize(ctx.Input("X")->dims()); + ctx.Output("Y")->Resize( + ctx.Input("X")->dims()); } }; @@ -71,7 +77,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { ctx.Input(framework::GradVarName("Y"))->dims(), "Input(Y) and its gradients should have a same shape."); - ctx.Output(framework::GradVarName("X")) + ctx.Output(framework::GradVarName("X")) ->Resize(ctx.Input("X")->dims()); } }; diff --git a/paddle/operators/split_op.cc b/paddle/operators/split_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..61296f5c8122fdce7083e9a91dc313482875c805 --- /dev/null +++ b/paddle/operators/split_op.cc @@ -0,0 +1,118 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/split_op.h" +#include "paddle/operators/net_op.h" + +namespace paddle { +namespace operators { +using framework::Tensor; + +class SplitOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + // infershape + auto *in = ctx.Input("X"); + auto outs = ctx.MultiOutput("Out"); + size_t axis = static_cast(ctx.Attr("axis")); + size_t num = static_cast(ctx.Attr("num")); + std::vector sections = + static_cast>(ctx.Attr>("sections")); + const size_t n = outs.size(); + + if (num > 0) { + int64_t in_axis_dim = in->dims()[axis]; + PADDLE_ENFORCE_EQ(in_axis_dim % num, 0, + "tensor split does not result" + " in an equal division"); + size_t out_axis_dim = in_axis_dim / num; + for (size_t i = 0; i < n; ++i) { + auto dim = in->dims(); + dim[axis] = out_axis_dim; + outs[i]->Resize(dim); + } + } else if (sections.size() > 0) { + PADDLE_ENFORCE_EQ(sections.size(), n, + "tensor split sections size" + "should be equal to output size."); + for (size_t i = 0; i < n; ++i) { + auto dim = in->dims(); + dim[axis] = sections[i]; + outs[i]->Resize(dim); + } + } else { + PADDLE_ENFORCE_NOT_NULL(nullptr, "split operator should", + " specify indices or sections."); + } + } +}; + +class SplitOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SplitOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "the input tensor of split operator."); + AddOutput("Out", "the output tensors of split operator.").AsDuplicable(); + AddComment(R"DOC( + Split the input tensor into multiple sub-tensors. + Example: + Input = [[1,2], + [3,4], + [5,6]] + sections = [2,1] + axis = 0 + Output[0] = [[1,2], + [3,4]] + Output[1] = [[5,6]] + + )DOC"); + AddAttr>("sections", + "the length for each" + "output along with the specify axis.") + .SetDefault(std::vector{}); + AddAttr("num", + "number of the sub-tensors, it must evenly divide " + "Input.dims()[axis]") + .SetDefault(0); + AddAttr("axis", "The axis which the input will be splited on.") + .SetDefault(0); + } +}; + +class SplitOpGrad : public NetOp { + public: + SplitOpGrad(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : NetOp(type, inputs, outputs, attrs) { + auto out_grad = Inputs(framework::GradVarName("Out")); + auto x_grad = Output(framework::GradVarName("X")); + AppendOp(framework::OpRegistry::CreateOp("concat", {{"X", out_grad}}, + {{"Out", {x_grad}}}, attrs)); + CompleteAddOp(false); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +USE_CPU_ONLY_OP(concat); +REGISTER_OP(split, ops::SplitOp, ops::SplitOpMaker, split_grad, + ops::SplitOpGrad); +REGISTER_OP_CPU_KERNEL(split, + ops::SplitKernel); diff --git a/paddle/operators/split_op.h b/paddle/operators/split_op.h new file mode 100644 index 0000000000000000000000000000000000000000..860690ee895075fda9ddef08776a2102642efff9 --- /dev/null +++ b/paddle/operators/split_op.h @@ -0,0 +1,62 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class SplitKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto outs = ctx.MultiOutput("Out"); + int64_t axis = static_cast(ctx.Attr("axis")); + size_t before = 1, after = 1; + const size_t n = outs.size(); + size_t input_axis_dim = in->dims()[axis]; + + for (int64_t i = 0; i < in->dims().size(); ++i) { + if (i == axis) { + continue; + } + if (i < axis) { + before *= in->dims()[i]; + } else { + after *= in->dims()[i]; + } + } + size_t input_offset = 0; + for (size_t i = 0; i < n; i++) { + auto& out = outs[i]; + size_t axis_dim = out->dims()[axis]; + for (size_t j = 0; j < before; j++) { + size_t len = axis_dim * after * sizeof(T); + T* dest = + out->mutable_data(platform::CPUPlace()) + axis_dim * after * j; + const T* src = + in->data() + input_offset + input_axis_dim * after * j; + memcpy(dest, src, len); + } + input_offset += axis_dim * after; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/squared_l2_distance_op.cc b/paddle/operators/squared_l2_distance_op.cc index 9f51d3efa8ecba894a1023b9de2df451ca85916c..39f4305877de20d451bc35fe698a0eabf9758d57 100644 --- a/paddle/operators/squared_l2_distance_op.cc +++ b/paddle/operators/squared_l2_distance_op.cc @@ -23,12 +23,18 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input of SquaredL2DistanceOp " - "must be initialized."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), - "Target of SquaredL2DistanceOp " - "must be initialized."); + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar("X"), + "Input(X) of SquaredL2DistanceOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.InputVar("Y"), + "Input(Y) of SquaredL2DistanceOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("sub_result"), + "Output(sub_result) of SquaredL2DistanceOp should not be null."); + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Out"), + "Output(Out) of SquaredL2DistanceOp should not be null."); auto* x = ctx.Input("X"); auto x_dims = x->dims(); @@ -48,9 +54,9 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { "First dimension of target must be equal to input " "or to 1."); - ctx.Output("sub_result") + ctx.Output("sub_result") ->Resize({x_dims[0], x->numel() / x_dims[0]}); - ctx.Output("Out")->Resize({x_dims[0], 1}); + ctx.Output("Out")->Resize({x_dims[0], 1}); } }; @@ -94,8 +100,10 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(out_dims[1], 1, "Second dimension of output gradient " "must be 1."); - auto* x_grad = ctx.Output(framework::GradVarName("X")); - auto* y_grad = ctx.Output(framework::GradVarName("Y")); + auto* x_grad = + ctx.Output(framework::GradVarName("X")); + auto* y_grad = + ctx.Output(framework::GradVarName("Y")); if (x_grad) x_grad->Resize(x_dims); if (y_grad) y_grad->Resize(y_dims); } diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc index 5805826ee8a555ca6dfc1ca81feaadffea9e1012..41e05c27f9029b2664685d3979fadcfd2bf6dbce 100644 --- a/paddle/operators/sum_op.cc +++ b/paddle/operators/sum_op.cc @@ -22,8 +22,13 @@ class SumOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), + "Input(X) of SumOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of SumOp should not be null."); + auto ins = ctx.MultiInput("X"); - auto *out = ctx.Output("Out"); + auto *out = ctx.Output("Out"); int N = ins.size(); auto in_dim = ins[0]->dims(); @@ -55,7 +60,8 @@ class SumGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto outputs = ctx.MultiOutput(framework::GradVarName("X")); + auto outputs = + ctx.MultiOutput(framework::GradVarName("X")); auto dims = ctx.Input(framework::GradVarName("Out"))->dims(); for (auto output : outputs) { output->Resize(dims); diff --git a/paddle/operators/top_k_op.cc b/paddle/operators/top_k_op.cc index 38d2f0a09aec751734864947a2f3cfa20107e22f..169b815feffd86f9ff04c129ccc997230ce03a8c 100644 --- a/paddle/operators/top_k_op.cc +++ b/paddle/operators/top_k_op.cc @@ -24,7 +24,12 @@ class TopkOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input of TopkOP must be initialized."); + "Input(X) of TopkOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of TopkOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Indices"), + "Output(Indices) of TopkOp should not be null."); + auto *input = ctx.Input("X"); const int k = static_cast(ctx.Attr("k")); @@ -35,8 +40,8 @@ class TopkOp : public framework::OperatorWithKernel { framework::DDim dims = input->dims(); dims[dims.size() - 1] = k; - ctx.Output("Out")->Resize(dims); - ctx.Output("Indices")->Resize(dims); + ctx.Output("Out")->Resize(dims); + ctx.Output("Indices")->Resize(dims); } }; diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index b8fbc9b52aecdb5c8d985b5de9bcd7cb85835b60..184bcbc29c0d26a214345506f126f9cc0d406b07 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -48,9 +48,13 @@ class UniformRandomOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE_NOT_NULL( + ctx.OutputVar("Out"), + "Output(Out) of UniformRandomOp should not be null."); + PADDLE_ENFORCE(Attr("min") < Attr("max"), "uniform_random's min must less then max"); - auto* tensor = ctx.Output("Out"); + auto* tensor = ctx.Output("Out"); auto dims = Attr>("dims"); std::vector temp; temp.reserve(dims.size()); diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 17bdac8749e31565b119b2cb84aed199fac0f441..8b605e51c3f4ea38fc358ce054bb36fcc82063c4 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -24,3 +24,4 @@ cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info) nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) +nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place) diff --git a/paddle/platform/cuda_helper.h b/paddle/platform/cuda_helper.h index 6feec0d7f8bd5d32d9e5eedee962fcbeff655f1c..a7d99cde106a0a66f122a8c43f49717c03e60dec 100644 --- a/paddle/platform/cuda_helper.h +++ b/paddle/platform/cuda_helper.h @@ -24,6 +24,11 @@ namespace platform { #define USE_CUDA_ATOMIC(op, T) \ CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); } +// Default thread count per block(or block size). +// TODO(typhoonzero): need to benchmark against setting this value +// to 1024. +constexpr int PADDLE_CUDA_NUM_THREADS = 512; + // For atomicAdd. USE_CUDA_ATOMIC(Add, float); diff --git a/paddle/platform/details/device_ptr_cast.h b/paddle/platform/details/device_ptr_cast.h new file mode 100644 index 0000000000000000000000000000000000000000..4015491fcdc3554029aa771ab7da1b2f3424321f --- /dev/null +++ b/paddle/platform/details/device_ptr_cast.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifndef __NVCC__ +#error device_ptr_cast must be include by .cu file +#endif + +#include + +namespace paddle { +namespace platform { +namespace details { +template +struct DevicePtrCast; + +template +struct DevicePtrCast { + using ELEM = typename std::remove_pointer::type; + using RTYPE = thrust::device_ptr; + + inline thrust::device_ptr operator()(ELEM* ele) const { + return thrust::device_pointer_cast(ele); + } +}; + +template +struct DevicePtrCast { + using RTYPE = T; + inline RTYPE operator()(RTYPE it) const { return it; } +}; + +// Cast T to thrust::device_ptr if T is a pointer. +// Otherwise, e.g., T is a iterator, return T itself. +template +auto DevPtrCast(T t) -> + typename DevicePtrCast::value>::RTYPE { + DevicePtrCast::value> cast; + return cast(t); +} + +} // namespace details +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 64fcbd93b6c4d5d9b36f2636c3ef4f7327f08d25..df5f71ed760952ed042d7ffa40a4319a73fb93bf 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -25,6 +25,10 @@ limitations under the License. */ #include "paddle/string/printf.h" #include "paddle/string/to_string.h" +#ifdef __GNUC__ +#include // for __cxa_demangle +#endif + #ifndef PADDLE_ONLY_CPU #include "paddle/platform/dynload/cublas.h" @@ -42,6 +46,19 @@ limitations under the License. */ namespace paddle { namespace platform { +namespace { +#ifdef __GNUC__ +inline std::string demangle(std::string name) { + int status = -4; // some arbitrary value to eliminate the compiler warning + std::unique_ptr res{ + abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free}; + return (status == 0) ? res.get() : name; +} +#else +inline std::string demangle(std::string name) { return name; } +#endif +} + struct EnforceNotMet : public std::exception { std::exception_ptr exp_; std::string err_str_; @@ -61,8 +78,8 @@ struct EnforceNotMet : public std::exception { Dl_info info; for (int i = 0; i < size; ++i) { - if (dladdr(call_stack[i], &info)) { - auto demangled = info.dli_sname; + if (dladdr(call_stack[i], &info) && info.dli_sname) { + auto demangled = demangle(info.dli_sname); auto addr_offset = static_cast(call_stack[i]) - static_cast(info.dli_saddr); sout << string::Sprintf("%-3d %*0p %s + %zd\n", i, diff --git a/paddle/platform/transform.h b/paddle/platform/transform.h new file mode 100644 index 0000000000000000000000000000000000000000..3ee4acd29660f201d318ce6d39baa6f3999ae274 --- /dev/null +++ b/paddle/platform/transform.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/platform/enforce.h" +#include "paddle/platform/hostdevice.h" +#include "paddle/platform/place.h" + +#include +#include +#ifdef __NVCC__ +#include +#include "paddle/platform/details/device_ptr_cast.h" +#endif + +namespace paddle { +namespace platform { +// Transform on host or device. It provides the same API in std library. +template +void Transform(Place place, InputIter first, InputIter last, OutputIter result, + UnaryOperation op) { + if (is_cpu_place(place)) { + std::transform(first, last, result, op); + } else { +#ifdef __NVCC__ + using namespace details; + thrust::transform(DevPtrCast(first), DevPtrCast(last), DevPtrCast(result), + op); +#else + PADDLE_THROW("Do not invoke `Transform` in .cc file"); +#endif + } +} + +template +void Transform(Place place, InputIter1 first1, InputIter1 last1, + InputIter2 first2, OutputIter result, BinaryOperation op) { + if (is_cpu_place(place)) { + std::transform(first1, last1, first2, result, op); + } else { +#ifdef __NVCC__ + using namespace details; + thrust::transform(DevPtrCast(first1), DevPtrCast(last1), DevPtrCast(first2), + DevPtrCast(result), op); +#else + PADDLE_THROW("Do not invoke `Transform` in .cc file"); +#endif + } +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/transform_test.cu b/paddle/platform/transform_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..600fed8f45077a6fee91f295aa854153c9cf9c01 --- /dev/null +++ b/paddle/platform/transform_test.cu @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include "paddle/memory/memcpy.h" +#include "paddle/memory/memory.h" +#include "paddle/platform/transform.h" + +template +class Scale { + public: + explicit Scale(const T& scale) : scale_(scale) {} + + HOSTDEVICE T operator()(const T& a) const { return a * scale_; } + + private: + T scale_; +}; + +template +class Multiply { + public: + HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; } +}; + +TEST(Transform, CPUUnary) { + using namespace paddle::platform; + float buf[4] = {0.1, 0.2, 0.3, 0.4}; + Transform(CPUPlace(), buf, buf + 4, buf, Scale(10)); + for (int i = 0; i < 4; ++i) { + ASSERT_NEAR(buf[i], static_cast(i + 1), 1e-5); + } +} + +TEST(Transform, GPUUnary) { + using namespace paddle::platform; + using namespace paddle::memory; + GPUPlace gpu0(0); + float cpu_buf[4] = {0.1, 0.2, 0.3, 0.4}; + float* gpu_buf = static_cast(Alloc(gpu0, sizeof(float) * 4)); + Copy(gpu0, gpu_buf, CPUPlace(), cpu_buf, sizeof(cpu_buf)); + Transform(gpu0, gpu_buf, gpu_buf + 4, gpu_buf, Scale(10)); + Copy(CPUPlace(), cpu_buf, gpu0, gpu_buf, sizeof(cpu_buf)); + Free(gpu0, gpu_buf); + for (int i = 0; i < 4; ++i) { + ASSERT_NEAR(cpu_buf[i], static_cast(i + 1), 1e-5); + } +} + +TEST(Transform, CPUBinary) { + using namespace paddle::platform; + using namespace paddle::memory; + int buf[4] = {1, 2, 3, 4}; + Transform(CPUPlace(), buf, buf + 4, buf, buf, Multiply()); + for (int i = 0; i < 4; ++i) { + ASSERT_EQ((i + 1) * (i + 1), buf[i]); + } +} + +TEST(Transform, GPUBinary) { + using namespace paddle::platform; + using namespace paddle::memory; + int buf[4] = {1, 2, 3, 4}; + GPUPlace gpu0(0); + int* gpu_buf = static_cast(Alloc(gpu0, sizeof(buf))); + Copy(gpu0, gpu_buf, CPUPlace(), buf, sizeof(buf)); + Transform(gpu0, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply()); + Copy(CPUPlace(), buf, gpu0, gpu_buf, sizeof(buf)); + Free(gpu0, gpu_buf); + for (int i = 0; i < 4; ++i) { + ASSERT_EQ((i + 1) * (i + 1), buf[i]); + } +} \ No newline at end of file diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 00030050700bfb2cee224124d090b0027d456ba0..4f05406c7f74113d8fb10aa6914166e553858338 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,5 +1,5 @@ if(WITH_PYTHON) -cc_library(paddle_pybind SHARED + cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python backward ${GLOB_OP_LIB}) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index fe1e50927a425018e471ac8065c0349474ee93a5..c7009a604f60cda11434ad33b6c7d7caee1befdd 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -19,10 +19,12 @@ limitations under the License. */ #include "paddle/framework/backward.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/cond_op.h" #include "paddle/operators/net_op.h" #include "paddle/operators/recurrent_op.h" #include "paddle/platform/enforce.h" #include "paddle/platform/place.h" +#include "paddle/pybind/pybind.h" #include "paddle/pybind/tensor_py.h" #include "paddle/string/to_string.h" #include "pybind11/numpy.h" @@ -31,33 +33,6 @@ limitations under the License. */ namespace py = pybind11; -USE_OP(add); -USE_OP(onehot_cross_entropy); -USE_OP(sgd); -USE_OP(mul); -USE_OP(elementwise_mul); -USE_OP(mean); -USE_OP(sigmoid); -USE_OP(softmax); -USE_OP(rowwise_add); -USE_OP(fill_zeros_like); -USE_NO_KERNEL_OP(recurrent); -USE_OP(gaussian_random); -USE_OP(uniform_random); -USE_OP(lookup_table); -USE_OP(scale); -USE_NO_KERNEL_OP(identity); -USE_OP(minus); -USE_OP(cos_sim); -USE_CPU_ONLY_OP(gather); -USE_OP(pad); -USE_CPU_ONLY_OP(scatter); -USE_CPU_ONLY_OP(concat); -USE_OP(top_k); -USE_OP(squared_l2_distance); -USE_OP(sum); -USE_OP(reshape); - namespace paddle { namespace framework { @@ -123,27 +98,21 @@ PYBIND11_PLUGIN(core) { return self.data()[offset]; }); - py::class_(m, "LoDTensor", R"DOC(LoD(Leval of Ddetails) Tensor. - -The tensor and LoD info should be created before creating the LoDTensor, then -call the set_tensor and set_lod functions to set them. - -)DOC") - .def("__init__", - [](LoDTensor &instance, - const std::vector> &lod, - Tensor *t) { + py::class_(m, "LoDTensor") + .def_buffer( + [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); }) + .def( + "__init__", + [](LoDTensor &instance, const std::vector> &lod) { #ifdef PADDLE_ONLY_CPU - new (&instance) LoDTensor(lod, t); + new (&instance) LoDTensor(lod); #else paddle::framework::LoD new_lod; new_lod.reserve(lod.size()); std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod)); - new (&instance) LoDTensor(new_lod, t); + new (&instance) LoDTensor(new_lod); #endif - }) - .def("set_tensor", - [](LoDTensor &self, Tensor *tensor) { self.set_tensor(tensor); }) + }) .def("set_lod", [](LoDTensor &self, const std::vector> &lod) { #ifdef PADDLE_ONLY_CPU @@ -155,9 +124,6 @@ call the set_tensor and set_lod functions to set them. self.set_lod(new_lod); #endif }) - .def("tensor", - [](LoDTensor &self) -> Tensor & { return self.tensor(); }, - py::return_value_policy::reference) .def("lod", [](LoDTensor &self) -> std::vector> { #ifdef PADDLE_ONLY_CPU return self.lod(); @@ -186,9 +152,6 @@ All parameter, weight, gradient are variables in Paddle. [](Variable &var, int val) -> void { *var.GetMutable() = val; }) .def("get_int", [](const Variable &var) -> int { return var.Get(); }) .def("get_tensor", - [](Variable &self) -> Tensor * { return self.GetMutable(); }, - py::return_value_policy::reference) - .def("get_lod_tensor", [](Variable &self) -> LoDTensor * { return self.GetMutable(); }, @@ -326,6 +289,28 @@ All parameter, weight, gradient are variables in Paddle. [](operators::RecurrentOp &self, const operators::NetOp &net) -> void { self.set_stepnet(net.Clone()); }); + // cond_op + py::class_(m, "CondOp") + .def_static("create", + [](py::bytes protobin) -> operators::CondOp * { + OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + auto cond_op = OpRegistry::CreateOp(desc); + return static_cast(cond_op.release()); + }) + .def("set_truenet", + [](operators::CondOp &self, const operators::NetOp &net) -> void { + self.set_truenet(net.Clone()); + }) + .def("set_falsenet", + [](operators::CondOp &self, const operators::NetOp &net) -> void { + self.set_falsenet(net.Clone()); + }); + m.def("unique_integer", UniqueIntegerGenerator); m.def("is_compile_gpu", IsCompileGPU); diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 4f68a8953446ffa0510df65c5b214d09b913cff8..a9e1d6d2e06d56f837690ec95fa8f8d41a90725f 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -2055,20 +2055,26 @@ class ConvLayerBase(LayerBase): if num_filters is not None: self.config.num_filters = num_filters + use_mkldnn = int(g_command_config_args.get("use_mkldnn", 0)) use_gpu = int(g_command_config_args.get("use_gpu", 0)) parallel_nn = int(g_command_config_args.get("parallel_nn", 0)) - # Automatically select cudnn_type for GPU and exconv for CPU + # Automatically select cudnn_type for GPU, exconv for CPU + # and mkldnn_conv for MKLDNN # if set type=conv, but still reserve the way user specify - # exconv or cudnn_conv manually. + # exconv, mkldnn_conv or cudnn_conv manually. if self.layer_type == "cudnn_conv": config_assert(use_gpu, "cudnn_conv only support GPU") + if self.layer_type == "mkldnn_conv": + config_assert(use_mkldnn, "mkldnn_conv only support MKLDNN") + if (use_gpu == 1 and self.layer_type != "exconv" and + self.layer_type != "mkldnn_conv" and (parallel_nn == 0 or self.config.device > -1)): self.layer_type = "cudnn_conv" else: - self.layer_type = "exconv" + self.layer_type = "mkldnn_conv" if use_mkldnn else "exconv" # need to specify layer in config self.config.type = self.layer_type @@ -2100,6 +2106,11 @@ class ConvLayer(ConvLayerBase): layer_type = 'exconv' +@config_layer('mkldnn_conv') +class ConvLayer(ConvLayerBase): + layer_type = 'mkldnn_conv' + + @config_layer('cudnn_conv') class ConvLayer(ConvLayerBase): layer_type = 'cudnn_conv' diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index 34be203ee254584027c79cf93fe54f404b7235db..93e8ac173e721d9623fce91f30ac4642d273caba 100644 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -11,10 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -""" -# from activations import * + from activations import LinearActivation, ReluActivation, SoftmaxActivation, \ IdentityActivation, TanhActivation, SequenceSoftmaxActivation from attrs import ExtraAttr @@ -55,49 +53,49 @@ def sequence_conv_pool(input, context_attr=None, pool_attr=None): """ - Text convolution pooling layers helper. + Text convolution pooling group. Text input => Context Projection => FC Layer => Pooling => Output. - :param name: name of output layer(pooling layer name) + :param name: group name. :type name: basestring - :param input: name of input layer + :param input: input layer. :type input: LayerOutput :param context_len: context projection length. See context_projection's document. :type context_len: int :param hidden_size: FC Layer size. :type hidden_size: int - :param context_start: context projection length. See + :param context_start: context start position. See context_projection's context_start. - :type context_start: int or None + :type context_start: int|None :param pool_type: pooling layer type. See pooling_layer's document. - :type pool_type: BasePoolingType. + :type pool_type: BasePoolingType :param context_proj_layer_name: context projection layer name. None if user don't care. :type context_proj_layer_name: basestring - :param context_proj_param_attr: context projection parameter attribute. - None if user don't care. - :type context_proj_param_attr: ParameterAttribute or None. + :param context_proj_param_attr: padding parameter attribute of context projection layer. + If false, it means padding always be zero. + :type context_proj_param_attr: ParameterAttribute|None :param fc_layer_name: fc layer name. None if user don't care. :type fc_layer_name: basestring :param fc_param_attr: fc layer parameter attribute. None if user don't care. - :type fc_param_attr: ParameterAttribute or None + :type fc_param_attr: ParameterAttribute|None :param fc_bias_attr: fc bias parameter attribute. False if no bias, None if user don't care. - :type fc_bias_attr: ParameterAttribute or None - :param fc_act: fc layer activation type. None means tanh + :type fc_bias_attr: ParameterAttribute|False|None + :param fc_act: fc layer activation type. None means tanh. :type fc_act: BaseActivation - :param pool_bias_attr: pooling layer bias attr. None if don't care. - False if no bias. - :type pool_bias_attr: ParameterAttribute or None. + :param pool_bias_attr: pooling layer bias attr. False if no bias. + None if user don't care. + :type pool_bias_attr: ParameterAttribute|False|None :param fc_attr: fc layer extra attribute. :type fc_attr: ExtraLayerAttribute :param context_attr: context projection layer extra attribute. :type context_attr: ExtraLayerAttribute :param pool_attr: pooling layer extra attribute. :type pool_attr: ExtraLayerAttribute - :return: output layer name. + :return: layer's output. :rtype: LayerOutput """ # Set Default Value to param @@ -163,45 +161,45 @@ def simple_img_conv_pool(input, """ Simple image convolution and pooling group. - Input => conv => pooling + Img input => Conv => Pooling => Output. - :param name: group name + :param name: group name. :type name: basestring - :param input: input layer name. + :param input: input layer. :type input: LayerOutput - :param filter_size: see img_conv_layer for details + :param filter_size: see img_conv_layer for details. :type filter_size: int - :param num_filters: see img_conv_layer for details + :param num_filters: see img_conv_layer for details. :type num_filters: int - :param pool_size: see img_pool_layer for details + :param pool_size: see img_pool_layer for details. :type pool_size: int - :param pool_type: see img_pool_layer for details + :param pool_type: see img_pool_layer for details. :type pool_type: BasePoolingType - :param act: see img_conv_layer for details + :param act: see img_conv_layer for details. :type act: BaseActivation - :param groups: see img_conv_layer for details + :param groups: see img_conv_layer for details. :type groups: int - :param conv_stride: see img_conv_layer for details + :param conv_stride: see img_conv_layer for details. :type conv_stride: int - :param conv_padding: see img_conv_layer for details + :param conv_padding: see img_conv_layer for details. :type conv_padding: int - :param bias_attr: see img_conv_layer for details + :param bias_attr: see img_conv_layer for details. :type bias_attr: ParameterAttribute - :param num_channel: see img_conv_layer for details + :param num_channel: see img_conv_layer for details. :type num_channel: int - :param param_attr: see img_conv_layer for details + :param param_attr: see img_conv_layer for details. :type param_attr: ParameterAttribute - :param shared_bias: see img_conv_layer for details + :param shared_bias: see img_conv_layer for details. :type shared_bias: bool - :param conv_layer_attr: see img_conv_layer for details + :param conv_layer_attr: see img_conv_layer for details. :type conv_layer_attr: ExtraLayerAttribute - :param pool_stride: see img_pool_layer for details + :param pool_stride: see img_pool_layer for details. :type pool_stride: int - :param pool_padding: see img_pool_layer for details + :param pool_padding: see img_pool_layer for details. :type pool_padding: int - :param pool_layer_attr: see img_pool_layer for details + :param pool_layer_attr: see img_pool_layer for details. :type pool_layer_attr: ExtraLayerAttribute - :return: Layer's output + :return: layer's output :rtype: LayerOutput """ _conv_ = img_conv_layer( @@ -252,48 +250,52 @@ def img_conv_bn_pool(input, pool_layer_attr=None): """ Convolution, batch normalization, pooling group. + + Img input => Conv => BN => Pooling => Output. - :param name: group name + :param name: group name. :type name: basestring - :param input: layer's input - :type input: LayerOutput - :param filter_size: see img_conv_layer's document + :param input: input layer. + :type input: LayerOutput + :param filter_size: see img_conv_layer for details. :type filter_size: int - :param num_filters: see img_conv_layer's document + :param num_filters: see img_conv_layer for details. :type num_filters: int - :param pool_size: see img_pool_layer's document. + :param pool_size: see img_pool_layer for details. :type pool_size: int - :param pool_type: see img_pool_layer's document. + :param pool_type: see img_pool_layer for details. :type pool_type: BasePoolingType - :param act: see batch_norm_layer's document. + :param act: see batch_norm_layer for details. :type act: BaseActivation - :param groups: see img_conv_layer's document + :param groups: see img_conv_layer for details. :type groups: int - :param conv_stride: see img_conv_layer's document. + :param conv_stride: see img_conv_layer for details. :type conv_stride: int - :param conv_padding: see img_conv_layer's document. + :param conv_padding: see img_conv_layer for details. :type conv_padding: int - :param conv_bias_attr: see img_conv_layer's document. + :param conv_bias_attr: see img_conv_layer for details. :type conv_bias_attr: ParameterAttribute - :param num_channel: see img_conv_layer's document. + :param num_channel: see img_conv_layer for details. :type num_channel: int - :param conv_param_attr: see img_conv_layer's document. + :param conv_param_attr: see img_conv_layer for details. :type conv_param_attr: ParameterAttribute - :param shared_bias: see img_conv_layer's document. + :param shared_bias: see img_conv_layer for details. :type shared_bias: bool - :param conv_layer_attr: see img_conv_layer's document. + :param conv_layer_attr: see img_conv_layer for details. :type conv_layer_attr: ExtraLayerOutput - :param bn_param_attr: see batch_norm_layer's document. - :type bn_param_attr: ParameterAttribute. - :param bn_bias_attr: see batch_norm_layer's document. - :param bn_layer_attr: ParameterAttribute. - :param pool_stride: see img_pool_layer's document. + :param bn_param_attr: see batch_norm_layer for details. + :type bn_param_attr: ParameterAttribute + :param bn_bias_attr: see batch_norm_layer for details. + :type bn_bias_attr: ParameterAttribute + :param bn_layer_attr: see batch_norm_layer for details. + :type bn_layer_attr: ExtraLayerAttribute + :param pool_stride: see img_pool_layer for details. :type pool_stride: int - :param pool_padding: see img_pool_layer's document. + :param pool_padding: see img_pool_layer for details. :type pool_padding: int - :param pool_layer_attr: see img_pool_layer's document. + :param pool_layer_attr: see img_pool_layer for details. :type pool_layer_attr: ExtraLayerAttribute - :return: Layer groups output + :return: layer's output :rtype: LayerOutput """ __conv__ = img_conv_layer( @@ -348,10 +350,10 @@ def img_conv_group(input, :param conv_batchnorm_drop_rate: if conv_with_batchnorm[i] is true, conv_batchnorm_drop_rate[i] represents the drop rate of each batch norm. :type conv_batchnorm_drop_rate: list - :param input: layer's input. + :param input: input layer. :type input: LayerOutput - :param conv_num_filter: output channels num. - :type conv_num_filter: int + :param conv_num_filter: list of output channels num. + :type conv_num_filter: list|tuple :param pool_size: pooling filter size. :type pool_size: int :param num_channels: input channels num. @@ -362,18 +364,18 @@ def img_conv_group(input, :type conv_filter_size: int :param conv_act: activation funciton after convolution. :type conv_act: BaseActivation - :param conv_with_batchnorm: conv_with_batchnorm[i] represents - if there is a batch normalization after each convolution. + :param conv_with_batchnorm: if conv_with_batchnorm[i] is true, + there is a batch normalization operation after each convolution. :type conv_with_batchnorm: list :param pool_stride: pooling stride size. :type pool_stride: int :param pool_type: pooling type. :type pool_type: BasePoolingType - :param param_attr: Convolution param attribute. - None means default attribute. + :param param_attr: param attribute of convolution layer, + None means default attribute. :type param_attr: ParameterAttribute - :return: Layer's output - :type: LayerOutput + :return: layer's output + :rtype: LayerOutput """ tmp = input @@ -466,12 +468,14 @@ def vgg_16_network(input_image, num_channels, num_classes=1000): """ Same model from https://gist.github.com/ksimonyan/211839e770f7b538e2d8 - :param num_classes: - :param input_image: + :param num_classes: number of class. + :type num_classes: int + :param input_image: input layer. :type input_image: LayerOutput - :param num_channels: + :param num_channels: input channels num. :type num_channels: int - :return: + :return: layer's output + :rtype: LayerOutput """ tmp = img_conv_group( @@ -560,8 +564,8 @@ def simple_lstm(input, """ Simple LSTM Cell. - It just combine a mixed layer with fully_matrix_projection and a lstmemory - layer. The simple lstm cell was implemented as follow equations. + It just combines a mixed layer with fully_matrix_projection and a lstmemory + layer. The simple lstm cell was implemented with follow equations. .. math:: @@ -575,37 +579,37 @@ def simple_lstm(input, h_t & = o_t tanh(c_t) - Please refer **Generating Sequences With Recurrent Neural Networks** if you - want to know what lstm is. Link_ is here. + Please refer to **Generating Sequences With Recurrent Neural Networks** for more + details about lstm. Link_ is here. .. _Link: http://arxiv.org/abs/1308.0850 :param name: lstm layer name. :type name: basestring - :param input: input layer name. + :param input: layer's input. :type input: LayerOutput :param size: lstm layer size. :type size: int - :param reverse: whether to process the input data in a reverse order + :param reverse: process the input in a reverse order or not. :type reverse: bool - :param mat_param_attr: mixed layer's matrix projection parameter attribute. + :param mat_param_attr: parameter attribute of matrix projection in mixed layer. :type mat_param_attr: ParameterAttribute :param bias_param_attr: bias parameter attribute. False means no bias, None means default bias. :type bias_param_attr: ParameterAttribute|False - :param inner_param_attr: lstm cell parameter attribute. + :param inner_param_attr: parameter attribute of lstm cell. :type inner_param_attr: ParameterAttribute - :param act: lstm final activiation type + :param act: last activiation type of lstm. :type act: BaseActivation - :param gate_act: lstm gate activiation type + :param gate_act: gate activiation type of lstm. :type gate_act: BaseActivation - :param state_act: lstm state activiation type. + :param state_act: state activiation type of lstm. :type state_act: BaseActivation - :param mixed_layer_attr: mixed layer's extra attribute. + :param mixed_layer_attr: extra attribute of mixed layer. :type mixed_layer_attr: ExtraLayerAttribute - :param lstm_cell_attr: lstm layer's extra attribute. + :param lstm_cell_attr: extra attribute of lstm. :type lstm_cell_attr: ExtraLayerAttribute - :return: lstm layer name. + :return: layer's output. :rtype: LayerOutput """ fc_name = 'lstm_transform_%s' % name @@ -643,9 +647,9 @@ def lstmemory_unit(input, lstm_bias_attr=None, lstm_layer_attr=None): """ - Define calculations that a LSTM unit performs during a single time step. - This function itself is not a recurrent layer, so it can not be - directly used to process sequence inputs. This function is always used in + lstmemory_unit defines the caculation process of a LSTM unit during a + single time step. This function is not a recurrent layer, so it can not be + directly used to process sequence input. This function is always used in recurrent_group (see layers.py for more details) to implement attention mechanism. @@ -676,7 +680,7 @@ def lstmemory_unit(input, state_act=TanhActivation()) - :param input: input layer name. + :param input: input layer. :type input: LayerOutput :param out_memory: output of previous time step :type out_memory: LayerOutput | None @@ -684,15 +688,15 @@ def lstmemory_unit(input, :type name: basestring :param size: lstmemory unit size. :type size: int - :param param_attr: Parameter config, None if use default. + :param param_attr: parameter attribute, None means default attribute. :type param_attr: ParameterAttribute - :param act: lstm final activiation type + :param act: last activiation type of lstm. :type act: BaseActivation - :param gate_act: lstm gate activiation type + :param gate_act: gate activiation type of lstm. :type gate_act: BaseActivation - :param state_act: lstm state activiation type. + :param state_act: state activiation type of lstm. :type state_act: BaseActivation - :param input_proj_bias_attr: bias attribute for input-to-hidden projection. + :param input_proj_bias_attr: bias attribute for input to hidden projection. False means no bias, None means default bias. :type input_proj_bias_attr: ParameterAttribute|False|None :param input_proj_layer_attr: extra layer attribute for input to hidden @@ -700,8 +704,8 @@ def lstmemory_unit(input, :type input_proj_layer_attr: ExtraLayerAttribute :param lstm_bias_attr: bias parameter attribute of lstm layer. False means no bias, None means default bias. - :type lstm_bias_attr: ParameterAttribute|False - :param lstm_layer_attr: lstm layer's extra attribute. + :type lstm_bias_attr: ParameterAttribute|False|None + :param lstm_layer_attr: extra attribute of lstm layer. :type lstm_layer_attr: ExtraLayerAttribute :return: lstmemory unit name. :rtype: LayerOutput @@ -758,9 +762,9 @@ def lstmemory_group(input, lstm_group is a recurrent_group version of Long Short Term Memory. It does exactly the same calculation as the lstmemory layer (see lstmemory in layers.py for the maths) does. A promising benefit is that LSTM memory - cell states, or hidden states in every time step are accessible to the + cell states(or hidden states) in every time step are accessible to the user. This is especially useful in attention model. If you do not need to - access the internal states of the lstm, but merely use its outputs, + access the internal states of the lstm and merely use its outputs, it is recommended to use the lstmemory, which is relatively faster than lstmemory_group. @@ -781,28 +785,28 @@ def lstmemory_group(input, gate_act=SigmoidActivation(), state_act=TanhActivation()) - :param input: input layer name. + :param input: input layer. :type input: LayerOutput :param size: lstmemory group size. :type size: int - :param name: name of the lstmemory group. + :param name: name of lstmemory group. :type name: basestring - :param out_memory: output of previous time step + :param out_memory: output of previous time step. :type out_memory: LayerOutput | None - :param reverse: is lstm reversed + :param reverse: process the input in a reverse order or not. :type reverse: bool - :param param_attr: Parameter config, None if use default. + :param param_attr: parameter attribute, None means default attribute. :type param_attr: ParameterAttribute - :param act: lstm final activiation type + :param act: last activiation type of lstm. :type act: BaseActivation - :param gate_act: lstm gate activiation type + :param gate_act: gate activiation type of lstm. :type gate_act: BaseActivation - :param state_act: lstm state activiation type. + :param state_act: state activiation type of lstm. :type state_act: BaseActivation :param lstm_bias_attr: bias parameter attribute of lstm layer. False means no bias, None means default bias. - :type lstm_bias_attr: ParameterAttribute|False - :param input_proj_bias_attr: bias attribute for input-to-hidden projection. + :type lstm_bias_attr: ParameterAttribute|False|None + :param input_proj_bias_attr: bias attribute for input to hidden projection. False means no bias, None means default bias. :type input_proj_bias_attr: ParameterAttribute|False|None :param input_proj_layer_attr: extra layer attribute for input to hidden @@ -848,15 +852,15 @@ def gru_unit(input, gru_layer_attr=None, naive=False): """ - Define calculations that a gated recurrent unit performs in a single time - step. This function itself is not a recurrent layer, so it can not be - directly used to process sequence inputs. This function is always used in + gru_unit defines the calculation process of a gated recurrent unit during a single + time step. This function is not a recurrent layer, so it can not be + directly used to process sequence input. This function is always used in the recurrent_group (see layers.py for more details) to implement attention mechanism. Please see grumemory in layers.py for the details about the maths. - :param input: input layer name. + :param input: input layer. :type input: LayerOutput :param memory_boot: the initialization state of the LSTM cell. :type memory_boot: LayerOutput | None @@ -864,12 +868,12 @@ def gru_unit(input, :type name: basestring :param size: hidden size of the gru. :type size: int - :param act: type of the activation + :param act: activation type of gru :type act: BaseActivation - :param gate_act: type of the gate activation + :param gate_act: gate activation type or gru :type gate_act: BaseActivation - :param gru_layer_attr: Extra parameter attribute of the gru layer. - :type gru_layer_attr: ParameterAttribute|False + :param gru_layer_attr: Extra attribute of the gru layer. + :type gru_layer_attr: ExtraLayerAttribute :return: the gru output layer. :rtype: LayerOutput """ @@ -915,7 +919,7 @@ def gru_group(input, does exactly the same calculation as the grumemory layer does. A promising benefit is that gru hidden states are accessible to the user. This is especially useful in attention model. If you do not need to access - any internal state, but merely use the outputs of a GRU, it is recommended + any internal state and merely use the outputs of a GRU, it is recommended to use the grumemory, which is relatively faster. Please see grumemory in layers.py for more detail about the maths. @@ -924,12 +928,12 @@ def gru_group(input, .. code-block:: python - gru = gur_group(input=[layer1], + gru = gru_group(input=[layer1], size=256, act=TanhActivation(), gate_act=SigmoidActivation()) - :param input: input layer name. + :param input: input layer. :type input: LayerOutput :param memory_boot: the initialization state of the LSTM cell. :type memory_boot: LayerOutput | None @@ -937,16 +941,17 @@ def gru_group(input, :type name: basestring :param size: hidden size of the gru. :type size: int - :param reverse: whether to process the input data in a reverse order + :param reverse: process the input in a reverse order or not. :type reverse: bool - :param act: type of the activiation + :param act: activiation type of gru :type act: BaseActivation - :param gate_act: type of the gate activiation + :param gate_act: gate activiation type of gru :type gate_act: BaseActivation - :param gru_bias_attr: bias. False means no bias, None means default bias. - :type gru_bias_attr: ParameterAttribute|False - :param gru_layer_attr: Extra parameter attribute of the gru layer. - :type gru_layer_attr: ParameterAttribute|False + :param gru_bias_attr: bias parameter attribute of gru layer, + False means no bias, None means default bias. + :type gru_bias_attr: ParameterAttribute|False|None + :param gru_layer_attr: Extra attribute of the gru layer. + :type gru_layer_attr: ExtraLayerAttribute :return: the gru group. :rtype: LayerOutput """ @@ -986,11 +991,11 @@ def simple_gru(input, gru_layer_attr=None, naive=False): """ - You maybe see gru_step_layer, grumemory in layers.py, gru_unit, gru_group, + You may see gru_step_layer, grumemory in layers.py, gru_unit, gru_group, simple_gru in network.py. The reason why there are so many interfaces is that we have two ways to implement recurrent neural network. One way is to use one complete layer to implement rnn (including simple rnn, gru and lstm) - with multiple time steps, such as recurrent_layer, lstmemory, grumemory. But, + with multiple time steps, such as recurrent_layer, lstmemory, grumemory. But the multiplication operation :math:`W x_t` is not computed in these layers. See details in their interfaces in layers.py. The other implementation is to use an recurrent group which can ensemble a @@ -1018,22 +1023,23 @@ def simple_gru(input, gru = simple_gru(input=[layer1], size=256) - :param input: input layer name. + :param input: input layer. :type input: LayerOutput :param name: name of the gru group. :type name: basestring :param size: hidden size of the gru. :type size: int - :param reverse: whether to process the input data in a reverse order + :param reverse: process the input in a reverse order or not. :type reverse: bool - :param act: type of the activiation + :param act: activiation type of gru :type act: BaseActivation - :param gate_act: type of the gate activiation + :param gate_act: gate activiation type of gru :type gate_act: BaseActivation - :param gru_bias_attr: bias. False means no bias, None means default bias. - :type gru_bias_attr: ParameterAttribute|False - :param gru_layer_attr: Extra parameter attribute of the gru layer. - :type gru_layer_attr: ParameterAttribute|False + :param gru_bias_attr: bias parameter attribute of gru layer, + False means no bias, None means default bias. + :type gru_bias_attr: ParameterAttribute|False|None + :param gru_layer_attr: Extra attribute of the gru layer. + :type gru_layer_attr: ExtraLayerAttribute :return: the gru group. :rtype: LayerOutput """ @@ -1071,8 +1077,8 @@ def simple_gru2(input, mixed_layer_attr=None, gru_cell_attr=None): """ - simple_gru2 is the same with simple_gru, but using grumemory instead - Please see grumemory in layers.py for more detail about the maths. + simple_gru2 is the same with simple_gru, but using grumemory instead. + Please refer to grumemory in layers.py for more detail about the math. simple_gru2 is faster than simple_gru. The example usage is: @@ -1081,22 +1087,23 @@ def simple_gru2(input, gru = simple_gru2(input=[layer1], size=256) - :param input: input layer name. + :param input: input layer. :type input: LayerOutput :param name: name of the gru group. :type name: basestring :param size: hidden size of the gru. :type size: int - :param reverse: whether to process the input data in a reverse order + :param reverse: process the input in a reverse order or not. :type reverse: bool - :param act: type of the activiation + :param act: activiation type of gru :type act: BaseActivation - :param gate_act: type of the gate activiation + :param gate_act: gate activiation type of gru :type gate_act: BaseActivation - :param gru_bias_attr: bias. False means no bias, None means default bias. - :type gru_bias_attr: ParameterAttribute|False - :param gru_layer_attr: Extra parameter attribute of the gru layer. - :type gru_layer_attr: ParameterAttribute|False + :param gru_bias_attr: bias parameter attribute of gru layer, + False means no bias, None means default bias. + :type gru_bias_attr: ParameterAttribute|False|None + :param gru_layer_attr: Extra attribute of the gru layer. + :type gru_layer_attr: ExtraLayerAttribute :return: the gru group. :rtype: LayerOutput """ @@ -1145,7 +1152,7 @@ def bidirectional_gru(input, concat_act=None): """ A bidirectional_gru is a recurrent unit that iterates over the input - sequence both in forward and bardward orders, and then concatenate two + sequence both in forward and backward orders, and then concatenate two outputs to form a final output. However, concatenation of two outputs is not the only way to form the final output, you can also, for example, just add them together. @@ -1162,11 +1169,10 @@ def bidirectional_gru(input, :type input: LayerOutput :param size: gru layer size. :type size: int - :param return_seq: If set False, outputs of the last time step are - concatenated and returned. - If set True, the entire output sequences that are - processed in forward and backward directions are + :param return_seq: If set False, the last time step of output are concatenated and returned. + If set True, the entire output sequences in forward + and backward directions are concatenated and returned. :type return_seq: bool :return: LayerOutput object. :rtype: LayerOutput @@ -1230,8 +1236,8 @@ def bidirectional_lstm(input, concat_act=None): """ A bidirectional_lstm is a recurrent unit that iterates over the input - sequence both in forward and bardward orders, and then concatenate two - outputs form a final output. However, concatenation of two outputs + sequence both in forward and backward orders, and then concatenate two + outputs to form a final output. However, concatenation of two outputs is not the only way to form the final output, you can also, for example, just add them together. @@ -1252,13 +1258,12 @@ def bidirectional_lstm(input, :type input: LayerOutput :param size: lstm layer size. :type size: int - :param return_seq: If set False, outputs of the last time step are - concatenated and returned. - If set True, the entire output sequences that are - processed in forward and backward directions are + :param return_seq: If set False, the last time step of output are concatenated and returned. + If set True, the entire output sequences in forward + and backward directions are concatenated and returned. :type return_seq: bool - :return: LayerOutput object accroding to the return_seq. + :return: LayerOutput object. :rtype: LayerOutput """ args = locals() @@ -1303,7 +1308,7 @@ def simple_attention(encoded_sequence, weight_act=None, name=None): """ - Calculate and then return a context vector by attention machanism. + Calculate and return a context vector with attention mechanism. Size of the context vector equals to size of the encoded_sequence. .. math:: @@ -1336,10 +1341,10 @@ def simple_attention(encoded_sequence, :param name: name of the attention model. :type name: basestring :param softmax_param_attr: parameter attribute of sequence softmax - that is used to produce attention weight + that is used to produce attention weight. :type softmax_param_attr: ParameterAttribute - :param weight_act: activation of the attention model - :type weight_act: Activation + :param weight_act: activation of the attention model. + :type weight_act: BaseActivation :param encoded_sequence: output of the encoder :type encoded_sequence: LayerOutput :param encoded_proj: attention weight is computed by a feed forward neural @@ -1411,7 +1416,7 @@ def inputs(layers, *args): def outputs(layers, *args): """ - Declare the outputs of network. If user have not defined the inputs of + Declare the outputs of network. If user has not defined the inputs of network, this method will calculate the input order by dfs travel. :param layers: Output layers. diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py index 15e0d125c495fbc0688d8dc4e66881cb9ab95a90..6cca41e43b38b8cccb65ff9b347ef226dddecd4d 100644 --- a/python/paddle/v2/framework/op.py +++ b/python/paddle/v2/framework/op.py @@ -215,5 +215,27 @@ class __RecurrentOp__(object): return core.RecurrentOp.create(proto.SerializeToString()) +class __CondOp__(object): + __proto__ = None + type = "cond" + + def __init__(self): + # cache recurrent_op's proto + if self.__proto__ is None: + for op_proto in get_all_op_protos(): + if op_proto.type == self.type: + self.__proto__ = op_proto + + def __call__(self, *args, **kwargs): + if self.type not in args and "type" not in kwargs: + kwargs["type"] = self.type + # create proto + create_method = OpDescCreationMethod(self.__proto__) + proto = create_method(*args, **kwargs) + # create condop + return core.CondOp.create(proto.SerializeToString()) + + Operator = OperatorFactory() # The default global factory RecurrentOp = __RecurrentOp__() +CondOp = __CondOp__() diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py index 9936fd76baf3e64aed01b8ae1d54e50b39793925..6bbea22c5f147c8314c5d607f8e6953b470b5bd1 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/framework/tests/op_test.py @@ -28,10 +28,10 @@ def create_op(scope, op_type, inputs, outputs, attrs): if out_name in outputs: kwargs[out_name] = [] if out_dup: - sub_in = outputs[out_name] - for sub_in_name, _ in sub_in: - var = scope.new_var(sub_in_name) - kwargs[out_name].append(sub_in_name) + sub_out = outputs[out_name] + for sub_out_name, _ in sub_out: + var = scope.new_var(sub_out_name) + kwargs[out_name].append(sub_out_name) else: var = scope.new_var(out_name) kwargs[out_name].append(out_name) @@ -39,6 +39,7 @@ def create_op(scope, op_type, inputs, outputs, attrs): for attr_name in Operator.get_op_attr_names(op_type): if attr_name in attrs: kwargs[attr_name] = attrs[attr_name] + return Operator(op_type, **kwargs) @@ -47,17 +48,24 @@ def set_input(scope, op, inputs, place): if in_name in inputs: if in_dup: sub_in = inputs[in_name] - for sub_in_name, sub_in_array in sub_in: + for sub_in_name, sub_in_val in sub_in: var = scope.find_var(sub_in_name) tensor = var.get_tensor() + sub_in_array = sub_in_val[0] \ + if isinstance(sub_in_val, tuple) else sub_in_val tensor.set_dims(sub_in_array.shape) tensor.set(sub_in_array, place) + if isinstance(sub_in_val, tuple): + tensor.set_lod(sub_in_val[1]) else: var = scope.find_var(in_name) tensor = var.get_tensor() - arr = inputs[in_name] - tensor.set_dims(arr.shape) - tensor.set(arr, place) + in_val = inputs[in_name] + in_array = in_val[0] if isinstance(in_val, tuple) else in_val + tensor.set_dims(in_array.shape) + tensor.set(in_array, place) + if isinstance(in_val, tuple): + tensor.set_lod(in_val[1]) def set_output_grad(scope, op, outputs, place): @@ -172,8 +180,9 @@ class OpTest(unittest.TestCase): def check_output_with_place(self, place): self.scope = core.Scope() op_inputs = self.inputs if hasattr(self, "inputs") else dict() + op_outputs = self.outputs if hasattr(self, "outputs") else dict() op_attrs = self.attrs if hasattr(self, "attrs") else dict() - self.op = create_op(self.scope, self.op_type, op_inputs, self.outputs, + self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs, op_attrs) if isinstance(place, core.GPUPlace) and not self.op.support_gpu(): return @@ -185,21 +194,26 @@ class OpTest(unittest.TestCase): for out_name, out_dup in Operator.get_op_outputs(self.op.type()): if out_dup: sub_out = self.outputs[out_name] - for sub_out_name in sub_out: + if not isinstance(sub_out, list): + raise AssertionError("sub_out type %s is not list", + type(sub_out)) + + for sub_out_name, expect in sub_out: actual = np.array( self.scope.find_var(sub_out_name).get_tensor()) - expect = sub_out[sub_out_name] self.assertTrue( np.allclose( actual, expect, atol=1e-05), - "output name: " + out_name + "has diff") + "output name: " + out_name + " has diff") else: - actual = np.array(self.scope.find_var(out_name).get_tensor()) - expect = self.outputs[out_name] - self.assertTrue( - np.allclose( - actual, expect, atol=1e-05), - "output name: " + out_name + "has diff") + var = self.scope.find_var(out_name) + if var is not None: + actual = np.array(var.get_tensor()) + expect = self.outputs[out_name] + self.assertTrue( + np.allclose( + actual, expect, atol=1e-05), + "output name: " + out_name + " has diff") def check_output(self): places = [core.CPUPlace()] @@ -234,8 +248,9 @@ class OpTest(unittest.TestCase): max_relative_error=0.005): self.scope = core.Scope() op_inputs = self.inputs if hasattr(self, "inputs") else dict() + op_outputs = self.outputs if hasattr(self, "outputs") else dict() op_attrs = self.attrs if hasattr(self, "attrs") else dict() - self.op = create_op(self.scope, self.op_type, op_inputs, self.outputs, + self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs, op_attrs) if no_grad_set is None: no_grad_set = set() diff --git a/python/paddle/v2/framework/tests/test_accuracy_op.py b/python/paddle/v2/framework/tests/test_accuracy_op.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f3a35d6f58ba90b39e3f6296ae635220a2e965 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_accuracy_op.py @@ -0,0 +1,26 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestAccuracyOp(OpTest): + def setUp(self): + self.op_type = "accuracy" + n = 8192 + infer = np.random.randint(0, 2, (n, 1)).astype("int") + label = np.random.randint(0, 2, (n, )).astype("int") + self.inputs = {'Inference': infer, "Label": label} + num_correct = 0 + for rowid in xrange(n): + for ele in infer[rowid]: + if ele == label[rowid]: + num_correct += 1 + break + self.outputs = {'Accuracy': [num_correct / float(n)]} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_add_two_op.py b/python/paddle/v2/framework/tests/test_add_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_add_two_op.py rename to python/paddle/v2/framework/tests/test_add_op.py diff --git a/python/paddle/v2/framework/tests/test_cond_op.py b/python/paddle/v2/framework/tests/test_cond_op.py new file mode 100644 index 0000000000000000000000000000000000000000..37177ae0b2482517c4183969c8ef0670f2b3de89 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_cond_op.py @@ -0,0 +1,116 @@ +import logging +import paddle.v2.framework.core as core +import unittest +import numpy as np +from paddle.v2.framework.op import Operator, CondOp + + +class PySimpleCond(object): + ''' + A simple implementation of dynamic if-else based on numpy + ''' + + def __init__(self): + array = [1] * 10 + for i in range(1, 10, 2): + array[i] = 0 + self.cond = np.array(array) + self.x = np.ones(shape=(10, 1)) + + def forward(self): + self.index_t = np.where(self.cond == 1) + self.index_f = np.where(self.cond == 0) + y_t = self.x[self.index_t] + y_f = self.x[self.index_f] + y_t = y_t * 2. + y_f = y_f * (-2.) + output = np.zeros(shape=(10, 1)) + output[self.index_t] = y_t + output[self.index_f] = y_f + return output + + +class PySimpleCondTest(unittest.TestCase): + def setUp(self): + self.condnn = PySimpleCond() + + def test_forward(self): + output = self.condnn.forward() + + +def create_tensor(scope, name, shape, np_data): + tensor = scope.new_var(name).get_tensor() + tensor.set_dims(shape) + tensor.set(np_data, core.CPUPlace()) + return tensor + + +class TestCondOp(unittest.TestCase): + ''' + Test CondOp + + equation: + cond = [True, False, True, False, ...] + y[index_t] = x[index_t] * 2. + y[index_f] = x[index_f] * -2. + outputs: + y + ''' + + def setUp(self): + self.py_cond = PySimpleCond() + + def forward(self): + self.scope = core.Scope() + self.create_global_variables() + self.create_cond_op() + self.create_sub_net() + ctx = core.DeviceContext.create(core.CPUPlace()) + self.condop.infer_shape(self.scope) + self.condop.run(self.scope, ctx) + return np.array(self.scope.find_var("Out").get_tensor()) + + def create_global_variables(self): + x_np_data = self.py_cond.x + create_tensor(self.scope, "X", [10, 1], x_np_data) + cond_np_data = self.py_cond.cond.astype("int32") + create_tensor(self.scope, "cond", [10, 1], cond_np_data) + self.scope.new_var("SubScopes") + self.scope.new_var("IndexTensors") + self.scope.new_var("Out") + + def create_cond_op(self): + self.condop = CondOp( + Cond="cond", + Xs=["X"], + Outs=["Out"], + SubScopes="SubScopes", + IndexTensors="IndexTensors") + + def create_sub_net(self): + truenet = core.Net.create() + scale_op_t = Operator("scale", X='X', Out='Out', scale=2.) + truenet.append_op(scale_op_t) + truenet.complete_add_op(True) + self.condop.set_truenet(truenet) + + falsenet = core.Net.create() + scale_op_t = Operator("scale", X='X', Out='Out', scale=-2.) + falsenet.append_op(scale_op_t) + falsenet.complete_add_op(True) + self.condop.set_falsenet(falsenet) + + def test_forward(self): + print 'test cond op forward' + pd_output = self.forward() + py_output = self.py_cond.forward() + print 'pd_output', pd_output + print + print 'py_output', py_output + self.assertEqual(pd_output.shape, py_output.shape) + print 'test passed' + return 0 + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_cos_sim_op.py b/python/paddle/v2/framework/tests/test_cos_sim_op.py index 797cbd8cc5cf7f73d58ca713d02667731d5c8a0e..d314ce391ea2f10a8bd77c24e84fa3e1eebb6c73 100644 --- a/python/paddle/v2/framework/tests/test_cos_sim_op.py +++ b/python/paddle/v2/framework/tests/test_cos_sim_op.py @@ -7,8 +7,8 @@ class TestCosSimOp(OpTest): def setUp(self): self.op_type = "cos_sim" self.inputs = { - 'X': np.random.random((10, 5)).astype("float32"), - 'Y': np.random.random((10, 5)).astype("float32") + 'X': np.random.random((6, 5)).astype("float32"), + 'Y': np.random.random((6, 5)).astype("float32") } expect_x_norm = np.linalg.norm(self.inputs['X'], axis=1) expect_y_norm = np.linalg.norm(self.inputs['Y'], axis=1) @@ -28,12 +28,66 @@ class TestCosSimOp(OpTest): def test_check_grad_ingore_x(self): self.check_grad( - ['Y'], 'Out', max_relative_error=0.05, no_grad_set=set('X')) + ['Y'], 'Out', max_relative_error=0.05, no_grad_set=set("X")) - def test_check_grad_ignore_y(self): + def test_check_grad_ingore_y(self): self.check_grad( ['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y')) -if __name__ == "__main__": +class TestCosSimOp2(TestCosSimOp): + def setUp(self): + self.op_type = "cos_sim" + self.inputs = { + 'X': np.random.random((6, 5)).astype("float32"), + 'Y': np.random.random((1, 5)).astype("float32") + } + expect_x_norm = np.linalg.norm(self.inputs['X'], axis=1) + expect_y_norm = np.linalg.norm(self.inputs['Y'], axis=1) + expect_out = (self.inputs['X'] * self.inputs['Y']).sum(axis=1) / \ + expect_x_norm / expect_y_norm + self.outputs = { + 'XNorm': np.expand_dims(expect_x_norm, 1), + 'YNorm': np.expand_dims(expect_y_norm, 1), + 'Out': np.expand_dims(expect_out, 1) + } + + +class TestCosSimOp3(TestCosSimOp): + def setUp(self): + self.op_type = "cos_sim" + self.inputs = { + 'X': np.random.random((6, 5, 2)).astype("float32"), + 'Y': np.random.random((6, 5, 2)).astype("float32") + } + expect_x_norm = np.linalg.norm(self.inputs['X'], axis=(1, 2)) + expect_y_norm = np.linalg.norm(self.inputs['Y'], axis=(1, 2)) + expect_out = (self.inputs['X'] * self.inputs['Y']).sum(axis=(1, 2)) / \ + expect_x_norm / expect_y_norm + self.outputs = { + 'XNorm': np.expand_dims(expect_x_norm, 1), + 'YNorm': np.expand_dims(expect_y_norm, 1), + 'Out': np.expand_dims(expect_out, 1) + } + + +class TestCosSimOp4(TestCosSimOp): + def setUp(self): + self.op_type = "cos_sim" + self.inputs = { + 'X': np.random.random((6, 5, 2)).astype("float32"), + 'Y': np.random.random((1, 5, 2)).astype("float32") + } + expect_x_norm = np.linalg.norm(self.inputs['X'], axis=(1, 2)) + expect_y_norm = np.linalg.norm(self.inputs['Y'], axis=(1, 2)) + expect_out = (self.inputs['X'] * self.inputs['Y']).sum(axis=(1, 2)) / \ + expect_x_norm / expect_y_norm + self.outputs = { + 'XNorm': np.expand_dims(expect_x_norm, 1), + 'YNorm': np.expand_dims(expect_y_norm, 1), + 'Out': np.expand_dims(expect_out, 1) + } + + +if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_fc_op.py b/python/paddle/v2/framework/tests/test_fc_op.py new file mode 100644 index 0000000000000000000000000000000000000000..9f56fe5049c66aa5fce40ce815105e7871ebc3b2 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_fc_op.py @@ -0,0 +1,62 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestFCOp1(OpTest): + def setUp(self): + x0 = np.random.random((16, 32)).astype("float32") + w0 = np.random.random((32, 10)).astype("float32") + + mul_out0 = np.dot(x0, w0) + identity_out = mul_out0 + + self.op_type = "fc" + self.inputs = {"X": [("X0", x0)], "W": [("W0", w0)]} + self.outputs = {"MulOut": [("MulOut0", mul_out0)], "Out": identity_out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X0", "W0"], "Out", max_relative_error=0.01) + + +class TestFCOp2(OpTest): + def setUp(self): + x0 = np.random.random((16, 4, 8)).astype("float32") + x1 = np.random.random((4, 4, 32)).astype("float32") + w0 = np.random.random((32, 10)).astype("float32") + w1 = np.random.random((32, 10)).astype("float32") + b = np.random.random(10).astype("float32") + + mul_out0 = np.dot(x0.reshape(16, 4 * 8), w0) + mul_out1 = np.dot(x1.reshape(4 * 4, 32), w1) + sum_out = mul_out0 + mul_out1 + add_out = np.add(sum_out, b) + sigmoid_out = 1 / (1 + np.exp(-add_out)) + + self.op_type = "fc" + self.inputs = { + "X": [("X0", x0), ("X1", x1)], + "W": [("W0", w0), ("W1", w1)], + "B": b + } + self.attrs = {"xNumColDims": [1, 2], "activation": "sigmoid"} + self.outputs = { + "MulOut": [("MulOut0", mul_out0), ("MulOut1", mul_out1)], + "SumOut": sum_out, + "AddOut": add_out, + "Out": sigmoid_out + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ["X0", "X1", "W0", "W1", "B"], "Out", max_relative_error=0.01) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_gaussian_random_op.py b/python/paddle/v2/framework/tests/test_gaussian_random_op.py index 1f9e4db783c9907a22db72c8a6ff06c7ca0735da..1888ee28f92c66496ce756d8a4a33d3e9ba57d7b 100644 --- a/python/paddle/v2/framework/tests/test_gaussian_random_op.py +++ b/python/paddle/v2/framework/tests/test_gaussian_random_op.py @@ -4,7 +4,7 @@ from paddle.v2.framework.op import Operator import numpy -class GaussianRandomTest(unittest.TestCase): +class TestGaussianRandomOp(unittest.TestCase): def test_cpu(self): self.gaussian_random_test(place=core.CPUPlace()) diff --git a/python/paddle/v2/framework/tests/test_identity_op.py b/python/paddle/v2/framework/tests/test_identity_op.py new file mode 100644 index 0000000000000000000000000000000000000000..26cec1fcc3ad003281c9c41571d475b55bd30026 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_identity_op.py @@ -0,0 +1,20 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestIdentityOp(OpTest): + def setUp(self): + self.op_type = "identity" + self.inputs = {'X': np.random.random((10, 10)).astype("float32")} + self.outputs = {'Y': self.inputs['X']} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_lookup_table.py b/python/paddle/v2/framework/tests/test_lookup_table_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_lookup_table.py rename to python/paddle/v2/framework/tests/test_lookup_table_op.py diff --git a/python/paddle/v2/framework/tests/test_minus_op.py b/python/paddle/v2/framework/tests/test_minus_op.py index dea797a1fea34265d0a32e097f413f421abf2521..c56d7cb548706880dd482bad750f2989c0e9a710 100644 --- a/python/paddle/v2/framework/tests/test_minus_op.py +++ b/python/paddle/v2/framework/tests/test_minus_op.py @@ -3,7 +3,7 @@ import numpy as np from op_test import OpTest -class MinusOpTest(OpTest): +class TestMinusOp(OpTest): def setUp(self): self.op_type = "minus" self.inputs = { diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_onehot_cross_entropy_op.py similarity index 52% rename from python/paddle/v2/framework/tests/test_cross_entropy_op.py rename to python/paddle/v2/framework/tests/test_onehot_cross_entropy_op.py index c2fc102a8b8de82da5c3fc5fee273790325908f8..fd3cbdb80374865ccf113768856096bf49dce643 100644 --- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py +++ b/python/paddle/v2/framework/tests/test_onehot_cross_entropy_op.py @@ -3,25 +3,27 @@ import numpy from op_test import OpTest -class TestCrossEntropy(OpTest): +class TestOnehotCrossEntropyOp(OpTest): def setUp(self): self.op_type = "onehot_cross_entropy" batch_size = 30 class_num = 10 + X = numpy.random.uniform(0.1, 1.0, [batch_size, class_num]).astype("float32") - label = (class_num / 2) * numpy.ones(batch_size).astype("int32") - self.inputs = {'X': X, 'label': label} - Y = [] - for i in range(0, batch_size): - Y.append(-numpy.log(X[i][label[i]])) - self.outputs = {'Y': numpy.array(Y).astype("float32")} + labels = numpy.random.randint(0, class_num, batch_size, dtype="int32") + + cross_entropy = numpy.asmatrix( + [[-numpy.log(X[i][labels[i]])] for i in range(X.shape[0])], + dtype="float32") + self.inputs = {"X": X, "label": labels} + self.outputs = {"Y": cross_entropy} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y') + self.check_grad(["X"], "Y") if __name__ == "__main__": diff --git a/python/paddle/v2/framework/tests/test_pad_op.py b/python/paddle/v2/framework/tests/test_pad_op.py index 456b765e331fc4c80e6fd817c88d7ec533158ecb..9052e63b5683801da7c73be4de23013c949add98 100644 --- a/python/paddle/v2/framework/tests/test_pad_op.py +++ b/python/paddle/v2/framework/tests/test_pad_op.py @@ -22,7 +22,7 @@ class TestPadOp(OpTest): self.check_output() def test_check_grad_normal(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', max_relative_error=0.006) def initTestCase(self): self.shape = (16, 16) diff --git a/python/paddle/v2/framework/tests/test_scale_and_identity_op.py b/python/paddle/v2/framework/tests/test_scale_op.py similarity index 56% rename from python/paddle/v2/framework/tests/test_scale_and_identity_op.py rename to python/paddle/v2/framework/tests/test_scale_op.py index 05d76d428299c8176d1a6adf6da15a203fa7502a..2ea1e185470280730ae8c8c0ea9568bbeb43eaf5 100644 --- a/python/paddle/v2/framework/tests/test_scale_and_identity_op.py +++ b/python/paddle/v2/framework/tests/test_scale_op.py @@ -3,20 +3,7 @@ import numpy as np from op_test import OpTest -class IdentityTest(OpTest): - def setUp(self): - self.op_type = "identity" - self.inputs = {'X': np.random.random((10, 10)).astype("float32")} - self.outputs = {'Out': self.inputs['X']} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - - -class ScaleTest(OpTest): +class TestScaleOp(OpTest): def setUp(self): self.op_type = "scale" self.inputs = {'X': np.random.random((10, 10)).astype("float32")} diff --git a/python/paddle/v2/framework/tests/test_seq_pool.py b/python/paddle/v2/framework/tests/test_seq_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..cf864936af6361da1f16df3cfb759b468214b970 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_seq_pool.py @@ -0,0 +1,51 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestSeqAvgPool1D(OpTest): + def setUp(self): + self.op_type = 'sequence_avg_pool' + # one level, batch size is 4 + x = np.random.uniform(0.1, 1, [11, 23]).astype('float32') + lod = [[0, 4, 5, 8, 11]] + + out = np.zeros((4, 23)).astype('float32') + for i in range(4): + sub_x = x[lod[0][i]:lod[0][i + 1], :] + out[i] = sub_x.mean(axis=0) + + self.inputs = {'X': (x, lod)} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestSeqAvgPool2D(OpTest): + def setUp(self): + self.op_type = 'sequence_avg_pool' + # one level, batch size is 4 + x = np.random.uniform(0.1, 1, [13, 3, 17]).astype('float32') + lod = [[0, 4, 5, 8, 13]] + + out = np.zeros((4, 3, 17)).astype('float32') + for i in range(4): + sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) + out[i] = np.reshape(sub_x.mean(axis=0), (3, 17)) + + self.inputs = {'X': (x, lod)} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py index 557cf15ace63e336462c7dcdbbc10f30aeedc6f4..64e54d1500c1bc134cc1efe33d41a16dbc08f2d4 100644 --- a/python/paddle/v2/framework/tests/test_sgd_op.py +++ b/python/paddle/v2/framework/tests/test_sgd_op.py @@ -3,7 +3,7 @@ import numpy as np from op_test import OpTest -class TestSGD(OpTest): +class TestSGDOp(OpTest): def setUp(self): self.op_type = "sgd" w = np.random.random((102, 105)).astype("float32") diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py index 2316e49eff7bb1cdb53acb3889a6ef05060b59f3..d65d887db4af58c40e4e78fdbfd8e8ee668b7ee3 100644 --- a/python/paddle/v2/framework/tests/test_sigmoid_op.py +++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py @@ -3,7 +3,7 @@ import numpy as np from op_test import OpTest -class TestSigmoid(OpTest): +class TestSigmoidOp(OpTest): def setUp(self): self.op_type = "sigmoid" self.inputs = { diff --git a/python/paddle/v2/framework/tests/test_split_op.py b/python/paddle/v2/framework/tests/test_split_op.py new file mode 100644 index 0000000000000000000000000000000000000000..b4420db9d71b99556e305104ac17ef5e4b4bd0f2 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_split_op.py @@ -0,0 +1,26 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestSplitOp(OpTest): + def setUp(self): + self.op_type = "split" + axis = 0 + num = 2 + x = np.random.random((4, 2)).astype('float32') + out = np.split(x, num, axis) + self.inputs = {'X': x} + self.attrs = {'axis': axis, 'num': num} + self.outputs = {'Out': [('out%d' % i, out[i]) \ + for i in xrange(len(out))]} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], ['out0', 'out1']) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_tensor.py b/python/paddle/v2/framework/tests/test_tensor.py index f26ed4964c521be1cd839b39d7244f96c653cb1a..8cd93b35d7d1cb7d3b4a19e0e402ef576f1c0982 100644 --- a/python/paddle/v2/framework/tests/test_tensor.py +++ b/python/paddle/v2/framework/tests/test_tensor.py @@ -44,79 +44,66 @@ class TestTensor(unittest.TestCase): self.assertAlmostEqual(2.0, tensor_array_2[19, 11]) def test_int_lod_tensor(self): - places = [core.CPUPlace(), core.GPUPlace(0)] - for place in places: - scope = core.Scope() - var = scope.new_var("test_tensor") - var_lod = scope.new_var("test_lod_tensor") - - tensor = var.get_tensor() - lod_tensor = var_lod.get_lod_tensor() - - tensor.set_dims([4, 4, 6]) - tensor.alloc_int(place) - array = numpy.array(tensor) - array[0, 0, 0] = 3 - array[3, 3, 5] = 10 - tensor.set(array, place) + place = core.CPUPlace() + scope = core.Scope() + var_lod = scope.new_var("test_lod_tensor") + lod_tensor = var_lod.get_tensor() - lod_tensor.set_tensor(tensor) - lod_tensor.set_lod([[0, 2, 4]]) + lod_tensor.set_dims([4, 4, 6]) + lod_tensor.alloc_int(place) + array = numpy.array(lod_tensor) + array[0, 0, 0] = 3 + array[3, 3, 5] = 10 + lod_tensor.set(array, place) + lod_tensor.set_lod([[0, 2, 4]]) - lod_v = numpy.array(lod_tensor.tensor()) - self.assertTrue(numpy.alltrue(array == lod_v)) + lod_v = numpy.array(lod_tensor) + self.assertTrue(numpy.alltrue(array == lod_v)) - lod = lod_tensor.lod() - self.assertEqual(0, lod[0][0]) - self.assertEqual(2, lod[0][1]) - self.assertEqual(4, lod[0][2]) + lod = lod_tensor.lod() + self.assertEqual(0, lod[0][0]) + self.assertEqual(2, lod[0][1]) + self.assertEqual(4, lod[0][2]) def test_float_lod_tensor(self): - places = [core.CPUPlace(), core.GPUPlace(0)] - for place in places: - scope = core.Scope() - var = scope.new_var("test_tensor") - var_lod = scope.new_var("test_lod_tensor") - - tensor = var.get_tensor() - lod_tensor = var_lod.get_lod_tensor() - - tensor.set_dims([5, 2, 3, 4]) - tensor.alloc_float(place) + place = core.CPUPlace() + scope = core.Scope() + var_lod = scope.new_var("test_lod_tensor") - tensor_array = numpy.array(tensor) - self.assertEqual((5, 2, 3, 4), tensor_array.shape) - tensor_array[0, 0, 0, 0] = 1.0 - tensor_array[0, 0, 0, 1] = 2.0 - tensor.set(tensor_array, place) + lod_tensor = var_lod.get_tensor() + lod_tensor.set_dims([5, 2, 3, 4]) + lod_tensor.alloc_float(place) - lod_tensor.set_tensor(tensor) + tensor_array = numpy.array(lod_tensor) + self.assertEqual((5, 2, 3, 4), tensor_array.shape) + tensor_array[0, 0, 0, 0] = 1.0 + tensor_array[0, 0, 0, 1] = 2.0 + lod_tensor.set(tensor_array, place) - lod_v = numpy.array(lod_tensor.tensor()) - self.assertAlmostEqual(1.0, lod_v[0, 0, 0, 0]) - self.assertAlmostEqual(2.0, lod_v[0, 0, 0, 1]) - self.assertEqual(len(lod_tensor.lod()), 0) + lod_v = numpy.array(lod_tensor) + self.assertAlmostEqual(1.0, lod_v[0, 0, 0, 0]) + self.assertAlmostEqual(2.0, lod_v[0, 0, 0, 1]) + self.assertEqual(len(lod_tensor.lod()), 0) - lod_py = [[0, 2, 5], [0, 2, 4, 5]] - lod_tensor.set_lod(lod_py) - lod = lod_tensor.lod() - self.assertListEqual(lod_py, lod) + lod_py = [[0, 2, 5], [0, 2, 4, 5]] + lod_tensor.set_lod(lod_py) + lod = lod_tensor.lod() + self.assertListEqual(lod_py, lod) def test_lod_tensor_init(self): scope = core.Scope() - var = scope.new_var("test_tensor") place = core.CPUPlace() - tensor = var.get_tensor() - tensor.set_dims([5, 2, 3, 4]) - tensor.alloc_float(place) - tensor_array = numpy.array(tensor) + lod_py = [[0, 2, 5], [0, 2, 4, 5]] + lod_tensor = core.LoDTensor(lod_py) + + lod_tensor.set_dims([5, 2, 3, 4]) + lod_tensor.alloc_float(place) + tensor_array = numpy.array(lod_tensor) tensor_array[0, 0, 0, 0] = 1.0 tensor_array[0, 0, 0, 1] = 2.0 - tensor.set(tensor_array, place) - lod_py = [[0, 2, 5], [0, 2, 4, 5]] + lod_tensor.set(tensor_array, place) - lod_tensor = core.LoDTensor(lod_py, tensor) - lod_v = numpy.array(lod_tensor.tensor()) + lod_v = numpy.array(lod_tensor) self.assertAlmostEqual(1.0, lod_v[0, 0, 0, 0]) self.assertAlmostEqual(2.0, lod_v[0, 0, 0, 1]) self.assertListEqual(lod_py, lod_tensor.lod()) diff --git a/python/paddle/v2/framework/tests/test_top_k_op.py b/python/paddle/v2/framework/tests/test_top_k_op.py index cab799256d791889c295aa7f9048080f5caaf2dc..694f37d612d4c46e673dc894b05a0a446190732c 100644 --- a/python/paddle/v2/framework/tests/test_top_k_op.py +++ b/python/paddle/v2/framework/tests/test_top_k_op.py @@ -21,6 +21,9 @@ class TestTopkOp(OpTest): self.outputs = {'Out': output, 'Indices': indices} + def test_check_output(self): + self.check_output() + class TestTopkOp3d(OpTest): def setUp(self): @@ -42,6 +45,9 @@ class TestTopkOp3d(OpTest): self.outputs = {'Out': output, 'Indices': indices} + def test_check_output(self): + self.check_output() + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/v2/framework/tests/test_uniform_random_op.py b/python/paddle/v2/framework/tests/test_uniform_random_op.py index 76a5e36e56ab08230bdc2597d209fcf5d1d2acb0..9e8898fb5920defdfaa361bf45def7666a88beea 100644 --- a/python/paddle/v2/framework/tests/test_uniform_random_op.py +++ b/python/paddle/v2/framework/tests/test_uniform_random_op.py @@ -4,7 +4,7 @@ import paddle.v2.framework.core as core import numpy -class UniformRandomTest(unittest.TestCase): +class TestUniformRandomOp(unittest.TestCase): def test_uniform_random_cpu(self): self.uniform_random_test(place=core.CPUPlace())