diff --git a/cmake/configure.cmake b/cmake/configure.cmake index c1c93e17fd82ea048ba27b127b1527d9a8c9da41..db8f5ab0456792f903093b9cf20e2541f00add5c 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -24,6 +24,10 @@ if(WITH_DOUBLE) add_definitions(-DPADDLE_TYPE_DOUBLE) endif(WITH_DOUBLE) +if(WITH_TESTING) + add_definitions(-DPADDLE_WITH_TESTING) +endif(WITH_TESTING) + if(NOT WITH_TIMER) add_definitions(-DPADDLE_DISABLE_TIMER) endif(NOT WITH_TIMER) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index ff9868fc4e0d970b11e4763d2e0c8581f4f85907..c311783aa3187678c31c27ddbbd074790ca444f3 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -389,13 +389,60 @@ function(go_test TARGET_NAME) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) endfunction(go_test) +# Modification of standard 'protobuf_generate_cpp()' with protobuf-lite support +# Usage: +# paddle_protobuf_generate_cpp( ) + +function(paddle_protobuf_generate_cpp SRCS HDRS) + if(NOT ARGN) + message(SEND_ERROR "Error: paddle_protobuf_generate_cpp() called without any proto files") + return() + endif() + + set(${SRCS}) + set(${HDRS}) + + if (MOBILE_INFERENCE) + set(EXTRA_FLAG "lite:") + else() + set(EXTRA_FLAG "") + endif() + + foreach(FIL ${ARGN}) + get_filename_component(ABS_FIL ${FIL} ABSOLUTE) + get_filename_component(FIL_WE ${FIL} NAME_WE) + + set(_protobuf_protoc_src "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc") + set(_protobuf_protoc_hdr "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h") + list(APPEND ${SRCS} "${_protobuf_protoc_src}") + list(APPEND ${HDRS} "${_protobuf_protoc_hdr}") + + add_custom_command( + OUTPUT "${_protobuf_protoc_src}" + "${_protobuf_protoc_hdr}" + + COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + -I${CMAKE_CURRENT_SOURCE_DIR} + --cpp_out "${EXTRA_FLAG}${CMAKE_CURRENT_BINARY_DIR}" ${ABS_FIL} + DEPENDS ${ABS_FIL} protoc + COMMENT "Running C++ protocol buffer compiler on ${FIL}" + VERBATIM ) + endforeach() + + set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE) + set(${SRCS} ${${SRCS}} PARENT_SCOPE) + set(${HDRS} ${${HDRS}} PARENT_SCOPE) +endfunction() + + function(proto_library TARGET_NAME) set(oneValueArgs "") set(multiValueArgs SRCS DEPS) cmake_parse_arguments(proto_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) set(proto_srcs) set(proto_hdrs) - protobuf_generate_cpp(proto_srcs proto_hdrs ${proto_library_SRCS}) + paddle_protobuf_generate_cpp(proto_srcs proto_hdrs ${proto_library_SRCS}) cc_library(${TARGET_NAME} SRCS ${proto_srcs} DEPS ${proto_library_DEPS} protobuf) endfunction() diff --git a/doc/design/block.md b/doc/design/block.md index 4d5dd4ba95a686d18b2339c69f0316c340681909..9c812732d6ead76eb3aa2d1b617449c96807f21a 100644 --- a/doc/design/block.md +++ b/doc/design/block.md @@ -5,12 +5,12 @@ Both deep learning systems and programming languages help users describe computation procedures. These systems use various representations of computation: - Caffe, Torch, and Paddle: sequences of layers. -- TensorFlow, Caffe2, Mxnet: graphs of operators. +- TensorFlow, Caffe2, Mxnet: graph of operators. - PaddlePaddle: nested blocks, like C++ and Java programs. ## Block in Programming Languages and Deep Learning -In programming languages, a block is a pair of curly braces that includes local variables definitions and a sequence of instructions, or operators. +In programming languages, a block is a pair of curly braces that includes local variables definitions and a sequence of instructions or operators. Blocks work with control flow structures like `if`, `else`, and `for`, which have equivalents in deep learning: @@ -24,14 +24,14 @@ A key difference is that a C++ program describes a one pass computation, whereas ## Stack Frames and the Scope Hierarchy -The existence of the backward makes the execution of a block of traditional programs and PaddlePaddle different to each other: +The existence of the backward pass makes the execution of a block of PaddlePaddle different from traditional programs: -| programming languages | PaddlePaddle | -|-----------------------|-------------------------------| -| stack | scope hierarchy | -| stack frame | scope | -| push at entering block| push at entering block | -| pop at leaving block | destroy at minibatch completes| +| programming languages | PaddlePaddle | +|-----------------------|---------------------------------| +| stack | scope hierarchy | +| stack frame | scope | +| push at entering block| push at entering block | +| pop at leaving block | destroy when minibatch completes| 1. In traditional programs: @@ -42,9 +42,9 @@ The existence of the backward makes the execution of a block of traditional prog 1. In PaddlePaddle - When the execution enters a block, PaddlePaddle adds a new scope, where it realizes variables. - - PaddlePaddle doesn't pop a scope after the execution of the block because variables therein are to be used by the backward pass. So it has a stack forest known as a *scope hierarchy*. + - PaddlePaddle doesn't pop a scope after the execution of the block because variables therein are used by the backward pass. So it has a stack forest known as a *scope hierarchy*. - The height of the highest tree is the maximum depth of nested blocks. - - After the process of a minibatch, PaddlePaddle destroys the scope hierarchy. + - After the processing of a minibatch, PaddlePaddle destroys the scope hierarchy. ## Use Blocks in C++ and PaddlePaddle Programs @@ -94,14 +94,14 @@ with ie.false_block(): o1, o2 = ie(cond) ``` -In both examples, the left branch computes `x+y` and `softmax(x+y)`, the right branch computes `x+1` and `fc(x)`. +In both examples, the left branch computes `x+y` and `softmax(x+y)`, the right branch computes `fc(x)` and `x+1` . -A difference is that variables in the C++ program contain scalar values, whereas those in the PaddlePaddle programs are mini-batches of instances. The `ie.input(true, 0)` invocation returns instances in the 0-th input, `x`, that corresponds to true values in `cond` as the local variable `x`, where `ie.input(false, 0)` returns instances corresponding to false values. +The difference is that variables in the C++ program contain scalar values, whereas those in the PaddlePaddle programs are mini-batches of instances. ### Blocks with `for` and `RNNOp` -The following RNN model from the [RNN design doc](./rnn.md) +The following RNN model in PaddlePaddle from the [RNN design doc](./rnn.md) : ```python x = sequence([10, 20, 30]) # shape=[None, 1] @@ -112,9 +112,9 @@ U = var(0.375, param=true) # shape=[1] rnn = pd.rnn() with rnn.step(): h = rnn.memory(init = m) - hh = rnn.previous_memory(h) + h_prev = rnn.previous_memory(h) a = layer.fc(W, x) - b = layer.fc(U, hh) + b = layer.fc(U, h_prev) s = pd.add(a, b) act = pd.sigmoid(s) rnn.update_memory(h, act) @@ -147,9 +147,9 @@ for (int i = 1; i <= sizeof(x)/sizeof(x[0]); ++i) { ## Compilation and Execution -Like TensorFlow programs, a PaddlePaddle program is written in Python. The first part describes a neural network as a protobuf message, and the rest part executes the message for training or inference. +Like TensorFlow, a PaddlePaddle program is written in Python. The first part describes a neural network as a protobuf message, and the rest executes the message for training or inference. -The generation of this protobuf message is like what a compiler generates a binary executable file. The execution of the message that the OS executes the binary file. +The generation of this protobuf message is similar to how a compiler generates a binary executable file. The execution of the message is similar to how the OS executes the binary file. ## The "Binary Executable File Format" @@ -186,8 +186,8 @@ Also, the RNN operator in above example is serialized into a protobuf message of ``` OpDesc { - inputs = {0} // the index of x - outputs = {5, 3} // indices of act and hidden_out + inputs = {0} // the index of x in vars of BlockDesc above + outputs = {5, 3} // indices of act and hidden_out in vars of BlockDesc above attrs { "memories" : {1} // the index of h "step_net" : @@ -203,14 +203,14 @@ This `OpDesc` value is in the `ops` field of the `BlockDesc` value representing During the generation of the Protobuf message, the Block should store VarDesc (the Protobuf message which describes Variable) and OpDesc (the Protobuf message which describes Operator). VarDesc in a block should have its name scope to avoid local variables affect parent block's name scope. -Child block's name scopes should inherit the parent's so that OpDesc in child block can reference a VarDesc that stored in parent block. For example +Child block's name scopes should inherit the parent's so that OpDesc in child block can reference a VarDesc that stored in parent block. For example: ```python -a = pd.Varaible(shape=[20, 20]) +a = pd.Variable(shape=[20, 20]) b = pd.fc(a, params=["fc.w", "fc.b"]) rnn = pd.create_rnn() -with rnn.stepnet() +with rnn.stepnet(): x = a.as_step_input() # reuse fc's parameter fc_without_b = pd.get_variable("fc.w") @@ -218,17 +218,17 @@ with rnn.stepnet() out = rnn() ``` -the method `pd.get_variable` can help retrieve a Variable by a name, a Variable may store in a parent block, but might be retrieved in a child block, so block should have a variable scope that supports inheritance. +The method `pd.get_variable` can help retrieve a Variable by the name. The Variable may be stored in a parent block, but might be retrieved in a child block, so block should have a variable scope that supports inheritance. In compiler design, the symbol table is a data structure created and maintained by compilers to store information about the occurrence of various entities such as variable names, function names, classes, etc. To store the definition of variables and operators, we define a C++ class `SymbolTable`, like the one used in compilers. -`SymbolTable` can do the following stuff: +`SymbolTable` can do the following: - store the definitions (some names and attributes) of variables and operators, -- to verify if a variable was declared, -- to make it possible to implement type checking (offer Protobuf message pointers to `InferShape` handlers). +- verify if a variable was declared, +- make it possible to implement type checking (offer Protobuf message pointers to `InferShape` handlers). ```c++ @@ -240,19 +240,18 @@ class SymbolTable { OpDesc* NewOp(const string& name=""); - // TODO determine whether name is generated by python or C++ - // currently assume that a unique name will be generated by C++ if the - // argument name left default. + // TODO determine whether name is generated by python or C++. + // Currently assume that a unique name will be generated by C++ if the + // argument name is left default. VarDesc* NewVar(const string& name=""); - // find a VarDesc by name, if recursive true, find parent's SymbolTable + // find a VarDesc by name, if recursive is true, find parent's SymbolTable // recursively. // this interface is introduced to support InferShape, find protobuf messages // of variables and operators, pass pointers into InferShape. - // operator // // NOTE maybe some C++ classes such as VarDescBuilder and OpDescBuilder should - // be proposed and embedded into pybind to enable python operate on C++ pointers. + // be proposed and embedded into pybind to enable python operation on C++ pointers. VarDesc* FindVar(const string& name, bool recursive=true); OpDesc* FindOp(const string& name); @@ -270,7 +269,7 @@ class SymbolTable { After all the description of variables and operators is added into SymbolTable, the block has enough information to run. -The `Block` class takes a `BlockDesc` as input, and provide `Run` and `InferShape` functions. +The `Block` class takes a `BlockDesc` as input, and provides `Run` and `InferShape` functions. ```c++ @@ -302,7 +301,7 @@ public: void CreateVariables(const framework::Scope& scope); void CreateOperators(); - // some other necessary interfaces of NetOp are list below + // some other necessary interfaces of NetOp are listed below // ... private: @@ -316,15 +315,14 @@ private: Block inherits from OperatorBase, which has a Run method. Block's Run method will run its operators sequentially. -There is another important interface called `Eval`, which take some arguments called targets, and generate a minimal graph which takes targets as the end points and creates a new Block, -after `Run`, `Eval` will get the latest value and return the targets. +There is another important interface called `Eval`, which takes some arguments called targets and generates a minimal graph which treats targets as the end points and creates a new Block. After `Run`, `Eval` will get the latest value and return the targets. The definition of Eval is as follows: ```c++ // clean a block description by targets using the corresponding dependency graph. // return a new BlockDesc with minimal number of operators. -// NOTE not return a Block but the block's description so that this can be distributed +// NOTE: The return type is not a Block but the block's description so that this can be distributed // to a cluster. BlockDesc Prune(const BlockDesc& desc, vector targets); diff --git a/doc/design/dcgan.png b/doc/design/dcgan.png new file mode 100644 index 0000000000000000000000000000000000000000..15e8e290a111ff43900934341365cb4360d87d28 Binary files /dev/null and b/doc/design/dcgan.png differ diff --git a/doc/design/gan_api.md b/doc/design/gan_api.md new file mode 100644 index 0000000000000000000000000000000000000000..fb41df8615f73d9fd4c32995eab265833eac1a55 --- /dev/null +++ b/doc/design/gan_api.md @@ -0,0 +1,253 @@ +# Design for GAN + +GAN (General Adversarial Net [https://arxiv.org/abs/1406.2661]) is an important model for unsupervised learning and widely used in many areas. + +It applies several important concepts in machine learning system design, including building and running subgraphs, dependency tracing, different optimizers in one executor and so forth. + +In our GAN design, we wrap it as a user-friendly easily customized python API to design different models. We take the conditional DC-GAN (Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks [https://arxiv.org/abs/1511.06434]) as an example due to its good performance on image generation. + +

+
+Figure 1. The overall running logic of GAN. The black solid arrows indicate the forward pass; the green dashed arrows indicate the backward pass of generator training; the red dashed arrows indicate the backward pass of the discriminator training. The BP pass of the green (red) arrow should only update the parameters in the green (red) boxes. The diamonds indicate the data providers. d\_loss and g\_loss marked in red and green are the two targets we would like to run. +

+ +The operators, layers and functions required/optional to build a GAN demo is summarized in https://github.com/PaddlePaddle/Paddle/issues/4563. + +

+
+Figure 2. Photo borrowed from the original DC-GAN paper. +

+ +## The Conditional-GAN might be a class. +This design we adopt the popular open source design in https://github.com/carpedm20/DCGAN-tensorflow and https://github.com/rajathkmp/DCGAN. It contains following data structure: + +- DCGAN(object): which contains everything required to build a GAN model. It provides following member functions methods as API: + +- __init__(...): Initialize hyper-parameters (like conv dimension and so forth), and declare model parameters of discriminator and generator as well. + +- generator(z, y=None): Generate a fake image from input noise z. If the label y is provided, the conditional GAN model will be chosen. +Returns a generated image. + +- discriminator(image): +Given an image, decide if it is from a real source or a fake one. +Returns a 0/1 binary label. + +- build_model(self): +build the whole GAN model, define training loss for both generator and discrimator. + +## Discussion on Engine Functions required to build GAN +- Trace the tensor and variable dependency in the engine executor. (Very critical, otherwise GAN can'be be trained correctly) +- Different optimizers responsible for optimizing different loss. + +To be more detailed, we introduce our design of DCGAN as following: + +### Class member Function: Initializer +- Set up hyper-parameters, including condtional dimension, noise dimension, batch size and so forth. +- Declare and define all the model variables. All the discriminator parameters are included in the list self.theta_D and all the generator parameters are included in the list self.theta_G. +```python +class DCGAN(object): + def __init__(self, y_dim=None): + + # hyper parameters + self.y_dim = y_dim # conditional gan or not + self.batch_size = 100 + self.z_dim = z_dim # input noise dimension + + # define parameters of discriminators + self.D_W0 = pd.Variable(shape=[3,3, 1, 128], data=pd.gaussian_normal_randomizer()) + self.D_b0 = pd.Variable(np.zeros(128)) # variable also support initialization using a numpy data + self.D_W1 = pd.Variable(shape=[784, 128], data=pd.gaussian_normal_randomizer()) + self.D_b1 = pd.Variable(np.zeros(128)) # variable also support initialization using a numpy data + self.D_W2 = pd.Varialble(np.random.rand(128, 1)) + self.D_b2 = pd.Variable(np.zeros(128)) + self.theta_D = [self.D_W0, self.D_b0, self.D_W1, self.D_b1, self.D_W2, self.D_b2] + + # define parameters of generators + self.G_W0 = pd.Variable(shape=[784, 128], data=pd.gaussian_normal_randomizer()) + self.G_b0 = pd.Variable(np.zeros(128)) # variable also support initialization using a numpy data + self.G_W1 = pd.Variable(shape=[784, 128], data=pd.gaussian_normal_randomizer()) + self.G_b1 = pd.Variable(np.zeros(128)) # variable also support initialization using a numpy data + self.G_W2 = pd.Varialble(np.random.rand(128, 1)) + self.G_b2 = pd.Variable(np.zeros(128)) + self.theta_G = [self.G_W0, self.G_b0, self.G_W1, self.G_b1, self.G_W2, self.G_b2] +``` + +### Class member Function: Generator +- Given a noisy input z, returns a fake image. +- Concatenation, batch-norm, FC operations required; +- Deconv layer required, which is missing now... +```python +class DCGAN(object): + def generator(self, z, y = None): + # input z: the random noise + # input y: input data label (optional) + # output G_im: generated fake images + + if not self.y_dim: + z = pd.layer.concat(1, [z, y]) + + G_h0 = pd.layer.fc(z, self.G_w0, self.G_b0) + G_h0_bn = pd.layer.batch_norm(G_h0) + G_h0_relu = pd.layer.relu(G_h0_bn) + + G_h1 = pd.layer.deconv(G_h0_relu, self.G_w1, self.G_b1) + G_h1_bn = pd.layer.batch_norm(G_h1) + G_h1_relu = pd.layer.relu(G_h1_bn) + + G_h2 = pd.layer.deconv(G_h1_relu, self.G_W2, self.G_b2)) + G_im = pd.layer.tanh(G_im) + return G_im +``` + +### Class member function: Discriminator +- Given a noisy input z, returns a fake image. +- Concatenation, Convolution, batch-norm, FC, Leaky-ReLU operations required; +```python +class DCGAN(object): + def discriminator(self, image): + # input image: either generated images or real ones + # output D_h2: binary logit of the label + + D_h0 = pd.layer.conv2d(image, w=self.D_w0, b=self.D_b0) + D_h0_bn = pd.layer.batchnorm(h0) + D_h0_relu = pd.layer.lrelu(h0_bn) + + D_h1 = pd.layer.conv2d(D_h0_relu, w=self.D_w1, b=self.D_b1) + D_h1_bn = pd.layer.batchnorm(D_h1) + D_h1_relu = pd.layer.lrelu(D_h1_bn) + + D_h2 = pd.layer.fc(D_h1_relu, w=self.D_w2, b=self.D_b2) + return D_h2 +``` + +### Class member function: Build the model +- Define data readers as placeholders to hold the data; +- Build generator and discriminators; +- Define two training losses for discriminator and generator, respectively. +If we have execution dependency engine to back-trace all tensors, the module building our GAN model will be like this: +```python +class DCGAN(object): + def build_model(self): + if self.y_dim: + self.y = pd.data(pd.float32, [self.batch_size, self.y_dim]) + self.images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size]) + self.faked_images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size]) + self.z = pd.data(tf.float32, [None, self.z_size]) + + # step 1: generate images by generator, classify real/fake images with discriminator + if self.y_dim: # if conditional GAN, includes label + self.G = self.generator(self.z, self.y) + self.D_t = self.discriminator(self.images) + # generated fake images + self.sampled = self.sampler(self.z, self.y) + self.D_f = self.discriminator(self.G) + else: # original version of GAN + self.G = self.generator(self.z) + self.D_t = self.discriminator(self.images) + # generate fake images + self.sampled = self.sampler(self.z) + self.D_f = self.discriminator(self.images) + + # step 2: define the two losses + self.d_loss_real = pd.reduce_mean(pd.cross_entropy(self.D_t, np.ones(self.batch_size)) + self.d_loss_fake = pd.reduce_mean(pd.cross_entropy(self.D_f, np.zeros(self.batch_size)) + self.d_loss = self.d_loss_real + self.d_loss_fake + + self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_f, np.ones(self.batch_szie)) +``` + +If we do not have dependency engine but blocks, the module building our GAN model will be like this: +```python +class DCGAN(object): + def build_model(self, default_block): + # input data in the default block + if self.y_dim: + self.y = pd.data(pd.float32, [self.batch_size, self.y_dim]) + self.images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size]) + # self.faked_images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size]) + self.z = pd.data(tf.float32, [None, self.z_size]) + + # step 1: generate images by generator, classify real/fake images with discriminator + with pd.default_block().g_block(): + if self.y_dim: # if conditional GAN, includes label + self.G = self.generator(self.z, self.y) + self.D_g = self.discriminator(self.G, self.y) + else: # original version of GAN + self.G = self.generator(self.z) + self.D_g = self.discriminator(self.G, self.y) + self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_g, np.ones(self.batch_szie)) + + with pd.default_block().d_block(): + if self.y_dim: # if conditional GAN, includes label + self.D_t = self.discriminator(self.images, self.y) + self.D_f = self.discriminator(self.G, self.y) + else: # original version of GAN + self.D_t = self.discriminator(self.images) + self.D_f = self.discriminator(self.G) + + # step 2: define the two losses + self.d_loss_real = pd.reduce_mean(pd.cross_entropy(self.D_t, np.ones(self.batch_size)) + self.d_loss_fake = pd.reduce_mean(pd.cross_entropy(self.D_f, np.zeros(self.batch_size)) + self.d_loss = self.d_loss_real + self.d_loss_fake +``` +Some small confusion and problems with this design: +- D\_g and D\_f are actually the same thing, but has to be written twice; i.e., if we want to run two sub-graphs conceptually, the same codes have to be written twice if they are shared by the graph. +- Requires ability to create a block anytime, rather than in if-else or rnn only; + +## Main function for the demo: +Generally, the user of GAN just need to the following things: +- Define an object as DCGAN class; +- Build the DCGAN model; +- Specify two optimizers for two different losses with respect to different parameters. +```python +# pd for short, should be more concise. +from paddle.v2 as pd +import numpy as np +import logging + +if __name__ == "__main__": + # dcgan class in the default graph/block + # if we use dependency engine as tensorflow + # the codes, will be slightly different like: + # dcgan = DCGAN() + # dcgan.build_model() + with pd.block() as def_block: + dcgan = DCGAN() + dcgan.build_model(def_block) + + # load mnist data + data_X, data_y = self.load_mnist() + + # Two subgraphs required!!! + with pd.block().d_block(): + d_optim = pd.train.Adam(lr = .001, beta= .1) + d_step = d_optim.minimize(dcgan.d_loss, dcgan.theta_D) + with pd.block.g_block(): + g_optim = pd.train.Adam(lr = .001, beta= .1) + g_step = pd.minimize(dcgan.g_loss, dcgan.theta_G) + + # executor + sess = pd.executor() + + # training + for epoch in xrange(10000): + for batch_id in range(N / batch_size): + idx = ... + # sample a batch + batch_im, batch_label = data_X[idx:idx+batch_size], data_y[idx:idx+batch_size] + # sample z + batch_z = np.random.uniform(-1., 1., [batch_size, z_dim]) + + if batch_id % 2 == 0: + sess.run(d_step, + feed_dict = {dcgan.images: batch_im, + dcgan.y: batch_label, + dcgan.z: batch_z}) + else: + sess.run(g_step, + feed_dict = {dcgan.z: batch_z}) +``` + +# More thinking about dependency engine v.s. block design: +- What if we just want to run an intermediate result? Do we need to run the whole block/graph? +- Should we call eval() to get the fake images in the first stage? And then train the discriminator in the second stage? diff --git a/doc/design/optimizer.md b/doc/design/optimizer.md new file mode 100644 index 0000000000000000000000000000000000000000..17440fae5028cfac5d58fc079ca2096d0be3a0f9 --- /dev/null +++ b/doc/design/optimizer.md @@ -0,0 +1,105 @@ +## Optimizer Design + +### The Problem + +A PaddlePaddle program, or a block, is a sequence of operators operating variables. A training program needs to do three kinds of works: + +1. the forward pass, which computes intermediate results and the cost(s), +1. the backward pass, which derives gradients from intermediate results and costs, and +1. the optimization pass, which update model parameters to optimize the cost(s). + +These works rely on three kinds of operators: + +1. forward operators, +1. gradient operators, and +1. optimization operators. + +It's true that users should be able to create all these operators manually by calling some low-level API, but it would be much more convenient if they could only describe the forward pass and let PaddlePaddle create the backward and optimization operators automatically. + +In this design, we propose a high-level API that automatically derives the optimisation pass and operators from the forward pass. + + +### High-level Python API to describe the training process + +1. User write code to describe the network: + + ```python + images = layer.data("images") + labels = layer.data("labels") + w1 = pd.var("w1") + b1 = pd.var("b1") + hidden = layer.fc(images, w=w1, b=b1) + cost = layer.mse(hidden, labels) + ``` + + The above code snippet will create forward operators in [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.md). + + +2. Users create a certain kind of Optimizer with some argument. + + ```python + optimizer = AdagradOptimizer(learing_rate=0.001) + ``` + +3. Users use the optimizer to `minimize` a certain `cost` through updating parameters in parameter_list. + + ```python + opt_op_list = optimizer.minimize(cost, parameter_list=[w1, b1]) + ``` + The above code snippet will create gradient and optimization operators in Block. The return value of `minimize()` is list of optimization operators that will be run by session. + +4. Users use Session/Executor to run this opt_op_list as target to do training. + + ```python + sess.run(target= opt_op_list, ...) + ``` + +#### Optimizer Python interface: + +```python +class Optimizer(object): + """Optimizer Base class. + + """ + + def __init__(self): + pass + + def create_backward_pass(self, loss, parameter_list=None): + """ + create and add gradient Operators in BlockDesc to Compute gradients of `loss` + for parameters in parameter_list + + Args: + loss: an variable generated by cost function. + parameter_list: parameters that need to compute gradient and update to optimize the lost. + + Returns: + list of (parameters, gradients) pair. + """ + return None + + def create_optimization_pass(self, parameters_and_grads): + """Add optimization operators to update gradients to variables. + + Args: + parameters_and_grads: a list of (variable, gradient) pair to update. + + Returns: + optmization_op_list: a list of optimization operator that will update parameter using gradient. + """ + return None + + def minimize(self, loss, parameter_list): + """Add operations to minimize `loss` by updating `parameter_list`. + + This method combines interface `create_backward_pass()` and + `create_optimization_pass()` into one. + """ + params_grads = self.create_backward_pass(loss, parameter_list) + update_ops = self.create_optimization_pass(params_grads) + return update_ops + +``` + +Users can inherit the Optimizer above to create their own Optimizer with some special logic, such as AdagradOptimizer. diff --git a/doc/design/python_api.md b/doc/design/python_api.md index 6213da65c8c5931bc16e42574b8628b676424873..56ae1d925a96622b5576013f38e33e5f92cbbb90 100644 --- a/doc/design/python_api.md +++ b/doc/design/python_api.md @@ -22,7 +22,7 @@ Whenever we create a block, we need to set its parent block to the current block ```python class Program(objects): def __init__(self): - self.proto = core.NewProgram() # a C++ ProgramDesc pointer. + self.desc = core.NewProgram() # a C++ ProgramDesc pointer. self.blocks = vector() self.blocks.append(Block(self, -1)) # the global block self.current_block = 0 # initialized to the global block @@ -57,7 +57,7 @@ A [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.m ```python class Block(objects): def __init__(self, program, parent_idx): - self.proto = core.NewBlock(program.proto) + self.desc = core.NewBlock(program.desc) self.program = program self.vars = map() self.ops = vector() @@ -98,11 +98,11 @@ class Operator(object): outputs,# dict attrs # dict ): - self.proto = core.NewOpDesc(block.proto, type, inputs, outputs, attrs) - core.infer_shape(self.proto, inputs, outputs) + self.desc = core.NewOpDesc(block.desc, type, inputs, outputs, attrs) + core.infer_shape(self.desc, inputs, outputs) def type(self): - return self.proto.type() + return self.desc.type() ``` `Operator` creates the `OpDesc` message in C++ space, so that it can call the `InferShape` function, which is in C++. @@ -124,7 +124,7 @@ class Variable(object): name = unique_name_generator() self.name = name self.block = block - self.proto = core.NewVarDesc(block.proto, name, shape, lod_level) + self.desc = core.NewVarDesc(block.desc, name, shape, lod_level) self.writer = None ``` @@ -214,3 +214,7 @@ def fc_layer(input, size, ...): out.writer = op return out ``` + +## Optimizer + +[Optimizer Design Doc](./optimizer.md) diff --git a/doc/design/refactorization.md b/doc/design/refactorization.md index 629422e7743af666b42fd69fbff442ce15bef596..ec51aa1a0ec667175ff7215dcd359023e296769f 100644 --- a/doc/design/refactorization.md +++ b/doc/design/refactorization.md @@ -17,22 +17,22 @@ The goals of refactoring include: 1. A graph is composed of *variables* and *operators*. -1. The description of graphs must be capable of being serialized/deserialized, so that: +1. The description of graphs must be serializable/deserializable, so that: - 1. It can to be sent to the cloud for distributed execution, and + 1. It can be sent to the cloud for distributed execution, and 1. It can be sent to clients for mobile or enterprise deployment. -1. The Python program does the following steps +1. The Python program does two things - 1. *compilation*: run a Python program to generate a protobuf message representation of the graph and send it to + 1. *Compilation* runs a Python program to generate a protobuf message representation of the graph and send it to 1. the C++ library `libpaddle.so` for local execution, 1. the master process of a distributed training job for training, or 1. the server process of a Kubernetes serving job for distributed serving. - 1. *execution*: execute the graph by constructing instances of class [`Variable`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/variable.h#L24) and [`OperatorBase`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L70), according to the protobuf message. + 1. *Execution* executes the graph by constructing instances of class [`Variable`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/variable.h#L24) and [`OperatorBase`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L70), according to the protobuf message. ## Description and Realization of Computation Graph -At compile time, the Python program generates a protobuf message representation of the graph, or the description of the graph. +At compile time, the Python program generates a protobuf message representation of the graph, or a description of the graph. At runtime, the C++ program realizes the graph and runs it. @@ -42,11 +42,11 @@ At runtime, the C++ program realizes the graph and runs it. |Operation|[OpDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L35)|[Operator](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L64)| |Block|BlockDesc|Block| -The word *graph* is interchangeable with *block* in this document. A graph represents computation steps and local variables similar to a C++/Java program block, or a pair of parentheses(`{` and `}`). +The word *graph* is interchangeable with *block* in this document. A graph consists of computation steps and local variables similar to a C++/Java program block, or a pair of parentheses(`{` and `}`). ## Compilation and Execution -1. Run an application Python program to describe the graph. In particular, the Python application program does the following: +1. Run a Python program to describe the graph. In particular, the Python application program does the following: 1. Create `VarDesc` to represent local/intermediate variables, 1. Create operators and set attributes, @@ -54,10 +54,10 @@ The word *graph* is interchangeable with *block* in this document. A graph repr 1. Infer the type and the shape of variables, 1. Plan memory-reuse for variables, 1. Generate the backward graph - 1. Optimize the computation graph. - 1. Potentially, split the graph for distributed training. + 1. Add optimization operators to the computation graph. + 1. Optionally, split the graph for distributed training. -1. The invocation of `train` or [`infer`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/inference.py#L108) methods in the application Python program does the following: +1. The invocation of `train` or [`infer`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/inference.py#L108) methods in the Python program does the following: 1. Create a new Scope instance in the [scope hierarchy](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/scope.md) for each run of a block, 1. realize local variables defined in the BlockDesc message in the new scope, @@ -107,8 +107,8 @@ Compile Time -> IR -> Runtime ![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/dd598e8f1976f5759f58af5e5ef94738a6b2e661/op.dot) * `Operator` is the fundamental building block of the user interface. - * Operator stores input/output variable names, and attributes. - * The `InferShape` interface is used to infer the shape of the output variable shapes based on the shapes of the input variables. + * Operator stores input/output variable names and attributes. + * The `InferShape` interface is used to infer the shape of the output variables based on the shapes of the input variables. * Use `Run` to compute the `output` variables from the `input` variables. --- @@ -139,7 +139,7 @@ Compile Time -> IR -> Runtime * Limit the number of `tensor.device(dev) = ` in your code. * `thrust::transform` and `std::transform`. * `thrust` has the same API as C++ standard library. Using `transform`, one can quickly implement customized element-wise kernels. - * `thrust` also has more complex APIs, like `scan`, `reduce`, `reduce_by_key`. + * `thrust`, in addition, supports more complex APIs, like `scan`, `reduce`, `reduce_by_key`. * Hand-writing `GPUKernel` and `CPU` code * Do not write in header (`.h`) files. CPU Kernel should be in cpp source (`.cc`) and GPU kernels should be in cuda (`.cu`) files. (GCC cannot compile GPU code.) --- @@ -185,10 +185,10 @@ Make sure the registration process is executed and linked. 1. Write an Op class and its gradient Op class, if required. 2. Write an Op maker class. In the constructor of this class, describe the inputs, outputs and attributes of the operator. 3. Invoke the macro `REGISTER_OP`. This macro will - 1. Call maker class to complete the `proto` and the `checker` + 1. Call maker class to complete `proto` and `checker` 2. Using the completed `proto` and `checker`, it will add a new key-value pair to the `OpInfoMap` -4. Invoke the `USE` macro in which the Op is used, to make sure that it is linked. +4. Invoke the `USE` macro in which the Op is used to make sure that it is linked. --- # Backward Module (1/2) @@ -199,13 +199,14 @@ Make sure the registration process is executed and linked. --- # Backward Module (2/2) ### Build Backward Network -- **Input**: graph of forward operators -- **Output**: graph of backward operators +- **Input**: a graph of forward operators +- **Output**: a graph of backward operators - **Corner cases in construction** - Shared Variables => insert an `Add` operator to combine gradients - No Gradient => insert a `fill_zero_grad` operator - Recursive NetOp => call `Backward` recursively - RNN Op => recursively call `Backward` on stepnet + - RNN Op => recursively call `Backward` on stepnet --- @@ -215,10 +216,10 @@ Make sure the registration process is executed and linked. * Only dims and data pointers are stored in `Tensor`. * All operations on `Tensor` are written in `Operator` or global functions. * Variable length Tensor design [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md) -* `Variable` instances are the inputs and the outputs of an operator. Not just `Tensor`. +* `Variable` instances are the inputs and the outputs of an operator, not just `Tensor`. * `step_scopes` in RNN is a variable and not a tensor. -* `Scope` is where variables are stores. - * map +* `Scope` is where variables are stored. + * map * `Scope` has a hierarchical structure. The local scope can get variables from its parent scope. --- @@ -246,7 +247,7 @@ Make sure the registration process is executed and linked. --- # Control the migration quality - Compare the performance of migrated models with old ones. -- Follow the google C++ style +- Follow the google C++ style guide. - Build the automatic workflow of generating Python/C++ documentations. - The documentation of layers and ops should be written inside the code. - Take the documentation quality into account when submitting pull requests. diff --git a/doc/design/selected_rows.md b/doc/design/selected_rows.md new file mode 100644 index 0000000000000000000000000000000000000000..9e6f3b20cbcdc55e481fbe7bf5fa555d8b3c3d45 --- /dev/null +++ b/doc/design/selected_rows.md @@ -0,0 +1,74 @@ +# Design Doc: Selected Rows + +`SelectedRows` is a kind of sparse tensor data type, which is designed to support `embedding` operators. The gradient of embedding table is a sparse tensor. Only a few rows are non-zero values in that tensor. It is straightforward to represent the sparse tensor by the following sparse tensor data structure: + +```cpp +class SelectedRows { + private: + vector rows_; + Tensor value_; + int height_; +}; +``` + +The field `height_` shows the first dimension of `SelectedRows`. The `rows` are the indices of which rows of `SelectedRows` are non-zeros. The `value_` field is an N-dim tensor and shape is `[rows.size() /* NUM_ROWS */, ...]`, which supplies values for each row. The dimension of `SelectedRows` satisfies `[height_] + value_.shape[1:]`. + +Suppose that a SelectedRows-typed variable `x` has many rows, but only two of them have values -- row 73 is `[1, 2]` and row 84 is `[3, 4]`, the `SelectedRows` representation would be: + +``` +x = SelectedRow { + rows = [73, 84], + value = [[1, 2], [3,4]] +} +``` + + +## SelectedRows in Protobuf + +`SelectedRows` is a kind of `Variable`. `VarDesc` in protobuf should describe the `SelectedRows` information. Only the tensor dimension of a `SelectedRows` will be described in compile-time since the `rows_` and `value_` are related to training data. +So we use `TensorDesc` to unify `data_type` and `dims`. A LodTensorDesc contains a `TensorDesc` and `lod_level`. The description of `SelectedRows` is a Tensor description. + +```proto +message TensorDesc { + required DataType data_type = 1; + repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] +} + +message LodTensorDesc { + required TensorDesc tensor = 1; + optional int lod_level = 2; +} + +message VarDesc { + required string name = 1; + enum VarType { + LOD_TENSOR = 0; + SELECTED_ROWS = 1; + } + required VarType type = 2; + optional LodTensorDesc lod_desc = 3; + optional TensorDesc selected_rows_desc = 4; + optional bool persistable = 5 [ default = false ]; +} +``` + +## InferShape for Selected Rows + +Just like `LoD` information, `InferShape` method will inference output tensor type as well. The operator should decide whether its output is a `SelectedRows` or `Dense` tensor. + +For example, the gradient operator of `TableLookup` will always generate `SelectedRows`. Its `InferShape` method should be like following + +```cpp +void TableLookupGrad::InferShape(context) { + ... + context.SetDataType("Embedding.Grad", kSelectedRows); +} +``` + + +## Sparse Operators + +There are several operators should be written to support `SelectedRows`. They are: + +1. Operators which generates `SelectedRows` gradient. e.g. Gradient of `TableLookupOp`. +2. Optimize operators which support `SelectedRows` gradient. e.g. `SGD` or `AdaGrad` for `SelectedRows`. However, there should be only one `SGD` operator. `OpWithKernel::Run` should select a suitable kernel for both `dense` tensor or `SelectedRows`. diff --git a/doc/design/test.dot b/doc/design/test.dot new file mode 100644 index 0000000000000000000000000000000000000000..62c69b8fc8010a26a54a6ee8ef1488aad94d747a --- /dev/null +++ b/doc/design/test.dot @@ -0,0 +1,35 @@ + +digraph Test { + z -> generator -> G_img; + G_img -> discriminator -> D_f -> d_loss_f; + label0 -> d_loss_f -> d_loss; + + img -> discriminator -> D_t -> d_loss_t; + label1 -> d_loss_t -> d_loss; + + d_loss -> d_loss_t[color=red, style=dashed]; + d_loss -> d_loss_f[color=red, style=dashed]; + d_loss_t -> D_t[color=red, style=dashed]; + d_loss_f -> D_f[color=red, style=dashed]; + D_t -> discriminator[color=red, style=dashed]; + D_f -> discriminator[color=red, style=dashed]; + + D_f -> g_loss; + label2 -> g_loss; + + g_loss -> D_f[color=green, style=dashed]; + D_f -> discriminator[color=green, style=dashed]; + discriminator -> G_img[color=green, style=dashed]; + G_img -> generator[color=green, style=dashed]; + + discriminator [color=red, shape=box]; + generator [color=green, shape=box]; + z [shape=diamond]; + img [shape=diamond]; + label0 [shape=diamond]; + label1 [shape=diamond]; + label2 [shape=diamond]; + + d_loss [color=red]; + g_loss [color=green]; +} diff --git a/doc/design/test.dot.png b/doc/design/test.dot.png new file mode 100644 index 0000000000000000000000000000000000000000..4e121a40b9f7b2232d7cdda315bad15926446f55 Binary files /dev/null and b/doc/design/test.dot.png differ diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 3e0e0f59038daa33cae1952ffbe5fc0bb1870485..6b34c3bbcfbdb0c36381df7de4dd227e317829e5 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -19,7 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope) proto_library(framework_proto SRCS framework.proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto) -cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute) +cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) @@ -42,5 +42,12 @@ add_custom_command(TARGET framework_py_proto POST_BUILD cc_library(backward SRCS backward.cc DEPS net_op) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) +cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward ${GLOB_OP_LIB}) +#if(WITH_GPU) +# nv_test(executor_test SRCS executor_test.cc DEPS executor) +#else() +# cc_test(executor_test SRCS executor_test.cc DEPS executor) +#endif() + cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor) cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 0a4688db9c930201c95d4a47658f43cad87fdbca..063b108500d95c94d5859cf6e1a5a88dcdb2ed31 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -172,30 +172,14 @@ static std::unique_ptr BackwardRecursive( std::to_string(i)); net->ops_[op_offset]->Rename(name, dup_outputs.back()); } - // collect all the offset to append `add` op for each alias - // - // one variable is shared between multiple operators. - // insert add operator one by one, then add it to output - for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1; - ++output_idx) { - auto insert_add_x = dup_outputs[output_idx]; - auto insert_add_y = dup_outputs[output_idx + 1]; - auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx); - // first add op inserted - if (output_idx == dup_outputs.size() - 2) { - insert_add_out = name; - } - if (output_idx != 0) { - insert_add_y = name + "@SHARED@" + std::to_string(output_idx - 1); - } - insert_position.push_back( - {dup_op.back(), - OpRegistry::CreateOp("sum", {{"X", {insert_add_x, insert_add_y}}}, - {{"Out", {insert_add_out}}}, {})}); - } + // collect all the offset for each alias, + // insert a sum operator to add all aliases to output + insert_position.push_back( + {dup_op.back(), OpRegistry::CreateOp("sum", {{"X", dup_outputs}}, + {{"Out", {name}}}, {})}); } - // make sure the inserted `add` ops follow the BFS order. + // make sure the inserted `sum` ops follow the BFS order. insert_position.sort( [](const Pos& l, const Pos& r) { return l.first > r.first; }); diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h index 55e3931f870d62dcaddc6c067f66999c59e2a262..649899d42572c9a22adca5337dcd56b0bcf42e7c 100644 --- a/paddle/framework/data_type.h +++ b/paddle/framework/data_type.h @@ -28,7 +28,6 @@ inline DataType ToDataType(std::type_index type) { return DataType::INT32; } else { PADDLE_THROW("Not supported"); - return static_cast(-1); } } diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..c388b2198e4fbf75d6584d710e00d3deca93eb51 --- /dev/null +++ b/paddle/framework/executor.cc @@ -0,0 +1,163 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/executor.h" + +#include +#include +#include +#include +#include + +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/scope.h" + +namespace paddle { +namespace framework { + +const std::string kFeedOpType = "feed"; +const std::string kFetchOpType = "fetch"; + +Executor::Executor(const std::vector& places) { + PADDLE_ENFORCE_GT(places.size(), 0); + device_contexts_.resize(places.size()); + for (size_t i = 0; i < places.size(); i++) { + if (platform::is_cpu_place(places[i])) { + device_contexts_[i] = new platform::CPUDeviceContext( + boost::get(places[i])); + } else if (platform::is_gpu_place(places[i])) { +#ifdef PADDLE_WITH_CUDA + device_contexts_[i] = new platform::CUDADeviceContext( + boost::get(places[i])); +#else + PADDLE_THROW( + "'GPUPlace' is not supported, Please re-compile with WITH_GPU " + "option"); +#endif + } + } +} + +Executor::~Executor() { + for (auto& device_context : device_contexts_) { + delete device_context; + } +} + +void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) { + // TODO(tonyyang-svail): + // - only runs on the first device (i.e. no interdevice communication) + // - will change to use multiple blocks for RNN op and Cond Op + PADDLE_ENFORCE_GT(pdesc.blocks_size(), block_id); + auto& block = pdesc.blocks(block_id); + auto& device = device_contexts_[0]; + + // Instantiate all the vars in the global scope + for (auto& var : block.vars()) { + scope->NewVar(var.name()); + } + + Scope& local_scope = scope->NewScope(); + + std::vector should_run = Prune(pdesc, block_id); + PADDLE_ENFORCE_EQ(should_run.size(), static_cast(block.ops_size())); + for (size_t i = 0; i < should_run.size(); ++i) { + if (should_run[i]) { + for (auto& var : block.ops(i).outputs()) { + for (auto& argu : var.arguments()) { + if (local_scope.FindVar(argu) == nullptr) { + local_scope.NewVar(argu); + } + } + } + auto op = paddle::framework::OpRegistry::CreateOp(block.ops(i)); + op->Run(local_scope, *device); + } + } + + // TODO(tonyyang-svail): + // - Destroy local_scope +} + +std::vector Prune(const ProgramDesc& pdesc, int block_id) { + // TODO(tonyyang-svail): + // - will change to use multiple blocks for RNN op and Cond Op + + auto& block = pdesc.blocks(block_id); + auto& ops = block.ops(); + + bool expect_feed = true; + for (auto& op_desc : ops) { + PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed, + "All FeedOps are at the beginning of the ProgramDesc"); + expect_feed = (op_desc.type() == kFeedOpType); + } + + bool expect_fetch = true; + for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) { + auto& op_desc = *op_iter; + PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch, + "All FetchOps must at the end of the ProgramDesc"); + expect_fetch = (op_desc.type() == kFetchOpType); + } + + std::set dependent_vars; + std::vector should_run; + for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) { + auto& op_desc = *op_iter; + + bool found_dependent_vars = false; + for (auto& var : op_desc.outputs()) { + for (auto& argu : var.arguments()) { + if (dependent_vars.count(argu) != 0) { + found_dependent_vars = true; + } + } + } + + if (op_desc.type() == kFetchOpType || found_dependent_vars) { + // erase its output to the dependency graph + for (auto& var : op_desc.outputs()) { + for (auto& argu : var.arguments()) { + dependent_vars.erase(argu); + } + } + + // insert its input to the dependency graph + for (auto& var : op_desc.inputs()) { + for (auto& argu : var.arguments()) { + dependent_vars.insert(argu); + } + } + + should_run.push_back(true); + } else { + should_run.push_back(false); + } + } + + // TODO(tonyyang-svail): + // - check this after integration of Init + // PADDLE_ENFORCE(dependent_vars.empty()); + + // since we are traversing the ProgramDesc in reverse order + // we reverse the should_run vector + std::reverse(should_run.begin(), should_run.end()); + + return should_run; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/executor.h b/paddle/framework/executor.h new file mode 100644 index 0000000000000000000000000000000000000000..4e3bc2c0a59dfee5b9993037671f14a109dc8a74 --- /dev/null +++ b/paddle/framework/executor.h @@ -0,0 +1,55 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/framework.pb.h" +#include "paddle/framework/op_info.h" +#include "paddle/framework/scope.h" +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace framework { + +class Executor { + public: + explicit Executor(const std::vector& places); + ~Executor(); + + /* @Brief + * Runtime evaluation of the given ProgramDesc under certain Scope + * + * @param + * ProgramDesc + * Scope + */ + void Run(const ProgramDesc&, Scope*, int); + + private: + std::vector device_contexts_; +}; + +/* @Brief + * Pruning the graph + * + * @param + * ProgramDesc + * + * @return + * vector Same size as ops. Indicates whether an op should be run. + */ +std::vector Prune(const ProgramDesc& pdesc, int block_id); + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7f6d8fe6a4aec9fdc39b4ffc0837a03e355ec937 --- /dev/null +++ b/paddle/framework/executor_test.cc @@ -0,0 +1,308 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/executor.h" + +#include +#include + +#include "gtest/gtest.h" +#include "paddle/framework/attribute.h" +#include "paddle/framework/backward.h" +#include "paddle/framework/block_desc.h" +#include "paddle/framework/op_desc.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +using namespace paddle::platform; +using namespace paddle::framework; + +void AddOp(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, AttributeMap attrs, + paddle::framework::BlockDescBind* block) { + // insert output + for (auto kv : outputs) { + for (auto v : kv.second) { + auto var = block->NewVar(v); + var->SetDataType(paddle::framework::DataType::FP32); + } + } + + // insert op + auto op = block->AppendOp(); + op->SetType(type); + for (auto& kv : inputs) { + op->SetInput(kv.first, kv.second); + } + for (auto& kv : outputs) { + op->SetOutput(kv.first, kv.second); + } + op->SetAttrMap(attrs); +} + +// Tensors in feed value variable will only be in CPUPlace +// So we can memcpy the data from vector to feed_value +template +void SetFeedVariable(const std::vector>& inputs, + const std::vector>& dims) { + Variable* g_feed_value = GetGlobalScope().FindVar("feed_value"); + auto& feed_inputs = + *(g_feed_value->GetMutable>()); + size_t size = inputs.size(); + feed_inputs.resize(size); + for (size_t i = 0; i < size; i++) { + T* dst = feed_inputs[i].mutable_data(make_ddim(dims[i]), CPUPlace()); + memcpy(dst, inputs[i].data(), inputs[i].size() * sizeof(T)); + } +} + +// Tensors in fetch value variable will only be in CPUPlace +// So we can memcpy the data from fetch_value to vector +template +std::vector> GetFetchVariable() { + Variable* g_fetch_value = GetGlobalScope().FindVar("fetch_value"); + auto& fetch_outputs = + *(g_fetch_value->GetMutable>()); + + size_t size = fetch_outputs.size(); + std::vector> result; + result.reserve(size); + for (size_t i = 0; i < size; i++) { + std::vector tmp; + tmp.resize(fetch_outputs[i].numel()); + memcpy(tmp.data(), fetch_outputs[i].data(), + fetch_outputs[i].numel() * sizeof(T)); + result.push_back(tmp); + } + + return result; +} + +class ExecutorTesterRandom : public ::testing::Test { + public: + virtual void SetUp() override { + int input_dim = 3, batch_size = 2, embed_dim = 5; + + auto temp_init_root_block = init_pdesc_.add_blocks(); + temp_init_root_block->set_idx(0); + temp_init_root_block->set_parent_idx(-1); + paddle::framework::ProgramDescBind& init_program = + paddle::framework::ProgramDescBind::Instance(&init_pdesc_); + paddle::framework::BlockDescBind* init_root_block = init_program.Block(0); + + AddOp("gaussian_random", {}, {{"Out", {"w1"}}}, + {{"dims", std::vector{input_dim, embed_dim}}}, init_root_block); + AddOp("gaussian_random", {}, {{"Out", {"w2"}}}, + {{"dims", std::vector{embed_dim, input_dim}}}, init_root_block); + AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, init_root_block); + AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, init_root_block); + + // flush + init_program.Proto(); + + // run block + auto temp_root_block = pdesc_.add_blocks(); + temp_root_block->set_idx(0); + temp_root_block->set_parent_idx(-1); + paddle::framework::ProgramDescBind& program = + paddle::framework::ProgramDescBind::Instance(&pdesc_); + paddle::framework::BlockDescBind* root_block = program.Block(0); + + // feed data + inputs_.push_back({1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + dims_.push_back({batch_size, input_dim}); + AddOp("feed", {}, {{"Out", {"a"}}}, + {{"dims", std::vector{batch_size, input_dim}}, {"col", 0}}, + root_block); + + // forward + AddOp("mul", {{"X", {"a"}}, {"Y", {"w1"}}}, {{"Out", {"b"}}}, {}, + root_block); + AddOp("mul", {{"X", {"b"}}, {"Y", {"w2"}}}, {{"Out", {"a_out"}}}, {}, + root_block); + AddOp("squared_l2_distance", {{"X", {"a"}}, {"Y", {"a_out"}}}, + {{"Out", {"l2_distance"}}, {"sub_result", {"l2_distance_sub"}}}, {}, + root_block); + + // backward + AddOp("fill_constant", {}, {{"Out", {"l2_distance@GRAD"}}}, + {{"shape", std::vector{batch_size, 1}}, {"value", float(1.0)}}, + root_block); + AppendBackward(program, {}); + + // update + AddOp("fill_constant", {}, {{"Out", {"learning_rate"}}}, + {{"shape", std::vector{1}}, {"value", float(0.001)}}, + root_block); + AddOp("sgd", {{"Param", {"w1"}}, + {"LearningRate", {"learning_rate"}}, + {"Grad", {"w1@GRAD"}}}, + {{"ParamOut", {"w1"}}}, {}, root_block); + AddOp("sgd", {{"Param", {"w2"}}, + {"LearningRate", {"learning_rate"}}, + {"Grad", {"w2@GRAD"}}}, + {{"ParamOut", {"w2"}}}, {}, root_block); + + AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, root_block); + AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, root_block); + AddOp("fetch", {{"Input", {"l2_distance"}}}, {}, {{"col", 0}}, root_block); + + // flush + program.Proto(); + } + + protected: + ProgramDesc init_pdesc_; + ProgramDesc pdesc_; + std::vector> inputs_; + std::vector> dims_; +}; + +class ExecutorTesterFeedAndFetch : public ::testing::Test { + public: + virtual void SetUp() override { + auto temp_root_block = pdesc_.add_blocks(); + temp_root_block->set_idx(0); + temp_root_block->set_parent_idx(-1); + + // wrap to BlockDescBind + paddle::framework::ProgramDescBind& program = + paddle::framework::ProgramDescBind::Instance(&pdesc_); + paddle::framework::BlockDescBind* root_block = program.Block(0); + + std::vector dim{6}; + + AddOp("feed", {}, {{"Out", {"a"}}}, {{"dims", dim}, {"col", 0}}, + root_block); + AddOp("feed", {}, {{"Out", {"b"}}}, {{"dims", dim}, {"col", 1}}, + root_block); + AddOp("fetch", {{"Input", {"a"}}}, {}, {{"col", 0}}, root_block); + AddOp("fetch", {{"Input", {"b"}}}, {}, {{"col", 1}}, root_block); + + // flush + program.Proto(); + + std::vector vec1 = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + std::vector vec2 = {4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + inputs_.push_back(vec1); + inputs_.push_back(vec2); + dims_.push_back({static_cast(vec1.size())}); + dims_.push_back({static_cast(vec2.size())}); + } + + protected: + ProgramDesc pdesc_; + std::vector> inputs_; + std::vector> dims_; +}; + +#ifndef PADDLE_WITH_CUDA +TEST_F(ExecutorTesterRandom, CPU) { + std::vector places; + CPUPlace cpu_place; + places.push_back(cpu_place); + + // We have a global Scope and BuddyAllocator, and we must ensure + // global BuddyAllocator is initialized before global Scope. Thus, + // global Scope will deconstruct before BuddyAllocator. Otherwise, + // "pointer being freed was not allocated" error will appear. + paddle::memory::Used(cpu_place); + + std::unique_ptr executor(new Executor(places)); + + executor->Run(init_pdesc_, &GetGlobalScope(), 0); + SetFeedVariable(inputs_, dims_); + executor->Run(pdesc_, &GetGlobalScope(), 0); + std::vector> result = GetFetchVariable(); +} + +TEST_F(ExecutorTesterFeedAndFetch, CPU) { + std::vector places; + CPUPlace cpu_place; + places.push_back(cpu_place); + + // We have a global Scope and BuddyAllocator, and we must ensure + // global BuddyAllocator is initialized before global Scope. Thus, + // global Scope will deconstruct before BuddyAllocator. Otherwise, + // "pointer being freed was not allocated" error will appear. + paddle::memory::Used(cpu_place); + + std::unique_ptr executor(new Executor(places)); + + for (int batch_id = 0; batch_id < 3; batch_id++) { + SetFeedVariable(inputs_, dims_); + executor->Run(pdesc_, &GetGlobalScope(), 0); + std::vector> result = GetFetchVariable(); + PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); + for (size_t i = 0; i < result.size(); ++i) { + PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size()); + for (size_t j = 0; j < result[i].size(); ++j) { + PADDLE_ENFORCE_EQ(result[i][j], inputs_[i][j]); + } + } + } +} +#else +TEST_F(ExecutorTesterRandom, GPU) { + std::vector places; + GPUPlace gpu_place(0); + places.push_back(gpu_place); + + // We have a global Scope and BuddyAllocator, and we must ensure + // global BuddyAllocator is initialized before global Scope. Thus, + // global Scope will deconstruct before BuddyAllocator. Otherwise, + // "pointer being freed was not allocated" error will appear. + // If paddle is compiled with GPU, both CPU and GPU BuddyAllocator + // need to be used at first. + paddle::memory::Used(CPUPlace()); + paddle::memory::Used(gpu_place); + + std::unique_ptr executor(new Executor(places)); + + executor->Run(init_pdesc_, &GetGlobalScope(), 0); + for (int batch_id = 0; batch_id < 3; batch_id++) { + SetFeedVariable(inputs_, dims_); + executor->Run(pdesc_, &GetGlobalScope(), 0); + } +} + +TEST_F(ExecutorTesterFeedAndFetch, GPU) { + std::vector places; + GPUPlace gpu_place(0); + places.push_back(gpu_place); + // We have a global Scope and BuddyAllocator, and we must ensure + // global BuddyAllocator is initialized before global Scope. Thus, + // global Scope will deconstruct before BuddyAllocator. Otherwise, + // "pointer being freed was not allocated" error will appear. + // If paddle is compiled with GPU, both CPU and GPU BuddyAllocator + // need to be used at first. + paddle::memory::Used(CPUPlace()); + paddle::memory::Used(gpu_place); + + std::unique_ptr executor(new Executor(places)); + + for (int batch_id = 0; batch_id < 3; batch_id++) { + SetFeedVariable(inputs_, dims_); + executor->Run(pdesc_, &GetGlobalScope(), 0); + std::vector> result = GetFetchVariable(); + PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); + for (size_t i = 0; i < result.size(); ++i) { + PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size()); + for (size_t j = 0; j < result[i].size(); ++j) { + PADDLE_ENFORCE_EQ(result[i][j], inputs_[i][j]); + } + } + } +} +#endif diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index 02aa74a8420a5c685c88d7cb0b487284814b3690..e7538b4af3429e566a439d5a0db8496efcd94969 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/op_desc.h" +#include +#include #include "paddle/framework/block_desc.h" +#include "paddle/framework/operator.h" namespace paddle { namespace framework { @@ -25,6 +28,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs, inputs_ = inputs; outputs_ = outputs; attrs_ = attrs; + need_update_ = true; } OpDesc *OpDescBind::Proto() { @@ -184,5 +188,38 @@ void OpDescBind::Sync() { need_update_ = false; } } + +using InferShapeFuncMap = + std::unordered_map>; + +static InferShapeFuncMap &InferShapeFuncs() { + static InferShapeFuncMap *g_map = nullptr; + if (g_map == nullptr) { + g_map = new InferShapeFuncMap(); + auto &info_map = OpInfoMap::Instance(); + // all registered kernels + for (auto &pair : OperatorWithKernel::AllOpKernels()) { + auto &info = info_map.Get(pair.first); + // use empty type here to avoid runtime checks. + auto op = + static_cast(info.Creator()("", {}, {}, {})); + g_map->insert( + {pair.first, [op](InferShapeContext *ctx) { op->InferShape(ctx); }}); + } + } + return *g_map; +} + +void OpDescBind::InferShape(const BlockDescBind &block) const { + auto &funcs = InferShapeFuncs(); + auto it = funcs.find(this->Type()); + if (it == funcs.end()) { + PADDLE_THROW("Operator %s has not been registered", this->Type()); + } + CompileTimeInferShapeContext ctx(*this, block); + it->second(&ctx); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index d0c314771c04d2a293f2d9ae0b7fc2be0ccb3add..81c4225041157ac600d1db73ef2363ebcd4abfc0 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -100,6 +100,8 @@ class OpDescBind { return &this->attrs_; } + void InferShape(const BlockDescBind &block) const; + private: template static std::vector MapKeys(const MapType &map) { diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 1e9ace99876e00b9f733da3e66f37dc0dc2f8cad..15f80b57206c90f689acfdcac60a0d9011025fc0 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -142,9 +142,9 @@ class OperatorBase { // Macro for define a clone method. // If you are writing an kernel operator, `Clone` will be defined when you // register it. i.e. `Clone` method is not needed to define by yourself. -#define DEFINE_OP_CLONE_METHOD(cls) \ - std::unique_ptr Clone() const final { \ - return std::unique_ptr(new cls(*this)); \ +#define DEFINE_OP_CLONE_METHOD(cls) \ + std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final { \ + return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \ } // Macro for define a default constructor for Operator. diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index 080b4ac621c1b8c0d4b4e7b26f394cf2be263894..5821bac928ed898971d61a3e2a86f59155d76991 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/scope.h" + +#include // for unique_ptr +#include // for call_once #include "paddle/string/printf.h" namespace paddle { @@ -62,5 +65,17 @@ void Scope::DropKids() { kids_.clear(); } +std::once_flag feed_variable_flag; + +framework::Scope& GetGlobalScope() { + static std::unique_ptr g_scope{nullptr}; + std::call_once(feed_variable_flag, [&]() { + g_scope.reset(new framework::Scope()); + g_scope->NewVar("feed_value"); + g_scope->NewVar("fetch_value"); + }); + return *(g_scope.get()); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index 7047f0d55e9844aec19892631fe4b5b387bf89ca..a8cfb107c25ccd62039db7349cc1c1dbff772f39 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -73,5 +73,7 @@ class Scope { DISABLE_COPY_AND_ASSIGN(Scope); }; +framework::Scope& GetGlobalScope(); + } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor_array.h b/paddle/framework/tensor_array.h index 94a14c2df492b175cf6a643800937878e95c5f37..293da04997304be41810446cb3e866d545805f83 100644 --- a/paddle/framework/tensor_array.h +++ b/paddle/framework/tensor_array.h @@ -87,12 +87,12 @@ class TensorArray { LoDTensor Stack() const; /* - * Unpacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors. + * Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors. */ void Unstack(const LoDTensor &source) const; /* - * Unpacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors, + * Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors, * with memory of tensors shared. */ void UnstackShared(const LoDTensor &source) const; diff --git a/paddle/framework/var_desc.cc b/paddle/framework/var_desc.cc index 13b9c5f3cdf98e6d22f4217fa1cf9a48910a78d8..a88e813b5e7c7e6420cb0ba8a25bba4f4d658e80 100644 --- a/paddle/framework/var_desc.cc +++ b/paddle/framework/var_desc.cc @@ -32,5 +32,13 @@ std::vector VarDescBind::Shape() const { DataType VarDescBind::GetDataType() const { return desc_.lod_tensor().data_type(); } + +void VarDescBind::SetLoDLevel(int32_t lod_level) { + desc_.mutable_lod_tensor()->set_lod_level(lod_level); +} + +int32_t VarDescBind::GetLodLevel() const { + return desc_.lod_tensor().lod_level(); +} } // namespace framework } // namespace paddle diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h index 4763bf09d004539ab24e4aad3bf429667f1fcc73..464fece85fe5c674690c2034054e551f14db2138 100644 --- a/paddle/framework/var_desc.h +++ b/paddle/framework/var_desc.h @@ -66,6 +66,10 @@ class VarDescBind { DataType GetDataType() const; + void SetLoDLevel(int32_t lod_level); + + int32_t GetLodLevel() const; + private: VarDesc desc_; }; diff --git a/paddle/math/tests/test_GpuProfiler.cpp b/paddle/math/tests/test_GpuProfiler.cpp index 9402bd3ec48fbed381ef1f676e8b179cabd4cb9f..d9f146f0d1f63480ddee784071b43ff85da0b15c 100644 --- a/paddle/math/tests/test_GpuProfiler.cpp +++ b/paddle/math/tests/test_GpuProfiler.cpp @@ -162,4 +162,4 @@ int main(int argc, char** argv) { return RUN_ALL_TESTS(); } -#endif /* PADDLE_ONLY_CPU */ +#endif diff --git a/paddle/memory/detail/buddy_allocator.cc b/paddle/memory/detail/buddy_allocator.cc index fdc5ed19dc2973e744676c3b795c8ab86da58590..e212f7737a4093125857126cabb5b1a7b3e055b1 100644 --- a/paddle/memory/detail/buddy_allocator.cc +++ b/paddle/memory/detail/buddy_allocator.cc @@ -182,7 +182,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() { max_chunk_size_ = platform::GpuMaxChunkSize(); } } -#endif // PADDLE_ONLY_CPU +#endif // Allocate a new maximum sized block size_t index = 0; diff --git a/paddle/memory/detail/system_allocator.cc b/paddle/memory/detail/system_allocator.cc index 6c9a46dd09c15347fca1a30971e7e732d887bc8e..33166d9ce23a4a345fc00a65adf63281b13643c3 100644 --- a/paddle/memory/detail/system_allocator.cc +++ b/paddle/memory/detail/system_allocator.cc @@ -134,7 +134,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) { bool GPUAllocator::UseGpu() const { return true; } -#endif // PADDLE_ONLY_CPU +#endif } // namespace detail } // namespace memory diff --git a/paddle/memory/detail/system_allocator.h b/paddle/memory/detail/system_allocator.h index ee9b012f91a9647839cf465c4074082f2d3509a6..552cab4f96ff21a6f3c66209eb62150e92996826 100644 --- a/paddle/memory/detail/system_allocator.h +++ b/paddle/memory/detail/system_allocator.h @@ -51,7 +51,7 @@ class GPUAllocator : public SystemAllocator { size_t gpu_alloc_size_ = 0; size_t fallback_alloc_size_ = 0; }; -#endif // PADDLE_ONLY_CPU +#endif } // namespace detail } // namespace memory diff --git a/paddle/memory/detail/system_allocator_test.cc b/paddle/memory/detail/system_allocator_test.cc index cd563844e7fa23241bb0bb56d1365ef34826c4a8..6a8558937bf0c924e5f48605ff066e2789fd59b6 100644 --- a/paddle/memory/detail/system_allocator_test.cc +++ b/paddle/memory/detail/system_allocator_test.cc @@ -62,4 +62,4 @@ TEST(GPUAllocator, Alloc) { TestAllocator(a, 2048); TestAllocator(a, 0); } -#endif // PADDLE_ONLY_CPU +#endif diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc index 790420a8ab41b1a61ee35dc086c8b95fa1a02019..1df88a6da9fb0c50d0d7ecd083c0533d8a886a67 100644 --- a/paddle/memory/memcpy.cc +++ b/paddle/memory/memcpy.cc @@ -89,7 +89,7 @@ void Copy(platform::GPUPlace dst_place, platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice); } -#endif // PADDLE_ONLY_CPU +#endif } // namespace memory } // namespace paddle diff --git a/paddle/memory/memcpy.h b/paddle/memory/memcpy.h index 0bccee58c3a22379c75523467e0c717b98b08bcf..9b36182c2b619317da31310141823442d8fd3f94 100644 --- a/paddle/memory/memcpy.h +++ b/paddle/memory/memcpy.h @@ -53,7 +53,7 @@ template void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, cudaStream_t stream); -#endif // PADDLE_ONLY_CPU +#endif } // namespace memory } // namespace paddle diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc index 30ce8a82e16ed26a41b009ce5d52dd1a2a1b7c21..5087c02385f7f37d78d134b739f3f22522977fb8 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -111,7 +111,7 @@ size_t Used(platform::GPUPlace place) { return GetGPUBuddyAllocator(place.device)->Used(); } -#endif // PADDLE_ONLY_CPU +#endif } // namespace memory } // namespace paddle diff --git a/paddle/memory/memory_test.cc b/paddle/memory/memory_test.cc index 0d402038a06f4ad93fd15946fc44aaeac58ada40..2444931e26774ae80b916fbb7bd46ff93025d9ed 100644 --- a/paddle/memory/memory_test.cc +++ b/paddle/memory/memory_test.cc @@ -135,4 +135,4 @@ TEST(BuddyAllocator, GPUMultAlloc) { } } -#endif // PADDLE_ONLY_CPU +#endif diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 0fa1fca2bcd3117e1e9a6a54c343b2d0d8c3822b..ad941bde2be3bbbc6d910fff262ea4cb3878f8be 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -55,12 +55,20 @@ function(op_library TARGET) set(pybind_flag 1) endif() + # pool_op contains several operators if ("${TARGET}" STREQUAL "pool_op") set(pybind_flag 1) # It's enough to just adding one operator to pybind file(APPEND ${pybind_file} "USE_OP(pool2d);\n") endif() + # pool_with_index_op contains several operators + if ("${TARGET}" STREQUAL "pool_with_index_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") + endif() + # activation_op contains several operators if ("${TARGET}" STREQUAL "activation_op") set(pybind_flag 1) @@ -104,7 +112,9 @@ set(DEPS_OPS cond_op cross_entropy_op softmax_with_cross_entropy_op - sum_op) + sum_op + pool_op + pool_with_index_op) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc @@ -113,6 +123,8 @@ op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) op_library(cross_entropy_op DEPS cross_entropy) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(sum_op DEPS net_op) +op_library(pool_op DEPS pooling) +op_library(pool_with_index_op DEPS pooling) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) @@ -125,3 +137,4 @@ cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory) +cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array) diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index 5e5df49b0788ce8422dfb0d82791ec8c0a7ee32d..ced14a8923140ec6b08e3e6725a5780b61033daf 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -49,6 +49,18 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { } }; +class LogSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LogSigmoidOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of LogSigmoid operator"); + AddOutput("Y", "Output of LogSigmoid operator"); + AddComment( + "Logsigmoid activation operator, logsigmoid = log (1 / (1 + exp(-x)))"); + } +}; + class ExpOpMaker : public framework::OpProtoAndCheckerMaker { public: ExpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) @@ -85,6 +97,23 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker { } }; +template +class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SoftShrinkOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Softshrink operator"); + AddOutput("Y", "Output of Softshrink operator"); + AddComment( + "Softshrink activation operator, " + "softshrink = x - lambda, if x > lambda;" + " x + lambda, if x < lambda; 0 otherwise"); + AddAttr("lambda", "non-negative offset") + .SetDefault(static_cast(0.5f)); + } +}; + class TanhOpMaker : public framework::OpProtoAndCheckerMaker { public: TanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) @@ -108,6 +137,24 @@ class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker { } }; +template +class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker { + public: + HardShrinkOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of HardShrink operator"); + AddOutput("Y", "Output of HardShrink operator"); + AddComment( + "HardShrink activation operator, " + "hard_shrink(x) = x if x > lambda" + "hard_shrink(x) = x if x < -lambda" + "hard_shrink(x) = 0 otherwise"); + AddAttr("threshold", "The value of threshold for HardShrink") + .SetDefault(static_cast(0.5)); + } +}; + class SqrtOpMaker : public framework::OpProtoAndCheckerMaker { public: SqrtOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) @@ -159,6 +206,17 @@ class SquareOpMaker : public framework::OpProtoAndCheckerMaker { } }; +class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SoftplusOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Softplus operator"); + AddOutput("Y", "Output of Softplus operator"); + AddComment("Softplus activation operator, softplus(x) = log(1 + exp(x))"); + } +}; + class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker { public: SoftsignOpMaker(framework::OpProto *proto, @@ -201,6 +259,27 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker { } }; +template +class ELUOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ELUOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(Tensor) The input of ELU operator, it shouldn't be empty. Input " + "is flattened and treated as a 1D array."); + AddOutput("Y", + "(Tensor) The output of ELU operator. It has the same shape as " + "the input."); + AddAttr( + "alpha", "(float, default 1.0) Alpha value in the elu formulation.") + .SetDefault(static_cast(1.)); + AddComment(R"DOC( + ELU activation operator. It applies this element-wise computation on + the input: f(x) = max(0, x) + min(0, alpha * (exp(x) - 1)). + Check .. _Link: https://arxiv.org/abs/1511.07289 for more details.)DOC"); + } +}; + template class Relu6OpMaker : public framework::OpProtoAndCheckerMaker { public: @@ -250,6 +329,9 @@ namespace ops = paddle::operators; REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad, ops::ActivationOpGrad); +REGISTER_OP(logsigmoid, ops::ActivationOp, ops::LogSigmoidOpMaker, + logsigmoid_grad, ops::ActivationOpGrad); + REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad, ops::ActivationOpGrad); @@ -262,6 +344,9 @@ REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad, REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker, tanh_shrink_grad, ops::ActivationOpGrad); +REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker, + softshrink_grad, ops::ActivationOpGrad); + REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad, ops::ActivationOpGrad); @@ -277,6 +362,9 @@ REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad, REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad, ops::ActivationOpGrad); +REGISTER_OP(softplus, ops::ActivationOp, ops::SoftplusOpMaker, softplus_grad, + ops::ActivationOpGrad); + REGISTER_OP(softsign, ops::ActivationOp, ops::SoftsignOpMaker, softsign_grad, ops::ActivationOpGrad); @@ -289,6 +377,9 @@ REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker, REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker, soft_relu_grad, ops::ActivationOpGrad); +REGISTER_OP(elu, ops::ActivationOp, ops::ELUOpMaker, elu_grad, + ops::ActivationOpGrad); + REGISTER_OP(relu6, ops::ActivationOp, ops::Relu6OpMaker, relu6_grad, ops::ActivationOpGrad); @@ -298,6 +389,9 @@ REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker, pow_grad, REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker, stanh_grad, ops::ActivationOpGrad); +REGISTER_OP(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker, + hard_shrink_grad, ops::ActivationOpGrad); + #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \ REGISTER_OP_CPU_KERNEL( \ act_type, \ diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index f127468125c265e5be7aec1f55f83fa5ba9be65a..f88c9c48eb9fcb779de5a99a45a832e582d76ab0 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -95,6 +95,41 @@ struct SigmoidGradFunctor : public BaseActivationFunctor { } }; +// Originally: logsigmoid(x) = -log (1 + exp(-x)) +// For numerical stability, we can use the log-sum-exp trick: +// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ +// We can rewrite the above equation as: +// y = -log( exp(0) + exp(-x)) [since exp(0) = 1] +// = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0))) +// = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) - exp(max(-x, 0)) * exp(-x - +// max(-x, 0))) +// = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0)))) +// = -log( exp(max(-x, 0)) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) +// +// Hence, logsigmoid(x) = - (max(-x, 0) + log(exp(-max(-x, 0)) +// + exp(-x - max(-x, 0)))) +template +struct LogSigmoidFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Y y) const { + auto temp = (-x).cwiseMax(static_cast(0)); // temp = max(-x, 0) + y.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log()); + } +}; + +// Originally: f' = exp(-x) / (1 + exp(-x)) +// For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + +// exp(-x - max(-x, 0))) +template +struct LogSigmoidGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { + auto temp = (-x).cwiseMax(static_cast(0)); // temp = max(-x, 0) + dx.device(d) = + dy * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp())); + } +}; + // exp(x) = e^x template struct ExpFunctor : public BaseActivationFunctor { @@ -164,6 +199,70 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor { } }; +// tanhshrink(x) = x - tanh(x) +// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) +template +struct HardShrinkFunctor : public BaseActivationFunctor { + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + template + void operator()(Device d, X x, Y y) const { + auto temp1 = (x < (threshold * -1)).template cast().eval(); + auto temp2 = (x > threshold).template cast().eval(); + y.device(d) = x * (temp1 + temp2); + } +}; + +template +struct HardShrinkGradFunctor : public BaseActivationFunctor { + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { + auto temp1 = (x < (threshold * -1)).template cast().eval(); + auto temp2 = (x > threshold).template cast().eval(); + dx.device(d) = dy * (temp1 + temp2).template cast(); + } +}; + +// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < lambda; 0 +// otherwise +template +struct SoftShrinkFunctor : public BaseActivationFunctor { + float lambda; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"lambda", &lambda}}; + } + + template + void operator()(Device d, X x, Y y) const { + auto temp1 = (x > lambda).template cast().eval(); + auto temp2 = (x < -lambda).template cast().eval(); + y.device(d) = temp1 * (x - lambda) + temp2 * (x + lambda); + } +}; + +template +struct SoftShrinkGradFunctor : public BaseActivationFunctor { + float lambda; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"lambda", &lambda}}; + } + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { + auto temp1 = (x > lambda).template cast().eval(); + auto temp2 = (x < -lambda).template cast().eval(); + dx.device(d) = dy * (temp1 + temp2).template cast(); + } +}; + // sqrt(x) = x^(1/2) template struct SqrtFunctor : public BaseActivationFunctor { @@ -285,8 +384,6 @@ template struct Relu6Functor : public BaseActivationFunctor { float threshold; - // NOTE: Explicit hides the `BaseActivationFunctor::GetAttrs` - // not polymorphism for speed. typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } @@ -310,6 +407,33 @@ struct Relu6GradFunctor : public BaseActivationFunctor { } }; +// softplus(x) = log(1 + exp(x)) +// When x is a very large positive number, exp(x) may explode to inf, +// Using trick below for numerical stability +// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ +// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0))) +template +struct SoftplusFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Y y) { + auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) + y.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log()); + } +}; + +// d(softplus(x))/dx = exp(x) / (1 + exp(x)) +// For numerical stability: +// d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) + +// exp(x - max(x, 0))) +template +struct SoftplusGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) + dx.device(d) = dy * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp())); + } +}; + // softsign(x) = x / (1 + |x|) template struct SoftsignFunctor : public BaseActivationFunctor { @@ -384,6 +508,35 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor { } }; +template +struct ELUFunctor : public BaseActivationFunctor { + float alpha; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + + template + void operator()(Device d, X x, Y y) const { + y.device(d) = + x.cwiseMax(static_cast(0)) + + (alpha * (x.exp() - static_cast(1))).cwiseMin(static_cast(0)); + } +}; + +template +struct ELUGradFunctor : public BaseActivationFunctor { + float alpha; + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { + dx.device(d) = + dy * (x > static_cast(0)).template cast() + + dy * (y + alpha) * (x < static_cast(0)).template cast(); + } +}; + template struct PowFunctor : public BaseActivationFunctor { float factor; @@ -440,21 +593,26 @@ struct STanhGradFunctor : public BaseActivationFunctor { } // namespace operators } // namespace paddle -#define FOR_EACH_KERNEL_FUNCTOR(__macro) \ - __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ - __macro(exp, ExpFunctor, ExpGradFunctor); \ - __macro(relu, ReluFunctor, ReluGradFunctor); \ - __macro(tanh, TanhFunctor, TanhGradFunctor); \ - __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \ - __macro(abs, AbsFunctor, AbsGradFunctor); \ - __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ - __macro(log, LogFunctor, LogGradFunctor); \ - __macro(square, SquareFunctor, SquareGradFunctor); \ - __macro(brelu, BReluFunctor, BReluGradFunctor); \ - __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \ - __macro(pow, PowFunctor, PowGradFunctor); \ - __macro(stanh, STanhFunctor, STanhGradFunctor); \ - __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \ - __macro(relu6, Relu6Functor, Relu6GradFunctor); \ - __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \ - __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor) +#define FOR_EACH_KERNEL_FUNCTOR(__macro) \ + __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ + __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ + __macro(exp, ExpFunctor, ExpGradFunctor); \ + __macro(relu, ReluFunctor, ReluGradFunctor); \ + __macro(tanh, TanhFunctor, TanhGradFunctor); \ + __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ + __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \ + __macro(abs, AbsFunctor, AbsGradFunctor); \ + __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ + __macro(log, LogFunctor, LogGradFunctor); \ + __macro(square, SquareFunctor, SquareGradFunctor); \ + __macro(brelu, BReluFunctor, BReluGradFunctor); \ + __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \ + __macro(pow, PowFunctor, PowGradFunctor); \ + __macro(stanh, STanhFunctor, STanhGradFunctor); \ + __macro(softplus, SoftplusFunctor, SoftplusGradFunctor); \ + __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \ + __macro(relu6, Relu6Functor, Relu6GradFunctor); \ + __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \ + __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \ + __macro(elu, ELUFunctor, ELUGradFunctor); \ + __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor) diff --git a/paddle/operators/conv_shift_op.cc b/paddle/operators/conv_shift_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e1e321ed5fce6ce4e4089cc5c5e488a2cbad6c82 --- /dev/null +++ b/paddle/operators/conv_shift_op.cc @@ -0,0 +1,206 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/conv_shift_op.h" +#include "paddle/framework/eigen.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +class ConvShiftOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2."); + PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], + "The 1st dimension of Input(X) and Input(Y) should " + "be equal."); + PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1, + "The 2nd dimension of Input(Y) should be odd."); + PADDLE_ENFORCE_LE(y_dims[1], x_dims[1], + "The 2nd dimension of Input(Y) should be less than or " + "equal to the 2nd dimension of Input(X)."); + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class ConvShiftGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should be not null."); + + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + auto x_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim(x_grad_name, x_dims); + } + + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(y_grad_name)) { + auto y_dims = ctx->GetInputDim("Y"); + ctx->SetOutputDim(y_grad_name, y_dims); + } + } +}; + +class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ConvShiftOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(Tensor, default Tensor), a 2-D tensor with shape B x M, " + "where B is the batch size and M is the data dimension."); + AddInput("Y", + "(Tensor, default Tensor), a 2-D tensor with shape B x N, " + "where B is the batch size and N is the data dimension. N must " + "be odd."); + AddOutput("Out", + "(Tensor, default Tensor), a 2-D tensor with shape B x M, " + "i.e., the same shape as X."); + AddComment(R"DOC( +ConvShift Operator. + +A layer for circular convolution of two vectors, +as used in the Neural Turing Machine: https://arxiv.org/abs/1410.5401 + +The equation is: + + \f[ + Out[i] = \sum_{j=-(N-1)/2}^{(N-1)/2} X_{i+j} * Y_{j} + \f] + +where X's index is computed modulo M, and b's index is computed modulo N. + +Both of the input `X` and `Y` can carry LoD (Level of Details) information. +However, the output only shares the LoD information with input `X`. +)DOC"); + } +}; + +template +class ConvShiftKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *X = context.Input("X"); + auto *Y = context.Input("Y"); + auto *Out = context.Output("Out"); + Out->mutable_data(context.GetPlace()); + + auto x = EigenMatrix::From(*X); + auto y = EigenMatrix::From(*Y); + auto out = EigenMatrix::From(*Out); + out.setZero(); + + size_t batch_size = X->dims()[0]; + size_t x_width = X->dims()[1]; + size_t y_width = Y->dims()[1]; + size_t y_half_width = (y_width - 1) / 2; + + for (size_t k = 0; k < batch_size; ++k) { + for (size_t i = 0; i < x_width; ++i) { + for (size_t j = 0; j < y_width; ++j) { + int index = (i + j - y_half_width + x_width) % x_width; + out(k, i) += x(k, index) * y(k, j); + } + } + } + } +}; + +template +class ConvShiftGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *X = context.Input("X"); + auto *Y = context.Input("Y"); + auto *dOut = context.Input(framework::GradVarName("Out")); + auto *dX = context.Output(framework::GradVarName("X")); + auto *dY = context.Output(framework::GradVarName("Y")); + + auto x = EigenMatrix::From(*X); + auto y = EigenMatrix::From(*Y); + auto dout = EigenMatrix::From(*dOut); + + auto x_dims = X->dims(); + auto y_dims = Y->dims(); + size_t batch_size = x_dims[0]; + size_t x_width = x_dims[1]; + size_t y_width = y_dims[1]; + size_t y_half_width = (y_width - 1) / 2; + + // The below trades code duplication for efficiency (keeping the if + // statement outside of the loop). + if (dX) { + dX->mutable_data(context.GetPlace()); + auto dx = EigenMatrix::From(*dX); + dx.setZero(); + for (size_t k = 0; k < batch_size; ++k) { + for (size_t i = 0; i < x_width; ++i) { + for (size_t j = 0; j < y_width; ++j) { + int index = (i + j - y_half_width + x_width) % x_width; + dx(k, index) += dout(k, i) * y(k, j); + } + } + } + } + + if (dY) { + dY->mutable_data(context.GetPlace()); + auto dy = EigenMatrix::From(*dY); + dy.setZero(); + for (size_t k = 0; k < batch_size; ++k) { + for (size_t i = 0; i < x_width; ++i) { + for (size_t j = 0; j < y_width; ++j) { + int index = (i + j - y_half_width + x_width) % x_width; + dy(k, j) += x(k, index) * dout(k, i); + } + } + } + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(conv_shift, ops::ConvShiftOp, ops::ConvShiftOpMaker, + conv_shift_grad, ops::ConvShiftGradOp); +REGISTER_OP_CPU_KERNEL(conv_shift, + ops::ConvShiftKernel); +REGISTER_OP_CPU_KERNEL( + conv_shift_grad, + ops::ConvShiftGradKernel); diff --git a/paddle/operators/conv_shift_op.cu b/paddle/operators/conv_shift_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..145e966fe9caa68f7485bb258fa78fd34bfd4c04 --- /dev/null +++ b/paddle/operators/conv_shift_op.cu @@ -0,0 +1,194 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/conv_shift_op.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +namespace { + +inline int div_up(int x, int y) { return (x + y - 1) / y; } + +// Some notes on the design: +// +// Each thread is responsible for computing a single output out[k, i]. +// Thread blocks are based on tiles of x with height 1 in the batch dimension. +// +// This design is based on the typical use case where the filter +// y is fairly small. For large y, it would probably be more efficient +// to also tile across y. +template +__global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width, + int y_width, int y_half_width, + int batch_size) { + extern __shared__ T mem[]; + + int tx = threadIdx.x; + int i = blockIdx.x * blockDim.x + tx; // global x index + int k = blockIdx.y; // batch index + + // Check if we are in a boundary block with fewer x's to process than + // blockDim.x. + int num_x = + (blockIdx.x == gridDim.x - 1) ? (x_width % blockDim.x) : blockDim.x; + + T *sx = mem; + T *sx_pad = &mem[num_x]; + T *sy = &mem[blockDim.x + y_width]; + + // Collaboratively load y[k, :] and length-y padding of x into shared memory. + int pad_start = blockIdx.x * blockDim.x + num_x + x_width - y_half_width; + for (int j = tx; j < y_width; j += blockDim.x) { + sy[j] = y[k * y_width + j]; + sx_pad[j] = x[k * x_width + (pad_start + j) % x_width]; + } + + // Load a cyclically shifted slice of x into shared memory. + if (tx < num_x) { + int load_i = (i - y_half_width + x_width) % x_width; + sx[tx] = x[k * x_width + load_i]; + } else { + return; + } + __syncthreads(); + + // Compute dot product of sx[tx:tx + y_width] and sy. + T sum = 0; + for (int j = 0; j < y_width; ++j) { + sum += sx[tx + j] * sy[j]; + } + + // Save to out[k, i]. + out[k * x_width + i] = sum; +} + +// Compute x gradient - initial naive implementation with atomic add. +template +__global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width, + int y_width, int y_half_width, int batch_size) { + int i = blockIdx.x * blockDim.x + threadIdx.x; // x index + int j = blockIdx.y; // y index + int k = blockIdx.z; // batch index + + if (i < x_width) { + int index = (i + j - y_half_width + x_width) % x_width; + atomicAdd(&dx[k * x_width + index], + dout[k * x_width + i] * y[k * y_width + j]); + } +} + +// Compute y gradient - initial naive implementation with atomic add. +template +__global__ void conv_shift_dy(const T *x, const T *dout, T *dy, int x_width, + int y_width, int y_half_width, int batch_size) { + int i = blockIdx.x * blockDim.x + threadIdx.x; // x index + int j = blockIdx.y; // y index + int k = blockIdx.z; // batch index + + if (i < x_width) { + int index = (i + j - y_half_width + x_width) % x_width; + atomicAdd(&dy[k * y_width + j], + x[k * x_width + index] * dout[k * x_width + i]); + } +} +} // namespace + +template +class ConvShiftKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const Tensor *X = context.Input("X"); + const Tensor *Y = context.Input("Y"); + Tensor *Out = context.Output("Out"); + const T *x_data = X->data(); + const T *y_data = Y->data(); + T *out_data = Out->mutable_data(context.GetPlace()); + + int batch_size = X->dims()[0]; + int x_width = X->dims()[1]; + int y_width = Y->dims()[1]; + int y_half_width = (y_width - 1) / 2; + + const int x_per_block = 256; + int num_x_blocks = div_up(x_width, x_per_block); + int mem_per_block = (x_per_block + 2 * y_width) * sizeof(T); + + dim3 grid_dim(num_x_blocks, batch_size); + + auto stream = reinterpret_cast( + context.device_context()) + .stream(); + + conv_shift_forward<<>>( + x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size); + } +}; + +template +class ConvShiftGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + const Tensor *X = context.Input("X"); + const Tensor *Y = context.Input("Y"); + const Tensor *dOut = context.Input(framework::GradVarName("Out")); + const T *x_data = X->data(); + const T *y_data = Y->data(); + const T *dout_data = dOut->data(); + + Tensor *dX = context.Output(framework::GradVarName("X")); + Tensor *dY = context.Output(framework::GradVarName("Y")); + + int batch_size = X->dims()[0]; + int x_width = X->dims()[1]; + int y_width = Y->dims()[1]; + int y_half_width = (y_width - 1) / 2; + + auto stream = reinterpret_cast( + context.device_context()) + .stream(); + + const int x_per_block = 256; + int num_x_blocks = div_up(x_width, x_per_block); + dim3 grid_dim(num_x_blocks, y_width, batch_size); + + if (dX) { + T *dx_data = dX->mutable_data(context.GetPlace()); + cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream); + conv_shift_dx<<>>( + dout_data, y_data, dx_data, x_width, y_width, y_half_width, + batch_size); + } + if (dY) { + T *dy_data = dY->mutable_data(context.GetPlace()); + cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream); + conv_shift_dy<<>>( + x_data, dout_data, dy_data, x_width, y_width, y_half_width, + batch_size); + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(conv_shift, + ops::ConvShiftKernel); +REGISTER_OP_GPU_KERNEL( + conv_shift_grad, + ops::ConvShiftGradKernel); diff --git a/paddle/operators/conv_shift_op.h b/paddle/operators/conv_shift_op.h new file mode 100644 index 0000000000000000000000000000000000000000..5a160b0f1696c70868fc48d219b38cde2018e8a3 --- /dev/null +++ b/paddle/operators/conv_shift_op.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class ConvShiftKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override; +}; + +template +class ConvShiftGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override; +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/dynamic_recurrent_op.cc b/paddle/operators/dynamic_recurrent_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b919aef8fb62e5b2331c2d842556e0642ea6b095 --- /dev/null +++ b/paddle/operators/dynamic_recurrent_op.cc @@ -0,0 +1,276 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve . + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/dynamic_recurrent_op.h" + +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::Scope; +using framework::TensorArray; +using framework::LoDTensor; +using framework::Variable; + +namespace detail { + +inline void CreateVariables(Scope& scope, + const std::vector& var_names) { + for (const auto& name : var_names) { + scope.NewVar(name); + } +} + +} // namespace detail + +class DynamicRecurrentOpProtoAndCheckerMaker + : public framework::OpProtoAndCheckerMaker { + public: + DynamicRecurrentOpProtoAndCheckerMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + const auto& name = DynamicRecurrentOp::kArgName; + // inputs and outputs stored in proto + AddInput(name.inlinks, + "the inputs that need to be segmented for each step.") + .AsDuplicable(); + AddInput(name.boot_memories, "variables to initialize memories.") + .AsDuplicable(); + + AddOutput(name.outlinks, "the outputs that need to concated for all steps.") + .AsDuplicable(); + AddOutput(name.step_scopes, "step scopes"); + + // Attributes stored in AttributeMap + AddAttr>(name.pre_memories, + "names of pre-memories"); + AddAttr>(name.memories, "names of memories"); + + AddComment("This is a RNN operator for varience-length sequences."); + } +}; + +void DynamicRecurrentOp::Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const { + cache_.Init(kArgName, *this, scope, &arg_); + SplitInputs(); + CreateScopes(); + WriteStepInputs(); + InitStates(); + + // call stepnet in all the time steps + for (size_t step = 0; step < cache_.num_steps; step++) { + auto& step_scope = cache_.GetScope(step); + stepnet_->Run(step_scope, dev_ctx); + } + + WriteStepOutputs(); + ConcatOutputs(); +} + +void DynamicRecurrentOp::SplitInputs() const { + // TODO(superjom) make level a config + // TODO(superjom) check all the inputs has the same LoD + int level = 0; + const auto& inlinks = cache_.inlinks; + for (const auto& item : inlinks) { + const auto& var = item.second; + const auto& tensor = var->Get(); + TensorArray& ta = step_inputs_[item.first]; + dy_seq_metas_[item.first] = + ta.Unpack(tensor, level, true /*length_descend*/); + + if (cache_.num_steps) { + PADDLE_ENFORCE_EQ(ta.size(), cache_.num_steps, + "inputs should have the same steps"); + } else { + cache_.num_steps = ta.size(); + } + } +} + +void DynamicRecurrentOp::WriteStepInputs() const { + for (const auto& item : cache_.inlinks) { + auto ta_it = step_inputs_.find(item.first); + PADDLE_ENFORCE(ta_it != step_inputs_.end(), + "step_inputs_ not compatible with memory set"); + TensorArray& ta = ta_it->second; + for (size_t step = 0; step < ta.size(); step++) { + auto tensor = ta.Read(step); + auto& step_scope = cache_.GetScope(step); + Variable* var = step_scope.FindVar(item.first); + if (var == nullptr) { + var = step_scope.NewVar(item.first); + } + var->GetMutable()->ShareDataWith(tensor); + } + } +} + +void DynamicRecurrentOp::WriteStepOutputs() const { + for (size_t step = 0; step < cache_.scopes->size(); step++) { + auto& scope = cache_.GetScope(step); + for (auto& item : step_outputs_) { + auto* var = scope.FindVar(item.first); + if (var == nullptr) { + var = scope.NewVar(item.first); + } + auto* tensor = var->GetMutable(); + item.second.WriteShared(step, *tensor); + } + } +} + +void DynamicRecurrentOp::CreateScopes() const { + PADDLE_ENFORCE_GT(cache_.num_steps, 0); + // resize scopes + size_t num_scopes_need_create = cache_.num_steps - cache_.scopes->size(); + for (size_t i = 0; i < num_scopes_need_create; i++) { + cache_.scopes->emplace_back(&cache_.scope->NewScope()); + } + + // init temporary inputs + PADDLE_ENFORCE_NOT_NULL(stepnet_, "stepnet should be set first"); + std::vector memories; + std::vector pre_memories; + std::transform(arg_.memories.begin(), arg_.memories.end(), + std::back_inserter(memories), + [](const rnn::MemoryAttr& m) { return m.var; }); + std::transform(arg_.memories.begin(), arg_.memories.end(), + std::back_inserter(pre_memories), + [](const rnn::MemoryAttr& m) { return m.pre_var; }); + + for (size_t step = 0; step < cache_.num_steps; step++) { + auto& scope = cache_.GetScope(step); + detail::CreateVariables(scope, arg_.inlinks); + detail::CreateVariables(scope, arg_.outlinks); + detail::CreateVariables(scope, memories); + detail::CreateVariables(scope, pre_memories); + } +} + +void DynamicRecurrentOp::ConcatOutputs() const { + // TODO(superjom) transform this to a config + int level = 0; + // TODO(superjom) pass in some lod + // just a placeholder + framework::LoD lod; + for (auto& item : step_outputs_) { + auto tensor = item.second.Pack(level, dy_seq_metas_[item.first], lod); + auto& output = cache_.outlinks[item.first]->Get(); + const_cast(&output)->ShareDataWith(tensor); + } +} + +void DynamicRecurrentOp::InitStates() const { + // init the first state + // TODO(superjom) parepare the scenerio that boot state not exists + for (auto memory : arg_.memories) { + auto* boot_state_var = cache_.scope->FindVar(memory.boot_var); + PADDLE_ENFORCE_NOT_NULL(boot_state_var); + auto& boot_state = boot_state_var->Get(); + const auto& dims = boot_state.dims(); + + for (size_t step = 0; step < cache_.num_steps; step++) { + auto& cur_scope = cache_.GetScope(step); + // link pre-state to boot_state + // init state and pre-state + auto* pre_state = cur_scope.FindVar(memory.pre_var); + PADDLE_ENFORCE_NOT_NULL(pre_state); + pre_state->GetMutable(); + + auto* state = cur_scope.FindVar(memory.var); + PADDLE_ENFORCE_NOT_NULL(state); + state->GetMutable()->Resize(dims); + state->GetMutable()->mutable_data( + platform::CPUPlace()); + + if (step == 0) { + auto* pre_state_tensor = pre_state->GetMutable(); + pre_state_tensor->Resize(boot_state.dims()); + pre_state_tensor->ShareDataWith(boot_state); + } else { + auto& pre_scope = cache_.GetScope(step - 1); + auto* state_pre = pre_scope.FindVar(memory.var); + PADDLE_ENFORCE_NOT_NULL(state_pre); + pre_state->GetMutable()->ShareDataWith( + *state_pre->GetMutable()); + } + } + } +} + +void DynamicRecurrentOp::ArgCache::Init( + const rnn::ArgumentName& name, const paddle::framework::OperatorBase& op, + const paddle::framework::Scope& scope, rnn::Argument* arg) { + this->scope = &scope; + InitArgument(name, op, arg); + CacheScopes(scope, *arg); + CacheInlinks(scope, arg->inlinks); + CacheOutlinks(scope, arg->outlinks); +} + +void DynamicRecurrentOp::ArgCache::InitArgument(const rnn::ArgumentName& name, + const OperatorBase& op, + rnn::Argument* arg) { + rnn::InitArgument(name, arg, op, false /*is_grad*/); +} + +void DynamicRecurrentOp::ArgCache::CacheScopes(const Scope& scope, + const rnn::Argument& arg) { + auto scopes_var = scope.FindVar(arg.step_scopes); + PADDLE_ENFORCE(scopes_var != nullptr, + "the step_scopes output argument [%s] should be created first " + "by framework.", + arg.step_scopes); + this->scopes = scopes_var->GetMutable>(); +} + +void DynamicRecurrentOp::ArgCache::CacheInlinks( + const Scope& scope, const std::vector& names) { + for (auto name : names) { + auto* var = GetVariable(scope, name); + inlinks[name] = var; + } +} + +void DynamicRecurrentOp::ArgCache::CacheOutlinks( + const Scope& scope, const std::vector& names) { + for (auto name : names) { + auto* var = GetVariable(scope, name); + outlinks[name] = var; + } +} + +Variable* DynamicRecurrentOp::ArgCache::GetVariable(const Scope& scope, + const std::string& name) { + auto* var = scope.FindVar(name); + PADDLE_ENFORCE_NOT_NULL(var, "variable [%s] not exist in scope", name); + return var; +} + +const rnn::ArgumentName DynamicRecurrentOp::kArgName{ + "step_net", "step_scopes", "inlinks", "outlinks", + "memories", "pre_memories", "boot_memories"}; + +void DynamicRecurrentGradientOp::Run( + const Scope& scope, const platform::DeviceContext& dev_ctx) const {} + +} // namespace operators +} // namespace paddle + +REGISTER_OP_WITHOUT_GRADIENT( + dynamic_recurrent, paddle::operators::DynamicRecurrentOp, + paddle::operators::DynamicRecurrentOpProtoAndCheckerMaker); diff --git a/paddle/operators/dynamic_recurrent_op.h b/paddle/operators/dynamic_recurrent_op.h new file mode 100644 index 0000000000000000000000000000000000000000..6a2970f27fd5bcb25e924dbc567e254159b55a3e --- /dev/null +++ b/paddle/operators/dynamic_recurrent_op.h @@ -0,0 +1,158 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#ifdef PADDLE_WITH_TESTING +#include "gtest/gtest.h" +#endif + +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/tensor_array.h" +#include "paddle/framework/variable.h" +#include "paddle/operators/rnn/recurrent_op_utils.h" + +namespace paddle { +namespace operators { + +class DynamicRecurrentOp : public framework::OperatorBase { + public: + static const rnn::ArgumentName kArgName; + using value_type = float; + + DynamicRecurrentOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + DynamicRecurrentOp(const DynamicRecurrentOp& o) + : framework::OperatorBase( + static_cast(o)) { + // TODO(yuyang18): Implement copy ctor well. + PADDLE_THROW("Not implemented"); + } + + void Run(const framework::Scope& scope, + const platform::DeviceContext& dev_ctx) const override; + + /* + * Split the inputs(LoDTensors) to segments for each time step. + */ + void SplitInputs() const; + + /* + * Create step-scopes to store temporary outputs in each time steps. + */ + void CreateScopes() const; + + /* + * Link TensorArray steps to the corresponding variables located in + * step-scopes. + */ + void WriteStepInputs() const; + + /* + * Write output of each step to the corresponding TensorArray. + */ + void WriteStepOutputs() const; + + /* + * Initialize the states, each state will have a corresponding pre-state, + * which share the memory with the state in the previous time state. The + * pre-state in the first time step will be initialized with an zero tensor or + * a tensor in parent scope if is provided. + */ + void InitStates() const; + + /* + * Concatenate outputs in each time step and generate a LoDTensor. + */ + void ConcatOutputs() const; + + /* + * set a stepnet that is created according to a RecurrentOp's stepnet. + */ + void SetStepNet(std::unique_ptr net) { + PADDLE_ENFORCE_NOT_NULL(net); + stepnet_ = std::move(net); + } + const OperatorBase& GetStepNet() const { return *stepnet_; } + + protected: + struct ArgCache { + framework::Scope const* scope; + std::vector* scopes; + std::map inlinks; + std::map outlinks; + + size_t num_steps{0}; + + void Init(const rnn::ArgumentName& name, const OperatorBase& op, + const framework::Scope& scope, rnn::Argument* arg); + + framework::Scope& GetScope(size_t index) { + PADDLE_ENFORCE_LT(index, num_steps); + return *scopes->at(index); + } + + private: + void InitArgument(const rnn::ArgumentName& name, const OperatorBase& op, + rnn::Argument* arg); + void CacheScopes(const framework::Scope& scope, const rnn::Argument& arg); + void CacheInlinks(const framework::Scope& scope, + const std::vector& names); + void CacheOutlinks(const framework::Scope& scope, + const std::vector& names); + framework::Variable* GetVariable(const framework::Scope& scope, + const std::string& name); + }; + + private: + std::unique_ptr stepnet_; + mutable framework::TensorArray states_; + mutable std::map step_inputs_; + mutable std::map step_outputs_; + mutable std::map> + dy_seq_metas_; + mutable rnn::Argument arg_; + mutable ArgCache cache_; + +#ifdef PADDLE_WITH_TESTING + friend class DynamicRecurrentOpTestHelper; + FRIEND_TEST(DynamicRecurrentOpTestHelper, SplitInputs); + FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateCache); + FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateScopes); + FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepInputs); + FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepOutputs); + FRIEND_TEST(DynamicRecurrentOpTestHelper, InitStates); + FRIEND_TEST(DynamicRecurrentOpTestHelper, ConcatOutputs); +#endif +}; + +class DynamicRecurrentGradientOp : public framework::OperatorBase { + public: + DynamicRecurrentGradientOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void Run(const framework::Scope& scope, + const platform::DeviceContext& dev_ctx) const override; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/dynamic_recurrent_op_test.cc b/paddle/operators/dynamic_recurrent_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..675a7890f3fa6bb7ab9dbbdb04894b2557214a8a --- /dev/null +++ b/paddle/operators/dynamic_recurrent_op_test.cc @@ -0,0 +1,222 @@ +#include "paddle/operators/dynamic_recurrent_op.h" + +#include + +#include "paddle/framework/ddim.h" +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/op_desc.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/net_op.h" + +namespace paddle { +namespace operators { + +using framework::Scope; +using framework::TensorArray; +using framework::LoDTensor; +using framework::Variable; + +class TestOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + DEFINE_OP_CLONE_METHOD(TestOp); + void Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const override {} +}; + +void OpDescNewVar(const std::string& param_name, + std::initializer_list arguments, + paddle::framework::OpDesc::Var* var) { + var->set_parameter(param_name); + for (auto& arg_name : arguments) { + var->add_arguments(arg_name); + } +} + +// create a LoD tensor in scope with specific dims +LoDTensor* CreateVar(Scope& scope, std::string name, framework::DDim dims, + const platform::Place& place) { + auto* var = scope.NewVar(name); + auto* tensor = var->GetMutable(); + tensor->Resize(dims); + tensor->mutable_data(place); + return tensor; +} + +class DynamicRecurrentOpTestHelper : public ::testing::Test { + protected: + const rnn::ArgumentName argname = DynamicRecurrentOp::kArgName; + + virtual void SetUp() override { + CreateGlobalVariables(); + + auto op_desc = CreateOpDesc(); + op = paddle::framework::OpRegistry::CreateOp(op_desc); + dop = dynamic_cast(op.get()); + InitCacheManually(); + InitStepNet(); + } + + framework::OpDesc CreateOpDesc() { + // create op + paddle::framework::OpDesc op_desc; + op_desc.set_type("dynamic_recurrent"); + + OpDescNewVar(argname.inlinks, {"in0"}, op_desc.add_inputs()); + OpDescNewVar(argname.boot_memories, {"boot_mem"}, op_desc.add_inputs()); + OpDescNewVar(argname.step_scopes, {"step_scopes"}, op_desc.add_outputs()); + OpDescNewVar(argname.outlinks, {"out0"}, op_desc.add_outputs()); + + // set pre-memories + auto pre_memories = op_desc.mutable_attrs()->Add(); + pre_memories->set_name(argname.pre_memories); + pre_memories->set_type(paddle::framework::AttrType::STRINGS); + auto pre_memories_item = pre_memories->add_strings(); + *pre_memories_item = "mem@pre"; + + // set memories + auto memories = op_desc.mutable_attrs()->Add(); + memories->set_name(argname.memories); + memories->set_type(paddle::framework::AttrType::STRINGS); + auto memories_item = memories->add_strings(); + *memories_item = "mem"; + return op_desc; + } + + void CreateGlobalVariables() { + platform::CPUPlace place; + scope.NewVar("step_scopes"); + CreateVar(scope, "boot_mem", framework::make_ddim({10, 20}), place); + // auto* out0 = + CreateVar(scope, "out0", framework::make_ddim({10, 20}), place); + auto* in0 = CreateVar(scope, "in0", framework::make_ddim({10, 8}), place); + // 10 instanes with 4 sentences, length is 4, 3, 2, 1 respectively. + framework::LoD in0_lod(1); + for (int x : std::vector{0, 4, 7, 9, 10}) { + in0_lod[0].push_back(x); + } + in0->set_lod(in0_lod); + in0->Resize(framework::make_ddim({10, 8})); + // set the content, each sentence content is seqid.batchid + // the seqid starts from 0 + int start = 0; + for (size_t seqid = 0; seqid < in0_lod.size() - 1; seqid++) { + for (size_t batchid = 0; + batchid < in0_lod[0][seqid + 1] - in0_lod[0][seqid]; batchid++) { + float v = seqid + batchid * 0.1; + + for (size_t dim = 0; dim < 8; dim++) { + in0->data()[start * 8 + dim] = v; + } + start++; + } + } + } + + void InitCacheManually() { + dop->cache_.Init(DynamicRecurrentOp::kArgName, *dop, scope, &dop->arg_); + } + + void InitStepNet() { + std::unique_ptr stepnet{new NetOp}; + dynamic_cast(stepnet.get()) + ->AppendOp(std::unique_ptr(new TestOp( + "test", {{"inlinks", {"in0"}}, {"boot_memories", {"boot_mem"}}}, + {{"outlinks", {"out0"}}, {"step_scopes", {"step_scopes"}}}, {}))); + dop->SetStepNet(std::move(stepnet)); + } + + protected: + DynamicRecurrentOp* dop; + std::unique_ptr op; + paddle::platform::CPUDeviceContext device_context; + paddle::framework::Scope scope; +}; + +TEST_F(DynamicRecurrentOpTestHelper, CreateCache) { + const rnn::Argument& arg = dop->arg_; + ASSERT_EQ(arg.inlinks.size(), 1UL); + ASSERT_EQ(arg.outlinks.size(), 1UL); +} + +TEST_F(DynamicRecurrentOpTestHelper, SplitInputs) { + dop->SplitInputs(); + auto& in0_ta = dop->step_inputs_["in0"]; + ASSERT_EQ(in0_ta.size(), 4UL); + + const auto& batch0 = in0_ta.Read(0); + const auto& batch1 = in0_ta.Read(1); + const auto& batch2 = in0_ta.Read(2); + const auto& batch3 = in0_ta.Read(3); + EXPECT_EQ(batch0.dims()[0], 4); + EXPECT_EQ(batch1.dims()[0], 3); + EXPECT_EQ(batch2.dims()[0], 2); + EXPECT_EQ(batch3.dims()[0], 1); +} + +TEST_F(DynamicRecurrentOpTestHelper, CreateScopes) { + dop->SplitInputs(); + dop->CreateScopes(); + ASSERT_EQ(dop->cache_.num_steps, 4UL); + ASSERT_EQ(dop->cache_.scopes->size(), 4UL); +} + +TEST_F(DynamicRecurrentOpTestHelper, WriteStepInputs) { + dop->SplitInputs(); + dop->CreateScopes(); + dop->WriteStepInputs(); + + for (size_t step = 0; step < dop->cache_.num_steps; step++) { + auto& scope = dop->cache_.GetScope(step); + for (auto name : std::vector({"in0"})) { + ASSERT_TRUE(scope.FindVar(name) != nullptr); + } + } +} + +TEST_F(DynamicRecurrentOpTestHelper, WriteStepOutputs) { + dop->SplitInputs(); + dop->CreateScopes(); + dop->WriteStepInputs(); + dop->WriteStepOutputs(); + + for (size_t step = 0; step < dop->cache_.num_steps; step++) { + auto& scope = dop->cache_.GetScope(step); + for (auto name : std::vector({"out0"})) { + ASSERT_TRUE(scope.FindVar(name)); + } + } +} + +TEST_F(DynamicRecurrentOpTestHelper, ConcatOutputs) { + // Let's leave this test to python unittest. +} + +TEST_F(DynamicRecurrentOpTestHelper, InitStates) { + dop->SplitInputs(); + dop->CreateScopes(); + dop->WriteStepInputs(); + dop->WriteStepOutputs(); + dop->InitStates(); + + for (size_t step = 0; step < dop->cache_.num_steps; step++) { + auto& scope = dop->cache_.GetScope(step); + auto state = scope.FindVar("mem"); + ASSERT_TRUE(state != nullptr); + + auto* pre_state = scope.FindVar("mem@pre"); + ASSERT_TRUE(pre_state != nullptr); + + auto* boot_state = scope.FindVar("boot_mem"); + ASSERT_TRUE(boot_state != nullptr); + + if (step == 0) { + // check pre_state is a reference of boot_state + ASSERT_EQ(boot_state->Get().data(), + pre_state->Get().data()); + } + } +} + +} // operators +} // namespace paddle diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fa325bb28299afe24a67772473529fb76b9c73e1 --- /dev/null +++ b/paddle/operators/feed_op.cc @@ -0,0 +1,59 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/feed_op.h" + +namespace paddle { +namespace operators { + +class FeedOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null."); + auto& shape = ctx->Attrs().Get>("dims"); + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + ctx->SetOutputDim("Out", framework::make_ddim(shape_int64)); + // TODO(qijun): need to handle LodTensor later + } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return static_cast(Attr("dataType")); + } +}; + +class FeedOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("dataType", "output data type") + .SetDefault(framework::DataType::FP32); + AddAttr("col", "The col in global feed variable").SetDefault(0); + AddAttr>("dims", "The dimension of feed tensor."); + AddOutput("Out", "The output of feed op."); + AddComment(R"DOC(Feed data from global feed variable)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(feed, ops::FeedOp, ops::FeedOpMaker); +REGISTER_OP_CPU_KERNEL(feed, ops::FeedKernel); diff --git a/paddle/operators/feed_op.cu b/paddle/operators/feed_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..7b6a2ac91e7b8d306804ca3d27b1eaf8177302f9 --- /dev/null +++ b/paddle/operators/feed_op.cu @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/feed_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(feed, ops::FeedKernel); diff --git a/paddle/operators/feed_op.h b/paddle/operators/feed_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9d8158299fea97a464a7bb64321b1adf8b7b2fab --- /dev/null +++ b/paddle/operators/feed_op.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class FeedKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + framework::Tensor* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + framework::Variable* g_feed_variable = + framework::GetGlobalScope().FindVar("feed_value"); + const auto& tensors = + g_feed_variable->Get>(); + int col = ctx.template Attr("col"); + PADDLE_ENFORCE_GT(tensors.size(), static_cast(col)); + // TODO(qijun): + // check tensors[col].dims() with attribute, + // except the first dimenson. + out->CopyFrom(tensors[col], ctx.GetPlace()); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..90737c8c550ca18f03c6a9ad0d9323d0b4d0b96d --- /dev/null +++ b/paddle/operators/fetch_op.cc @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/fetch_op.h" + +namespace paddle { +namespace operators { + +class FetchOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null."); + } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return static_cast(Attr("dataType")); + } +}; + +class FetchOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("dataType", "output data type") + .SetDefault(framework::DataType::FP32); + AddAttr("col", "The col in global fetch variable").SetDefault(0); + AddInput("Input", "The output of fetch op."); + AddComment(R"DOC(Fetch data to global fetch variable)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(fetch, ops::FetchOp, ops::FetchOpMaker); +REGISTER_OP_CPU_KERNEL(fetch, ops::FetchKernel); diff --git a/paddle/operators/fetch_op.cu b/paddle/operators/fetch_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..ca39d24c791ded71149777acc53e3b5cc240329f --- /dev/null +++ b/paddle/operators/fetch_op.cu @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/fetch_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(fetch, ops::FetchKernel); diff --git a/paddle/operators/fetch_op.h b/paddle/operators/fetch_op.h new file mode 100644 index 0000000000000000000000000000000000000000..eb9c3a7b593b84da7c8dc12d71c4f748269c64e6 --- /dev/null +++ b/paddle/operators/fetch_op.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class FetchKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const framework::Tensor* input = ctx.Input("Input"); + framework::Variable* g_fetch_variable = + framework::GetGlobalScope().FindVar("fetch_value"); + auto* tensors = + g_fetch_variable->GetMutable>(); + int col = ctx.template Attr("col"); + if (tensors->size() < static_cast(col + 1)) { + tensors->resize(col + 1); + } + PADDLE_ENFORCE_GT(tensors->size(), static_cast(col)); + (*tensors)[col].Resize(input->dims()); + (*tensors)[col].mutable_data(platform::CPUPlace()); + (*tensors)[col].CopyFrom(*input, platform::CPUPlace()); + // TODO(qijun): need to handle LodTensor later + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/margin_rank_loss_op.cc b/paddle/operators/margin_rank_loss_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5be61dfec3bb58ab9b658cb59ab0dd49bb67d8cb --- /dev/null +++ b/paddle/operators/margin_rank_loss_op.cc @@ -0,0 +1,124 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/margin_rank_loss_op.h" + +namespace paddle { +namespace operators { + +class MarginRankLossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + // input check + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("X1"), "Input(X1) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("X2"), "Input(X2) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); + auto label_dims = ctx->GetInputDim("Label"); + auto x1_dims = ctx->GetInputDim("X1"); + auto x2_dims = ctx->GetInputDim("X2"); + PADDLE_ENFORCE( + (label_dims == x1_dims) && (x1_dims == x2_dims) && + (label_dims.size() == 2) && (label_dims[1] == 1), + "All inputs must be 2-D tensor with shape [batch_size x 1]."); + ctx->SetOutputDim("Activated", label_dims); + ctx->SetOutputDim("Out", label_dims); + } +}; + +template +class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MarginRankLossOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X1", + "(2-D tensor with shape [batch_size x 1]) The score for " + "one item X1 to be ranked, from pairwise ranking model."); + AddInput("X2", + "(2-D tensor with shape [batch_size x 1]) The score for " + "another item X2 to be ranked, from pairwise ranking model."); + AddInput("Label", + "(2-D tensor with shape [batch_size x 1]) " + "The label indicating X1 ranked higher than X2 or not, " + "can only be +1 or -1."); + AddAttr("margin", "(scalar, default 0) Margin for MarginRankLossOp.") + .SetDefault(static_cast(0)); + AddOutput("Activated", + "(2-D tensor with shape [batch_size x 1]) Intermediate tensor " + "to indicate whether each element of Output(Out) is activated.") + .AsIntermediate(); + AddOutput("Out", + "(2-D tensor with shape [batch_size x 1]) " + "The output loss of MarginRankLoss operator."); + AddComment(R"DOC( + +MarginRankLoss operator measures the loss given a pair of training sample +{`X1`, `X2`} and the `Label` with attribute `margin`, where `Label = +1` +indicating X1 is ranked higher than `X2`, otherwise `Label = -1`. The loss +turns out + +loss(X1, X2, Label) = max(0, -Label * (X1 - X2) + margin). + +The attribute `margin` involved here helps make the predictions more robust. +Denote the item ranked higher as the positive sample, otherwise the negative +sample. If the score of the two samples satisfies + +positive sample - negative sample < margin, + +the pair of samples will contribute to the final loss, which will backpropogate +and train the ranking model to enlarge the difference of the two score. + +For batch input with size `batch_size`, `X1`, `X2` and `Label` +all have the same shape [batch_size x 1]. + +)DOC"); + } +}; + +class MarginRankLossGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("X1"), "Input(X1) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("X2"), "Input(X2) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("Activated"), + "Intermediate(Activated) shouldn't be null."); + auto dims = ctx->GetInputDim("Label"); + ctx->SetOutputDim(framework::GradVarName("X1"), dims); + ctx->SetOutputDim(framework::GradVarName("X2"), dims); + } +}; + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OP(margin_rank_loss, ops::MarginRankLossOp, + ops::MarginRankLossOpMaker, margin_rank_loss_grad, + ops::MarginRankLossGradOp); +REGISTER_OP_CPU_KERNEL( + margin_rank_loss, + ops::MarginRankLossKernel); +REGISTER_OP_CPU_KERNEL( + margin_rank_loss_grad, + ops::MarginRankLossGradKernel); diff --git a/paddle/operators/margin_rank_loss_op.cu b/paddle/operators/margin_rank_loss_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..3a639f25d478a712c1030d57c57d7e55de1488b5 --- /dev/null +++ b/paddle/operators/margin_rank_loss_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/margin_rank_loss_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + margin_rank_loss, + ops::MarginRankLossKernel); +REGISTER_OP_GPU_KERNEL( + margin_rank_loss_grad, + ops::MarginRankLossGradKernel); diff --git a/paddle/operators/margin_rank_loss_op.h b/paddle/operators/margin_rank_loss_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8d0830147ecc465909e8988e90125929829f6f34 --- /dev/null +++ b/paddle/operators/margin_rank_loss_op.h @@ -0,0 +1,98 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +struct ReLU { + HOSTDEVICE T operator()(const T& val) const { + return val > 0 ? val : static_cast(0); + } +}; + +template +struct Heaviside { + HOSTDEVICE T operator()(const T& val) const { + return static_cast(val > 0 ? 1 : 0); + } +}; + +template +class MarginRankLossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out_t = ctx.Output("Out"); + auto* act_t = ctx.Output("Activated"); + + auto* label_t = ctx.Input("Label"); + auto* x1_t = ctx.Input("X1"); + auto* x2_t = ctx.Input("X2"); + + out_t->mutable_data(ctx.GetPlace()); + act_t->mutable_data(ctx.GetPlace()); + + auto margin = static_cast(ctx.Attr("margin")); + auto out = framework::EigenVector::Flatten(*out_t); + auto act = framework::EigenVector::Flatten(*act_t); + + auto label = framework::EigenVector::Flatten(*label_t); + auto x1 = framework::EigenVector::Flatten(*x1_t); + auto x2 = framework::EigenVector::Flatten(*x2_t); + + auto& dev = ctx.GetEigenDevice(); + out.device(dev) = (-label * (x1 - x2) + margin).unaryExpr(ReLU()); + act.device(dev) = out.unaryExpr(Heaviside()); + } +}; + +template +class MarginRankLossGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* d_x1_t = + ctx.Output(framework::GradVarName("X1")); + auto* d_x2_t = + ctx.Output(framework::GradVarName("X2")); + + auto* act_t = ctx.Input("Activated"); + auto* d_out_t = ctx.Input(framework::GradVarName("Out")); + auto* label_t = ctx.Input("Label"); + + auto d_out = framework::EigenVector::Flatten(*d_out_t); + auto act = framework::EigenVector::Flatten(*act_t); + auto label = framework::EigenVector::Flatten(*label_t); + auto& dev = ctx.GetEigenDevice(); + + // compute d_x1 + if (d_x1_t) { + d_x1_t->mutable_data(ctx.GetPlace()); + auto d_x1 = framework::EigenVector::Flatten(*d_x1_t); + d_x1.device(dev) = -d_out * act * label; + } + // compute d_x2 + if (d_x2_t) { + d_x2_t->mutable_data(ctx.GetPlace()); + auto d_x2 = framework::EigenVector::Flatten(*d_x2_t); + d_x2.device(dev) = d_out * act * label; + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index a0ceb029e3abee2fe591325ffa3100168c3aa8e3..1a2f623ce7917b1e60656743e699271eab9c7011 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,13 +1,18 @@ if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu pooling.cc pooling.cu DEPS cblas device_context operator) + nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) + nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) + nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) else() - cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context operator) + cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) + cc_library(pooling SRCS pooling.cc DEPS device_context) + cc_library(vol2col SRCS vol2col.cc DEPS device_context) endif() cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) +cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor) diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc index 3b706529d8f1ed0d673904b81047a5614bd4cf23..50cfb88bb5700dda3785e63e0ccc6457cc928da0 100644 --- a/paddle/operators/math/pooling.cc +++ b/paddle/operators/math/pooling.cc @@ -18,6 +18,11 @@ namespace paddle { namespace operators { namespace math { +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class Pool2dFunctor { public: @@ -73,6 +78,11 @@ class Pool2dFunctor { } }; +/* +* All tensors are in NCHW format. +* Ksize, strides, paddings are two elements. These two elements represent height +* and width, respectively. +*/ template class Pool2dGradFunctor { public: @@ -135,6 +145,11 @@ class Pool2dGradFunctor { } }; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class MaxPool2dGradFunctor { public: @@ -197,7 +212,7 @@ class MaxPool2dGradFunctor { }; template class MaxPool2dGradFunctor; -// template class MaxPool2dGradFunctor; +template class MaxPool2dGradFunctor; template class Pool2dFunctor, float>; @@ -216,6 +231,11 @@ template class Pool2dGradFunctor< template class Pool2dGradFunctor< platform::CPUPlace, paddle::operators::math::AvgPoolGrad, double>; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dFunctor { public: @@ -286,6 +306,11 @@ class Pool3dFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dGradFunctor { public: @@ -364,6 +389,11 @@ class Pool3dGradFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class MaxPool3dGradFunctor { public: @@ -440,7 +470,7 @@ class MaxPool3dGradFunctor { }; template class MaxPool3dGradFunctor; -// template class MaxPool3dGradFunctor; +template class MaxPool3dGradFunctor; template class Pool3dFunctor, float>; @@ -458,6 +488,253 @@ template class Pool3dGradFunctor< platform::CPUPlace, paddle::operators::math::MaxPoolGrad, double>; template class Pool3dGradFunctor< platform::CPUPlace, paddle::operators::math::AvgPoolGrad, double>; + +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ +template +class MaxPool2dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + const int input_stride = input_height * input_width; + const int output_stride = output_height * output_width; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + + T ele = static_cast(-FLT_MAX); + int index = -1; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (ele < input_data[h * input_width + w]) { + ele = input_data[h * input_width + w]; + index = h * input_width + w; + } + } + } + output_data[ph * output_width + pw] = ele; + mask_data[ph * output_width + pw] = index; + } + } + // offset + input_data += input_stride; + output_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ +template +class MaxPool2dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_height = input_grad.dims()[2]; + const int input_width = input_grad.dims()[3]; + const int output_channels = output_grad.dims()[1]; + const int output_height = output_grad.dims()[2]; + const int output_width = output_grad.dims()[3]; + const int input_stride = input_height * input_width; + const int output_stride = output_height * output_width; + + const T* mask_data = mask.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < output_channels; ++c) { + for (int ph = 0; ph < output_height; ++ph) { + for (int pw = 0; pw < output_width; ++pw) { + const int output_idx = ph * output_width + pw; + const int input_idx = static_cast(mask_data[output_idx]); + input_grad_data[input_idx] += output_grad_data[output_idx]; + } + } + // offset + input_grad_data += input_stride; + output_grad_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; + +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ +template +class MaxPool3dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + const int input_stride = input_depth * input_height * input_width; + const int output_stride = output_depth * output_height * output_width; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int pd = 0; pd < output_depth; ++pd) { + int dstart = pd * stride_depth - padding_depth; + int dend = std::min(dstart + ksize_depth, input_depth); + dstart = std::max(dstart, 0); + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + + int output_idx = (pd * output_height + ph) * output_width + pw; + T ele = static_cast(-FLT_MAX); + int index = -1; + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_idx = (d * input_height + h) * input_width + w; + if (ele < input_data[input_idx]) { + index = input_idx; + ele = input_data[input_idx]; + } + } + } + } + output_data[output_idx] = ele; + mask_data[output_idx] = index; + } + } + } + // offset + input_data += input_stride; + output_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ +template +class MaxPool3dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_depth = input_grad.dims()[2]; + const int input_height = input_grad.dims()[3]; + const int input_width = input_grad.dims()[4]; + const int output_channels = output_grad.dims()[1]; + const int output_depth = output_grad.dims()[2]; + const int output_height = output_grad.dims()[3]; + const int output_width = output_grad.dims()[4]; + const int input_stride = input_depth * input_height * input_width; + const int output_stride = output_depth * output_height * output_width; + + const T* mask_data = mask.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < output_channels; ++c) { + for (int pd = 0; pd < output_depth; ++pd) { + for (int ph = 0; ph < output_height; ++ph) { + for (int pw = 0; pw < output_width; ++pw) { + const int output_idx = + (pd * output_height + ph) * output_width + pw; + const int input_idx = static_cast(mask_data[output_idx]); + input_grad_data[input_idx] += output_grad_data[output_idx]; + } + } + } + // offset + input_grad_data += input_stride; + output_grad_data += output_stride; + mask_data += output_stride; + } + } + } +}; + +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu index 8aeccd1f8e8855c51ad85016f0cb239b4c9c8fb0..736327f4b7b9e9df9ce8f7f60b0437fc1d2d373a 100644 --- a/paddle/operators/math/pooling.cu +++ b/paddle/operators/math/pooling.cu @@ -144,11 +144,16 @@ __global__ void KernelMaxPool2DGrad( if (maxIndex != -1) { // atomic add - atomicAdd(input_grad + maxIndex, output_grad[index]); + platform::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]); } } } +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class Pool2dFunctor { public: @@ -190,6 +195,11 @@ class Pool2dFunctor { } }; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class Pool2dGradFunctor { public: @@ -234,6 +244,11 @@ class Pool2dGradFunctor { } }; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ template class MaxPool2dGradFunctor { public: @@ -278,9 +293,7 @@ class MaxPool2dGradFunctor { }; template class MaxPool2dGradFunctor; -// template class MaxPool2dGradFunctor; // The -// 64-bit floating-point version of atomicAdd() is only supported by devices of -// compute capability 6.x and higher. +template class MaxPool2dGradFunctor; template class Pool2dFunctor, float>; @@ -453,11 +466,16 @@ __global__ void KernelMaxPool3DGrad( } if (maxIdx != -1) { // atomic add - atomicAdd(input_grad + maxIdx, output_grad[index]); + platform::CudaAtomicAdd(input_grad + maxIdx, output_grad[index]); } } } +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dFunctor { public: @@ -506,6 +524,11 @@ class Pool3dFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class Pool3dGradFunctor { public: @@ -558,6 +581,11 @@ class Pool3dGradFunctor { } }; +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ template class MaxPool3dGradFunctor { public: @@ -609,9 +637,7 @@ class MaxPool3dGradFunctor { }; template class MaxPool3dGradFunctor; -// template class MaxPool3dGradFunctor; // The -// 64-bit floating-point version of atomicAdd() is only supported by devices of -// compute capability 6.x and higher. +template class MaxPool3dGradFunctor; template class Pool3dFunctor, float>; @@ -630,6 +656,404 @@ template class Pool3dGradFunctor< template class Pool3dGradFunctor< platform::GPUPlace, paddle::operators::math::AvgPoolGrad, double>; +template +__global__ void KernelMaxPool2dWithIdx( + const int nthreads, const T* input_data, T* output_data, T* mask_data, + const int channels, const int input_height, const int input_width, + const int output_height, const int output_width, const int ksize_height, + const int ksize_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int pw = index % output_width; + int ph = (index / output_width) % output_height; + int c = (index / output_width / output_height) % channels; + int batch_idx = index / output_width / output_height / channels; + + int hstart = ph * stride_height - padding_height; + int hend = min(hstart + ksize_height, input_height); + hstart = max(hstart, 0); + + int wstart = pw * stride_width - padding_width; + int wend = min(wstart + ksize_width, input_width); + wstart = max(wstart, 0); + + input_data += (batch_idx * channels + c) * input_height * input_width; + T ele = -FLT_MAX; + int max_index = -1; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_index = h * input_width + w; + if (ele < input_data[input_index]) { + max_index = input_index; + ele = input_data[input_index]; + } + } + } + output_data[index] = ele; + mask_data[index] = max_index; + } +} + +template +__global__ void KernelMaxPool2DWithIdxGrad( + const int nthreads, T* input_grad, const T* output_grad, const T* mask_data, + const int channels, const int input_height, const int input_width, + const int output_height, const int output_width, const int ksize_height, + const int ksize_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int w_offset = index % input_width; + int h_offset = (index / input_width) % input_height; + int c_offset = (index / input_width / input_height) % channels; + int batch_idx = index / input_width / input_height / channels; + + int ph_start = + (h_offset + padding_height < ksize_height) + ? 0 + : (h_offset + padding_height - ksize_height) / stride_height + 1; + int pw_start = + (w_offset + padding_width < ksize_width) + ? 0 + : (w_offset + padding_width - ksize_width) / stride_width + 1; + int ph_end = + min((h_offset + padding_height) / stride_height + 1, output_height); + int pw_end = + min((w_offset + padding_width) / stride_width + 1, output_width); + + T gradient = 0; + int input_current_featuremap_idx = h_offset * input_width + w_offset; + int output_idx = + (batch_idx * channels + c_offset) * output_height * output_width; + + mask_data += output_idx; + output_grad += output_idx; + for (int ph = ph_start; ph < ph_end; ++ph) { + for (int pw = pw_start; pw < pw_end; ++pw) { + if (mask_data[ph * output_width + pw] == input_current_featuremap_idx) + gradient += output_grad[ph * output_width + pw]; + } + } + input_grad[index] = gradient; + } +} + +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ +template +class MaxPool2dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_height * output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool2dWithIdx< + T><<(context) + .stream()>>>(nthreads, input_data, output_data, mask_data, + input_channels, input_height, input_width, + output_height, output_width, ksize_height, + ksize_width, stride_height, stride_width, + padding_height, padding_width); + } +}; + +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ +template +class MaxPool2dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_channels = input_grad.dims()[1]; + const int input_height = input_grad.dims()[2]; + const int input_width = input_grad.dims()[3]; + const int output_height = output_grad.dims()[2]; + const int output_width = output_grad.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const T* mask_data = mask.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + int nthreads = batch_size * input_channels * input_height * input_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool2DWithIdxGrad< + T><<(context) + .stream()>>>(nthreads, input_grad_data, output_grad_data, + mask_data, input_channels, input_height, + input_width, output_height, output_width, + ksize_height, ksize_width, stride_height, + stride_width, padding_height, padding_width); + } +}; + +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithIndexGradFunctor; + +template +__global__ void KernelMaxPool3DWithIdx( + const int nthreads, const T* input_data, T* output_data, T* mask_data, + const int channels, const int input_depth, const int input_height, + const int input_width, const int output_depth, const int output_height, + const int output_width, const int ksize_depth, const int ksize_height, + const int ksize_width, const int stride_depth, const int stride_height, + const int stride_width, const int padding_depth, const int padding_height, + const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int pw = index % output_width; + int ph = (index / output_width) % output_height; + int pd = (index / output_width / output_height) % output_depth; + int c = (index / output_width / output_height / output_depth) % channels; + int batch_idx = + index / output_width / output_height / output_depth / channels; + + int dstart = pd * stride_depth - padding_depth; + int hstart = ph * stride_height - padding_height; + int wstart = pw * stride_width - padding_width; + int dend = min(dstart + ksize_depth, input_depth); + int hend = min(hstart + ksize_height, input_height); + int wend = min(wstart + ksize_width, input_width); + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + + T ele = -FLT_MAX; + int max_index = -1; + input_data += + (batch_idx * channels + c) * input_depth * input_height * input_width; + + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (ele < input_data[(d * input_height + h) * input_width + w]) { + max_index = (d * input_height + h) * input_width + w; + ele = input_data[max_index]; + } + } + } + } + output_data[index] = ele; + mask_data[index] = max_index; + } +} + +template +__global__ void KernelMaxPool3DWithIdxGrad( + const int nthreads, T* input_grad, const T* output_grad, const T* mask, + const int channels, const int input_depth, const int input_height, + const int input_width, const int output_depth, const int output_height, + const int output_width, const int ksize_depth, const int ksize_height, + const int ksize_width, const int stride_depth, const int stride_height, + const int stride_width, const int padding_depth, const int padding_height, + const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int w_offset = index % input_width; + int h_offset = (index / input_width) % input_height; + int d_offset = (index / input_width / input_height) % input_depth; + int c_offset = + (index / input_width / input_height / input_depth) % channels; + int batch_idx = index / input_width / input_height / input_depth / channels; + + int pd_start = + (d_offset + padding_depth < ksize_depth) + ? 0 + : (d_offset + padding_depth - ksize_depth) / stride_depth + 1; + int ph_start = + (h_offset + padding_height < ksize_height) + ? 0 + : (h_offset + padding_height - ksize_height) / stride_height + 1; + int pw_start = + (w_offset + padding_width < ksize_width) + ? 0 + : (w_offset + padding_width - ksize_width) / stride_width + 1; + int pd_end = + min((d_offset + padding_depth) / stride_depth + 1, output_depth); + int ph_end = + min((h_offset + padding_height) / stride_height + 1, output_height); + int pw_end = + min((w_offset + padding_width) / stride_width + 1, output_width); + + T gradient = 0; + int input_current_feature_map_idx = + (d_offset * input_height + h_offset) * input_width + w_offset; + int output_idx = (batch_idx * channels + c_offset) * output_depth * + output_height * output_width; + mask += output_idx; + output_grad += output_idx; + + for (int pd = pd_start; pd < pd_end; ++pd) { + for (int ph = ph_start; ph < ph_end; ++ph) { + for (int pw = pw_start; pw < pw_end; ++pw) { + if (mask[(pd * output_height + ph) * output_width + pw] == + input_current_feature_map_idx) + gradient += + output_grad[(pd * output_height + ph) * output_width + pw]; + } + } + } + input_grad[index] = gradient; + } +} + +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ +template +class MaxPool3dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + T* mask_data = mask.mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_depth * output_height * + output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool3DWithIdx< + T><<(context) + .stream()>>>( + nthreads, input_data, output_data, mask_data, input_channels, + input_depth, input_height, input_width, output_depth, output_height, + output_width, ksize_depth, ksize_height, ksize_width, stride_depth, + stride_height, stride_width, padding_depth, padding_height, + padding_width); + } +}; + +/* + * All tensors are in NCDHW format. + * Ksize, strides, paddings are three elements. These three elements represent + * depth, height and width, respectively. + */ +template +class MaxPool3dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input_grad.dims()[0]; + const int input_channels = input_grad.dims()[1]; + const int input_depth = input_grad.dims()[2]; + const int input_height = input_grad.dims()[3]; + const int input_width = input_grad.dims()[4]; + const int output_depth = output_grad.dims()[2]; + const int output_height = output_grad.dims()[3]; + const int output_width = output_grad.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + + const T* output_grad_data = output_grad.data(); + const T* mask_data = mask.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + int nthreads = + batch_size * input_channels * input_depth * input_height * input_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool3DWithIdxGrad< + T><<(context) + .stream()>>>( + nthreads, input_grad_data, output_grad_data, mask_data, input_channels, + input_depth, input_height, input_width, output_depth, output_height, + output_width, ksize_depth, ksize_height, ksize_width, stride_depth, + stride_height, stride_width, padding_depth, padding_height, + padding_width); + } +}; + +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; +template class MaxPool3dWithIndexFunctor; +template class MaxPool3dWithIndexGradFunctor; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h index d214c689235ad4233d3e4e1c2aa0fdc993bf20c6..c50c57b5c52cdc5c12425cb119b80502aef5451e 100644 --- a/paddle/operators/math/pooling.h +++ b/paddle/operators/math/pooling.h @@ -21,15 +21,27 @@ limitations under the License. */ namespace paddle { namespace operators { namespace math { -////////////////////// -#define FLT_MAX __FLT_MAX__ // +#define FLT_MAX \ + __FLT_MAX__ // It might need to be placed in another file, but I'm still + // wondering where to put it. + +/* + * \brief Extracting simple operations from pooling. + * Both MaxPool and AvgPool need "initial", "compute" and "finalize" + * operation. + * MaxPool initializes temp variable to the negative maximum to find the + * maximum value in the pooling field. + * AvgPool initializes temp variable to the zero to accumulate all values + * in pool pooling, and finally takes the average. + * MaxPoolGrad and AvgPoolGrad are gradient operations respectively. + */ template class MaxPool { public: DEVICE inline T initial() { return static_cast(-FLT_MAX); } DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; } - DEVICE inline void finalize(T& y, const T& poo_size) {} + DEVICE inline void finalize(T& y, const T& pool_field) {} }; template @@ -37,8 +49,9 @@ class AvgPool { public: DEVICE inline T initial() { return static_cast(0); } DEVICE inline void compute(T& y, const T& x) { y += x; } - DEVICE inline void finalize(T& y, const T& poo_size) { y /= poo_size; } + DEVICE inline void finalize(T& y, const T& pool_field) { y /= pool_field; } }; + template class MaxPoolGrad { public: @@ -57,6 +70,20 @@ class AvgPoolGrad { } }; +/* + * \brief Getting pooling results, and calculating gradient. + * + * In pool2d, all tensors are in NCHW format. Where N is batch size, C is the + * number of channels, H and W is the height and width of feature. + * In pool3d, all tensors are in NCDHW format. Where N is batch size, C is the + * number of channels, D, H and W is the depth, height and width of feature. + * + * In max pooling, it is possible that the pooling region has multiple maximum + * elements. In this case, we should compute the gradient of the first maximum + * element. + * This is different from average pooling. So we rewrite the max_pool_grad: + * MaxPool2dGradFunctor, MaxPool3dGradFunctor. + */ template class Pool2dFunctor { public: @@ -117,6 +144,51 @@ class MaxPool3dGradFunctor { std::vector& strides, std::vector& paddings); }; +/* + * \brief Getting max pooling results and corresponding max index, and + * calculating gradient. + * In up-sampling-pooling, it is necessary to know max element index. + * In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in + * NCDHW format. + */ +template +class MaxPool2dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +template +class MaxPool2dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +template +class MaxPool3dWithIndexFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +template +class MaxPool3dWithIndexGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& input_grad, + const framework::Tensor& output_grad, + const framework::Tensor& mask, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/vol2col.cc b/paddle/operators/math/vol2col.cc new file mode 100644 index 0000000000000000000000000000000000000000..e9718a047381596a1570b4b00546622968b70227 --- /dev/null +++ b/paddle/operators/math/vol2col.cc @@ -0,0 +1,155 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/vol2col.h" + +namespace paddle { +namespace operators { +namespace math { + +/* + * vol = [input_channels, input_depth, input_height, input_width] + * col = + * [input_channels, filter_depth, filter_height, filter_width, + * output_depth, output_height, output_width] + */ +template +class Vol2ColFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& vol, framework::Tensor& col, + int stride_depth, int stride_height, int stride_width, + int padding_depth, int padding_height, + int padding_width) const { + PADDLE_ENFORCE(vol.dims().size() == 4); + PADDLE_ENFORCE(col.dims().size() == 7); + + int input_channels = vol.dims()[0]; + int input_depth = vol.dims()[1]; + int input_height = vol.dims()[2]; + int input_width = vol.dims()[3]; + int filter_depth = col.dims()[1]; + int filter_height = col.dims()[2]; + int filter_width = col.dims()[3]; + int output_depth = col.dims()[4]; + int output_height = col.dims()[5]; + int output_width = col.dims()[6]; + int channels_col = + input_channels * filter_depth * filter_height * filter_width; + + const T* vol_data = vol.data(); + T* col_data = col.data(); + + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int d_offset = (c / filter_width / filter_height) % filter_depth; + int c_in = c / filter_width / filter_height / filter_depth; + for (int d = 0; d < output_depth; ++d) { + int d_pad = d * stride_depth - padding_depth + d_offset; + for (int h = 0; h < output_height; ++h) { + int h_pad = h * stride_height - padding_height + h_offset; + for (int w = 0; w < output_width; ++w) { + int w_pad = w * stride_width - padding_width + w_offset; + + int col_idx = + ((c * output_depth + d) * output_height + h) * output_width + w; + if (h_pad < 0 || h_pad >= input_height || w_pad < 0 || + w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) { + col_data[col_idx] = static_cast(0); + } else { + int vol_idx = + ((c_in * input_depth + d_pad) * input_height + h_pad) * + input_width + + w_pad; + col_data[col_idx] = vol_data[vol_idx]; + } + } + } + } + } + } +}; + +/* + * vol = [input_channels,input_depth, input_height, input_width] + * col = + * [input_channels, filter_depth, filter_height, filter_width, + * output_depth, output_height, output_width] + */ +template +class Col2VolFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& vol, const framework::Tensor& col, + int stride_depth, int stride_height, int stride_width, + int padding_depth, int padding_height, + int padding_width) const { + PADDLE_ENFORCE(vol.dims().size() == 4); + PADDLE_ENFORCE(col.dims().size() == 7); + + int input_channels = vol.dims()[0]; + int input_depth = vol.dims()[1]; + int input_height = vol.dims()[2]; + int input_width = vol.dims()[3]; + int filter_depth = col.dims()[1]; + int filter_height = col.dims()[2]; + int filter_width = col.dims()[3]; + int output_depth = col.dims()[4]; + int output_height = col.dims()[5]; + int output_width = col.dims()[6]; + int channels_col = + input_channels * filter_depth * filter_height * filter_width; + + T* vol_data = vol.data(); + const T* col_data = col.data(); + + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int d_offset = (c / filter_width / filter_height) % filter_depth; + int cIm = c / filter_width / filter_height / filter_depth; + for (int d = 0; d < output_depth; ++d) { + int d_pad = d * stride_depth - padding_depth + d_offset; + for (int h = 0; h < output_height; ++h) { + int h_pad = h * stride_height - padding_height + h_offset; + for (int w = 0; w < output_width; ++w) { + int w_pad = w * stride_width - padding_width + w_offset; + + if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && + w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { + int vol_idx = + ((cIm * input_depth + d_pad) * input_height + h_pad) * + input_width + + w_pad; + int col_idx = + ((c * output_depth + d) * output_height + h) * output_width + + w; + vol_data[vol_idx] += col_data[col_idx]; + } + } + } + } + } + } +}; + +template class Vol2ColFunctor; +template class Vol2ColFunctor; +template class Col2VolFunctor; +template class Col2VolFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/vol2col.cu b/paddle/operators/math/vol2col.cu new file mode 100644 index 0000000000000000000000000000000000000000..27b11fb237575fd25a789a5fcc24ed4e30607009 --- /dev/null +++ b/paddle/operators/math/vol2col.cu @@ -0,0 +1,204 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/vol2col.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +template +__global__ void vol2col(int num_kernels, const T* data_vol, int depth, + int height, int width, int filter_depth, + int filter_height, int filter_width, int stride_depth, + int stride_height, int stride_width, int padding_depth, + int padding_height, int padding_width, int output_detph, + int output_height, int output_width, T* data_col) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; + index += blockDim.x * gridDim.x) { + int w_out = index % output_width; + int h_out = (index / output_width) % output_height; + int d_out = (index / output_width / output_height) % output_detph; + int channel_in = index / output_width / output_height / output_detph; + int channel_out = channel_in * filter_depth * filter_height * filter_width; + int w_in = w_out * stride_width - padding_width; + int h_in = h_out * stride_height - padding_height; + int d_in = d_out * stride_depth - padding_depth; + + data_col += ((channel_out * output_detph + d_out) * output_height + h_out) * + output_width + + w_out; + data_vol += ((channel_in * depth + d_in) * height + h_in) * width + w_in; + for (int k = 0; k < filter_depth; ++k) { + for (int i = 0; i < filter_height; ++i) { + for (int j = 0; j < filter_width; ++j) { + int d = d_in + k; + int h = h_in + i; + int w = w_in + j; + *data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && + w < width) + ? data_vol[(k * height + i) * width + j] + : 0; + data_col += output_detph * output_height * output_width; + } + } + } + } +} + +/* + * im = [input_channels,intpu_depth, input_height, input_width] + * col = + * [input_channels, filter_depth, filter_height, filter_width, + * output_depth, output_height, output_width] + */ +template +class Vol2ColFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& vol, framework::Tensor& col, + int stride_depth, int stride_height, int stride_width, + int padding_depth, int padding_height, + int padding_width) const { + PADDLE_ENFORCE(vol.dims().size() == 4); + PADDLE_ENFORCE(col.dims().size() == 7); + + int input_channels = vol.dims()[0]; + int input_depth = vol.dims()[1]; + int input_height = vol.dims()[2]; + int input_width = vol.dims()[3]; + int filter_depth = col.dims()[1]; + int filter_height = col.dims()[2]; + int filter_width = col.dims()[3]; + int output_depth = col.dims()[4]; + int output_height = col.dims()[5]; + int output_width = col.dims()[6]; + + int num_outputs = + input_channels * output_depth * output_height * output_width; + + const int threads = 1024; + const int blocks = (num_outputs + 1024 - 1) / 1024; + vol2col<<(context) + .stream()>>>( + num_outputs, vol.data(), input_depth, input_height, input_width, + filter_depth, filter_height, filter_width, stride_depth, stride_height, + stride_width, padding_depth, padding_height, padding_width, + output_depth, output_height, output_width, col.data()); + } +}; + +template +__global__ void col2vol(int num_kernels, const T* data_col, int depth, + int height, int width, int filter_depth, + int filter_height, int filter_width, int stride_depth, + int stride_height, int stride_width, int padding_depth, + int padding_height, int padding_width, int output_detph, + int output_height, int output_width, T* data_vol) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; + index += blockDim.x * gridDim.x) { + T src_val = 0; + int w = index % width + padding_width; + int h = (index / width) % height + padding_height; + int d = (index / width / height) % depth + padding_depth; + int c = index / width / height / depth; + // compute the start and end of the output + int w_col_start = + (w < filter_width) ? 0 : (w - filter_width) / stride_width + 1; + int w_col_end = min(w / stride_width + 1, output_width); + int h_col_start = + (h < filter_height) ? 0 : (h - filter_height) / stride_height + 1; + int h_col_end = min(h / stride_height + 1, output_height); + int d_col_start = + (d < filter_depth) ? 0 : (d - filter_depth) / stride_depth + 1; + int d_col_end = min(d / stride_depth + 1, output_detph); + + int offset = (c * filter_depth * filter_height * filter_width + + d * filter_width * filter_height + h * filter_width + w) * + output_detph * output_height * output_width; + + int coeff_d_col = + (1 - stride_depth * filter_width * filter_height * output_detph) * + output_height * output_width; + int coeff_h_col = + (1 - stride_height * filter_width * output_detph * output_height) * + output_width; + int coeff_w_col = + (1 - stride_width * output_detph * output_height * output_width); + + for (int d_col = d_col_start; d_col < d_col_end; ++d_col) { + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + src_val += data_col[offset + d_col * coeff_d_col + + h_col * coeff_h_col + w_col * coeff_w_col]; + } + } + } + data_vol[index] = src_val; + } +} + +/* + * im = [input_channels, input_depth, input_height, input_width] + * col = + * [input_channels, filter_depth, filter_height, filter_width, + * output_depth, output_height, output_width] + */ +template +class Col2VolFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& vol, const framework::Tensor& col, + int stride_depth, int stride_height, int stride_width, + int padding_depth, int padding_height, + int padding_width) const { + PADDLE_ENFORCE(vol.dims().size() == 4); + PADDLE_ENFORCE(col.dims().size() == 7); + + int input_channels = vol.dims()[0]; + int input_depth = vol.dims()[1]; + int input_height = vol.dims()[2]; + int input_width = vol.dims()[3]; + int filter_depth = col.dims()[1]; + int filter_height = col.dims()[2]; + int filter_width = col.dims()[3]; + int output_depth = col.dims()[4]; + int output_height = col.dims()[5]; + int output_width = col.dims()[6]; + + int num_kernels = input_channels * input_depth * input_height * input_width; + + const int threads = 1024; + const int blocks = (num_kernels + 1024 - 1) / 1024; + + col2vol<<(context) + .stream()>>>( + num_kernels, col.data(), input_depth, input_height, input_width, + filter_depth, filter_height, filter_width, stride_depth, stride_height, + stride_width, padding_depth, padding_height, padding_width, + output_depth, output_height, output_width, vol.data()); + } +}; + +template class Vol2ColFunctor; +template class Vol2ColFunctor; +template class Col2VolFunctor; +template class Col2VolFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/vol2col.h b/paddle/operators/math/vol2col.h new file mode 100644 index 0000000000000000000000000000000000000000..f022365a16fbf61981e94bedbd8b21a32887b235 --- /dev/null +++ b/paddle/operators/math/vol2col.h @@ -0,0 +1,78 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { +/* + * \brief Converts the feature data of four dimensions(CDHW) into a colData of + * seven dimensions in the Vol2ColFunctor calculation, + * And in the Col2VolFunctor calculation, it is reversed. + * + * \param volData Vol data. + * \param volShape The shape of volData, + * [input_channels, input_depth, input_height, input_width]. + * \param colData Column data. + * \param colShape The shape of colData. + * + * The shape of colData is: + * [input_channels, filter_depth, filter_height, filter_width, output_depth, + * output_height, output_width] + * So, it is easy to reshape into a convolution matrix for convolution + * calculation based on matrix multiplication. + * The shape of convolution matrix is [height, width], where the height is equal + * input_channels * filter_depth * filter_height * filter_width, and the width + * is equal output_depth * output_height * output_width. + * + * Reshape: + * shape of colData shape of convolution matrix + * [input_channels, + * filter_depth, + * filter_height, + * filter_width, ======> [height, width] + * output_depth, + * output_height, + * output_width] + * + * \note The caller needs to ensure that volShape.inputChannels is equal to + * colShape.inputChannels. + */ +template +class Vol2ColFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& vol, framework::Tensor& col, + int stride_depth, int stride_height, int stride_width, + int padding_depth, int padding_height, + int padding_width) const; +}; + +template +class Col2VolFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& vol, const framework::Tensor& col, + int stride_depth, int stride_height, int stride_width, + int padding_depth, int padding_height, + int padding_width) const; +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/vol2col_test.cc b/paddle/operators/math/vol2col_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..81225e9a9803ce371d23620876ac22da63a8e2d1 --- /dev/null +++ b/paddle/operators/math/vol2col_test.cc @@ -0,0 +1,135 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/vol2col.h" +#include +#include + +template +void testVol2col() { + paddle::framework::Tensor input; + paddle::framework::Tensor input_tmp; + paddle::framework::Tensor output; + paddle::framework::Tensor output_tmp; + + auto* place = new Place(); + paddle::platform::DeviceContext* context; + if (paddle::platform::is_cpu_place(*place)) { + context = + new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace()); + } else { +#ifdef PADDLE_WITH_CUDA + context = + new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace()); +#else + PADDLE_THROW("no GPU support"); +#endif // PADDLE_WITH_CUDA + } + + /** + * input = [[0, 1, 2, + * 3, 4, 5] + * [6, 7, 8, + * 9, 10, 11]] + * + * output = [0, 1 + * 1, 2 + * 3, 4 + * 4, 5 + * 6, 7 + * 7, 8 + * 9, 10 + * 10, 11] + * + * col2vol = [[0, 2, 2, + * 3, 8, 5] + * [6, 14, 8, + * 9, 20, 11]] + * + */ + int input_depth = 2; + int input_height = 2; + int input_width = 3; + int filter_size = 2; + int stride = 1; + int padding = 0; + int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1; + int output_height = (input_height - filter_size + 2 * padding) / stride + 1; + int output_width = (input_width - filter_size + 2 * padding) / stride + 1; + + // Vol2Col test + float* input_ptr = + input_tmp.mutable_data({1, input_depth, input_height, input_width}, + paddle::platform::CPUPlace()); + float arr[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + memcpy(input_ptr, arr, 12 * sizeof(float)); + + if (paddle::platform::is_cpu_place(*place)) { + input = input_tmp; + } else { + input.CopyFrom(input_tmp, *place); + } + output.mutable_data({1, filter_size, filter_size, filter_size, + output_depth, output_height, output_width}, + *place); + + paddle::operators::math::Vol2ColFunctor vol2col; + vol2col(*context, input, output, stride, stride, stride, padding, padding, + padding); + + float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; + float* out_cfo_ptr; + if (paddle::platform::is_cpu_place(*place)) { + out_cfo_ptr = output.data(); + } else { + output_tmp.CopyFrom(output, paddle::platform::CPUPlace()); + out_cfo_ptr = output_tmp.data(); + } + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(out_cfo_ptr[i], vol_2_col[i]); + } + + // Col2Vol test + float col_2_vol[] = {0, 2, 2, 3, 8, 5, 6, 14, 8, 9, 20, 11}; + memset(input_ptr, 0, 12 * sizeof(float)); + if (paddle::platform::is_cpu_place(*place)) { + input = input_tmp; + } else { + input.CopyFrom(input_tmp, *place); + } + + paddle::operators::math::Col2VolFunctor col2vol; + col2vol(*context, input, output, stride, stride, stride, padding, padding, + padding); + + float* in_ptr; + if (paddle::platform::is_cpu_place(*place)) { + in_ptr = input.data(); + } else { + input_tmp.CopyFrom(input, paddle::platform::CPUPlace()); + in_ptr = input_tmp.data(); + } + + for (int i = 0; i < 12; ++i) { + EXPECT_EQ(in_ptr[i], col_2_vol[i]); + } +} + +TEST(math, vol2col) { + testVol2col(); +#ifdef PADDLE_WITH_CUDA + testVol2col(); +#endif // PADDLE_WITH_CUDA +} diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index ba3b5ed2075ceca284b49ecddb90ba5950b820c3..c6d9aae13322ebcc9ebbe394d9b9831bd67fe632 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -22,157 +22,181 @@ int OutputSizePool(int input_size, int filter_size, int padding, int stride) { return output_size; } -class PoolOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "X(Input) of Pooling should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Out(Output) of Pooling should not be null."); - - auto in_x_dims = ctx->GetInputDim("X"); - - std::string pooling_type = ctx->Attrs().Get("poolingType"); - std::vector ksize = ctx->Attrs().Get>("ksize"); - std::vector strides = ctx->Attrs().Get>("strides"); - std::vector paddings = ctx->Attrs().Get>("paddings"); - - PADDLE_ENFORCE(pooling_type == "max" || pooling_type == "avg", - "pooling_type should be 'max' or 'avg'"); - PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, - "Pooling intput should be 4-D or 5-D"); - - if (ctx->Attrs().Get("globalPooling")) { - ksize.resize(static_cast(in_x_dims.size()) - 2); - for (size_t i = 0; i < ksize.size(); ++i) - ksize[i] = static_cast(in_x_dims[i + 2]); - } - - PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U, - "Input size and Pooling size should be consistent."); - PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3, - "Pooling size should be 2 elements. or 3 elements."); - PADDLE_ENFORCE_EQ(ksize.size(), strides.size(), - "strides size and pooling size should be the same."); - PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(), - "paddings size and pooling size should be the same."); - - std::vector output_shape({in_x_dims[0], in_x_dims[1]}); - for (size_t i = 0; i < ksize.size(); ++i) { - output_shape.push_back( - OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); - } - ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); +void PoolOp::InferShape(framework::InferShapeContext *ctx) const { + PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Out(Output) of Pooling should not be null."); + + auto in_x_dims = ctx->GetInputDim("X"); + + std::string pooling_type = ctx->Attrs().Get("poolingType"); + std::vector ksize = ctx->Attrs().Get>("ksize"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + + PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, + "Pooling intput should be 4-D or 5-D tensor."); + + if (ctx->Attrs().Get("globalPooling")) { + ksize.resize(static_cast(in_x_dims.size()) - 2); + for (size_t i = 0; i < ksize.size(); ++i) + ksize[i] = static_cast(in_x_dims[i + 2]); } -}; - -class PoolOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "X(Input) of Pooling should not be null."); - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), - "Input@Grad of Pooling should not be null."); - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + + PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U, + "Input size and pooling size should be consistent."); + PADDLE_ENFORCE_EQ(ksize.size(), strides.size(), + "Strides size and pooling size should be the same."); + PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(), + "Paddings size and pooling size should be the same."); + + std::vector output_shape({in_x_dims[0], in_x_dims[1]}); + for (size_t i = 0; i < ksize.size(); ++i) { + output_shape.push_back( + OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); } -}; - -class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { - public: - Pool2dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput( - "X", - "The input tensor of pooling operator. " - "The format of input tensor is NCHW. Where N is batch size, C is the " - "number of channels, H and W is the height and width of feature."); - AddOutput("Out", - "The output tensor of pooling operator." - "The format of output tensor is also NCHW."); - - AddAttr("poolingType", - "PoolingType of pooling operator." - "Str constant equal to 'max' or 'avg'.") - .InEnum({"max", "avg"}); - AddAttr>( - "ksize", - "Pooling size(depth, height, width) of pooling operator." - "If globalPooling = true, ksize is ignored and need not be " - "specified."); // TODO(Add checker) - AddAttr( - "globalPooling", - "Whether to use the globalPooling." - "Bool constant equal to false or true." - "Default false." - "If globalPooling = true, ksize is ignored and need not be specified.") - .SetDefault(false); - AddAttr>("strides", - "Strides(height, width) of pooling operator." - "Default {1,1}") - .SetDefault({1, 1}); // TODO(Add checker) - AddAttr>("paddings", - "Paddings(height, width) of pooling operator." - "Default {0,0}.") - .SetDefault({0, 0}); // TODO(Add checker) - AddComment(R"DOC( + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); +} + +void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Input(X@GRAD) should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); +} + +Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "(Tensor) The input tensor of pooling operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of feature."); + AddOutput("Out", + "(Tensor) The output tensor of pooling operator." + "The format of output tensor is also NCHW." + "Where N is batch size, C is " + "the number of channels, H and W is the height and " + "width of feature."); + + AddAttr("poolingType", + "PoolingType of pooling operator." + "Str constant equal to 'max' or 'avg'.") + .InEnum({"max", "avg"}); + + AddAttr>( + "ksize", + "The pooling window size(height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " + "specified."); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr( + "globalPooling", + "Whether to use the globalPooling." + "Bool constant equal to false or true." + "Default false." + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); + AddAttr>("strides", + "The strides(height, width) of pooling window." + "Default {1,1}.") + .SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr>("paddings", + "The zero padding(height, width) size on both sides" + "Default {0,0}.") + .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + + AddComment(R"DOC( The pooling2d operation calculates the output based on the input, poolingType and ksize, strides, paddings parameters. +Input(X) and output(Out) are in NCHW format. Where N is batch size, C is the +number of channels, H and W is the height and width of feature. +Parameters(ksize, strides, paddings) are two elements. +These two elements represent height and width, respectively. +The input(X) size and output(Out) size may be different. + +Example: + Input: + X shape: (N, C, H_in, W_in) + Output: + Out shape: (N, C, H_out, W_out) + Mask shape: (N, C, H_out, W_out) + where + H_out = (H_in - ksize[0] + 2 * paddings[0]) / strides[0] + 1; + W_out = (W_in - ksize[1] + 2 * paddings[1]) / strides[1] + 1; )DOC"); - } -}; - -class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { - public: - Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", - "The input tensor of pooling operator. " - "The format of input tensor is NCDHW. Where N is batch size, C is " - "the " - "number of channels, D, H and W is the depth, height and width of " - "feature."); - AddOutput("Out", - "The output tensor of pooling operator." - "The format of output tensor is also NCDHW."); - - AddAttr("poolingType", - "PoolingType of pooling operator." - "str constant equal to 'max' or 'avg'.") - .InEnum({"max", "avg"}); - AddAttr>( - "ksize", - "Pooling size(depth, height, width) of pooling operator." - "If globalPooling = true, ksize is ignored and need not be " - "specified."); // TODO(Add checker) - AddAttr( - "globalPooling", - "Whether to use the globalPooling." - "Bool constant equal to false or true." - "Default false." - "If globalPooling = true, ksize is ignored and need not be specified.") - .SetDefault(false); - AddAttr>( - "strides", - "Strides(depth, height, width) of pooling operator." - "Default {1,1,1}.") - .SetDefault({1, 1, 1}); // TODO(Add checker) - AddAttr>( - "paddings", - "Paddings(depth, height, width) of pooling operator." - "Default {0,0,0}.") - .SetDefault({0, 0, 0}); // TODO(Add checker) - AddComment(R"DOC( +} + +Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "(Tensor) The input tensor of pooling operator. " + "The format of input tensor is NCDHW. Where N is batch size, C is " + "the number of channels, D, H and W is the depth, height and width of " + "feature."); + AddOutput("Out", + "(Tensor) The output tensor of pooling operator." + "The format of output tensor is also NCDHW." + "Where N is batch size, C is " + "the number of channels, D, H and W is the depth, height and " + "width of feature."); + + AddAttr("poolingType", + "PoolingType of pooling operator." + "Str constant equal to 'max' or 'avg'.") + .InEnum({"max", "avg"}); + + AddAttr>( + "ksize", + "The pooling window size(depth, height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " + "specified."); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr( + "globalPooling", + "Whether to use the globalPooling." + "Bool constant equal to false or true." + "Default false." + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); + AddAttr>("strides", + "Strides(depth, height, width) of pooling operator." + "Default {1,1,1}.") + .SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr>( + "paddings", + "Paddings(depth, height, width) of pooling operator." + "Default {0,0,0}.") + .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + + AddComment(R"DOC( The pooling3d operation calculates the output based on the input, poolingType and ksize, strides, paddings parameters. +Input(X) and output(Out) are in NCDHW format. Where N is batch +size, C is the number of channels, D, H and W is the depth, height and +width of feature. Parameters(ksize, strides, paddings) are three elements. +These three elements represent depth, height and width, respectively. +The input(X) size and output(Out) size may be different. + +Example: + Input: + X shape: (N, C, D_in, H_in, W_in) + Output: + Out shape: (N, C, D_out, H_out, W_out) + Mask shape: (N, C, D_out, H_out, W_out) + where + D_out = (D_in - ksize[0] + 2 * paddings[0]) / strides[0] + 1; + H_out = (H_in - ksize[1] + 2 * paddings[1]) / strides[1] + 1; + W_out = (W_in - ksize[2] + 2 * paddings[2]) / strides[2] + 1; )DOC"); - } -}; +} } // namespace operators } // namespace paddle diff --git a/paddle/operators/pool_op.h b/paddle/operators/pool_op.h index c2bc358def42959f2cc8f61cb00436fae1b7514b..e5016d573dde0a9c8a90cddf14f68706b69fade5 100644 --- a/paddle/operators/pool_op.h +++ b/paddle/operators/pool_op.h @@ -24,6 +24,34 @@ namespace operators { using Tensor = framework::Tensor; +class PoolOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override; +}; + +class PoolOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override; +}; + +class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Pool2dOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker); +}; + +class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Pool3dOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker); +}; + template class PoolKernel : public framework::OpKernel { public: diff --git a/paddle/operators/pool_with_index_op.cc b/paddle/operators/pool_with_index_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..005ee886934b193064cc739638398b3535db9274 --- /dev/null +++ b/paddle/operators/pool_with_index_op.cc @@ -0,0 +1,253 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/pool_with_index_op.h" + +namespace paddle { +namespace operators { + +inline int OutputSizeMaxPool(int input_size, int filter_size, int padding, + int stride) { + int output_size = (input_size - filter_size + 2 * padding) / stride + 1; + return output_size; +} + +class MaxPoolWithIndexOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "X(Input) of Pooling should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Out(Output) of Pooling should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Mask"), + "Mask(Output) of Pooling should not be null."); + + auto in_x_dims = ctx->GetInputDim("X"); + + std::vector ksize = ctx->Attrs().Get>("ksize"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + + PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, + "Pooling intput should be 4-D or 5-D tensor."); + + if (ctx->Attrs().Get("globalPooling")) { + ksize.resize(static_cast(in_x_dims.size()) - 2); + for (size_t i = 0; i < ksize.size(); ++i) + ksize[i] = static_cast(in_x_dims[i + 2]); + } + + PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U, + "Input size and pooling size should be consistent."); + PADDLE_ENFORCE_EQ(ksize.size(), strides.size(), + "Strides size and pooling size should be the same."); + PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(), + "Paddings size and pooling size should be the same."); + + std::vector output_shape({in_x_dims[0], in_x_dims[1]}); + for (size_t i = 0; i < ksize.size(); ++i) { + output_shape.push_back(OutputSizeMaxPool(in_x_dims[i + 2], ksize[i], + paddings[i], strides[i])); + } + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + ctx->SetOutputDim("Mask", framework::make_ddim(output_shape)); + } +}; + +class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Mask"), "Input(Mask) must not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Input(X@GRAD) should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; + +class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MaxPool2dWithIndexOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "(Tensor) The input tensor of pooling operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of image."); + AddOutput("Out", + "(Tensor) The output tensor of pooling operator." + "The format of output tensor is also NCHW." + "Where N is batch size, C is " + "the number of channels, H and W is the height and " + "width of image."); + AddOutput("Mask", + "(Tensor) The Mask tensor of pooling operator." + "The format of output tensor is also NCHW." + "Where N is batch size, C is the number of channels, H and W " + "is the height and width of image." + "The value in it is the index in current feature map"); + + AddAttr>( + "ksize", + "The pooling window size(height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " + "specified."); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr( + "globalPooling", + "Whether to use the globalPooling." + "Bool constant equal to false or true." + "Default false." + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); + AddAttr>("strides", + "The strides(height, width) of pooling window." + "Default {1,1}.") + .SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr>( + "paddings", + "The zero padding(height, width) size on both sides" + "Default {0,0}.") + .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + + AddComment(R"DOC( +The maxPooling2d with index operation calculates the output and the mask +based on the input and ksize, strides, paddings parameters. Input(X) and +output(Out, Mask) are in NCHW format. Where N is batch size, C is the +number of channels, H and W is the height and width of feature. +Parameters(ksize, strides, paddings) are two elements. +These two elements represent height and width, respectively. +The input(X) size and output(Out, Mask) size may be different. + +Example: + Input: + X shape: (N, C, H_in, W_in) + Output: + Out shape: (N, C, H_out, W_out) + Mask shape: (N, C, H_out, W_out) + where + H_out = (H_in - ksize[0] + 2 * paddings[0]) / strides[0] + 1; + W_out = (W_in - ksize[1] + 2 * paddings[1]) / strides[1] + 1; +)DOC"); + } +}; + +class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MaxPool3dWithIndexOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "(Tensor) The input tensor of pooling operator. " + "The format of input tensor is NCDHW. Where N is batch size, C is " + "the number of channels, D, H and W is the depth, height and width of " + "image."); + AddOutput("Out", + "(Tensor) The output tensor of pooling operator." + "The format of output tensor is also NCDHW." + "Where N is batch size, C is " + "the number of channels, D, H and W is the depth, height and " + "width of image."); + AddOutput("Mask", + "(Tensor) The Mask tensor of pooling operator." + "The format of output tensor is also NCDHW." + "Where N is batch size, C is the number of channels, D, H and W " + "is the depth, height and width of image." + "The value in it is the index in current feature map"); + + AddAttr>( + "ksize", + "The pooling window size(depth, height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " + "specified."); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr( + "globalPooling", + "Whether to use the globalPooling." + "Bool constant equal to false or true." + "Default false." + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); + AddAttr>( + "strides", + "Strides(depth, height, width) of pooling operator." + "Default {1,1,1}.") + .SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + AddAttr>( + "paddings", + "Paddings(depth, height, width) of pooling operator." + "Default {0,0,0}.") + .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, + // TypedAttrChecker don't support vector type.) + + AddComment(R"DOC( +The maxpooling3d with index operation calculates the output and the mask +based on the input and ksize, strides, paddings parameters. +Input(X) and output(Out, Mask) are in NCDHW format. Where N is batch +size, C is the number of channels, D, H and W is the depth, height and +width of feature. Parameters(ksize, strides, paddings) are three elements. +These three elements represent depth, height and width, respectively. +The input(X) size and output(Out, Mask) size may be different. + +Example: + Input: + X shape: (N, C, D_in, H_in, W_in) + Output: + Out shape: (N, C, D_out, H_out, W_out) + Mask shape: (N, C, D_out, H_out, W_out) + where + D_out = (D_in - ksize[0] + 2 * paddings[0]) / strides[0] + 1; + H_out = (H_in - ksize[1] + 2 * paddings[1]) / strides[1] + 1; + W_out = (W_in - ksize[2] + 2 * paddings[2]) / strides[2] + 1; +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp, + ops::MaxPool2dWithIndexOpMaker, max_pool2d_with_index_grad, + ops::MaxPoolWithIndexOpGrad); + +REGISTER_OP_CPU_KERNEL( + max_pool2d_with_index, + ops::MaxPoolWithIndexKernel); +REGISTER_OP_CPU_KERNEL( + max_pool2d_with_index_grad, + ops::MaxPoolWithIndexGradKernel) + +REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp, + ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad, + ops::MaxPoolWithIndexOpGrad); + +REGISTER_OP_CPU_KERNEL( + max_pool3d_with_index, + ops::MaxPoolWithIndexKernel); +REGISTER_OP_CPU_KERNEL( + max_pool3d_with_index_grad, + ops::MaxPoolWithIndexGradKernel) diff --git a/paddle/operators/pool_with_index_op.cu b/paddle/operators/pool_with_index_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..287657d4b1c57f354ef050885f71261092bdc062 --- /dev/null +++ b/paddle/operators/pool_with_index_op.cu @@ -0,0 +1,31 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/pool_with_index_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + max_pool2d_with_index, + ops::MaxPoolWithIndexKernel); +REGISTER_OP_GPU_KERNEL( + max_pool2d_with_index_grad, + ops::MaxPoolWithIndexGradKernel) + +REGISTER_OP_GPU_KERNEL( + max_pool3d_with_index, + ops::MaxPoolWithIndexKernel); +REGISTER_OP_GPU_KERNEL( + max_pool3d_with_index_grad, + ops::MaxPoolWithIndexGradKernel) diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h new file mode 100644 index 0000000000000000000000000000000000000000..01b961ca8295f723bea7335e43ec5ab100dfc65c --- /dev/null +++ b/paddle/operators/pool_with_index_op.h @@ -0,0 +1,103 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/pooling.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class MaxPoolWithIndexKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* in_x = context.Input("X"); + Tensor* out = context.Output("Out"); + Tensor* mask = context.Output("Mask"); + + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + if (context.Attr("globalPooling")) { + for (size_t i = 0; i < ksize.size(); ++i) { + ksize[i] = static_cast(in_x->dims()[i + 2]); + } + } + + switch (ksize.size()) { + case 2: { + paddle::operators::math::MaxPool2dWithIndexFunctor + pool2d_forward; + pool2d_forward(context.device_context(), *in_x, *out, *mask, ksize, + strides, paddings); + } break; + case 3: { + paddle::operators::math::MaxPool3dWithIndexFunctor + pool3d_forward; + pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize, + strides, paddings); + } break; + } + } +}; + +template +class MaxPoolWithIndexGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* mask = context.Input("Mask"); + const Tensor* out_grad = + context.Input(framework::GradVarName("Out")); + Tensor* in_x_grad = context.Output(framework::GradVarName("X")); + + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + if (context.Attr("globalPooling")) { + for (size_t i = 0; i < ksize.size(); ++i) { + ksize[i] = static_cast(in_x_grad->dims()[i + 2]); + } + } + + if (in_x_grad) { + in_x_grad->mutable_data(context.GetPlace()); + auto temp = framework::EigenVector::Flatten(*in_x_grad); + temp.device(context.GetEigenDevice()) = + temp.constant(static_cast(0)); + + switch (ksize.size()) { + case 2: { + paddle::operators::math::MaxPool2dWithIndexGradFunctor + pool2d_backward; + pool2d_backward(context.device_context(), *in_x_grad, *out_grad, + *mask, ksize, strides, paddings); + } break; + case 3: { + paddle::operators::math::MaxPool3dWithIndexGradFunctor + pool3d_backward; + pool3d_backward(context.device_context(), *in_x_grad, *out_grad, + *mask, ksize, strides, paddings); + } break; + } + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/sequence_concat_op.cc b/paddle/operators/sequence_concat_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..287fb1942e4a2b17f6d51c9a6b7f6fb71fbaa601 --- /dev/null +++ b/paddle/operators/sequence_concat_op.cc @@ -0,0 +1,129 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/sequence_concat_op.h" + +namespace paddle { +namespace operators { + +class SequenceConcatOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInputs("X"), + "Inputs(X) of SequenceConcatOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceConcatOp should not be null."); + const size_t level = static_cast(ctx->Attrs().Get("level")); + const size_t axis = static_cast(ctx->Attrs().Get("axis")); + PADDLE_ENFORCE(level == 0UL || level == 1UL, + "The sequence_concat operator only accepts sequence " + "or a nested sequence as its input."); + auto ins_dims = ctx->GetInputsDim("X"); + framework::DDim out_dims = ins_dims[0]; + const size_t n = ins_dims.size(); + for (size_t i = 1; i < n; ++i) { + out_dims[axis] += ins_dims[i][axis]; + } + ctx->SetOutputDim("Out", out_dims); + } +}; + +class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SequenceConcatOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(A vector of LoDTensor), the input is a vector of LoDTensor, " + "each of which is a variable-length sequence or nested sequence.") + .AsDuplicable(); + AddOutput("Out", + "(A LoDTensor), the variable-length output of " + "sequence_concat Op."); + AddAttr("axis", + "(int, default 0)" + "The axis which the inputs will be joined with. " + "If axis is 0, the inputs will be joined with LoD index.") + .SetDefault(0); + AddAttr("level", + "(int, default 0)" + "The level at which the inputs will be joined. " + "If the level is 0, the inputs will be joined at the nested " + "sequence level. " + "If the level is 1, the inputs will be joined at the " + "sequence level. " + "The level should be less than the level number of inputs.") + .SetDefault(0); + AddComment(R"DOC( + The sequence_concat operator concatenates multiple LoDTensors. + It only supports sequence (LoD Tensor with level number is 1) + or a nested sequence (LoD tensor with level number is 2) as its input. + - Case1: + If the axis is other than 0(here, axis is 1 and level is 1), + each input should have the same LoD information and the LoD + information of the output keeps the same as the input. + + LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) + LoD(x1) = {{0,2,4}, {0,1,2,3,4}}; Dims(x1) = (4,4,4) + LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4) + + - Case2: + If the axis is 0(here, leve is 0), the inputs are concatenated along + time steps, the LoD information of the output need to re-compute. + + LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) + LoD(x1) = {{0,3,5}, {0,1,2,3,5}}; Dims(x1) = (5,3,4) + LoD(Out) = {{0,5,9}, {0,1,2,3,4,5,6,7,9}}; Dims(Out) = (9,3,4) + + - Case3: + If the axis is 0(here, level is 1). + + LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) + LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (5,3,4) + LoD(Out) = {{0,5,9}, {0,2,5,7,9}}; Dims(Out) = (9,3,4) + + NOTE: The levels of all the inputs should be the same. + )DOC"); + } +}; + +class SequenceConcatGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "The gradient of Out should not be null."); + PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")), + "The gradient of X should not be null."); + ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(sequence_concat, ops::SequenceConcatOp, ops::SequenceConcatOpMaker, + sequence_concat_grad, ops::SequenceConcatGradOp); +REGISTER_OP_CPU_KERNEL( + sequence_concat, + ops::SequenceConcatOpKernel); +REGISTER_OP_CPU_KERNEL( + sequence_concat_grad, + ops::SequenceConcatGradOpKernel); diff --git a/paddle/operators/sequence_concat_op.cu b/paddle/operators/sequence_concat_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..8dc4764785871262d21a5631cc9e8b805ba84244 --- /dev/null +++ b/paddle/operators/sequence_concat_op.cu @@ -0,0 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU + +#include "paddle/operators/sequence_concat_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + sequence_concat, + ops::SequenceConcatOpKernel); +REGISTER_OP_GPU_KERNEL( + sequence_concat_grad, + ops::SequenceConcatGradOpKernel); diff --git a/paddle/operators/sequence_concat_op.h b/paddle/operators/sequence_concat_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a197a05bbb881806b24f9dcce5282a4d972e3adc --- /dev/null +++ b/paddle/operators/sequence_concat_op.h @@ -0,0 +1,155 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/strided_memcpy.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using LoD = framework::LoD; + +template +LoD concatLoD(const std::vector ins, const size_t axis, + const size_t level) { + auto out_lod = ins[0]->lod(); + const size_t n = ins.size(); + if (axis == 0UL) { + for (size_t i = 1; i < n; ++i) { + for (size_t j = 0; j < ins[i]->lod()[0].size(); ++j) { + out_lod[0][j] += ins[i]->lod()[0][j]; + } + + if (ins[0]->NumLevels() == 2) { + for (size_t j = 1; j < ins[i]->lod()[1].size(); ++j) { + if (level == 0UL) { + out_lod[1].push_back(out_lod[1].back() + ins[i]->lod()[1][j] - + ins[i]->lod()[1][j - 1]); + } else if (level == 1UL) { + out_lod[1][j] += ins[1]->lod()[1][j]; + } + } + } + } + } + return out_lod; +} + +template +class SequenceConcatOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto ins = ctx.MultiInput("X"); + auto* out = ctx.Output("Out"); + const size_t axis = static_cast(ctx.Attr("axis")); + const size_t level = static_cast(ctx.Attr("level")); + const size_t n = ins.size(); + + for (size_t i = 1; i < n; ++i) { + PADDLE_ENFORCE_EQ(ins[0]->NumLevels(), ins[i]->NumLevels(), + "The levels of all the input LoDTensors " + "should be the same."); + PADDLE_ENFORCE_EQ(ins[0]->dims().size(), ins[i]->dims().size(), + "The dimension size of all the input LoDTensors " + "should be the same."); + + const size_t dims_size = ins[i]->dims().size(); + for (size_t j = 0; j < dims_size; ++j) { + if (j == axis) continue; + PADDLE_ENFORCE_EQ(ins[0]->dims()[j], ins[i]->dims()[j], + "Except for the dimension of the specified " + "axis along which all the inputs are concatenated, " + "dimensions of all the other axises of the input " + "LoDTensors should be the same."); + } + } + PADDLE_ENFORCE_GT(ins[0]->NumLevels(), level, + "The levels of all the input LoDTensors " + "should be greater than the specify level"); + + out->mutable_data(ctx.GetPlace()); + auto out_lod = concatLoD(ins, axis, level); + out->set_lod(out_lod); + + auto out_lod_level = out_lod[level]; + for (size_t i = 0; i < out_lod_level.size() - 1; ++i) { + Tensor out_t = out->Slice(static_cast(out_lod_level[i]), + static_cast(out_lod_level[i + 1])); + auto out_stride = framework::stride(out_t.dims()); + size_t offset = 0; + + for (size_t j = 0; j < n; ++j) { + auto in_lod_level = ins[j]->lod()[level]; + auto in_stride = framework::stride(ins[j]->dims()); + Tensor in_t = ins[j]->Slice(static_cast(in_lod_level[i]), + static_cast(in_lod_level[i + 1])); + size_t axis_dim = in_t.dims()[axis]; + StridedMemcpy(ctx.device_context(), in_t.data(), in_stride, + in_t.dims(), out_stride, out_t.data() + offset); + offset += axis_dim * in_stride[axis]; + } + } + } +}; + +template +class SequenceConcatGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto ins = ctx.MultiInput("X"); + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + auto x_grads = + ctx.MultiOutput(framework::GradVarName("X")); + size_t axis = static_cast(ctx.Attr("axis")); + size_t level = static_cast(ctx.Attr("level")); + const size_t n = x_grads.size(); + + // Set Grad(X) LoD as X + for (size_t i = 0; i < n; i++) { + x_grads[i]->set_lod(ins[i]->lod()); + x_grads[i]->mutable_data(ctx.GetPlace()); + } + + auto out_lod = concatLoD(ins, axis, level); + auto out_lod_level = out_lod[level]; + + for (size_t i = 0; i < out_lod_level.size() - 1; ++i) { + Tensor out_grad_t = + out_grad->Slice(static_cast(out_lod_level[i]), + static_cast(out_lod_level[i + 1])); + auto out_grad_stride = framework::stride(out_grad_t.dims()); + size_t offset = 0; + + for (size_t j = 0; j < n; ++j) { + auto x_grad_lod_level = x_grads[j]->lod()[level]; + auto x_grad_stride = framework::stride(x_grads[j]->dims()); + Tensor x_grad_t = + x_grads[j]->Slice(static_cast(x_grad_lod_level[i]), + static_cast(x_grad_lod_level[i + 1])); + size_t axis_dim = x_grad_t.dims()[axis]; + StridedMemcpy(ctx.device_context(), out_grad_t.data() + offset, + out_grad_stride, out_grad_t.dims(), x_grad_stride, + x_grad_t.data()); + offset += axis_dim * out_grad_stride[axis]; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index a9b6b799036a4f2ba93ef52398131db4fcb599f5..36450e926891342f37424447703781a33c1190ae 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -136,7 +136,7 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; } -#endif // PADDLE_ONLY_CPU +#endif } // namespace platform } // namespace paddle diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 15d8446cd8dceb2fdc03536e1f7bbcde73403a23..cd906c3fa9375cd6edaed0377a596771e25043d4 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -41,7 +41,7 @@ limitations under the License. */ #include #include -#endif // PADDLE_ONLY_CPU +#endif namespace paddle { namespace platform { diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc index 70ad611d5dd61937e6bf7d980e34b5c9023977b2..0cab5ffc5609bbd6fd08c74329d8370fb95f8102 100644 --- a/paddle/platform/gpu_info.cc +++ b/paddle/platform/gpu_info.cc @@ -43,6 +43,8 @@ int GetCurrentDeviceId() { } void SetDeviceId(int id) { + // TODO(qijun): find a better way to cache the cuda device count + PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); PADDLE_ENFORCE(cudaSetDevice(id), "cudaSetDevice failed in paddle::platform::SetDeviceId"); } diff --git a/paddle/platform/gpu_info.h b/paddle/platform/gpu_info.h index fb33db07bd54d37dec2e5d687ecefb01cc330e44..37665b97d764fbcfe0964127d230b1d28d90b687 100644 --- a/paddle/platform/gpu_info.h +++ b/paddle/platform/gpu_info.h @@ -63,4 +63,4 @@ void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device, } // namespace platform } // namespace paddle -#endif // PADDLE_ONLY_CPU +#endif diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 47bd7bc3bb3f0737ba3c9efe5b49defed87f36a1..116c99bd2c1ca59b093392f9e6cc481c089309bc 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -166,7 +166,9 @@ void BindVarDsec(py::module &m) { .def("set_shape", &VarDescBind::SetShape) .def("set_data_type", &VarDescBind::SetDataType) .def("shape", &VarDescBind::Shape, py::return_value_policy::reference) - .def("data_type", &VarDescBind::GetDataType); + .def("data_type", &VarDescBind::GetDataType) + .def("lod_level", &VarDescBind::GetLodLevel) + .def("set_lod_level", &VarDescBind::SetLoDLevel); } void BindOpDesc(py::module &m) { @@ -196,7 +198,8 @@ void BindOpDesc(py::module &m) { .def("set_attr", &OpDescBind::SetAttr) .def("attr", &OpDescBind::GetAttr) .def("set_block_attr", &OpDescBind::SetBlockAttr) - .def("get_block_attr", &OpDescBind::GetBlockAttr); + .def("get_block_attr", &OpDescBind::GetBlockAttr) + .def("infer_shape", &OpDescBind::InferShape); } } // namespace pybind diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 356c4986e2e182e904215f7ebb8cac5146364f8b..0f6e3101e26c5ac249664ce8badc10adc939305f 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -231,21 +231,6 @@ All parameter, weight, gradient are variables in Paddle. desc.InitializationErrorString()); return OpRegistry::CreateOp(desc); }) - .def_static("infer_shape", - [](OpDescBind &op_desc, BlockDescBind &block) { - auto op = OpRegistry::CreateOp(*op_desc.Proto()); - auto *op_with_kernel = - dynamic_cast(op.get()); - if (op_with_kernel != nullptr) { - auto ctx = CompileTimeInferShapeContext(op_desc, block); - op_with_kernel->InferShape(&ctx); - } else { - PADDLE_THROW( - "OP(%s) is not type of OperatorWithKernel, " - "should not call this function", - op_desc.Type()); - } - }) .def("backward", [](const OperatorBase &forwardOp, const std::unordered_set &no_grad_vars) { diff --git a/proto/CMakeLists.txt b/proto/CMakeLists.txt index 6212c2e60a8ed94ecc1d6e58535a2b3d365e3eb8..5d898d860cfc6dc26eaf5a81d8aed6d757ed5831 100644 --- a/proto/CMakeLists.txt +++ b/proto/CMakeLists.txt @@ -1,4 +1,10 @@ -file(GLOB proto_filenames . *.proto) +if (MOBILE_INFERENCE) + file(GLOB proto_filenames . ModelConfig.proto ParameterConfig.proto + TrainerConfig.proto DataConfig.proto) +else() + file(GLOB proto_filenames . *.proto) +endif() + include_directories(${CMAKE_CURRENT_BINARY_DIR}) proto_library(paddle_proto SRCS ${proto_filenames}) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index d37f29d2c4bf9177398ea82c99bc40affdd952c2..5043fb811de09638ef8261eb6a9c7d21685bf29c 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -318,7 +318,7 @@ class LayerOutput(object): :param activation: Layer Activation. :type activation: BaseActivation. :param parents: Layer's parents. - :type parents: list|tuple|collections.Sequence + :type parents: list | tuple | collections.Sequence """ def __init__(self, @@ -435,7 +435,7 @@ def full_matrix_projection(input, size=0, param_attr=None): size=100, param_attr=ParamAttr(name='_proj')) - :param input: input layer + :param input: The input of this layer. :type input: LayerOutput :param size: The parameter size. Means the width of parameter. :type size: int @@ -471,7 +471,7 @@ def trans_full_matrix_projection(input, size=0, param_attr=None): initial_mean=0.0, initial_std=0.01)) - :param input: input layer + :param input: The input of this layer. :type input: LayerOutput :param size: The parameter size. Means the width of parameter. :type size: int @@ -516,7 +516,7 @@ def table_projection(input, size=0, param_attr=None): param_attr=ParamAttr(name='_proj')) - :param input: Input layer, which must contains id fields. + :param input: The input of this layer, which must contains id fields. :type input: LayerOutput :param size: The parameter size. Means the width of parameter. :type size: int @@ -561,7 +561,7 @@ def identity_projection(input, offset=None, size=None): Note that both of two projections should not have any parameter. - :param input: Input Layer. + :param input: The input of this layer. :type input: LayerOutput :param offset: Offset, None if use default. :type offset: int @@ -596,7 +596,7 @@ def slice_projection(input, slices): Note that slice_projection should not have any parameter. - :param input: Input Layer. + :param input: The input of this layer. :type input: LayerOutput :param slices: An array of slice parameters. Each slice contains the start and end offsets based @@ -634,7 +634,7 @@ def scaling_projection(input, param_attr=None): proj = scaling_projection(input=layer) - :param input: Input Layer. + :param input: The input of this layer. :type input: LayerOutput :param param_attr: Parameter config, None if use default. :type param_attr: ParameterAttribute @@ -663,7 +663,7 @@ def dotmul_projection(input, param_attr=None): proj = dotmul_projection(input=layer) - :param input: Input layer. + :param input: The input of this layer. :type input: LayerOutput :param param_attr: Parameter config, None if use default. :type param_attr: ParameterAttribute @@ -734,7 +734,7 @@ def context_projection(input, after context projection and not set padding_attr, sequence will be [ 0AB ABC BCD CDE DEF EFG FG0 ]. - :param input: Input Sequence. + :param input: The input of this layer, which should be a sequence. :type input: LayerOutput :param context_len: context length. :type context_len: int @@ -744,7 +744,7 @@ def context_projection(input, :param padding_attr: Padding Parameter Attribute. If false, it means padding always be zero. Otherwise Padding is learnable, and parameter attribute is set by this parameter. - :type padding_attr: bool|ParameterAttribute + :type padding_attr: bool | ParameterAttribute :return: Projection :rtype: Projection """ @@ -782,13 +782,13 @@ class MixedLayerType(LayerOutput): :type name: basestring :param size: layer size. :type size: int - :param act: activation type. + :param act: Activation type. :type act: BaseActivation :param bias_attr: The Bias Attribute. If the parameter is set to False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute or None """ @@ -880,15 +880,15 @@ def mixed_layer(size=0, :type name: basestring :param size: layer size. :type size: int - :param input: inputs layer. It is an optional parameter. If set, + :param input: The input of this layer. It is an optional parameter. If set, then this function will just return layer's name. - :param act: Activation Type. + :param act: Activation Type. LinearActivation is the default. :type act: BaseActivation :param bias_attr: The Bias Attribute. If the parameter is set to False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param layer_attr: The extra layer config. Default is None. :type layer_attr: ExtraLayerAttribute :return: MixedLayerType object can add inputs or layer name. @@ -929,9 +929,9 @@ def data_layer(name, size, depth=None, height=None, width=None, :param size: Size of this data layer. :type size: int :param height: Height of this data layer, used for image - :type height: int|None + :type height: int | None :param width: Width of this data layer, used for image - :type width: int|None + :type width: int | None :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute. :return: LayerOutput object. @@ -966,15 +966,15 @@ def embedding_layer(input, size, name=None, param_attr=None, layer_attr=None): :param name: The name of this layer. It is optional. :type name: basestring - :param input: The input layer for this embedding. NOTE: must be Index Data. + :param input: The input of this layer, which must be Index Data. :type input: LayerOutput :param size: The embedding dimension. :type size: int :param param_attr: The embedding parameter attribute. See ParameterAttribute for details. - :type param_attr: ParameterAttribute|None + :type param_attr: ParameterAttribute | None :param layer_attr: Extra layer Config. Default is None. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -1021,11 +1021,11 @@ def fc_layer(input, :param name: The name of this layer. It is optional. :type name: basestring - :param input: The input layer. Could be a list/tuple of input layer. - :type input: LayerOutput|list|tuple + :param input: The input of this layer. + :type input: LayerOutput | list | tuple :param size: The layer dimension. :type size: int - :param act: Activation Type. Default is tanh. + :param act: Activation Type. TanhActivation is the default. :type act: BaseActivation :param param_attr: The Parameter Attribute|list. :type param_attr: ParameterAttribute @@ -1033,9 +1033,9 @@ def fc_layer(input, False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -1072,8 +1072,8 @@ def printer_layer(input, format=None, name=None): :param name: The name of this layer. It is optional. :type name: basestring - :param input: The input layer. Could be a list/tuple of input layer. - :type input: LayerOutput|list|tuple + :param input: The input of this layer. + :type input: LayerOutput | list | tuple :return: LayerOutput """ if isinstance(input, LayerOutput): @@ -1110,7 +1110,7 @@ def priorbox_layer(input, :param name: The name of this layer. It is optional. :type name: basestring - :param input: The input layer. + :param input: The input of this layer. :type input: LayerOutput :param image: The network input image. :type image: LayerOutput @@ -1306,7 +1306,7 @@ def cross_channel_norm_layer(input, name=None, param_attr=None): :param name: The name of this layer. It is optional. :type name: basestring - :param input: The input layer. + :param input: The input of this layer. :type input: LayerOutput :param param_attr: The Parameter Attribute|list. :type param_attr: ParameterAttribute @@ -1371,20 +1371,20 @@ def pooling_layer(input, :type agg_level: AggregateLevel :param name: The name of this layer. It is optional. :type name: basestring - :param input: input layer name. + :param input: The input of this layer. :type input: LayerOutput :param pooling_type: Type of pooling, MaxPooling(default), AvgPooling, SumPooling, SquareRootNPooling. - :type pooling_type: BasePoolingType|None + :type pooling_type: BasePoolingType | None :param stride: The step size between successive pooling regions. :type stride: Int :param bias_attr: The Bias Attribute. If the parameter is set to False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param layer_attr: The Extra Attributes for layer, such as dropout. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -1469,11 +1469,11 @@ def lstmemory(input, :type name: basestring :param size: DEPRECATED. size of the lstm cell :type size: int - :param input: input layer name. + :param input: The input of this layer. :type input: LayerOutput :param reverse: is sequence process reversed or not. :type reverse: bool - :param act: activation type, TanhActivation by default. :math:`h_t` + :param act: Activation type. TanhActivation is the default. :math:`h_t` :type act: BaseActivation :param gate_act: gate activation type, SigmoidActivation by default. :type gate_act: BaseActivation @@ -1483,11 +1483,11 @@ def lstmemory(input, False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param param_attr: Parameter Attribute. - :type param_attr: ParameterAttribute|None|False + :type param_attr: ParameterAttribute | None | False :param layer_attr: Extra Layer attribute - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -1591,14 +1591,14 @@ def grumemory(input, gru = grumemory(input) :param name: The gru layer name. - :type name: None|basestring - :param input: input layer. + :type name: None | basestring + :param input: The input of this layer. :type input: LayerOutput. :param size: DEPRECATED. size of the gru cell :type size: int :param reverse: Whether sequence process is reversed or not. :type reverse: bool - :param act: activation type, TanhActivation by default. This activation + :param act: Activation type, TanhActivation is the default. This activation affects the :math:`{\\tilde{h_t}}`. :type act: BaseActivation :param gate_act: gate activation type, SigmoidActivation by default. @@ -1609,11 +1609,11 @@ def grumemory(input, False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param param_attr: Parameter Attribute. - :type param_attr: ParameterAttribute|None|False + :type param_attr: ParameterAttribute | None | False :param layer_attr: Extra Layer attribute - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -1670,7 +1670,7 @@ def last_seq(input, :param agg_level: Aggregated level :param name: The name of this layer. It is optional. :type name: basestring - :param input: Input layer name. + :param input: The input of this layer. :type input: LayerOutput :param stride: The step size between successive pooling regions. :type stride: Int @@ -1726,7 +1726,7 @@ def first_seq(input, :param agg_level: aggregation level :param name: The name of this layer. It is optional. :type name: basestring - :param input: Input layer name. + :param input: The input of this layer. :type input: LayerOutput :param stride: The step size between successive pooling regions. :type stride: Int @@ -1799,7 +1799,7 @@ def expand_layer(input, expand_as=layer2, expand_level=ExpandLevel.FROM_NO_SEQUENCE) - :param input: Input layer + :param input: The input of this layer. :type input: LayerOutput :param expand_as: Expand as this layer's sequence info. :type expand_as: LayerOutput @@ -1809,7 +1809,7 @@ def expand_layer(input, False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param expand_level: whether input layer is timestep(default) or sequence. :type expand_level: ExpandLevel :param layer_attr: extra layer attributes. @@ -1858,7 +1858,7 @@ def repeat_layer(input, expand = repeat_layer(input=layer, num_repeats=4) - :param input: Input layer + :param input: The input of this layer. :type input: LayerOutput :param num_repeats: Repeat the input so many times :type num_repeats: int @@ -1869,7 +1869,7 @@ def repeat_layer(input, False for treating input as column vector and repeating in the row direction. :type as_row_vector: bool - :param act: Activation type. + :param act: Activation type. IdentityActivation is the default. :type act: BaseActivation :type name: basestring :param layer_attr: extra layer attributes. @@ -1917,13 +1917,13 @@ def seq_reshape_layer(input, reshape = seq_reshape_layer(input=layer, reshape_size=4) - :param input: Input layer. + :param input: The input of this layer. :type input: LayerOutput :param reshape_size: the size of reshaped sequence. :type reshape_size: int :param name: The name of this layer. It is optional. :type name: basestring - :param act: Activation type. + :param act: Activation type. IdentityActivation is the default. :type act: BaseActivation :param layer_attr: extra layer attributes. :type layer_attr: ExtraLayerAttribute. @@ -1931,7 +1931,7 @@ def seq_reshape_layer(input, False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :return: LayerOutput object. :rtype: LayerOutput """ @@ -1970,8 +1970,8 @@ def interpolation_layer(input, weight, name=None, layer_attr=None): interpolation = interpolation_layer(input=[layer1, layer2], weight=layer3) - :param input: Input layer. - :type input: list|tuple + :param input: The input of this layer. + :type input: list | tuple :param weight: Weight layer. :type weight: LayerOutput :param name: The name of this layer. It is optional. @@ -2023,11 +2023,11 @@ def bilinear_interp_layer(input, :param input: A input layer. :type input: LayerOutput. :param out_size_x: bilinear interpolation output width. - :type out_size_x: int|None + :type out_size_x: int | None :param out_size_y: bilinear interpolation output height. - :type out_size_y: int|None + :type out_size_y: int | None :param name: The layer's name, which cna not be specified. - :type name: None|basestring + :type name: None | basestring :param layer_attr: Extra Layer attribute. :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. @@ -2075,7 +2075,7 @@ def power_layer(input, weight, name=None, layer_attr=None): power = power_layer(input=layer1, weight=layer2) - :param input: Input layer. + :param input: The input of this layer. :type input: LayerOutput :param weight: Weight layer. :type weight: LayerOutput @@ -2119,7 +2119,7 @@ def scaling_layer(input, weight, name=None, layer_attr=None): scale = scaling_layer(input=layer1, weight=layer2) - :param input: Input layer. + :param input: The input of this layer. :type input: LayerOutput :param weight: Weight layer. :type weight: LayerOutput @@ -2159,7 +2159,7 @@ def trans_layer(input, name=None, layer_attr=None): trans = trans_layer(input=layer) - :param input: Input layer. + :param input: The input of this layer. :type input: LayerOutput :param name: The name of this layer. It is optional. :type name: basestring @@ -2197,7 +2197,7 @@ def rotate_layer(input, height, width, name=None, layer_attr=None): height=100, width=100) - :param input: Input layer. + :param input: The input of this layer. :type input: LayerOutput :param height: The height of the sample matrix :type height: int @@ -2306,22 +2306,21 @@ def hsigmoid(input, cost = hsigmoid(input=[layer1, layer2], label=data_layer) - :param input: Input layers. It could be a LayerOutput or list/tuple of - LayerOutput. - :type input: LayerOutput|list|tuple + :param input: The input of this layer. + :type input: LayerOutput | list | tuple :param label: Label layer. :type label: LayerOutput :param num_classes: number of classes. - :type num_classes: int|None + :type num_classes: int | None :param name: The name of this layer. It is optional. :type name: basestring :param bias_attr: The Bias Attribute. If the parameter is set to False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param param_attr: Parameter Attribute. None means default parameter. - :type param_attr: ParameterAttribute|None + :type param_attr: ParameterAttribute | None :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. @@ -2429,40 +2428,40 @@ def img_conv_layer(input, :param name: The name of this layer. It is optional. :type name: basestring - :param input: Layer Input. + :param input: The input of this layer. :type input: LayerOutput :param filter_size: The x dimension of a filter kernel. Or input a tuple for two image dimension. - :type filter_size: int|tuple|list + :type filter_size: int | tuple | list :param filter_size_y: The y dimension of a filter kernel. Since PaddlePaddle currently supports rectangular filters, the filter's shape will be (filter_size, filter_size_y). - :type filter_size_y: int|None + :type filter_size_y: int | None :param num_filters: Each filter group's number of filter - :param act: Activation type. Default is tanh + :param act: Activation type. ReluActivation is the default. :type act: BaseActivation :param groups: Group size of filters. :type groups: int :param stride: The x dimension of the stride. Or input a tuple for two image dimension. - :type stride: int|tuple|list + :type stride: int | tuple | list :param stride_y: The y dimension of the stride. :type stride_y: int :param padding: The x dimension of the padding. Or input a tuple for two image dimension - :type padding: int|tuple|list + :type padding: int | tuple | list :param padding_y: The y dimension of the padding. :type padding_y: int :param dilation: The x dimension of the dilation. Or input a tuple for two image dimension - :type dilation: int|tuple|list + :type dilation: int | tuple | list :param dilation_y: The y dimension of the dilation. :type dilation_y: int :param bias_attr: The Bias Attribute. If the parameter is set to False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param num_channels: number of input channels. If None will be set automatically from previous output. :type num_channels: int @@ -2616,15 +2615,15 @@ def img_pool_layer(input, :param padding: pooling padding width. :type padding: int :param padding_y: pooling padding height. It's equal to padding by default. - :type padding_y: int|None + :type padding_y: int | None :param name: name of pooling layer :type name: basestring. - :param input: layer's input + :param input: The input of this layer. :type input: LayerOutput :param pool_size: pooling window width :type pool_size: int :param pool_size_y: pooling window height. It's eaqual to pool_size by default. - :type pool_size_y: int|None + :type pool_size_y: int | None :param num_channels: number of input channel. :type num_channels: int :param pool_type: pooling type. MaxPooling or AvgPooling. Default is @@ -2633,7 +2632,7 @@ def img_pool_layer(input, :param stride: stride width of pooling. :type stride: int :param stride_y: stride height of pooling. It is equal to stride by default. - :type stride_y: int|None + :type stride_y: int | None :param layer_attr: Extra Layer attribute. :type layer_attr: ExtraLayerAttribute :param ceil_mode: Wether to use ceil mode to calculate output height and with. @@ -2743,20 +2742,20 @@ def img_pool3d_layer(input, pool_type=MaxPooling()) :param padding: pooling padding width. - :type padding: int|tuple|list + :type padding: int | tuple | list :param name: name of pooling layer :type name: basestring. - :param input: layer's input + :param input: The input of this layer. :type input: LayerOutput :param pool_size: pooling window width - :type pool_size: int|tuple|list + :type pool_size: int | tuple | list :param num_channels: number of input channel. :type num_channels: int :param pool_type: pooling type. MaxPooling or AvgPooling. Default is MaxPooling. :type pool_type: BasePoolingType :param stride: stride width of pooling. - :type stride: int|tuple|list + :type stride: int | tuple | list :param layer_attr: Extra Layer attribute. :type layer_attr: ExtraLayerAttribute :param ceil_mode: Wether to use ceil mode to calculate output height and with. @@ -2855,7 +2854,7 @@ def spp_layer(input, :param name: The name of this layer. It is optional. :type name: basestring - :param input: layer's input. + :param input: The input of this layer. :type input: LayerOutput :param num_channels: number of input channel. :type num_channels: int @@ -2948,8 +2947,8 @@ def img_cmrnorm_layer(input, norm = img_cmrnorm_layer(input=net, size=5) :param name: The name of this layer. It is optional. - :type name: None|basestring - :param input: layer's input. + :type name: None | basestring + :param input: The input of this layer. :type input: LayerOutput :param size: Normalize in number of :math:`size` feature maps. :type size: int @@ -3024,7 +3023,7 @@ def batch_norm_layer(input, batch_norm for CPU. Otherwise, select batch norm type based on the specified type. If you use cudnn_batch_norm, we suggested you use latest version, such as v5.1. - :type batch_norm_type: None|string, None or "batch_norm" or "cudnn_batch_norm" + :type batch_norm_type: None | string, None or "batch_norm" or "cudnn_batch_norm" :param act: Activation Type. Better be relu. Because batch normalization will normalize input near zero. :type act: BaseActivation @@ -3034,7 +3033,7 @@ def batch_norm_layer(input, :type num_channels: int :param bias_attr: :math:`\\beta`, better be zero when initialize. So the initial_std=0, initial_mean=1 is best practice. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param param_attr: :math:`\\gamma`, better be one when initialize. So the initial_std=0, initial_mean=1 is best practice. :type param_attr: ParameterAttribute @@ -3046,7 +3045,7 @@ def batch_norm_layer(input, testing. If False, it will use the mean and variance of current batch of test data for testing. - :type use_global_stats: bool|None. + :type use_global_stats: bool | None. :param moving_average_fraction: Factor used in the moving average computation, referred to as facotr, :math:`runningMean = newMean*(1-factor) @@ -3107,7 +3106,7 @@ def sum_to_one_norm_layer(input, name=None, layer_attr=None): sum_to_one_norm = sum_to_one_norm_layer(input=layer) - :param input: Input layer. + :param input: The input of this layer. :type input: LayerOutput :param name: The name of this layer. It is optional. :type name: basestring @@ -3143,7 +3142,7 @@ def row_l2_norm_layer(input, name=None, layer_attr=None): row_l2_norm_layer = row_l2_norm_layer(input=layer) - :param input: Input layer. + :param input: The input of this layer. :type input: LayerOutput :param name: The name of this layer. It is optional. :type name: basestring @@ -3201,14 +3200,14 @@ def addto_layer(input, act=None, name=None, bias_attr=None, layer_attr=None): :type name: basestring :param input: Input layers. It could be a LayerOutput or list/tuple of LayerOutput. - :type input: LayerOutput|list|tuple - :param act: Activation Type, default is tanh. + :type input: LayerOutput | list | tuple + :param act: Activation Type. LinearActivation is the default. :type act: BaseActivation :param bias_attr: The Bias Attribute. If the parameter is set to False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param layer_attr: Extra Layer attribute. :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. @@ -3260,8 +3259,8 @@ def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None): :param name: The name of this layer. It is optional. :type name: basestring :param input: input layers or projections - :type input: list|tuple|collections.Sequence - :param act: Activation type. + :type input: list | tuple | collections.Sequence + :param act: Activation type. IdentityActivation is the default. :type act: BaseActivation :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute @@ -3356,7 +3355,7 @@ def seq_concat_layer(a, b, act=None, name=None, layer_attr=None, :type a: LayerOutput :param b: input sequence layer :type b: LayerOutput - :param act: Activation type. + :param act: Activation type. IdentityActivation is the default. :type act: BaseActivation :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute @@ -3364,7 +3363,7 @@ def seq_concat_layer(a, b, act=None, name=None, layer_attr=None, False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :return: LayerOutput object. :rtype: LayerOutput """ @@ -3440,9 +3439,9 @@ def memory(name, :param is_seq: DEPRECATED. is sequence for boot_layer :type is_seq: bool :param boot_layer: boot layer of memory. - :type boot_layer: LayerOutput|None + :type boot_layer: LayerOutput | None :param boot_bias: boot layer's bias - :type boot_bias: ParameterAttribute|None + :type boot_bias: ParameterAttribute | None :param boot_bias_active_type: boot layer's active type. :type boot_bias_active_type: BaseActivation :param boot_with_const_id: boot layer's id. @@ -3537,19 +3536,17 @@ def lstm_step_layer(input, :type input: LayerOutput :param state: State Layer. :math:`c_{t-1}` :type state: LayerOutput - :param act: Activation type. Default is tanh + :param act: Activation type. TanhActivation is the default. :type act: BaseActivation - :param gate_act: Gate Activation Type. Default is sigmoid, and should - be sigmoid only. + :param gate_act: Gate Activation Type. SigmoidActivation is the default. :type gate_act: BaseActivation - :param state_act: State Activation Type. Default is sigmoid, and should - be sigmoid only. + :param state_act: State Activation Type. TanhActivation is the default. :type state_act: BaseActivation :param bias_attr: The Bias Attribute. If the parameter is set to False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param layer_attr: layer's extra attribute. :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. @@ -3600,13 +3597,15 @@ def gru_step_layer(input, :param output_mem: :param size: :param act: + :type act: BaseActivation :param name: The name of this layer. It is optional. - :param gate_act: + :param gate_act: Activation type of this layer's two gates. Default is Sigmoid. + :type gate_act: BaseActivation :param bias_attr: The Bias Attribute. If the parameter is set to False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param param_attr: the parameter_attribute for transforming the output_mem from previous step. :param layer_attr: @@ -3662,12 +3661,14 @@ def gru_step_naive_layer(input, :param size: :param name: The name of this layer. It is optional. :param act: - :param gate_act: + :type act: BaseActivation + :param gate_act: Activation type of this layer's two gates. Default is Sigmoid. + :type gate_act: BaseActivation :param bias_attr: The Bias Attribute. If the parameter is set to False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param param_attr: :param layer_attr: :return: @@ -3786,15 +3787,15 @@ def recurrent_layer(input, out_{i} = act(in_{i} + out_{i+1} * W) \\ \\ \\text{for} \\ start <= i < end - :param input: Input Layer + :param input: The input of this layer. :type input: LayerOutput - :param act: activation. + :param act: Activation type. TanhActivation is the default. :type act: BaseActivation :param bias_attr: The Bias Attribute. If the parameter is set to False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param param_attr: parameter attribute. :type param_attr: ParameterAttribute :param name: The name of this layer. It is optional. @@ -3901,7 +3902,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): StaticInput will be imported to each time step, and doesn't change through time. It's a mechanism to access layer outside step function. - :type input: LayerOutput|StaticInput|SubsequenceInput|list|tuple + :type input: LayerOutput | StaticInput | SubsequenceInput | list | tuple :param reverse: If reverse is set true, the recurrent unit will process the input sequence in a reverse order. @@ -3916,7 +3917,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): of words in each sentence) with all layer group's outputs. targetInlink should be one of the layer group's input. - :type targetInlink: LayerOutput|SubsequenceInput + :type targetInlink: LayerOutput | SubsequenceInput :return: LayerOutput object. :rtype: LayerOutput @@ -4034,7 +4035,7 @@ def maxid_layer(input, name=None, layer_attr=None): maxid = maxid_layer(input=layer) - :param input: Input layer name. + :param input: The input of this layer. :type input: LayerOutput :param name: The name of this layer. It is optional. :type name: basestring @@ -4112,7 +4113,7 @@ def eos_layer(input, eos_id, name=None, layer_attr=None): :param name: The name of this layer. It is optional. :type name: basestring - :param input: Input layer name. + :param input: The input of this layer. :type input: LayerOutput :param eos_id: end id of sequence :type eos_id: int @@ -4504,7 +4505,7 @@ def conv_projection(input, num_filters=64, num_channels=64) - :param input: input layer + :param input: The input of this layer. :type input: LayerOutput :param filter_size: The x dimension of a filter kernel. :type filter_size: int @@ -4529,7 +4530,7 @@ def conv_projection(input, :param param_attr: Convolution param attribute. None means default attribute :type param_attr: ParameterAttribute :param trans: whether it is convTrans or conv - :type trans: boolean + :type trans: bool :return: A DotMulProjection Object. :rtype: DotMulProjection """ @@ -4637,14 +4638,14 @@ def pad_layer(input, pad_h=[0,0], pad_w=[2,2]) - :param input: layer's input. + :param input: The input of this layer. :type input: LayerOutput :param pad_c: padding size in channel dimension. - :type pad_c: list|None + :type pad_c: list | None :param pad_h: padding size in height dimension. - :type pad_h: list|None + :type pad_h: list | None :param pad_w: padding size in width dimension. - :type pad_w: list|None + :type pad_w: list | None :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute :param name: The name of this layer. It is optional. @@ -4779,7 +4780,7 @@ def tensor_layer(a, :type b: LayerOutput :param size: the layer dimension. :type size: int. - :param act: Activation Type. Default is tanh. + :param act: Activation type. LinearActivation is the default. :type act: BaseActivation :param param_attr: The Parameter Attribute. :type param_attr: ParameterAttribute @@ -4787,9 +4788,9 @@ def tensor_layer(a, False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -4836,15 +4837,15 @@ def selective_fc_layer(input, :param name: The name of this layer. It is optional. :type name: basestring - :param input: The input layer. - :type input: LayerOutput|list|tuple + :param input: The input of this layer. + :type input: LayerOutput | list | tuple :param select: The select layer. The output of select layer should be a sparse binary matrix, and treat as the mask of selective fc. If is None, acts exactly like fc_layer. :type select: LayerOutput :param size: The layer dimension. :type size: int - :param act: Activation Type. Default is tanh. + :param act: Activation type. TanhActivation is the default. :type act: BaseActivation :param param_attr: The Parameter Attribute. :type param_attr: ParameterAttribute @@ -4852,9 +4853,9 @@ def selective_fc_layer(input, False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -4906,12 +4907,12 @@ def sampling_id_layer(input, name=None, layer_attr=None): samping_id = sampling_id_layer(input=input) - :param input: The input layer. + :param input: The input of this layer. :type input: LayerOutput :param name: The name of this layer. It is optional. :type name: basestring :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -4944,7 +4945,7 @@ def slope_intercept_layer(input, scale = slope_intercept_layer(input=input, slope=-1.0, intercept=1.0) - :param input: The input layer. + :param input: The input of this layer. :type input: LayerOutput :param name: The name of this layer. It is optional. :type name: basestring @@ -4953,7 +4954,7 @@ def slope_intercept_layer(input, :param intercept: the offset. :type intercept: float. :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -5013,7 +5014,7 @@ def linear_comb_layer(weights, vectors, size=None, name=None, layer_attr=None): :param name: The name of this layer. It is optional. :type name: basestring :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -5077,10 +5078,10 @@ def block_expand_layer(input, block_x=1, block_x=3) - :param input: The input layer. + :param input: The input of this layer. :type input: LayerOutput :param num_channels: The channel number of input layer. - :type num_channels: int|None + :type num_channels: int | None :param block_x: The width of sub block. :type block_x: int :param block_y: The width of sub block. @@ -5094,9 +5095,9 @@ def block_expand_layer(input, :param padding_y: The padding size in vertical direction. :type padding_y: int :param name: The name of this layer. It is optional. - :type name: None|basestring. + :type name: None | basestring. :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -5155,15 +5156,15 @@ def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None): num_channels=128, groups=4) - :param input: The input layer. + :param input: The input of this layer. :type input: LayerOutput :param num_channels: The channel number of input layer. If None will be set automatically from previous output. - :type num_channels: int|None + :type num_channels: int | None :param groups: The group number of input layer. :type groups: int :param name: The name of this layer. It is optional. - :type name: None|basestring. + :type name: None | basestring. :param layer_attr: Extra Layer attribute. :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. @@ -5220,18 +5221,18 @@ def ctc_layer(input, size=9055, norm_by_times=True) - :param input: The input layer. + :param input: The input of this layer. :type input: LayerOutput :param label: The data layer of label with variable length. :type label: LayerOutput :param size: category numbers + 1. :type size: int :param name: The name of this layer. It is optional. - :type name: basestring|None + :type name: basestring | None :param norm_by_times: Whether to normalization by times. False by default. :type norm_by_times: bool :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -5297,20 +5298,20 @@ def warp_ctc_layer(input, blank=1000, norm_by_times=False) - :param input: The input layer. + :param input: The input of this layer. :type input: LayerOutput :param label: The data layer of label with variable length. :type label: LayerOutput :param size: category numbers + 1. :type size: int :param name: The name of this layer. It is optional. - :type name: basestring|None + :type name: basestring | None :param blank: the 'blank' label used in ctc :type blank: int :param norm_by_times: Whether to normalization by times. False by default. :type norm_by_times: bool :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -5368,11 +5369,11 @@ def crf_layer(input, :param param_attr: Parameter attribute. None means default attribute :type param_attr: ParameterAttribute :param name: The name of this layer. It is optional. - :type name: None|basestring + :type name: None | basestring :param coeff: The coefficient affects the gradient in the backward. :type coeff: float :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -5438,9 +5439,9 @@ def crf_decoding_layer(input, :param param_attr: Parameter attribute. None means default attribute :type param_attr: ParameterAttribute :param name: The name of this layer. It is optional. - :type name: None|basestring + :type name: None | basestring :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -5499,14 +5500,14 @@ def nce_layer(input, :param name: The name of this layer. It is optional. :type name: basestring :param input: The input layers. It could be a LayerOutput of list/tuple of LayerOutput. - :type input: LayerOutput|list|tuple|collections.Sequence + :type input: LayerOutput | list | tuple | collections.Sequence :param label: label layer :type label: LayerOutput :param weight: weight layer, can be None(default) :type weight: LayerOutput :param num_classes: number of classes. :type num_classes: int - :param act: Activation, default is Sigmoid. + :param act: Activation type. SigmoidActivation is the default. :type act: BaseActivation :param param_attr: The Parameter Attribute|list. :type param_attr: ParameterAttribute @@ -5515,12 +5516,12 @@ def nce_layer(input, :param neg_distribution: The distribution for generating the random negative labels. A uniform distribution will be used if not provided. If not None, its length must be equal to num_classes. - :type neg_distribution: list|tuple|collections.Sequence|None + :type neg_distribution: list | tuple | collections.Sequence | None :param bias_attr: The Bias Attribute. If the parameter is set to False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute :return: layer name. @@ -5636,7 +5637,7 @@ def rank_cost(left, It is an optional argument. :type weight: LayerOutput :param name: The name of this layer. It is optional. - :type name: None|basestring + :type name: None | basestring :param coeff: The coefficient affects the gradient in the backward. :type coeff: float :param layer_attr: Extra Layer Attribute. @@ -5701,7 +5702,7 @@ def lambda_cost(input, entire list of get gradient. :type max_sort_size: int :param name: The name of this layer. It is optional. - :type name: None|basestring + :type name: None | basestring :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. @@ -5745,7 +5746,7 @@ def cross_entropy(input, :param label: The input label. :type input: LayerOutput. :param name: The name of this layer. It is optional. - :type name: None|basestring. + :type name: None | basestring. :param coeff: The cost is multiplied with coeff. The coefficient affects the gradient in the backward. :type coeff: float. @@ -5793,7 +5794,7 @@ def cross_entropy_with_selfnorm(input, :param label: The input label. :type input: LayerOutput. :param name: The name of this layer. It is optional. - :type name: None|basestring. + :type name: None | basestring. :param coeff: The coefficient affects the gradient in the backward. :type coeff: float. :param softmax_selfnorm_alpha: The scale factor affects the cost. @@ -5830,10 +5831,10 @@ def sum_cost(input, name=None, layer_attr=None): cost = sum_cost(input=input_layer) - :param input: The first input layer. + :param input: The input of this layer. :type input: LayerOutput. :param name: The name of this layer. It is optional. - :type name: None|basestring. + :type name: None | basestring. :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. @@ -5878,7 +5879,7 @@ def huber_regression_cost(input, :param label: The input label. :type input: LayerOutput. :param name: The name of this layer. It is optional. - :type name: None|basestring. + :type name: None | basestring. :param delta: The difference between the observed and predicted values. :type delta: float. :param coeff: The coefficient affects the gradient in the backward. @@ -5928,7 +5929,7 @@ def huber_classification_cost(input, :param label: The input label. :type input: LayerOutput. :param name: The name of this layer. It is optional. - :type name: None|basestring. + :type name: None | basestring. :param coeff: The coefficient affects the gradient in the backward. :type coeff: float. :param layer_attr: Extra Layer Attribute. @@ -5971,7 +5972,7 @@ def multi_binary_label_cross_entropy(input, :param label: The input label. :type input: LayerOutput :param name: The name of this layer. It is optional. - :type name: None|basestring + :type name: None | basestring :param coeff: The coefficient affects the gradient in the backward. :type coeff: float :param layer_attr: Extra Layer Attribute. @@ -6139,7 +6140,7 @@ def smooth_l1_cost(input, label, name=None, coeff=1.0, layer_attr=None): :param label: The input label. :type input: LayerOutput :param name: The name of this layer. It is optional. - :type name: None|basestring + :type name: None | basestring :param coeff: The coefficient affects the gradient in the backward. :type coeff: float :param layer_attr: Extra Layer Attribute. @@ -6226,7 +6227,7 @@ def dropout_layer(input, dropout_rate, name=None): :param name: The name of this layer. It is optional. :type name: basestring - :param input: The input layer. + :param input: The input of this layer. :type input: LayerOutput :param dropout_rate: The probability of dropout. :type dropout_rate: float @@ -6285,18 +6286,18 @@ def row_conv_layer(input, row_conv = row_conv_layer(input=input_layer, context_len=3) - :param input: The input layer. + :param input: The input of this layer. :type input: LayerOutput :param context_len: The context length equals the lookahead step number plus one. :type context_len: int - :param act: Activation Type. Default is linear activation. + :param act: Activation Type. LinearActivation is the default. :type act: BaseActivation :param param_attr: The Parameter Attribute. If None, the parameter will be initialized smartly. It's better to set it by yourself. :type param_attr: ParameterAttribute :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput @@ -6342,7 +6343,7 @@ def prelu_layer(input, :param name: The name of this layer. It is optional. :type name: basestring - :param input: The input layer. + :param input: The input of this layer. :type input: LayerOutput :param partial_sum: this parameter makes a group of inputs share a same weight. @@ -6352,9 +6353,9 @@ def prelu_layer(input, :type partial_sum: int :param param_attr: The parameter attribute. See ParameterAttribute for details. - :type param_attr: ParameterAttribute|None + :type param_attr: ParameterAttribute | None :param layer_attr: Extra layer configurations. Default is None. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -6407,37 +6408,37 @@ def gated_unit_layer(input, .. code-block:: python gated_unit = gated_unit_layer(size=128, input=input_layer)) - :param input: input for this layer. + :param input: The input of this layer. :type input: LayerOutput :param size: output size of the gated unit. :type size: int - :param act: activation type of the projected input. + :param act: Activation type of the projected input. LinearActivation is the default. :type act: BaseActivation :param name: The name of this layer. It is optional. :type name: basestring :param gate_attr: Attributes to tune the gate output, for example, error clipping threshold, dropout and so on. See ExtraLayerAttribute for more details. - :type gate_attr: ExtraLayerAttribute|None + :type gate_attr: ExtraLayerAttribute | None :param gate_param_attr: Attributes to tune the learnable projected matrix parameter of the gate. - :type gate_param_attr: ParameterAttribute|None + :type gate_param_attr: ParameterAttribute | None :param gate_bias_attr: Attributes to tune the learnable bias of the gate. - :type gate_bias_attr: ParameterAttribute|None + :type gate_bias_attr: ParameterAttribute | None :param inproj_attr: Attributes to the tune the projected input, for example, error clipping threshold, dropout and so on. See ExtraLayerAttribute for more details. - :type inproj_attr: ExtraLayerAttribute|None + :type inproj_attr: ExtraLayerAttribute | None :param inproj_param_attr: Attributes to tune the learnable parameter of the projection of input. - :type inproj_param_attr: ParameterAttribute|None + :type inproj_param_attr: ParameterAttribute | None :param inproj_bias_attr: Attributes to tune the learnable bias of projection of the input. - :type inproj_bias_attr: ParameterAttribute|None + :type inproj_bias_attr: ParameterAttribute | None :param layer_attr: Attributes to tune the final output of the gated unit, for example, error clipping threshold, dropout and so on. See ExtraLayerAttribute for more details. - :type layer_attr: ExtraLayerAttribute|None + :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput """ @@ -6487,7 +6488,7 @@ def switch_order_layer(input, switch = switch_order(input=layer, name='switch', reshape_axis=reshape_axis) reshape = {'height':[ 0, 1, 2], 'width':[3]} - :param input: The input layer. + :param input: The input of this layer. :type input: LayerOutput :param name: The name of this layer. It is optional. :type name: basestring @@ -6521,7 +6522,7 @@ def switch_order_layer(input, @layer_support() def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None): """ - The crop layer crops images by offset and shape. User can set crop shape by + This layer crops images by offset and shape. User can set crop shape by args 'shape' explicitly or by reference input layer. The example usage is: @@ -6529,10 +6530,10 @@ def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None): .. code-block:: python crop = crop_layer(input=[image_input, reference_input], axis=2, offset=[2, 3]) - :param input: The input layer.If two inputs were setted, - the second input will be regarded as reference input - :type input: LayerOutput or Sequence - :param offset: The crop offset + :param input: The input of this layer. If two inputs are given, the second input + will be regarded as reference input. + :type input: LayerOutput | Sequence + :param offset: The crop offset. :type offset: Sequence :param axis: start axis to be cropped. To image input layer: - 0: batch size @@ -6581,12 +6582,12 @@ def sub_nested_seq_layer(input, selected_indices, name=None): .. code-block:: python - sub_nest_seq = sub_nested_seq_layer(input=[data, selected_indices]) + sub_nest_seq = sub_nested_seq_layer(input=data, selected_indices=selected_ids) - :param input: A nested sequence. + :param input: The input of this layer. It is a nested sequence. :type input: LayerOutput - :param selected_indices: a set of sequence indices in the nested sequence. + :param selected_indices: A set of sequence indices in the nested sequence. :type input: LayerOutput :param name: The name of this layer. It is optional. :type name: basestring @@ -6628,7 +6629,7 @@ def clip_layer(input, min, max, name=None): :param name: The name of this layer. It is optional. :type name: basestring - :param input: The input layer. + :param input: The input of this layer. :type input: LayerOutput. :param min: The lower threshold for clipping. :type min: double @@ -6673,12 +6674,12 @@ def seq_slice_layer(input, starts, ends, name=None): :param name: The name of this layer. It is optional. :type name: basestring - :param input: input for this layer, it should be a sequence. + :param input: The input of this layer, which should be a sequence. :type input: LayerOutput :param starts: start indices to slice the input sequence. - :type starts: LayerOutput|None + :type starts: LayerOutput | None :param ends: end indices to slice the input sequence. - :type ends: LayerOutput|None + :type ends: LayerOutput | None :return: LayerOutput object. :rtype: LayerOutput @@ -6727,9 +6728,9 @@ def kmax_seq_score_layer(input, name=None, beam_size=1): :param name: The name of this layer. It is optional. :type name: basestring - :param input: The input layer. It stores scores over a sequence or a nested + :param input: The input of this layer. It stores scores over a sequence or a nested sequence and its size must be 1. - :type input: LayerOutput. + :type input: LayerOutput :param beam_size: sequence indices with top beam_size scores are returned. :type beam_size: double :return: LayerOutput object. @@ -6785,24 +6786,24 @@ def img_conv3d_layer(input, :param name: The name of this layer. It is optional. :type name: basestring - :param input: Layer Input. + :param input: The input of this layer. :type input: LayerOutput :param filter_size: The x dimension of a filter kernel. Or input a list. - :type filter_size: int|tuple|list + :type filter_size: int | tuple | list :param num_filters: Each filter group's number of filter - :param act: Activation type. Default is tanh + :param act: Activation type. ReluActivation is the default. :type act: BaseActivation :param groups: Group size of filters. :type groups: int :param stride: The x dimension of the stride. Or input a tuple for two image dimension. - :type stride: int|tuple|list + :type stride: int | tuple | list :param padding: The x dimension of the padding. Or input a tuple for two image dimension - :type padding: int|tuple|list + :type padding: int | tuple | list :param bias_attr: Convolution bias attribute. None means default bias. False means no bias. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :param num_channels: number of input channels. If None will be set automatically from previous output. :type num_channels: int @@ -6916,15 +6917,15 @@ def scale_shift_layer(input, name=None, param_attr=None, bias_attr=None): :param name: The name of this layer. It is optional. :type name: basestring - :param input: The input layer. - :type input: LayerOutput. + :param input: The input of this layer. + :type input: LayerOutput :param param_attr: The parameter attribute of scaling. :type param_attr: ParameterAttribute :param bias_attr: The Bias Attribute. If the parameter is set to False or something not type of ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. - :type bias_attr: ParameterAttribute|None|Bool|Any + :type bias_attr: ParameterAttribute | None | bool | Any :return: LayerOutput object. :rtype: LayerOutput """ @@ -6944,11 +6945,11 @@ def resize_layer(input, size, name=None): into the output matrix with a shape of [Height x Width / size, size], where size is the parameter of this layer indicating the output dimension. - :param input: The input to this layer. + :param input: The input of this layer. :type input: LayerOutput. :param name: The name of this layer. It is optional. :type name: basestring - :param size: The resized output dimesion of this layer. + :param size: The resized output dimension of this layer. :type size: int :return: A LayerOutput object. :rtype: LayerOutput diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..0f0a2847e58a1ca172bf1ba382abb2ebc1ecb8ed --- /dev/null +++ b/python/paddle/v2/framework/graph.py @@ -0,0 +1,240 @@ +import paddle.v2.framework.core as core +import collections +import numpy as np +import copy + +__all__ = ['Block', 'Variable', 'Program', 'Operator'] + + +class Variable(object): + def __init__(self, + block, + name=None, + shape=None, + dtype=None, + lod_level=None, + **kwargs): + self.block = block + + if name is None: + name = Variable._unique_var_name_() + try: + self.desc = self.block.desc.var(name) + is_new_var = False + except core.EnforceNotMet: + self.desc = self.block.desc.new_var(name) + is_new_var = True + + if shape is not None: + if is_new_var: + self.desc.set_shape(shape) + else: + old_shape = self.shape + shape = tuple(shape) + if shape != old_shape: + raise ValueError( + "Variable {0} has been created before. the previous " + "shape is {1}; the new shape is {2}. They are not " + "matched.".format(self.name, old_shape, shape)) + if dtype is not None: + if not isinstance(dtype, core.DataType): + dtype = Variable._convert_np_dtype_to_dtype_(dtype) + if is_new_var: + self.desc.set_data_type(dtype) + else: + old_dtype = self.data_type() + if dtype != old_shape: + raise ValueError("Variable {0} has been created before. " + "The previous data type is {1}; the new " + "data type is {2}. They are not " + "matched.".format(self.name, old_dtype, + dtype)) + + if lod_level is not None: + if is_new_var: + self.desc.set_lod_level(lod_level) + else: + if lod_level != self.lod_level: + raise ValueError("Variable {0} has been created before. " + "The previous lod_level is {1}; the new " + "lod_level is {2}. They are not " + "matched".format(self.name, self.lod_level, + lod_level)) + self.block.vars[name] = self + self.op = None + + @property + def name(self): + return self.desc.name() + + @property + def shape(self): + # convert to tuple, make it as same as numpy API. + return tuple(self.desc.shape()) + + @property + def data_type(self): + return self.desc.data_type() + + @property + def lod_level(self): + return self.desc.lod_level() + + @staticmethod + def _unique_var_name_(): + uid = core.unique_integer() # unique during whole process. + return "_generated_var_%d" % uid + + @staticmethod + def _convert_np_dtype_to_dtype_(np_dtype): + dtype = np.dtype(np_dtype) + if dtype == np.float32: + return core.DataType.FP32 + elif dtype == np.float64: + return core.DataType.FP64 + elif dtype == np.float16: + return core.DataType.FP16 + elif dtype == np.int32: + return core.DataType.INT32 + elif dtype == np.int16: + return core.DataType.INT16 + elif dtype == np.int64: + return core.DataType.INT64 + elif dtype == np.bool: + return core.DataType.BOOL + else: + raise ValueError("Not supported numpy dtype " + str(dtype)) + + +class Operator(object): + def __init__(self, + block, + desc, + type=None, + inputs=None, + outputs=None, + attrs=None): + self.block = block + self.desc = desc + if type is not None: + # TODO. + pass + if inputs is not None: + # TODO + pass + if outputs is not None: + # TODO + pass + if attrs is not None: + # TODO + pass + + # TODO: Getters + + +class Block(object): + def __init__(self, program, idx): + self.desc = program.desc.block(idx) + self.vars = dict() # var_name --> var + self.ops = collections.deque() # operator list + self.program = program + + @property + def parent_idx(self): + return self.desc.parent + + @property + def idx(self): + return self.desc.id + + def create_var(self, *args, **kwargs): + return Variable(self, *args, **kwargs) + + def create_parameter(self, *args, **kwargs): + global_block = self.program.global_block() + return Parameter(global_block, *args, **kwargs) + + def append_op(self, *args, **kwargs): + op_desc = self.desc.append_op() + op = Operator(self, op_desc, *args, **kwargs) + self.ops.append(op) + return op + + def prepend_op(self, *args, **kwargs): + op_desc = self.desc.prepend_op() + op = Operator(self, op_desc, *args, **kwargs) + self.ops.appendleft(op) + return op + + +class Program(object): + @classmethod + def instance(cls): + # From https://stackoverflow.com/questions/8212053 + # Making Program as a Singleton class. + if not hasattr(cls, '_instance'): + cls._instance = cls() + return cls._instance + + def __init__(self): + assert not hasattr(self.__class__, + '_instance'), 'Do not call constructor directly!' + self.desc = core.ProgramDesc.instance() + self.blocks = [Block(self, 0)] + self.current_block_idx = 0 + + def global_block(self): + return self.blocks[0] + + def current_block(self): + return self.blocks[self.current_block_idx] + + def create_block(self): + new_block_idx = len(self.blocks) + self.desc.append_block(self.current_block().desc) + self.current_block_idx = new_block_idx + self.blocks.append(Block(self, self.current_block_idx)) + return self.current_block() + + def rollback(self): + self.current_block_idx = self.current_block().parent_idx + + +class Parameter(Variable): + def __init__(self, block, shape, dtype, **kwargs): + if shape is None or dtype is None: + raise ValueError("Parameter must set shape and dtype") + if len(shape) == 0: + raise ValueError("Parameter shape cannot be empty") + + for each in shape: + if each < 0: + raise ValueError("Parameter shape should not be related with " + "batch-size") + + Variable.__init__(self, block, shape=shape, dtype=dtype, **kwargs) + self.trainable = kwargs.get('trainable', True) + self.init_attr = kwargs.get('initialize_attr', { + 'type': 'uniform_random', + 'min': -1.0, + 'max': 1.0 + }) + + self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0}) + self._append_initialize_ops_() + + def _append_initialize_ops_(self): + attr = copy.deepcopy(self.init_attr) + op_type = attr.pop('type', None) + block = self.block + assert isinstance(block, Block) + shape = self.shape + attr['dims'] = shape + attr['data_type'] = int(self.data_type) + op = block.prepend_op( + type=op_type, inputs=None, outputs={'Out': [self]}, attrs=attr) + self.op = op + + +# program is a global instance. +g_program = Program.instance() diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/framework/tests/test_activation_op.py index 8b76decaecdcb23d8292490b2988d2df043b5581..a28c4431e1ae9230750247c0ed16c9aff37364fa 100644 --- a/python/paddle/v2/framework/tests/test_activation_op.py +++ b/python/paddle/v2/framework/tests/test_activation_op.py @@ -33,6 +33,21 @@ class TestSigmoid(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.008) +class TestLogSigmoid(OpTest): + def setUp(self): + self.op_type = "logsigmoid" + self.inputs = { + 'X': np.random.uniform(-1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': np.log(1 / (1 + np.exp(-self.inputs['X'])))} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.008) + + class TestTanh(OpTest): def setUp(self): self.op_type = "tanh" @@ -63,6 +78,46 @@ class TestTanhShrink(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.008) +class TestHardShrink(OpTest): + def setUp(self): + self.op_type = "hard_shrink" + x = np.random.uniform(-1, 1, [4, 4]).astype("float32") + threshold = 0.5 + + self.inputs = {'X': x} + self.attrs = {'lambda': threshold} + + t = np.copy(x) + t[(t >= -threshold) & (t <= threshold)] = 0 + self.outputs = {'Y': t} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.005) + + +class TestSoftShrink(OpTest): + def setUp(self): + self.op_type = "softshrink" + lambda_val = 0.1 + self.attrs = {'lambda': lambda_val} + self.inputs = { + 'X': np.random.uniform(0.25, 10, [4, 4]).astype("float32") + } + y = np.copy(self.inputs['X']) + y = (y < -lambda_val) * (y + lambda_val) + (y > lambda_val) * ( + y - lambda_val) + self.outputs = {'Y': y} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.007) + + class TestSqrt(OpTest): def setUp(self): self.op_type = "sqrt" @@ -181,6 +236,26 @@ class TestSoftRelu(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.02) +class TestELU(OpTest): + def setUp(self): + self.op_type = "elu" + x = np.random.uniform(-3, 3, [4, 4]).astype("float32") + alpha = 1. + # Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. alpha = 1) + # is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here + self.inputs = {'X': x} + self.attrs = {'alpha': alpha} + self.outputs = { + 'Y': np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1)) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.02) + + class TestReciprocal(OpTest): def setUp(self): self.op_type = "reciprocal" @@ -256,6 +331,21 @@ class TestSTanh(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.007) +class TestSoftplus(OpTest): + def setUp(self): + self.op_type = "softplus" + self.inputs = { + 'X': np.random.uniform(-1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': np.log(1 + np.exp(self.inputs['X']))} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.007) + + class TestSoftsign(OpTest): def setUp(self): self.op_type = "softsign" diff --git a/python/paddle/v2/framework/tests/test_conv_shift_op.py b/python/paddle/v2/framework/tests/test_conv_shift_op.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ab21a06a1c6e8e2d1e936a0b4b8a07a59f57b9 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_conv_shift_op.py @@ -0,0 +1,47 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def conv_shift_forward(x, y): + out = np.zeros_like(x) + M = x.shape[1] + N = y.shape[1] + y_half_width = (N - 1) / 2 + for i in xrange(M): + for j in xrange(N): + out[:, i] += x[:, (i + j + M - y_half_width) % M] * y[:, j] + return out + + +class TestConvShiftOp(OpTest): + def setUp(self): + self.op_type = "conv_shift" + + batch_size = 4 + x_dim = 17 + y_dim = 3 # must be odd and <= x_dim + x = np.random.random((batch_size, x_dim)).astype("float32") + y = np.random.random((batch_size, y_dim)).astype("float32") + self.inputs = {'X': x, 'Y': y} + + out = conv_shift_forward(x, y) + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05) + + def test_check_grad_ignore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.05, no_grad_set=set("X")) + + def test_check_grad_ignore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y')) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_infer_shape.py b/python/paddle/v2/framework/tests/test_infer_shape.py index b38ec9c03740a2e69f1247c094ce56ab43fa8e32..99562890fdd4d8b10f420869f1ba9f694db5969a 100644 --- a/python/paddle/v2/framework/tests/test_infer_shape.py +++ b/python/paddle/v2/framework/tests/test_infer_shape.py @@ -1,6 +1,6 @@ import unittest + import paddle.v2.framework.core as core -from paddle.v2.framework.op import Operator class TestInferShape(unittest.TestCase): @@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase): sum_op_desc.set_input("X", ["x1", "x2"]) sum_op_desc.set_output("Out", ["out"]) - core.Operator.infer_shape(sum_op_desc, block) + sum_op_desc.infer_shape(block) self.assertEqual(out.shape(), shape) def test_mul_op(self): @@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase): mul_op_desc.set_attr("x_num_col_dims", 1) mul_op_desc.set_attr("y_num_col_dims", 1) - core.Operator.infer_shape(mul_op_desc, block) + mul_op_desc.infer_shape(block) self.assertEqual(out.shape(), [x_shape[0], y_shape[1]]) diff --git a/python/paddle/v2/framework/tests/test_margin_rank_loss_op.py b/python/paddle/v2/framework/tests/test_margin_rank_loss_op.py new file mode 100644 index 0000000000000000000000000000000000000000..63378cbc4ec95d7d3c49a92f750b55a8dbc22414 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_margin_rank_loss_op.py @@ -0,0 +1,39 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestMarginRankLossOp(OpTest): + def setUp(self): + self.op_type = "margin_rank_loss" + batch_size = 5 + margin = 0.5 + # labels_{i} = {-1, 1} + label = 2 * np.random.randint( + 0, 2, size=(batch_size, 1)).astype("float32") - 1 + x1 = np.random.random((batch_size, 1)).astype("float32") + x2 = np.random.random((batch_size, 1)).astype("float32") + # loss = max(0, -label * (x1 - x2) + margin) + loss = -label * (x1 - x2) + margin + loss = np.where(loss > 0, loss, 0) + act = np.where(loss > 0, 1., 0.) + + self.attrs = {'margin': margin} + self.inputs = {'Label': label, 'X1': x1, 'X2': x2} + self.outputs = {'Activated': act, 'Out': loss} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X1", "X2"], "Out") + + def test_check_grad_ignore_x1(self): + self.check_grad(["X2"], "Out", no_grad_set=set('X1')) + + def test_check_grad_ignore_x2(self): + self.check_grad(["X1"], "Out", no_grad_set=set('X2')) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_parameter.py b/python/paddle/v2/framework/tests/test_parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..3b5d38f257e6f51be30d9f1fa42285461b2a0eb7 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_parameter.py @@ -0,0 +1,27 @@ +import unittest +from paddle.v2.framework.graph import g_program +import paddle.v2.framework.core as core + + +class TestParameter(unittest.TestCase): + def test_param(self): + b = g_program.create_block() + param = b.create_parameter( + name='fc.w', + shape=[784, 100], + dtype='float32', + initialize_attr={ + 'type': 'uniform_random', + 'seed': 13, + 'min': -5.0, + 'max': 5.0 + }) + self.assertIsNotNone(param) + self.assertEqual('fc.w', param.name) + self.assertEqual((784, 100), param.shape) + self.assertEqual(core.DataType.FP32, param.data_type) + self.assertEqual(0, param.block.idx) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_pool_max_op.py b/python/paddle/v2/framework/tests/test_pool_max_op.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f8aa6089c74d31702a6a5d37362099205d96b2 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_pool_max_op.py @@ -0,0 +1,212 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def max_pool3D_forward_naive(x, + ksize, + strides, + paddings=[0, 0, 0], + global_pool=0): + + N, C, D, H, W = x.shape + if global_pool == 1: + ksize = [D, H, W] + D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1 + H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1 + out = np.zeros((N, C, D_out, H_out, W_out)) + mask = np.zeros((N, C, D_out, H_out, W_out)) + for k in xrange(D_out): + d_start = np.max((k * strides[0] - paddings[0], 0)) + d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D)) + for i in xrange(H_out): + h_start = np.max((i * strides[0] - paddings[0], 0)) + h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + for j in xrange(W_out): + w_start = np.max((j * strides[1] - paddings[1], 0)) + w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end] + + out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4)) + + for n in xrange(N): + for c in xrange(C): + arr = x_masked[n, c, :, :, :] + index = np.where(arr == np.max(arr)) + sub_deep = index[0][0] + sub_row = index[1][0] + sub_col = index[2][0] + index = ((d_start + sub_deep) * H + + (h_start + sub_row)) * W + w_start + sub_col + mask[n, c, k, i, j] = index + + return out, mask + + +def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): + + N, C, H, W = x.shape + if global_pool == 1: + ksize = [H, W] + H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1 + W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + out = np.zeros((N, C, H_out, W_out)) + mask = np.zeros((N, C, H_out, W_out)) + for i in xrange(H_out): + for j in xrange(W_out): + r_start = np.max((i * strides[0] - paddings[0], 0)) + r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + c_start = np.max((j * strides[1] - paddings[1], 0)) + c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + x_masked = x[:, :, r_start:r_end, c_start:c_end] + + out[:, :, i, j] = np.max(x_masked, axis=(2, 3)) + + for n in xrange(N): + for c in xrange(C): + arr = x_masked[n, c, :, :] + index = np.where(arr == np.max(arr)) + sub_row = index[0][0] + sub_col = index[1][0] + index = (r_start + sub_row) * W + c_start + sub_col + mask[n, c, i, j] = index + + return out, mask + + +class TestMaxPoolWithIndex_Op(OpTest): + def setUp(self): + self.initTestCase() + input = np.random.random(self.shape).astype("float32") + output, mask = self.pool_forward_naive(input, self.ksize, self.strides, + self.paddings, self.global_pool) + + self.attrs = { + 'strides': self.strides, + 'paddings': self.paddings, + 'ksize': self.ksize, + 'globalPooling': self.global_pool, + } + + self.inputs = {'X': input} + self.outputs = {'Out': output, "Mask": mask} + + def test_check_output(self): + self.check_output() + + # def test_check_grad(self): + # self.check_grad(set(['X']), ['Out'], max_relative_error=0.07) + + def initTestCase(self): + self.global_pool = True + self.index = "max_pool3d_with_index" + self.op_type = "%s" % self.index + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase1(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase2(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 7, 7, 7] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase3(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 7, 7, 7] + self.ksize = [3, 3, 3] + self.strides = [2, 2, 2] + self.paddings = [0, 0, 0] + + +class TestCase4(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase5(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool3d_with_index" + self.pool_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [2, 2, 2] + self.paddings = [0, 0, 0] + + +class TestCase6(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "max_pool2d_with_index" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1] + + +class TestCase7(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "max_pool2d_with_index" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [2, 2] + self.paddings = [0, 0] + + +class TestCase8(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool2d_with_index" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1] + + +class TestCase9(TestMaxPoolWithIndex_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "max_pool2d_with_index" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [2, 2] + self.paddings = [0, 0] + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/framework/tests/test_program.py new file mode 100644 index 0000000000000000000000000000000000000000..b82d1760d65a24401aaa336bc41f75ed60af8ae9 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_program.py @@ -0,0 +1,36 @@ +import unittest +from paddle.v2.framework.graph import g_program + + +class TestProgram(unittest.TestCase): + def test_program(self): + b = g_program.current_block() + self.assertEqual(-1, b.parent_idx) + self.assertEqual(0, b.idx) + + b = g_program.create_block() + self.assertEqual(1, b.idx) + self.assertEqual(0, b.parent_idx) + + b = g_program.create_block() + self.assertEqual(2, b.idx) + self.assertEqual(1, b.parent_idx) + + g_program.rollback() + + b = g_program.current_block() + self.assertEqual(1, b.idx) + self.assertEqual(0, b.parent_idx) + + b = g_program.create_block() + self.assertEqual(3, b.idx) + self.assertEqual(1, b.parent_idx) + + g_program.rollback() + b = g_program.current_block() + self.assertEqual(1, b.idx) + self.assertEqual(0, b.parent_idx) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_seq_concat_op.py b/python/paddle/v2/framework/tests/test_seq_concat_op.py new file mode 100644 index 0000000000000000000000000000000000000000..6309b09bc98f6d529f80bfa269a0eaadd799fcbc --- /dev/null +++ b/python/paddle/v2/framework/tests/test_seq_concat_op.py @@ -0,0 +1,77 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestConcatOp(OpTest): + def set_data(self): + # two level, batch size is 3 + x0 = np.random.random((4, 6, 3)).astype('float32') + lod0 = [[0, 2, 4], [0, 1, 2, 3, 4]] + x1 = np.random.random((4, 8, 3)).astype('float32') + lod1 = [[0, 2, 4], [0, 1, 2, 3, 4]] + axis = 1 + level = 1 + self.inputs = {'X': [('x0', (x0, lod0)), ('x1', (x1, lod1))]} + self.attrs = {'axis': axis, 'level': level} + outs = [] + for i in range(4): + sub_x0 = x0[lod0[level][i]:lod0[level][i + 1], :] + sub_x1 = x1[lod1[level][i]:lod1[level][i + 1], :] + outs.append(np.concatenate((sub_x0, sub_x1), axis=axis)) + + self.outputs = {'Out': np.concatenate(outs, axis=0)} + + def setUp(self): + self.op_type = "sequence_concat" + self.set_data() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['x0'], 'Out') + + +class TestConcatOpDiffLod(TestConcatOp): + def set_data(self): + # two level, batch size is 3 + x0 = np.random.random((4, 6, 3)).astype('float32') + lod0 = [[0, 2, 4], [0, 1, 2, 3, 4]] + x1 = np.random.random((5, 6, 3)).astype('float32') + lod1 = [[0, 3, 5], [0, 1, 2, 3, 5]] + axis = 0 + level = 1 + self.inputs = {'X': [('x0', (x0, lod0)), ('x1', (x1, lod1))]} + self.attrs = {'axis': axis, 'level': level} + outs = [] + for i in range(4): + sub_x0 = x0[lod0[level][i]:lod0[level][i + 1], :] + sub_x1 = x1[lod1[level][i]:lod1[level][i + 1], :] + outs.append(np.concatenate((sub_x0, sub_x1), axis=axis)) + + self.outputs = {'Out': np.concatenate(outs, axis=0)} + + +class TestConcatOpLevelZero(TestConcatOp): + def set_data(self): + # two level, batch size is 3 + x0 = np.random.random((4, 3, 4)).astype('float32') + lod0 = [[0, 2, 4], [0, 1, 2, 3, 4]] + x1 = np.random.random((5, 3, 4)).astype('float32') + lod1 = [[0, 3, 5], [0, 1, 3, 4, 5]] + axis = 0 + level = 0 + self.inputs = {'X': [('x0', (x0, lod0)), ('x1', (x1, lod1))]} + self.attrs = {'axis': axis, 'level': level} + outs = [] + for i in range(2): + sub_x0 = x0[lod0[level][i]:lod0[level][i + 1], :] + sub_x1 = x1[lod1[level][i]:lod1[level][i + 1], :] + outs.append(np.concatenate((sub_x0, sub_x1), axis=axis)) + + self.outputs = {'Out': np.concatenate(outs, axis=0)} + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_variable.py b/python/paddle/v2/framework/tests/test_variable.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea1083ff6535d2d517f2ac587a956bfed906f03 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_variable.py @@ -0,0 +1,40 @@ +import unittest +from paddle.v2.framework.graph import Variable, g_program +import paddle.v2.framework.core as core +import numpy as np + + +class TestVariable(unittest.TestCase): + def test_np_dtype_convert(self): + DT = core.DataType + convert = Variable._convert_np_dtype_to_dtype_ + self.assertEqual(DT.FP32, convert(np.float32)) + self.assertEqual(DT.FP16, convert("float16")) + self.assertEqual(DT.FP64, convert("float64")) + self.assertEqual(DT.INT32, convert("int32")) + self.assertEqual(DT.INT16, convert("int16")) + self.assertEqual(DT.INT64, convert("int64")) + self.assertEqual(DT.BOOL, convert("bool")) + self.assertRaises(ValueError, lambda: convert("int8")) + + def test_var(self): + b = g_program.current_block() + w = b.create_var( + dtype="float64", shape=[784, 100], lod_level=0, name="fc.w") + self.assertEqual(core.DataType.FP64, w.data_type) + self.assertEqual((784, 100), w.shape) + self.assertEqual("fc.w", w.name) + self.assertEqual(0, w.lod_level) + + w = b.create_var(name='fc.w') + self.assertEqual(core.DataType.FP64, w.data_type) + self.assertEqual((784, 100), w.shape) + self.assertEqual("fc.w", w.name) + self.assertEqual(0, w.lod_level) + + self.assertRaises(ValueError, + lambda: b.create_var(name="fc.w", shape=(24, 100))) + + +if __name__ == '__main__': + unittest.main()