Commit e85fa5e9 authored by chengduo, committed by GitHub

Merge branch 'develop' into fix_pool_op_doc_and_class_declarations

......@@ -24,6 +24,10 @@ if(WITH_DOUBLE)
add_definitions(-DPADDLE_TYPE_DOUBLE)
endif(WITH_DOUBLE)
if(WITH_TESTING)
add_definitions(-DPADDLE_WITH_TESTING)
endif(WITH_TESTING)
if(NOT WITH_TIMER)
add_definitions(-DPADDLE_DISABLE_TIMER)
endif(NOT WITH_TIMER)
......
......@@ -5,12 +5,12 @@
Both deep learning systems and programming languages help users describe computation procedures. These systems use various representations of computation:
- Caffe, Torch, and Paddle: sequences of layers.
- TensorFlow, Caffe2, Mxnet: graphs of operators.
- TensorFlow, Caffe2, Mxnet: graph of operators.
- PaddlePaddle: nested blocks, like C++ and Java programs.
## Block in Programming Languages and Deep Learning
In programming languages, a block is a pair of curly braces that includes local variables definitions and a sequence of instructions, or operators.
In programming languages, a block is a pair of curly braces that includes local variables definitions and a sequence of instructions or operators.
Blocks work with control flow structures like `if`, `else`, and `for`, which have equivalents in deep learning:
......@@ -24,14 +24,14 @@ A key difference is that a C++ program describes a one pass computation, whereas
## Stack Frames and the Scope Hierarchy
The existence of the backward makes the execution of a block of traditional programs and PaddlePaddle different to each other:
The existence of the backward pass makes the execution of a block of PaddlePaddle different from traditional programs:
| programming languages | PaddlePaddle |
|-----------------------|-------------------------------|
| stack | scope hierarchy |
| stack frame | scope |
| push at entering block| push at entering block |
| pop at leaving block | destroy at minibatch completes|
| programming languages | PaddlePaddle |
|-----------------------|---------------------------------|
| stack | scope hierarchy |
| stack frame | scope |
| push at entering block| push at entering block |
| pop at leaving block | destroy when minibatch completes|
1. In traditional programs:
......@@ -42,9 +42,9 @@ The existence of the backward makes the execution of a block of traditional prog
1. In PaddlePaddle
- When the execution enters a block, PaddlePaddle adds a new scope, where it realizes variables.
- PaddlePaddle doesn't pop a scope after the execution of the block because variables therein are to be used by the backward pass. So it has a stack forest known as a *scope hierarchy*.
- PaddlePaddle doesn't pop a scope after the execution of the block because variables therein are used by the backward pass. So it has a stack forest known as a *scope hierarchy*.
- The height of the highest tree is the maximum depth of nested blocks.
- After the process of a minibatch, PaddlePaddle destroys the scope hierarchy.
- After the processing of a minibatch, PaddlePaddle destroys the scope hierarchy.
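To make this lifecycle concrete, here is a minimal Python sketch (illustrative only, not PaddlePaddle's actual `Scope` API): entering a block creates a child scope, leaving the block pops nothing, and the whole hierarchy is torn down only after the minibatch.

```python
# Illustrative sketch only -- not PaddlePaddle's actual Scope API.
class Scope(object):
    def __init__(self, parent=None):
        self.vars = {}          # variables realized in this scope
        self.children = []      # children are kept alive: the "stack forest"
        self.parent = parent
        if parent is not None:
            parent.children.append(self)

    def find_var(self, name):
        # look up locally, then recursively in ancestor scopes
        if name in self.vars:
            return self.vars[name]
        return self.parent.find_var(name) if self.parent else None

root = Scope()
block_scope = Scope(parent=root)  # "push" at block entry
block_scope.vars["h"] = 0.42      # realized by the forward pass
# Leaving the block does NOT pop block_scope; the backward pass can
# still read its variables through the hierarchy.
assert block_scope.find_var("h") == 0.42
root = Scope()                    # destroyed only after the minibatch
```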
## Use Blocks in C++ and PaddlePaddle Programs
......@@ -94,14 +94,14 @@ with ie.false_block():
o1, o2 = ie(cond)
```
In both examples, the left branch computes `x+y` and `softmax(x+y)`, the right branch computes `x+1` and `fc(x)`.
In both examples, the left branch computes `x+y` and `softmax(x+y)`, the right branch computes `fc(x)` and `x+1`.
A difference is that variables in the C++ program contain scalar values, whereas those in the PaddlePaddle programs are mini-batches of instances. The `ie.input(true, 0)` invocation returns instances in the 0-th input, `x`, that corresponds to true values in `cond` as the local variable `x`, where `ie.input(false, 0)` returns instances corresponding to false values.
The difference is that variables in the C++ program contain scalar values, whereas those in the PaddlePaddle programs are mini-batches of instances.
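The mini-batch semantics can be sketched in NumPy (a toy model of the behavior described above, not the real `IfElseOp`): instances are partitioned by `cond`, each branch transforms its own partition, and the results are scattered back together.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
cond = np.array([True, False, True, False])

o = np.empty_like(x)
o[cond] = x[cond] + 10.0   # stands in for the true block
o[~cond] = x[~cond] * 2.0  # stands in for the false block
print(o)                   # [11.  4. 13.  8.]
```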
### Blocks with `for` and `RNNOp`
The following RNN model from the [RNN design doc](./rnn.md)
The following RNN model in PaddlePaddle from the [RNN design doc](./rnn.md):
```python
x = sequence([10, 20, 30]) # shape=[None, 1]
......@@ -112,9 +112,9 @@ U = var(0.375, param=true) # shape=[1]
rnn = pd.rnn()
with rnn.step():
h = rnn.memory(init = m)
hh = rnn.previous_memory(h)
h_prev = rnn.previous_memory(h)
a = layer.fc(W, x)
b = layer.fc(U, hh)
b = layer.fc(U, h_prev)
s = pd.add(a, b)
act = pd.sigmoid(s)
rnn.update_memory(h, act)
......@@ -147,9 +147,9 @@ for (int i = 1; i <= sizeof(x)/sizeof(x[0]); ++i) {
## Compilation and Execution
Like TensorFlow programs, a PaddlePaddle program is written in Python. The first part describes a neural network as a protobuf message, and the rest part executes the message for training or inference.
Like TensorFlow, a PaddlePaddle program is written in Python. The first part describes a neural network as a protobuf message, and the rest executes the message for training or inference.
The generation of this protobuf message is like what a compiler generates a binary executable file. The execution of the message that the OS executes the binary file.
The generation of this protobuf message is similar to how a compiler generates a binary executable file. The execution of the message is similar to how the OS executes the binary file.
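The split can be sketched in a few lines of Python (hypothetical names, for intuition only): the first phase merely records a description of the computation, and the second phase interprets that description.

```python
# Hypothetical sketch of the describe-then-execute split.
program = []             # stands in for the protobuf message

def fc(x):
    program.append(("fc", x))   # "compilation" only records the op
    return "fc_out"

y = fc("x")

def execute(program):    # "execution" interprets the recorded message
    for op_type, op_input in program:
        print("running %s on %s" % (op_type, op_input))

execute(program)
```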
## The "Binary Executable File Format"
......@@ -186,8 +186,8 @@ Also, the RNN operator in above example is serialized into a protobuf message of
```
OpDesc {
inputs = {0} // the index of x
outputs = {5, 3} // indices of act and hidden_out
inputs = {0} // the index of x in vars of BlockDesc above
outputs = {5, 3} // indices of act and hidden_out in vars of BlockDesc above
attrs {
"memories" : {1} // the index of h
"step_net" : <above step net>
......@@ -203,14 +203,14 @@ This `OpDesc` value is in the `ops` field of the `BlockDesc` value representing
During the generation of the Protobuf message, the Block should store VarDesc (the Protobuf message which describes Variable) and OpDesc (the Protobuf message which describes Operator).
VarDesc in a block should have its own name scope to avoid local variables affecting the parent block's name scope.
Child block's name scopes should inherit the parent's so that OpDesc in child block can reference a VarDesc that stored in parent block. For example
Child block's name scopes should inherit the parent's so that an OpDesc in a child block can reference a VarDesc stored in a parent block. For example:
```python
a = pd.Varaible(shape=[20, 20])
a = pd.Variable(shape=[20, 20])
b = pd.fc(a, params=["fc.w", "fc.b"])
rnn = pd.create_rnn()
with rnn.stepnet()
with rnn.stepnet():
x = a.as_step_input()
# reuse fc's parameter
fc_without_b = pd.get_variable("fc.w")
......@@ -218,17 +218,17 @@ with rnn.stepnet()
out = rnn()
```
the method `pd.get_variable` can help retrieve a Variable by a name, a Variable may store in a parent block, but might be retrieved in a child block, so block should have a variable scope that supports inheritance.
The method `pd.get_variable` helps retrieve a Variable by name. The Variable may be stored in a parent block but retrieved in a child block, so a block should have a variable scope that supports inheritance.
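A minimal sketch of such inheritable lookup (assumed names, not PaddlePaddle's actual API): a child table searches itself first and then falls back to its parent, which is how `fc.w` defined outside the step block can be found inside it.

```python
class SymbolTable(object):
    def __init__(self, parent=None):
        self.var_descs = {}
        self.parent = parent

    def new_var(self, name, desc):
        self.var_descs[name] = desc

    def find_var(self, name, recursive=True):
        # search the current block first, then ancestors
        if name in self.var_descs:
            return self.var_descs[name]
        if recursive and self.parent is not None:
            return self.parent.find_var(name)
        return None

root = SymbolTable()
root.new_var("fc.w", {"shape": [20, 20]})
step = SymbolTable(parent=root)  # the RNN step block's table
assert step.find_var("fc.w")["shape"] == [20, 20]
```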
In compiler design, the symbol table is a data structure created and maintained by compilers to store information about the occurrence of various entities such as variable names, function names, classes, etc.
To store the definition of variables and operators, we define a C++ class `SymbolTable`, like the one used in compilers.
`SymbolTable` can do the following stuff:
`SymbolTable` can do the following:
- store the definitions (some names and attributes) of variables and operators,
- to verify if a variable was declared,
- to make it possible to implement type checking (offer Protobuf message pointers to `InferShape` handlers).
- verify if a variable was declared,
- make it possible to implement type checking (offer Protobuf message pointers to `InferShape` handlers).
```c++
......@@ -240,19 +240,18 @@ class SymbolTable {
OpDesc* NewOp(const string& name="");
// TODO determine whether name is generated by python or C++
// currently assume that a unique name will be generated by C++ if the
// argument name left default.
// TODO determine whether name is generated by python or C++.
// Currently assume that a unique name will be generated by C++ if the
// argument name is left default.
VarDesc* NewVar(const string& name="");
// find a VarDesc by name, if recursive true, find parent's SymbolTable
// find a VarDesc by name, if recursive is true, find parent's SymbolTable
// recursively.
  // This interface is introduced to support InferShape: it finds the protobuf
  // messages of variables and operators and passes pointers into InferShape.
  //
// NOTE maybe some C++ classes such as VarDescBuilder and OpDescBuilder should
// be proposed and embedded into pybind to enable python operate on C++ pointers.
// be proposed and embedded into pybind to enable python operation on C++ pointers.
VarDesc* FindVar(const string& name, bool recursive=true);
OpDesc* FindOp(const string& name);
......@@ -270,7 +269,7 @@ class SymbolTable {
After all the descriptions of variables and operators are added into the SymbolTable,
the block has enough information to run.
The `Block` class takes a `BlockDesc` as input, and provide `Run` and `InferShape` functions.
The `Block` class takes a `BlockDesc` as input, and provides `Run` and `InferShape` functions.
```c++
......@@ -302,7 +301,7 @@ public:
void CreateVariables(const framework::Scope& scope);
void CreateOperators();
// some other necessary interfaces of NetOp are list below
// some other necessary interfaces of NetOp are listed below
// ...
private:
......@@ -316,15 +315,14 @@ private:
Block inherits from OperatorBase, which has a Run method.
Block's Run method will run its operators sequentially.
There is another important interface called `Eval`, which take some arguments called targets, and generate a minimal graph which takes targets as the end points and creates a new Block,
after `Run`, `Eval` will get the latest value and return the targets.
There is another important interface called `Eval`, which takes some arguments called targets and generates a minimal graph which treats targets as the end points and creates a new Block. After `Run`, `Eval` will get the latest value and return the targets.
The definition of Eval is as follows:
```c++
// clean a block description by targets using the corresponding dependency graph.
// return a new BlockDesc with minimal number of operators.
// NOTE not return a Block but the block's description so that this can be distributed
// NOTE: The return type is not a Block but the block's description so that this can be distributed
// to a cluster.
BlockDesc Prune(const BlockDesc& desc, vector<string> targets);
......
# Design for GAN
GAN (Generative Adversarial Network [https://arxiv.org/abs/1406.2661]) is an important model for unsupervised learning and is widely used in many areas.
It applies several important concepts in machine learning system design, including building and running subgraphs, dependency tracing, running different optimizers in one executor, and so forth.
In our GAN design, we wrap it as a user-friendly, easily customizable Python API for designing different models. We take the conditional DC-GAN (Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks [https://arxiv.org/abs/1511.06434]) as an example due to its good performance on image generation.
<p align="center">
<img src="./test.dot.png" width = "35%" align="center"/><br/>
Figure 1. The overall running logic of GAN. The black solid arrows indicate the forward pass; the green dashed arrows indicate the backward pass of generator training; the red dashed arrows indicate the backward pass of the discriminator training. The BP pass of the green (red) arrow should only update the parameters in the green (red) boxes. The diamonds indicate the data providers. d\_loss and g\_loss marked in red and green are the two targets we would like to run.
</p>
The operators, layers and functions required/optional to build a GAN demo are summarized in https://github.com/PaddlePaddle/Paddle/issues/4563.
<p align="center">
<img src="./dcgan.png" width = "90%" align="center"/><br/>
Figure 2. Photo borrowed from the original DC-GAN paper.
</p>
## The Conditional-GAN might be a class.
In this design, we adopt the popular open-source designs in https://github.com/carpedm20/DCGAN-tensorflow and https://github.com/rajathkmp/DCGAN. It contains the following data structure:
- DCGAN(object): contains everything required to build a GAN model. It provides the following member functions as its API:
- __init__(...): Initialize hyper-parameters (like conv dimension and so forth), and declare model parameters of discriminator and generator as well.
- generator(z, y=None): Generate a fake image from input noise z. If the label y is provided, the conditional GAN model will be chosen.
Returns a generated image.
- discriminator(image):
Given an image, decide if it is from a real source or a fake one.
Returns a 0/1 binary label.
- build_model(self):
Build the whole GAN model and define the training losses for both the generator and the discriminator.
## Discussion on Engine Functions required to build GAN
- Trace the tensor and variable dependencies in the engine executor. (Very critical; otherwise GAN cannot be trained correctly.)
- Different optimizers responsible for optimizing different losses.
In more detail, we introduce our design of DCGAN as follows:
### Class member function: Initializer
- Set up hyper-parameters, including conditional dimension, noise dimension, batch size and so forth.
- Declare and define all the model variables. All the discriminator parameters are included in the list self.theta_D and all the generator parameters are included in the list self.theta_G.
```python
class DCGAN(object):
    def __init__(self, y_dim=None, z_dim=100):  # z_dim default is assumed here
# hyper parameters
self.y_dim = y_dim # conditional gan or not
self.batch_size = 100
self.z_dim = z_dim # input noise dimension
# define parameters of discriminators
self.D_W0 = pd.Variable(shape=[3,3, 1, 128], data=pd.gaussian_normal_randomizer())
        self.D_b0 = pd.Variable(np.zeros(128)) # variables also support initialization from numpy data
self.D_W1 = pd.Variable(shape=[784, 128], data=pd.gaussian_normal_randomizer())
        self.D_b1 = pd.Variable(np.zeros(128)) # variables also support initialization from numpy data
        self.D_W2 = pd.Variable(np.random.rand(128, 1))
self.D_b2 = pd.Variable(np.zeros(128))
self.theta_D = [self.D_W0, self.D_b0, self.D_W1, self.D_b1, self.D_W2, self.D_b2]
# define parameters of generators
self.G_W0 = pd.Variable(shape=[784, 128], data=pd.gaussian_normal_randomizer())
        self.G_b0 = pd.Variable(np.zeros(128)) # variables also support initialization from numpy data
self.G_W1 = pd.Variable(shape=[784, 128], data=pd.gaussian_normal_randomizer())
        self.G_b1 = pd.Variable(np.zeros(128)) # variables also support initialization from numpy data
        self.G_W2 = pd.Variable(np.random.rand(128, 1))
self.G_b2 = pd.Variable(np.zeros(128))
self.theta_G = [self.G_W0, self.G_b0, self.G_W1, self.G_b1, self.G_W2, self.G_b2]
```
### Class member function: Generator
- Given a noisy input z, returns a fake image.
- Concatenation, batch-norm, FC operations required;
- Deconv layer required, which is missing now...
```python
class DCGAN(object):
def generator(self, z, y = None):
# input z: the random noise
# input y: input data label (optional)
# output G_im: generated fake images
        if self.y_dim:
            z = pd.layer.concat(1, [z, y])
        G_h0 = pd.layer.fc(z, self.G_W0, self.G_b0)
        G_h0_bn = pd.layer.batch_norm(G_h0)
        G_h0_relu = pd.layer.relu(G_h0_bn)
        G_h1 = pd.layer.deconv(G_h0_relu, self.G_W1, self.G_b1)
        G_h1_bn = pd.layer.batch_norm(G_h1)
        G_h1_relu = pd.layer.relu(G_h1_bn)
        G_h2 = pd.layer.deconv(G_h1_relu, self.G_W2, self.G_b2)
        G_im = pd.layer.tanh(G_h2)
return G_im
```
### Class member function: Discriminator
- Given an input image, decides whether it is from a real source or a generated (fake) one;
- Concatenation, Convolution, batch-norm, FC, Leaky-ReLU operations required;
```python
class DCGAN(object):
def discriminator(self, image):
# input image: either generated images or real ones
# output D_h2: binary logit of the label
        D_h0 = pd.layer.conv2d(image, w=self.D_W0, b=self.D_b0)
        D_h0_bn = pd.layer.batch_norm(D_h0)
        D_h0_relu = pd.layer.lrelu(D_h0_bn)
        D_h1 = pd.layer.conv2d(D_h0_relu, w=self.D_W1, b=self.D_b1)
        D_h1_bn = pd.layer.batch_norm(D_h1)
        D_h1_relu = pd.layer.lrelu(D_h1_bn)
        D_h2 = pd.layer.fc(D_h1_relu, w=self.D_W2, b=self.D_b2)
return D_h2
```
### Class member function: Build the model
- Define data readers as placeholders to hold the data;
- Build generator and discriminators;
- Define two training losses for discriminator and generator, respectively.
If we have an execution dependency engine to back-trace all tensors, the module building our GAN model looks like this:
```python
class DCGAN(object):
def build_model(self):
if self.y_dim:
self.y = pd.data(pd.float32, [self.batch_size, self.y_dim])
self.images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
self.faked_images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
        self.z = pd.data(pd.float32, [None, self.z_dim])
# step 1: generate images by generator, classify real/fake images with discriminator
if self.y_dim: # if conditional GAN, includes label
self.G = self.generator(self.z, self.y)
self.D_t = self.discriminator(self.images)
# generated fake images
self.sampled = self.sampler(self.z, self.y)
self.D_f = self.discriminator(self.G)
else: # original version of GAN
self.G = self.generator(self.z)
self.D_t = self.discriminator(self.images)
# generate fake images
self.sampled = self.sampler(self.z)
            self.D_f = self.discriminator(self.G)
# step 2: define the two losses
        self.d_loss_real = pd.reduce_mean(pd.cross_entropy(self.D_t, np.ones(self.batch_size)))
        self.d_loss_fake = pd.reduce_mean(pd.cross_entropy(self.D_f, np.zeros(self.batch_size)))
        self.d_loss = self.d_loss_real + self.d_loss_fake
        self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_f, np.ones(self.batch_size)))
```
If we do not have a dependency engine but use blocks instead, the module building our GAN model looks like this:
```python
class DCGAN(object):
def build_model(self, default_block):
# input data in the default block
if self.y_dim:
self.y = pd.data(pd.float32, [self.batch_size, self.y_dim])
self.images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
# self.faked_images = pd.data(pd.float32, [self.batch_size, self.im_size, self.im_size])
        self.z = pd.data(pd.float32, [None, self.z_dim])
# step 1: generate images by generator, classify real/fake images with discriminator
with pd.default_block().g_block():
if self.y_dim: # if conditional GAN, includes label
self.G = self.generator(self.z, self.y)
self.D_g = self.discriminator(self.G, self.y)
else: # original version of GAN
self.G = self.generator(self.z)
                self.D_g = self.discriminator(self.G)
            self.g_loss = pd.reduce_mean(pd.cross_entropy(self.D_g, np.ones(self.batch_size)))
with pd.default_block().d_block():
if self.y_dim: # if conditional GAN, includes label
self.D_t = self.discriminator(self.images, self.y)
self.D_f = self.discriminator(self.G, self.y)
else: # original version of GAN
self.D_t = self.discriminator(self.images)
self.D_f = self.discriminator(self.G)
# step 2: define the two losses
            self.d_loss_real = pd.reduce_mean(pd.cross_entropy(self.D_t, np.ones(self.batch_size)))
            self.d_loss_fake = pd.reduce_mean(pd.cross_entropy(self.D_f, np.zeros(self.batch_size)))
self.d_loss = self.d_loss_real + self.d_loss_fake
```
Some remaining points of confusion and problems with this design:
- D\_g and D\_f are actually the same thing but have to be written twice; i.e., if we want to run two sub-graphs conceptually, code shared by both graphs has to be written twice.
- Requires the ability to create a block at any time, rather than only inside if-else or RNN constructs;
## Main function for the demo:
Generally, the user of GAN just needs to do the following:
- Define a DCGAN object;
- Build the DCGAN model;
- Specify two optimizers for two different losses with respect to different parameters.
```python
# pd for short, should be more concise.
import paddle.v2 as pd
import numpy as np
import logging
if __name__ == "__main__":
# dcgan class in the default graph/block
    # if we used a dependency engine as TensorFlow does,
    # the code would be slightly different:
# dcgan = DCGAN()
# dcgan.build_model()
with pd.block() as def_block:
dcgan = DCGAN()
dcgan.build_model(def_block)
# load mnist data
    data_X, data_y = load_mnist()
# Two subgraphs required!!!
    with pd.block().d_block():
        d_optim = pd.train.Adam(lr=0.001, beta=0.1)
        d_step = d_optim.minimize(dcgan.d_loss, dcgan.theta_D)
    with pd.block().g_block():
        g_optim = pd.train.Adam(lr=0.001, beta=0.1)
        g_step = g_optim.minimize(dcgan.g_loss, dcgan.theta_G)
# executor
sess = pd.executor()
# training
for epoch in xrange(10000):
for batch_id in range(N / batch_size):
idx = ...
# sample a batch
batch_im, batch_label = data_X[idx:idx+batch_size], data_y[idx:idx+batch_size]
# sample z
batch_z = np.random.uniform(-1., 1., [batch_size, z_dim])
if batch_id % 2 == 0:
sess.run(d_step,
feed_dict = {dcgan.images: batch_im,
dcgan.y: batch_label,
dcgan.z: batch_z})
else:
sess.run(g_step,
feed_dict = {dcgan.z: batch_z})
```
# More thoughts on dependency engine vs. block design:
- What if we just want to run an intermediate result? Do we need to run the whole block/graph?
- Should we call eval() to get the fake images in the first stage? And then train the discriminator in the second stage?
......@@ -22,7 +22,7 @@ Whenever we create a block, we need to set its parent block to the current block
```python
class Program(object):
def __init__(self):
self.proto = core.NewProgram() # a C++ ProgramDesc pointer.
self.desc = core.NewProgram() # a C++ ProgramDesc pointer.
self.blocks = vector<Block>()
self.blocks.append(Block(self, -1)) # the global block
self.current_block = 0 # initialized to the global block
......@@ -57,7 +57,7 @@ A [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.m
```python
class Block(object):
def __init__(self, program, parent_idx):
self.proto = core.NewBlock(program.proto)
self.desc = core.NewBlock(program.desc)
self.program = program
self.vars = map<string, Variable>()
self.ops = vector<Operator>()
......@@ -98,11 +98,11 @@ class Operator(object):
                 outputs,  # dict<string, Variable>
attrs # dict<string, Any>
):
self.proto = core.NewOpDesc(block.proto, type, inputs, outputs, attrs)
core.infer_shape(self.proto, inputs, outputs)
self.desc = core.NewOpDesc(block.desc, type, inputs, outputs, attrs)
core.infer_shape(self.desc, inputs, outputs)
def type(self):
return self.proto.type()
return self.desc.type()
```
`Operator` creates the `OpDesc` message in C++ space, so that it can call the `InferShape` function, which is in C++.
......@@ -124,7 +124,7 @@ class Variable(object):
name = unique_name_generator()
self.name = name
self.block = block
self.proto = core.NewVarDesc(block.proto, name, shape, lod_level)
self.desc = core.NewVarDesc(block.desc, name, shape, lod_level)
self.writer = None
```
......
......@@ -17,22 +17,22 @@ The goals of refactoring include:
1. A graph is composed of *variables* and *operators*.
1. The description of graphs must be capable of being serialized/deserialized, so that:
1. The description of graphs must be serializable/deserializable, so that:
1. It can to be sent to the cloud for distributed execution, and
1. It can be sent to the cloud for distributed execution, and
1. It can be sent to clients for mobile or enterprise deployment.
1. The Python program does the following steps
1. The Python program does two things
1. *compilation*: run a Python program to generate a protobuf message representation of the graph and send it to
1. *Compilation* runs a Python program to generate a protobuf message representation of the graph and send it to
1. the C++ library `libpaddle.so` for local execution,
1. the master process of a distributed training job for training, or
1. the server process of a Kubernetes serving job for distributed serving.
1. *execution*: execute the graph by constructing instances of class [`Variable`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/variable.h#L24) and [`OperatorBase`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L70), according to the protobuf message.
1. *Execution* executes the graph by constructing instances of class [`Variable`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/variable.h#L24) and [`OperatorBase`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L70), according to the protobuf message.
## Description and Realization of Computation Graph
At compile time, the Python program generates a protobuf message representation of the graph, or the description of the graph.
At compile time, the Python program generates a protobuf message representation of the graph, or a description of the graph.
At runtime, the C++ program realizes the graph and runs it.
......@@ -42,11 +42,11 @@ At runtime, the C++ program realizes the graph and runs it.
|Operation|[OpDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L35)|[Operator](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L64)|
|Block|BlockDesc|Block|
The word *graph* is interchangeable with *block* in this document. A graph represents computation steps and local variables similar to a C++/Java program block, or a pair of parentheses(`{` and `}`).
The word *graph* is interchangeable with *block* in this document. A graph consists of computation steps and local variables similar to a C++/Java program block, or a pair of curly braces (`{` and `}`).
## Compilation and Execution
1. Run an application Python program to describe the graph. In particular, the Python application program does the following:
1. Run a Python program to describe the graph. In particular, the Python application program does the following:
1. Create `VarDesc` to represent local/intermediate variables,
1. Create operators and set attributes,
......@@ -54,10 +54,10 @@ The word *graph* is interchangeable with *block* in this document. A graph repr
1. Infer the type and the shape of variables,
1. Plan memory-reuse for variables,
1. Generate the backward graph
1. Optimize the computation graph.
1. Potentially, split the graph for distributed training.
1. Add optimization operators to the computation graph.
1. Optionally, split the graph for distributed training.
1. The invocation of `train` or [`infer`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/inference.py#L108) methods in the application Python program does the following:
1. The invocation of `train` or [`infer`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/inference.py#L108) methods in the Python program does the following:
1. Create a new Scope instance in the [scope hierarchy](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/scope.md) for each run of a block,
1. realize local variables defined in the BlockDesc message in the new scope,
......@@ -107,8 +107,8 @@ Compile Time -> IR -> Runtime
![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/dd598e8f1976f5759f58af5e5ef94738a6b2e661/op.dot)
* `Operator` is the fundamental building block of the user interface.
* Operator stores input/output variable names, and attributes.
* The `InferShape` interface is used to infer the shape of the output variable shapes based on the shapes of the input variables.
* Operator stores input/output variable names and attributes.
* The `InferShape` interface is used to infer the shape of the output variables based on the shapes of the input variables.
* Use `Run` to compute the `output` variables from the `input` variables.
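The contract can be sketched in Python (hypothetical names, for intuition only): an operator stores only variable names and attributes; `infer_shape` maps input shapes to output shapes, and `run` maps input values held in a scope to output values.

```python
# Hypothetical analogue of the Operator contract described above.
class AddOp(object):
    def __init__(self, inputs, outputs):
        self.inputs, self.outputs = inputs, outputs  # names, not tensors

    def infer_shape(self, shapes):
        # output shape follows the (elementwise) input shape
        shapes[self.outputs[0]] = shapes[self.inputs[0]]

    def run(self, scope):
        a, b = (scope[name] for name in self.inputs)
        scope[self.outputs[0]] = [u + v for u, v in zip(a, b)]

op = AddOp(["X", "Y"], ["Out"])
shapes = {"X": [2], "Y": [2]}
op.infer_shape(shapes)
scope = {"X": [1.0, 2.0], "Y": [3.0, 4.0]}
op.run(scope)
print(shapes["Out"], scope["Out"])  # [2] [4.0, 6.0]
```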
---
......@@ -139,7 +139,7 @@ Compile Time -> IR -> Runtime
* Limit the number of `tensor.device(dev) = ` in your code.
* `thrust::transform` and `std::transform`.
* `thrust` has the same API as C++ standard library. Using `transform`, one can quickly implement customized element-wise kernels.
* `thrust` also has more complex APIs, like `scan`, `reduce`, `reduce_by_key`.
* `thrust`, in addition, supports more complex APIs, like `scan`, `reduce`, `reduce_by_key`.
* Hand-writing `GPUKernel` and `CPU` code
* Do not write in header (`.h`) files. CPU Kernel should be in cpp source (`.cc`) and GPU kernels should be in cuda (`.cu`) files. (GCC cannot compile GPU code.)
---
......@@ -185,10 +185,10 @@ Make sure the registration process is executed and linked.
1. Write an Op class and its gradient Op class, if required.
2. Write an Op maker class. In the constructor of this class, describe the inputs, outputs and attributes of the operator.
3. Invoke the macro `REGISTER_OP`. This macro will
1. Call maker class to complete the `proto` and the `checker`
1. Call maker class to complete `proto` and `checker`
2. Using the completed `proto` and `checker`, it will add a new key-value pair to the `OpInfoMap`
4. Invoke the `USE` macro in which the Op is used, to make sure that it is linked.
4. Invoke the `USE` macro in which the Op is used to make sure that it is linked.
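A rough Python analogue of this registration flow (hypothetical, for intuition only): the maker completes the proto and checker, and the pair is stored in a global map keyed by op type.

```python
# Hypothetical analogue of REGISTER_OP populating OpInfoMap.
op_info_map = {}

def register_op(op_type, op_class, maker):
    proto, checker = {}, {}
    maker(proto, checker)          # the maker completes proto and checker
    op_info_map[op_type] = (op_class, proto, checker)

class ScaleOp(object):
    pass

def scale_op_maker(proto, checker):
    proto["inputs"], proto["outputs"] = ["X"], ["Out"]
    checker["scale"] = 1.0         # default attribute value

register_op("scale", ScaleOp, scale_op_maker)
print(op_info_map["scale"][1])     # {'inputs': ['X'], 'outputs': ['Out']}
```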
---
# Backward Module (1/2)
......@@ -199,13 +199,14 @@ Make sure the registration process is executed and linked.
---
# Backward Module (2/2)
### Build Backward Network
- **Input**: graph of forward operators
- **Output**: graph of backward operators
- **Input**: a graph of forward operators
- **Output**: a graph of backward operators
- **Corner cases in construction**
- Shared Variables => insert an `Add` operator to combine gradients
- No Gradient => insert a `fill_zero_grad` operator
- Recursive NetOp => call `Backward` recursively
- RNN Op => recursively call `Backward` on stepnet
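A toy sketch of the shared-variable corner case (illustrative only, not the real backward builder): a variable consumed by two forward operators receives two partial gradients, so an `Add` operator is inserted to combine them.

```python
# Forward ops as (type, inputs, outputs); x is shared by two operators.
forward = [("mul", ["x"], ["y1"]), ("exp", ["x"], ["y2"])]

backward, partials = [], {}
for i, (op_type, ins, outs) in enumerate(reversed(forward)):
    for name in ins:
        partial = "%s@GRAD_%d" % (name, i)  # one partial gradient per use
        partials.setdefault(name, []).append(partial)
        backward.append((op_type + "_grad",
                         [out + "@GRAD" for out in outs], [partial]))

for name, grads in partials.items():
    if len(grads) > 1:  # shared variable => insert an Add operator
        backward.append(("add", grads, [name + "@GRAD"]))

for op in backward:
    print(op)
```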
---
......@@ -215,10 +216,10 @@ Make sure the registration process is executed and linked.
* Only dims and data pointers are stored in `Tensor`.
* All operations on `Tensor` are written in `Operator` or global functions.
* Variable length Tensor design [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md)
* `Variable` instances are the inputs and the outputs of an operator. Not just `Tensor`.
* `Variable` instances are the inputs and the outputs of an operator, not just `Tensor`.
* `step_scopes` in RNN is a variable and not a tensor.
* `Scope` is where variables are stores.
* map<string `variable_name`, Variable>
* `Scope` is where variables are stored.
* map<string `var name`, Variable>
* `Scope` has a hierarchical structure. The local scope can get variables from its parent scope.
---
......@@ -246,7 +247,7 @@ Make sure the registration process is executed and linked.
---
# Control the migration quality
- Compare the performance of migrated models with old ones.
- Follow the google C++ style
- Follow the google C++ style guide.
- Build an automatic workflow for generating the Python/C++ documentation.
- The documentation of layers and ops should be written inside the code.
- Take the documentation quality into account when submitting pull requests.
......
digraph Test {
z -> generator -> G_img;
G_img -> discriminator -> D_f -> d_loss_f;
label0 -> d_loss_f -> d_loss;
img -> discriminator -> D_t -> d_loss_t;
label1 -> d_loss_t -> d_loss;
d_loss -> d_loss_t[color=red, style=dashed];
d_loss -> d_loss_f[color=red, style=dashed];
d_loss_t -> D_t[color=red, style=dashed];
d_loss_f -> D_f[color=red, style=dashed];
D_t -> discriminator[color=red, style=dashed];
D_f -> discriminator[color=red, style=dashed];
D_f -> g_loss;
label2 -> g_loss;
g_loss -> D_f[color=green, style=dashed];
D_f -> discriminator[color=green, style=dashed];
discriminator -> G_img[color=green, style=dashed];
G_img -> generator[color=green, style=dashed];
discriminator [color=red, shape=box];
generator [color=green, shape=box];
z [shape=diamond];
img [shape=diamond];
label0 [shape=diamond];
label1 [shape=diamond];
label2 [shape=diamond];
d_loss [color=red];
g_loss [color=green];
}
......@@ -19,7 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library(framework_proto SRCS framework.proto)
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
......
......@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_desc.h"
#include <functional>
#include <unordered_map>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
......@@ -25,6 +28,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
inputs_ = inputs;
outputs_ = outputs;
attrs_ = attrs;
need_update_ = true;
}
OpDesc *OpDescBind::Proto() {
......@@ -184,5 +188,38 @@ void OpDescBind::Sync() {
need_update_ = false;
}
}
using InferShapeFuncMap =
std::unordered_map<std::string /*op_type*/,
std::function<void(InferShapeContext *)>>;
static InferShapeFuncMap &InferShapeFuncs() {
static InferShapeFuncMap *g_map = nullptr;
if (g_map == nullptr) {
g_map = new InferShapeFuncMap();
auto &info_map = OpInfoMap::Instance();
// all registered kernels
for (auto &pair : OperatorWithKernel::AllOpKernels()) {
auto &info = info_map.Get(pair.first);
// use empty type here to avoid runtime checks.
auto op =
static_cast<OperatorWithKernel *>(info.Creator()("", {}, {}, {}));
g_map->insert(
{pair.first, [op](InferShapeContext *ctx) { op->InferShape(ctx); }});
}
}
return *g_map;
}
void OpDescBind::InferShape(const BlockDescBind &block) const {
auto &funcs = InferShapeFuncs();
auto it = funcs.find(this->Type());
if (it == funcs.end()) {
PADDLE_THROW("Operator %s has not been registered", this->Type());
}
CompileTimeInferShapeContext ctx(*this, block);
it->second(&ctx);
}
} // namespace framework
} // namespace paddle
......@@ -100,6 +100,8 @@ class OpDescBind {
return &this->attrs_;
}
void InferShape(const BlockDescBind &block) const;
private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
......@@ -142,9 +142,9 @@ class OperatorBase {
// Macro for defining a clone method.
// If you are writing a kernel operator, `Clone` will be defined when you
// register it, i.e., you do not need to define the `Clone` method yourself.
#define DEFINE_OP_CLONE_METHOD(cls) \
std::unique_ptr<OperatorBase> Clone() const final { \
return std::unique_ptr<OperatorBase>(new cls(*this)); \
#define DEFINE_OP_CLONE_METHOD(cls) \
std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final { \
return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \
}
// Macro for defining a default constructor for Operator.
......
......@@ -87,12 +87,12 @@ class TensorArray {
LoDTensor Stack() const;
/*
* Unpacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors.
* Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors.
*/
void Unstack(const LoDTensor &source) const;
/*
* Unpacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors,
* Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors,
* with memory of tensors shared.
*/
void UnstackShared(const LoDTensor &source) const;
......
......@@ -32,5 +32,13 @@ std::vector<int64_t> VarDescBind::Shape() const {
DataType VarDescBind::GetDataType() const {
return desc_.lod_tensor().data_type();
}
void VarDescBind::SetLoDLevel(int32_t lod_level) {
desc_.mutable_lod_tensor()->set_lod_level(lod_level);
}
int32_t VarDescBind::GetLodLevel() const {
return desc_.lod_tensor().lod_level();
}
} // namespace framework
} // namespace paddle
......@@ -66,6 +66,10 @@ class VarDescBind {
DataType GetDataType() const;
void SetLoDLevel(int32_t lod_level);
int32_t GetLodLevel() const;
private:
VarDesc desc_;
};
......
......@@ -162,4 +162,4 @@ int main(int argc, char** argv) {
return RUN_ALL_TESTS();
}
#endif /* PADDLE_ONLY_CPU */
#endif
......@@ -182,7 +182,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
max_chunk_size_ = platform::GpuMaxChunkSize();
}
}
#endif // PADDLE_ONLY_CPU
#endif
// Allocate a new maximum sized block
size_t index = 0;
......
......@@ -134,7 +134,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) {
bool GPUAllocator::UseGpu() const { return true; }
#endif // PADDLE_ONLY_CPU
#endif
} // namespace detail
} // namespace memory
......
......@@ -51,7 +51,7 @@ class GPUAllocator : public SystemAllocator {
size_t gpu_alloc_size_ = 0;
size_t fallback_alloc_size_ = 0;
};
#endif // PADDLE_ONLY_CPU
#endif
} // namespace detail
} // namespace memory
......
......@@ -62,4 +62,4 @@ TEST(GPUAllocator, Alloc) {
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
#endif // PADDLE_ONLY_CPU
#endif
......@@ -89,7 +89,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
}
#endif // PADDLE_ONLY_CPU
#endif
} // namespace memory
} // namespace paddle
......@@ -53,7 +53,7 @@ template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
cudaStream_t stream);
#endif // PADDLE_ONLY_CPU
#endif
} // namespace memory
} // namespace paddle
......@@ -111,7 +111,7 @@ size_t Used<platform::GPUPlace>(platform::GPUPlace place) {
return GetGPUBuddyAllocator(place.device)->Used();
}
#endif // PADDLE_ONLY_CPU
#endif
} // namespace memory
} // namespace paddle
......@@ -135,4 +135,4 @@ TEST(BuddyAllocator, GPUMultAlloc) {
}
}
#endif // PADDLE_ONLY_CPU
#endif
......@@ -137,3 +137,4 @@ cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array)
......@@ -49,6 +49,18 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class LogSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LogSigmoidOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of LogSigmoid operator");
AddOutput("Y", "Output of LogSigmoid operator");
AddComment(
"Logsigmoid activation operator, logsigmoid = log (1 / (1 + exp(-x)))");
}
};
class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ExpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
......@@ -85,6 +97,23 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
template <typename AttrType>
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftShrinkOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softshrink operator");
AddOutput("Y", "Output of Softshrink operator");
AddComment(
"Softshrink activation operator, "
"softshrink = x - lambda, if x > lambda;"
" x + lambda, if x < lambda; 0 otherwise");
AddAttr<AttrType>("lambda", "non-negative offset")
.SetDefault(static_cast<AttrType>(0.5f));
}
};
class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
......@@ -271,6 +300,9 @@ namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad,
ops::ActivationOpGrad);
REGISTER_OP(logsigmoid, ops::ActivationOp, ops::LogSigmoidOpMaker,
logsigmoid_grad, ops::ActivationOpGrad);
REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
ops::ActivationOpGrad);
......@@ -283,6 +315,9 @@ REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad,
REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
tanh_shrink_grad, ops::ActivationOpGrad);
REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker<float>,
softshrink_grad, ops::ActivationOpGrad);
REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
ops::ActivationOpGrad);
......
......@@ -95,6 +95,41 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
}
};
// Originally: logsigmoid(x) = -log (1 + exp(-x))
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// We can rewrite the above equation as:
// y = -log( exp(0) + exp(-x)) [since exp(0) = 1]
// = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0)))
// = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) + exp(max(-x, 0)) * exp(-x -
// max(-x, 0)))
// = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
// = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
//
// Hence, logsigmoid(x) = - (max(-x, 0) + log(exp(-max(-x, 0))
// + exp(-x - max(-x, 0))))
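//
// For example, at x = -1000 the naive form -log(1 + exp(1000)) overflows,
// while the rewritten form gives y = -1000 - log(exp(-1000) + exp(0)),
// which evaluates to approximately -1000.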
template <typename T>
struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
auto temp = (-x).cwiseMax(static_cast<T>(0)); // temp = max(-x, 0)
y.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
}
};
// Originally: f' = exp(-x) / (1 + exp(-x))
// For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) +
// exp(-x - max(-x, 0)))
template <typename T>
struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp = (-x).cwiseMax(static_cast<T>(0)); // temp = max(-x, 0)
dx.device(d) =
dy * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
}
};
// exp(x) = e^x
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
......@@ -164,6 +199,37 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
}
};
// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
// otherwise
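// For example, with lambda = 0.5: softshrink(0.7) = 0.2,
// softshrink(-0.7) = -0.2, and softshrink(0.2) = 0. The gradient is 1
// where |x| > lambda and 0 elsewhere.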
template <typename T>
struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
float lambda;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"lambda", &lambda}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
auto temp1 = (x > lambda).template cast<T>().eval();
auto temp2 = (x < -lambda).template cast<T>().eval();
y.device(d) = temp1 * (x - lambda) + temp2 * (x + lambda);
}
};
template <typename T>
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
float lambda;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"lambda", &lambda}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp1 = (x > lambda).template cast<T>().eval();
auto temp2 = (x < -lambda).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
}
};
// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
......@@ -471,9 +537,11 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
......@@ -484,7 +552,7 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor); \
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
__macro(relu6, Relu6Functor, Relu6GradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
__macro(elu, ELUFunctor, ELUGradFunctor)
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv_shift_op.h"
#include "paddle/framework/eigen.h"
namespace paddle {
namespace operators {
using framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
class ConvShiftOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
"The 1st dimension of Input(X) and Input(Y) should "
"be equal.");
PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
"The 2nd dimension of Input(Y) should be odd.");
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
"The 2nd dimension of Input(Y) should be less than or "
"equal to the 2nd dimension of Input(X).");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class ConvShiftGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should be not null.");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(x_grad_name, x_dims);
}
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(y_grad_name)) {
auto y_dims = ctx->GetInputDim("Y");
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
};
class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ConvShiftOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
"where B is the batch size and M is the data dimension.");
AddInput("Y",
"(Tensor, default Tensor<float>), a 2-D tensor with shape B x N, "
"where B is the batch size and N is the data dimension. N must "
"be odd.");
AddOutput("Out",
"(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
"i.e., the same shape as X.");
AddComment(R"DOC(
ConvShift Operator.
A layer for circular convolution of two vectors,
as used in the Neural Turing Machine: https://arxiv.org/abs/1410.5401
The equation is:
\f[
Out[i] = \sum_{j=-(N-1)/2}^{(N-1)/2} X_{i+j} * Y_{j}
\f]
where X's index is computed modulo M, and Y's index is computed modulo N.
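For example, with M = 3, N = 3, X = [[1, 2, 4]] and Y = [[1, 1, 1]], every
output element circularly sums all of X, so Out = [[7, 7, 7]].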
Both of the input `X` and `Y` can carry LoD (Level of Details) information.
However, the output only shares the LoD information with input `X`.
)DOC");
}
};
template <typename T>
class ConvShiftKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *X = context.Input<Tensor>("X");
auto *Y = context.Input<Tensor>("Y");
auto *Out = context.Output<Tensor>("Out");
Out->mutable_data<T>(context.GetPlace());
auto x = EigenMatrix<T>::From(*X);
auto y = EigenMatrix<T>::From(*Y);
auto out = EigenMatrix<T>::From(*Out);
out.setZero();
size_t batch_size = X->dims()[0];
size_t x_width = X->dims()[1];
size_t y_width = Y->dims()[1];
size_t y_half_width = (y_width - 1) / 2;
for (size_t k = 0; k < batch_size; ++k) {
for (size_t i = 0; i < x_width; ++i) {
for (size_t j = 0; j < y_width; ++j) {
int index = (i + j - y_half_width + x_width) % x_width;
out(k, i) += x(k, index) * y(k, j);
}
}
}
}
};
template <typename T>
class ConvShiftGradKernel<platform::CPUPlace, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *X = context.Input<Tensor>("X");
auto *Y = context.Input<Tensor>("Y");
auto *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto *dX = context.Output<Tensor>(framework::GradVarName("X"));
auto *dY = context.Output<Tensor>(framework::GradVarName("Y"));
auto x = EigenMatrix<T>::From(*X);
auto y = EigenMatrix<T>::From(*Y);
auto dout = EigenMatrix<T>::From(*dOut);
auto x_dims = X->dims();
auto y_dims = Y->dims();
size_t batch_size = x_dims[0];
size_t x_width = x_dims[1];
size_t y_width = y_dims[1];
size_t y_half_width = (y_width - 1) / 2;
// The below trades code duplication for efficiency (keeping the if
// statement outside of the loop).
if (dX) {
dX->mutable_data<T>(context.GetPlace());
auto dx = EigenMatrix<T>::From(*dX);
dx.setZero();
for (size_t k = 0; k < batch_size; ++k) {
for (size_t i = 0; i < x_width; ++i) {
for (size_t j = 0; j < y_width; ++j) {
int index = (i + j - y_half_width + x_width) % x_width;
dx(k, index) += dout(k, i) * y(k, j);
}
}
}
}
if (dY) {
dY->mutable_data<T>(context.GetPlace());
auto dy = EigenMatrix<T>::From(*dY);
dy.setZero();
for (size_t k = 0; k < batch_size; ++k) {
for (size_t i = 0; i < x_width; ++i) {
for (size_t j = 0; j < y_width; ++j) {
int index = (i + j - y_half_width + x_width) % x_width;
dy(k, j) += x(k, index) * dout(k, i);
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(conv_shift, ops::ConvShiftOp, ops::ConvShiftOpMaker,
conv_shift_grad, ops::ConvShiftGradOp);
REGISTER_OP_CPU_KERNEL(conv_shift,
ops::ConvShiftKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv_shift_grad,
ops::ConvShiftGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv_shift_op.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
using framework::Tensor;
namespace {
inline int div_up(int x, int y) { return (x + y - 1) / y; }
// Some notes on the design:
//
// Each thread is responsible for computing a single output out[k, i].
// Thread blocks are based on tiles of x with height 1 in the batch dimension.
//
// This design is based on the typical use case where the filter
// y is fairly small. For large y, it would probably be more efficient
// to also tile across y.
template <typename T>
__global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width,
int y_width, int y_half_width,
int batch_size) {
extern __shared__ T mem[];
int tx = threadIdx.x;
int i = blockIdx.x * blockDim.x + tx; // global x index
int k = blockIdx.y; // batch index
// Check if we are in a boundary block with fewer x's to process than
// blockDim.x.
int num_x =
(blockIdx.x == gridDim.x - 1) ? (x_width % blockDim.x) : blockDim.x;
T *sx = mem;
T *sx_pad = &mem[num_x];
T *sy = &mem[blockDim.x + y_width];
// Collaboratively load y[k, :] and length-y padding of x into shared memory.
int pad_start = blockIdx.x * blockDim.x + num_x + x_width - y_half_width;
for (int j = tx; j < y_width; j += blockDim.x) {
sy[j] = y[k * y_width + j];
sx_pad[j] = x[k * x_width + (pad_start + j) % x_width];
}
// Load a cyclically shifted slice of x into shared memory.
if (tx < num_x) {
int load_i = (i - y_half_width + x_width) % x_width;
sx[tx] = x[k * x_width + load_i];
} else {
return;
}
__syncthreads();
// Compute dot product of sx[tx:tx + y_width] and sy.
T sum = 0;
for (int j = 0; j < y_width; ++j) {
sum += sx[tx + j] * sy[j];
}
// Save to out[k, i].
out[k * x_width + i] = sum;
}
// Compute x gradient - initial naive implementation with atomic add.
template <typename T>
__global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width,
int y_width, int y_half_width, int batch_size) {
int i = blockIdx.x * blockDim.x + threadIdx.x; // x index
int j = blockIdx.y; // y index
int k = blockIdx.z; // batch index
if (i < x_width) {
int index = (i + j - y_half_width + x_width) % x_width;
atomicAdd(&dx[k * x_width + index],
dout[k * x_width + i] * y[k * y_width + j]);
}
}
// Compute y gradient - initial naive implementation with atomic add.
template <typename T>
__global__ void conv_shift_dy(const T *x, const T *dout, T *dy, int x_width,
int y_width, int y_half_width, int batch_size) {
int i = blockIdx.x * blockDim.x + threadIdx.x; // x index
int j = blockIdx.y; // y index
int k = blockIdx.z; // batch index
if (i < x_width) {
int index = (i + j - y_half_width + x_width) % x_width;
atomicAdd(&dy[k * y_width + j],
x[k * x_width + index] * dout[k * x_width + i]);
}
}
} // namespace
template <typename T>
class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Y = context.Input<Tensor>("Y");
Tensor *Out = context.Output<Tensor>("Out");
const T *x_data = X->data<T>();
const T *y_data = Y->data<T>();
T *out_data = Out->mutable_data<T>(context.GetPlace());
int batch_size = X->dims()[0];
int x_width = X->dims()[1];
int y_width = Y->dims()[1];
int y_half_width = (y_width - 1) / 2;
const int x_per_block = 256;
int num_x_blocks = div_up(x_width, x_per_block);
int mem_per_block = (x_per_block + 2 * y_width) * sizeof(T);
dim3 grid_dim(num_x_blocks, batch_size);
auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
context.device_context())
.stream();
conv_shift_forward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>(
x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size);
}
};
template <typename T>
class ConvShiftGradKernel<platform::GPUPlace, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Y = context.Input<Tensor>("Y");
const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
const T *x_data = X->data<T>();
const T *y_data = Y->data<T>();
const T *dout_data = dOut->data<T>();
Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
Tensor *dY = context.Output<Tensor>(framework::GradVarName("Y"));
int batch_size = X->dims()[0];
int x_width = X->dims()[1];
int y_width = Y->dims()[1];
int y_half_width = (y_width - 1) / 2;
auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
context.device_context())
.stream();
const int x_per_block = 256;
int num_x_blocks = div_up(x_width, x_per_block);
dim3 grid_dim(num_x_blocks, y_width, batch_size);
if (dX) {
T *dx_data = dX->mutable_data<T>(context.GetPlace());
cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream);
conv_shift_dx<T><<<grid_dim, x_per_block, 0, stream>>>(
dout_data, y_data, dx_data, x_width, y_width, y_half_width,
batch_size);
}
if (dY) {
T *dy_data = dY->mutable_data<T>(context.GetPlace());
cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream);
conv_shift_dy<T><<<grid_dim, x_per_block, 0, stream>>>(
x_data, dout_data, dy_data, x_width, y_width, y_half_width,
batch_size);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv_shift,
ops::ConvShiftKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv_shift_grad,
ops::ConvShiftGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class ConvShiftKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override;
};
template <typename Place, typename T>
class ConvShiftGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override;
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/dynamic_recurrent_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::Scope;
using framework::TensorArray;
using framework::LoDTensor;
using framework::Variable;
namespace detail {
inline void CreateVariables(Scope& scope,
const std::vector<std::string>& var_names) {
for (const auto& name : var_names) {
scope.NewVar(name);
}
}
} // namespace detail
class DynamicRecurrentOpProtoAndCheckerMaker
: public framework::OpProtoAndCheckerMaker {
public:
DynamicRecurrentOpProtoAndCheckerMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
const auto& name = DynamicRecurrentOp::kArgName;
// inputs and outputs stored in proto
AddInput(name.inlinks,
"the inputs that need to be segmented for each step.")
.AsDuplicable();
AddInput(name.boot_memories, "variables to initialize memories.")
.AsDuplicable();
    AddOutput(name.outlinks, "the outputs that need to be concatenated for all steps.")
.AsDuplicable();
AddOutput(name.step_scopes, "step scopes");
// Attributes stored in AttributeMap
AddAttr<std::vector<std::string>>(name.pre_memories,
"names of pre-memories");
AddAttr<std::vector<std::string>>(name.memories, "names of memories");
AddComment("This is a RNN operator for varience-length sequences.");
}
};
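// Run pipeline: split the batch into per-step inputs, create one scope per
// step, write inputs and initial states into the step scopes, run the stepnet
// once per step, then write back and concatenate the step outputs.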
void DynamicRecurrentOp::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
cache_.Init(kArgName, *this, scope, &arg_);
SplitInputs();
CreateScopes();
WriteStepInputs();
InitStates();
// call stepnet in all the time steps
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& step_scope = cache_.GetScope(step);
stepnet_->Run(step_scope, dev_ctx);
}
WriteStepOutputs();
ConcatOutputs();
}
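// Example (mirroring the unit test below): an input with LoD {0, 4, 7, 9, 10}
// holds 4 sequences of lengths 4, 3, 2 and 1; Unpack with
// length_descend = true produces num_steps = 4 step tensors whose batch sizes
// are 4, 3, 2 and 1.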
void DynamicRecurrentOp::SplitInputs() const {
// TODO(superjom) make level a config
  // TODO(superjom) check that all the inputs have the same LoD
int level = 0;
const auto& inlinks = cache_.inlinks;
for (const auto& item : inlinks) {
const auto& var = item.second;
const auto& tensor = var->Get<LoDTensor>();
TensorArray& ta = step_inputs_[item.first];
dy_seq_metas_[item.first] =
ta.Unpack(tensor, level, true /*length_descend*/);
if (cache_.num_steps) {
PADDLE_ENFORCE_EQ(ta.size(), cache_.num_steps,
"inputs should have the same steps");
} else {
cache_.num_steps = ta.size();
}
}
}
void DynamicRecurrentOp::WriteStepInputs() const {
for (const auto& item : cache_.inlinks) {
auto ta_it = step_inputs_.find(item.first);
PADDLE_ENFORCE(ta_it != step_inputs_.end(),
"step_inputs_ not compatible with memory set");
TensorArray& ta = ta_it->second;
for (size_t step = 0; step < ta.size(); step++) {
auto tensor = ta.Read(step);
auto& step_scope = cache_.GetScope(step);
Variable* var = step_scope.FindVar(item.first);
if (var == nullptr) {
var = step_scope.NewVar(item.first);
}
var->GetMutable<LoDTensor>()->ShareDataWith<value_type>(tensor);
}
}
}
void DynamicRecurrentOp::WriteStepOutputs() const {
for (size_t step = 0; step < cache_.scopes->size(); step++) {
auto& scope = cache_.GetScope(step);
for (auto& item : step_outputs_) {
auto* var = scope.FindVar(item.first);
if (var == nullptr) {
var = scope.NewVar(item.first);
}
auto* tensor = var->GetMutable<LoDTensor>();
item.second.WriteShared(step, *tensor);
}
}
}
void DynamicRecurrentOp::CreateScopes() const {
PADDLE_ENFORCE_GT(cache_.num_steps, 0);
// resize scopes
size_t num_scopes_need_create = cache_.num_steps - cache_.scopes->size();
for (size_t i = 0; i < num_scopes_need_create; i++) {
cache_.scopes->emplace_back(&cache_.scope->NewScope());
}
// init temporary inputs
PADDLE_ENFORCE_NOT_NULL(stepnet_, "stepnet should be set first");
std::vector<std::string> memories;
std::vector<std::string> pre_memories;
std::transform(arg_.memories.begin(), arg_.memories.end(),
std::back_inserter(memories),
[](const rnn::MemoryAttr& m) { return m.var; });
std::transform(arg_.memories.begin(), arg_.memories.end(),
std::back_inserter(pre_memories),
[](const rnn::MemoryAttr& m) { return m.pre_var; });
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& scope = cache_.GetScope(step);
detail::CreateVariables(scope, arg_.inlinks);
detail::CreateVariables(scope, arg_.outlinks);
detail::CreateVariables(scope, memories);
detail::CreateVariables(scope, pre_memories);
}
}
void DynamicRecurrentOp::ConcatOutputs() const {
// TODO(superjom) transform this to a config
int level = 0;
// TODO(superjom) pass in some lod
// just a placeholder
framework::LoD lod;
for (auto& item : step_outputs_) {
auto tensor = item.second.Pack(level, dy_seq_metas_[item.first], lod);
auto& output = cache_.outlinks[item.first]->Get<LoDTensor>();
const_cast<LoDTensor*>(&output)->ShareDataWith<value_type>(tensor);
}
}
void DynamicRecurrentOp::InitStates() const {
// init the first state
  // TODO(superjom) prepare for the scenario where the boot state does not exist
for (auto memory : arg_.memories) {
auto* boot_state_var = cache_.scope->FindVar(memory.boot_var);
PADDLE_ENFORCE_NOT_NULL(boot_state_var);
auto& boot_state = boot_state_var->Get<LoDTensor>();
const auto& dims = boot_state.dims();
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& cur_scope = cache_.GetScope(step);
// link pre-state to boot_state
// init state and pre-state
auto* pre_state = cur_scope.FindVar(memory.pre_var);
PADDLE_ENFORCE_NOT_NULL(pre_state);
pre_state->GetMutable<LoDTensor>();
auto* state = cur_scope.FindVar(memory.var);
PADDLE_ENFORCE_NOT_NULL(state);
state->GetMutable<LoDTensor>()->Resize(dims);
state->GetMutable<LoDTensor>()->mutable_data<value_type>(
platform::CPUPlace());
if (step == 0) {
auto* pre_state_tensor = pre_state->GetMutable<LoDTensor>();
pre_state_tensor->Resize(boot_state.dims());
pre_state_tensor->ShareDataWith<value_type>(boot_state);
} else {
auto& pre_scope = cache_.GetScope(step - 1);
auto* state_pre = pre_scope.FindVar(memory.var);
PADDLE_ENFORCE_NOT_NULL(state_pre);
pre_state->GetMutable<LoDTensor>()->ShareDataWith<value_type>(
*state_pre->GetMutable<LoDTensor>());
}
}
}
}
void DynamicRecurrentOp::ArgCache::Init(
const rnn::ArgumentName& name, const paddle::framework::OperatorBase& op,
const paddle::framework::Scope& scope, rnn::Argument* arg) {
this->scope = &scope;
InitArgument(name, op, arg);
CacheScopes(scope, *arg);
CacheInlinks(scope, arg->inlinks);
CacheOutlinks(scope, arg->outlinks);
}
void DynamicRecurrentOp::ArgCache::InitArgument(const rnn::ArgumentName& name,
const OperatorBase& op,
rnn::Argument* arg) {
rnn::InitArgument(name, arg, op, false /*is_grad*/);
}
void DynamicRecurrentOp::ArgCache::CacheScopes(const Scope& scope,
const rnn::Argument& arg) {
auto scopes_var = scope.FindVar(arg.step_scopes);
PADDLE_ENFORCE(scopes_var != nullptr,
"the step_scopes output argument [%s] should be created first "
"by framework.",
arg.step_scopes);
this->scopes = scopes_var->GetMutable<std::vector<Scope*>>();
}
void DynamicRecurrentOp::ArgCache::CacheInlinks(
const Scope& scope, const std::vector<std::string>& names) {
for (auto name : names) {
auto* var = GetVariable(scope, name);
inlinks[name] = var;
}
}
void DynamicRecurrentOp::ArgCache::CacheOutlinks(
const Scope& scope, const std::vector<std::string>& names) {
for (auto name : names) {
auto* var = GetVariable(scope, name);
outlinks[name] = var;
}
}
Variable* DynamicRecurrentOp::ArgCache::GetVariable(const Scope& scope,
const std::string& name) {
auto* var = scope.FindVar(name);
  PADDLE_ENFORCE_NOT_NULL(var, "variable [%s] does not exist in scope", name);
return var;
}
const rnn::ArgumentName DynamicRecurrentOp::kArgName{
"step_net", "step_scopes", "inlinks", "outlinks",
"memories", "pre_memories", "boot_memories"};
void DynamicRecurrentGradientOp::Run(
const Scope& scope, const platform::DeviceContext& dev_ctx) const {}
} // namespace operators
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
dynamic_recurrent, paddle::operators::DynamicRecurrentOp,
paddle::operators::DynamicRecurrentOpProtoAndCheckerMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_TESTING
#include "gtest/gtest.h"
#endif
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor_array.h"
#include "paddle/framework/variable.h"
#include "paddle/operators/rnn/recurrent_op_utils.h"
namespace paddle {
namespace operators {
class DynamicRecurrentOp : public framework::OperatorBase {
public:
static const rnn::ArgumentName kArgName;
using value_type = float;
DynamicRecurrentOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
DynamicRecurrentOp(const DynamicRecurrentOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
// TODO(yuyang18): Implement copy ctor well.
PADDLE_THROW("Not implemented");
}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override;
/*
 * Split the inputs (LoDTensors) into segments for each time step.
*/
void SplitInputs() const;
/*
 * Create step-scopes to store temporary outputs in each time step.
*/
void CreateScopes() const;
/*
* Link TensorArray steps to the corresponding variables located in
* step-scopes.
*/
void WriteStepInputs() const;
/*
* Write output of each step to the corresponding TensorArray.
*/
void WriteStepOutputs() const;
/*
 * Initialize the states. Each state has a corresponding pre-state, which
 * shares memory with the state of the previous time step. The pre-state of
 * the first time step is initialized with a zero tensor, or with a tensor
 * from the parent scope if one is provided.
*/
void InitStates() const;
/*
* Concatenate outputs in each time step and generate a LoDTensor.
*/
void ConcatOutputs() const;
/*
 * Set a stepnet that is created according to a RecurrentOp's stepnet.
*/
void SetStepNet(std::unique_ptr<OperatorBase> net) {
PADDLE_ENFORCE_NOT_NULL(net);
stepnet_ = std::move(net);
}
const OperatorBase& GetStepNet() const { return *stepnet_; }
protected:
struct ArgCache {
framework::Scope const* scope;
std::vector<framework::Scope*>* scopes;
std::map<std::string, framework::Variable*> inlinks;
std::map<std::string, framework::Variable*> outlinks;
size_t num_steps{0};
void Init(const rnn::ArgumentName& name, const OperatorBase& op,
const framework::Scope& scope, rnn::Argument* arg);
framework::Scope& GetScope(size_t index) {
PADDLE_ENFORCE_LT(index, num_steps);
return *scopes->at(index);
}
private:
void InitArgument(const rnn::ArgumentName& name, const OperatorBase& op,
rnn::Argument* arg);
void CacheScopes(const framework::Scope& scope, const rnn::Argument& arg);
void CacheInlinks(const framework::Scope& scope,
const std::vector<std::string>& names);
void CacheOutlinks(const framework::Scope& scope,
const std::vector<std::string>& names);
framework::Variable* GetVariable(const framework::Scope& scope,
const std::string& name);
};
private:
std::unique_ptr<OperatorBase> stepnet_;
mutable framework::TensorArray states_;
mutable std::map<std::string, framework::TensorArray> step_inputs_;
mutable std::map<std::string, framework::TensorArray> step_outputs_;
mutable std::map<std::string, std::vector<framework::DySeqMeta>>
dy_seq_metas_;
mutable rnn::Argument arg_;
mutable ArgCache cache_;
#ifdef PADDLE_WITH_TESTING
friend class DynamicRecurrentOpTestHelper;
FRIEND_TEST(DynamicRecurrentOpTestHelper, SplitInputs);
FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateCache);
FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateScopes);
FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepInputs);
FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepOutputs);
FRIEND_TEST(DynamicRecurrentOpTestHelper, InitStates);
FRIEND_TEST(DynamicRecurrentOpTestHelper, ConcatOutputs);
#endif
};
class DynamicRecurrentGradientOp : public framework::OperatorBase {
public:
DynamicRecurrentGradientOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override;
};
} // namespace operators
} // namespace paddle
#include "paddle/operators/dynamic_recurrent_op.h"
#include <gtest/gtest.h>
#include "paddle/framework/ddim.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
using framework::Scope;
using framework::TensorArray;
using framework::LoDTensor;
using framework::Variable;
class TestOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
DEFINE_OP_CLONE_METHOD(TestOp);
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
};
void OpDescNewVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
var->add_arguments(arg_name);
}
}
// create a LoD tensor in scope with specific dims
LoDTensor* CreateVar(Scope& scope, std::string name, framework::DDim dims,
const platform::Place& place) {
auto* var = scope.NewVar(name);
auto* tensor = var->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(place);
return tensor;
}
class DynamicRecurrentOpTestHelper : public ::testing::Test {
protected:
const rnn::ArgumentName argname = DynamicRecurrentOp::kArgName;
virtual void SetUp() override {
CreateGlobalVariables();
auto op_desc = CreateOpDesc();
op = paddle::framework::OpRegistry::CreateOp(op_desc);
dop = dynamic_cast<DynamicRecurrentOp*>(op.get());
InitCacheManually();
InitStepNet();
}
framework::OpDesc CreateOpDesc() {
// create op
paddle::framework::OpDesc op_desc;
op_desc.set_type("dynamic_recurrent");
OpDescNewVar(argname.inlinks, {"in0"}, op_desc.add_inputs());
OpDescNewVar(argname.boot_memories, {"boot_mem"}, op_desc.add_inputs());
OpDescNewVar(argname.step_scopes, {"step_scopes"}, op_desc.add_outputs());
OpDescNewVar(argname.outlinks, {"out0"}, op_desc.add_outputs());
// set pre-memories
auto pre_memories = op_desc.mutable_attrs()->Add();
pre_memories->set_name(argname.pre_memories);
pre_memories->set_type(paddle::framework::AttrType::STRINGS);
auto pre_memories_item = pre_memories->add_strings();
*pre_memories_item = "mem@pre";
// set memories
auto memories = op_desc.mutable_attrs()->Add();
memories->set_name(argname.memories);
memories->set_type(paddle::framework::AttrType::STRINGS);
auto memories_item = memories->add_strings();
*memories_item = "mem";
return op_desc;
}
void CreateGlobalVariables() {
platform::CPUPlace place;
scope.NewVar("step_scopes");
CreateVar(scope, "boot_mem", framework::make_ddim({10, 20}), place);
// auto* out0 =
CreateVar(scope, "out0", framework::make_ddim({10, 20}), place);
auto* in0 = CreateVar(scope, "in0", framework::make_ddim({10, 8}), place);
    // 10 instances in 4 sequences, whose lengths are 4, 3, 2 and 1 respectively.
framework::LoD in0_lod(1);
for (int x : std::vector<int>{0, 4, 7, 9, 10}) {
in0_lod[0].push_back(x);
}
in0->set_lod(in0_lod);
in0->Resize(framework::make_ddim({10, 8}));
    // Set the content: each sequence element's value encodes seqid.batchid,
    // where seqid starts from 0.
int start = 0;
    for (size_t seqid = 0; seqid < in0_lod[0].size() - 1; seqid++) {
for (size_t batchid = 0;
batchid < in0_lod[0][seqid + 1] - in0_lod[0][seqid]; batchid++) {
float v = seqid + batchid * 0.1;
for (size_t dim = 0; dim < 8; dim++) {
in0->data<float>()[start * 8 + dim] = v;
}
start++;
}
}
}
void InitCacheManually() {
dop->cache_.Init(DynamicRecurrentOp::kArgName, *dop, scope, &dop->arg_);
}
void InitStepNet() {
std::unique_ptr<framework::OperatorBase> stepnet{new NetOp};
dynamic_cast<NetOp*>(stepnet.get())
->AppendOp(std::unique_ptr<TestOp>(new TestOp(
"test", {{"inlinks", {"in0"}}, {"boot_memories", {"boot_mem"}}},
{{"outlinks", {"out0"}}, {"step_scopes", {"step_scopes"}}}, {})));
dop->SetStepNet(std::move(stepnet));
}
protected:
DynamicRecurrentOp* dop;
std::unique_ptr<framework::OperatorBase> op;
paddle::platform::CPUDeviceContext device_context;
paddle::framework::Scope scope;
};
TEST_F(DynamicRecurrentOpTestHelper, CreateCache) {
const rnn::Argument& arg = dop->arg_;
ASSERT_EQ(arg.inlinks.size(), 1UL);
ASSERT_EQ(arg.outlinks.size(), 1UL);
}
TEST_F(DynamicRecurrentOpTestHelper, SplitInputs) {
dop->SplitInputs();
auto& in0_ta = dop->step_inputs_["in0"];
ASSERT_EQ(in0_ta.size(), 4UL);
const auto& batch0 = in0_ta.Read(0);
const auto& batch1 = in0_ta.Read(1);
const auto& batch2 = in0_ta.Read(2);
const auto& batch3 = in0_ta.Read(3);
EXPECT_EQ(batch0.dims()[0], 4);
EXPECT_EQ(batch1.dims()[0], 3);
EXPECT_EQ(batch2.dims()[0], 2);
EXPECT_EQ(batch3.dims()[0], 1);
}
TEST_F(DynamicRecurrentOpTestHelper, CreateScopes) {
dop->SplitInputs();
dop->CreateScopes();
ASSERT_EQ(dop->cache_.num_steps, 4UL);
ASSERT_EQ(dop->cache_.scopes->size(), 4UL);
}
TEST_F(DynamicRecurrentOpTestHelper, WriteStepInputs) {
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
for (size_t step = 0; step < dop->cache_.num_steps; step++) {
auto& scope = dop->cache_.GetScope(step);
for (auto name : std::vector<std::string>({"in0"})) {
ASSERT_TRUE(scope.FindVar(name) != nullptr);
}
}
}
TEST_F(DynamicRecurrentOpTestHelper, WriteStepOutputs) {
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
dop->WriteStepOutputs();
for (size_t step = 0; step < dop->cache_.num_steps; step++) {
auto& scope = dop->cache_.GetScope(step);
for (auto name : std::vector<std::string>({"out0"})) {
ASSERT_TRUE(scope.FindVar(name));
}
}
}
TEST_F(DynamicRecurrentOpTestHelper, ConcatOutputs) {
// Let's leave this test to python unittest.
}
TEST_F(DynamicRecurrentOpTestHelper, InitStates) {
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
dop->WriteStepOutputs();
dop->InitStates();
for (size_t step = 0; step < dop->cache_.num_steps; step++) {
auto& scope = dop->cache_.GetScope(step);
auto state = scope.FindVar("mem");
ASSERT_TRUE(state != nullptr);
auto* pre_state = scope.FindVar("mem@pre");
ASSERT_TRUE(pre_state != nullptr);
auto* boot_state = scope.FindVar("boot_mem");
ASSERT_TRUE(boot_state != nullptr);
if (step == 0) {
// check pre_state is a reference of boot_state
ASSERT_EQ(boot_state->Get<LoDTensor>().data<float>(),
pre_state->Get<LoDTensor>().data<float>());
}
}
}
} // namespace operators
} // namespace paddle
......@@ -74,8 +74,8 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Mask"), "Input(Mask) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mask"), "Input(Mask) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
......
......@@ -136,7 +136,7 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
#endif
} // namespace platform
} // namespace paddle
......@@ -41,7 +41,7 @@ limitations under the License. */
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#endif
namespace paddle {
namespace platform {
......
......@@ -63,4 +63,4 @@ void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
} // namespace platform
} // namespace paddle
#endif
......@@ -166,7 +166,9 @@ void BindVarDsec(py::module &m) {
.def("set_shape", &VarDescBind::SetShape)
.def("set_data_type", &VarDescBind::SetDataType)
.def("shape", &VarDescBind::Shape, py::return_value_policy::reference)
.def("data_type", &VarDescBind::GetDataType);
.def("data_type", &VarDescBind::GetDataType)
.def("lod_level", &VarDescBind::GetLodLevel)
.def("set_lod_level", &VarDescBind::SetLoDLevel);
}
void BindOpDesc(py::module &m) {
......@@ -196,7 +198,8 @@ void BindOpDesc(py::module &m) {
.def("set_attr", &OpDescBind::SetAttr)
.def("attr", &OpDescBind::GetAttr)
.def("set_block_attr", &OpDescBind::SetBlockAttr)
.def("get_block_attr", &OpDescBind::GetBlockAttr);
.def("get_block_attr", &OpDescBind::GetBlockAttr)
.def("infer_shape", &OpDescBind::InferShape);
}
} // namespace pybind
......
......@@ -231,21 +231,6 @@ All parameter, weight, gradient are variables in Paddle.
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc);
})
.def_static("infer_shape",
[](OpDescBind &op_desc, BlockDescBind &block) {
auto op = OpRegistry::CreateOp(*op_desc.Proto());
auto *op_with_kernel =
dynamic_cast<OperatorWithKernel *>(op.get());
if (op_with_kernel != nullptr) {
auto ctx = CompileTimeInferShapeContext(op_desc, block);
op_with_kernel->InferShape(&ctx);
} else {
PADDLE_THROW(
"OP(%s) is not type of OperatorWithKernel, "
"should not call this function",
op_desc.Type());
}
})
.def("backward",
[](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) {
......
import paddle.v2.framework.core as core
import collections
import numpy as np
import copy
__all__ = ['Block', 'Variable', 'Program', 'Operator']
class Variable(object):
def __init__(self,
block,
name=None,
shape=None,
dtype=None,
lod_level=None,
**kwargs):
self.block = block
if name is None:
name = Variable._unique_var_name_()
try:
self.desc = self.block.desc.var(name)
is_new_var = False
except core.EnforceNotMet:
self.desc = self.block.desc.new_var(name)
is_new_var = True
if shape is not None:
if is_new_var:
self.desc.set_shape(shape)
else:
old_shape = self.shape
shape = tuple(shape)
if shape != old_shape:
raise ValueError(
"Variable {0} has been created before. the previous "
"shape is {1}; the new shape is {2}. They are not "
"matched.".format(self.name, old_shape, shape))
if dtype is not None:
if not isinstance(dtype, core.DataType):
dtype = Variable._convert_np_dtype_to_dtype_(dtype)
if is_new_var:
self.desc.set_data_type(dtype)
else:
                old_dtype = self.data_type
                if dtype != old_dtype:
raise ValueError("Variable {0} has been created before. "
"The previous data type is {1}; the new "
"data type is {2}. They are not "
"matched.".format(self.name, old_dtype,
dtype))
if lod_level is not None:
if is_new_var:
self.desc.set_lod_level(lod_level)
else:
if lod_level != self.lod_level:
raise ValueError("Variable {0} has been created before. "
"The previous lod_level is {1}; the new "
"lod_level is {2}. They are not "
"matched".format(self.name, self.lod_level,
lod_level))
self.block.vars[name] = self
self.op = None
@property
def name(self):
return self.desc.name()
@property
def shape(self):
        # Convert to a tuple, to match the numpy API.
return tuple(self.desc.shape())
@property
def data_type(self):
return self.desc.data_type()
@property
def lod_level(self):
return self.desc.lod_level()
@staticmethod
def _unique_var_name_():
        uid = core.unique_integer()  # unique during the whole process.
return "_generated_var_%d" % uid
@staticmethod
def _convert_np_dtype_to_dtype_(np_dtype):
dtype = np.dtype(np_dtype)
if dtype == np.float32:
return core.DataType.FP32
elif dtype == np.float64:
return core.DataType.FP64
elif dtype == np.float16:
return core.DataType.FP16
elif dtype == np.int32:
return core.DataType.INT32
elif dtype == np.int16:
return core.DataType.INT16
elif dtype == np.int64:
return core.DataType.INT64
elif dtype == np.bool:
return core.DataType.BOOL
else:
raise ValueError("Not supported numpy dtype " + str(dtype))
class Operator(object):
def __init__(self,
block,
desc,
type=None,
inputs=None,
outputs=None,
attrs=None):
self.block = block
self.desc = desc
if type is not None:
# TODO.
pass
if inputs is not None:
# TODO
pass
if outputs is not None:
# TODO
pass
if attrs is not None:
# TODO
pass
# TODO: Getters
class Block(object):
def __init__(self, program, idx):
self.desc = program.desc.block(idx)
self.vars = dict() # var_name --> var
self.ops = collections.deque() # operator list
self.program = program
@property
def parent_idx(self):
return self.desc.parent
@property
def idx(self):
return self.desc.id
def create_var(self, *args, **kwargs):
return Variable(self, *args, **kwargs)
def create_parameter(self, *args, **kwargs):
global_block = self.program.global_block()
return Parameter(global_block, *args, **kwargs)
def append_op(self, *args, **kwargs):
op_desc = self.desc.append_op()
op = Operator(self, op_desc, *args, **kwargs)
self.ops.append(op)
return op
def prepend_op(self, *args, **kwargs):
op_desc = self.desc.prepend_op()
op = Operator(self, op_desc, *args, **kwargs)
self.ops.appendleft(op)
return op
class Program(object):
@classmethod
def instance(cls):
        # From https://stackoverflow.com/questions/8212053
        # Make Program a singleton class.
if not hasattr(cls, '_instance'):
cls._instance = cls()
return cls._instance
def __init__(self):
assert not hasattr(self.__class__,
'_instance'), 'Do not call constructor directly!'
self.desc = core.ProgramDesc.instance()
self.blocks = [Block(self, 0)]
self.current_block_idx = 0
def global_block(self):
return self.blocks[0]
def current_block(self):
return self.blocks[self.current_block_idx]
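    # create_block() appends a new block whose parent is the current block and
    # switches to it; rollback() switches back to the parent block (see
    # test_program.py below for the resulting idx / parent_idx sequence).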
def create_block(self):
new_block_idx = len(self.blocks)
self.desc.append_block(self.current_block().desc)
self.current_block_idx = new_block_idx
self.blocks.append(Block(self, self.current_block_idx))
return self.current_block()
def rollback(self):
self.current_block_idx = self.current_block().parent_idx
class Parameter(Variable):
def __init__(self, block, shape, dtype, **kwargs):
if shape is None or dtype is None:
raise ValueError("Parameter must set shape and dtype")
if len(shape) == 0:
raise ValueError("Parameter shape cannot be empty")
for each in shape:
if each < 0:
raise ValueError("Parameter shape should not be related with "
"batch-size")
Variable.__init__(self, block, shape=shape, dtype=dtype, **kwargs)
self.trainable = kwargs.get('trainable', True)
self.init_attr = kwargs.get('initialize_attr', {
'type': 'uniform_random',
'min': -1.0,
'max': 1.0
})
self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
self._append_initialize_ops_()
def _append_initialize_ops_(self):
attr = copy.deepcopy(self.init_attr)
op_type = attr.pop('type', None)
block = self.block
assert isinstance(block, Block)
shape = self.shape
attr['dims'] = shape
attr['data_type'] = int(self.data_type)
op = block.prepend_op(
type=op_type, inputs=None, outputs={'Out': [self]}, attrs=attr)
self.op = op
# program is a global instance.
g_program = Program.instance()
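# A minimal usage sketch (mirroring test_variable.py below): create a variable
# in the current block of the global program.
#
#   b = g_program.current_block()
#   w = b.create_var(dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")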
......@@ -33,6 +33,21 @@ class TestSigmoid(OpTest):
self.check_grad(['X'], 'Y', max_relative_error=0.008)
class TestLogSigmoid(OpTest):
def setUp(self):
self.op_type = "logsigmoid"
self.inputs = {
'X': np.random.uniform(-1, 1, [11, 17]).astype("float32")
}
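        # logsigmoid(x) = log(sigmoid(x)) = log(1 / (1 + exp(-x)))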
self.outputs = {'Y': np.log(1 / (1 + np.exp(-self.inputs['X'])))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.008)
class TestTanh(OpTest):
def setUp(self):
self.op_type = "tanh"
......@@ -63,6 +78,26 @@ class TestTanhShrink(OpTest):
self.check_grad(['X'], 'Y', max_relative_error=0.008)
class TestSoftShrink(OpTest):
def setUp(self):
self.op_type = "softshrink"
lambda_val = 0.1
self.attrs = {'lambda': lambda_val}
self.inputs = {
'X': np.random.uniform(0.25, 10, [4, 4]).astype("float32")
}
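        # softshrink: y = x - lambda if x > lambda; y = x + lambda if
        # x < -lambda; y = 0 otherwise.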
y = np.copy(self.inputs['X'])
y = (y < -lambda_val) * (y + lambda_val) + (y > lambda_val) * (
y - lambda_val)
self.outputs = {'Y': y}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestSqrt(OpTest):
def setUp(self):
self.op_type = "sqrt"
......
import unittest
import numpy as np
from op_test import OpTest
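# Reference implementation of circular convolution (conv_shift):
#   out[b, i] = sum_j x[b, (i + j - (N - 1) / 2) % M] * y[b, j]
# where M is the x width and N is the (odd) y width.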
def conv_shift_forward(x, y):
out = np.zeros_like(x)
M = x.shape[1]
N = y.shape[1]
y_half_width = (N - 1) / 2
for i in xrange(M):
for j in xrange(N):
out[:, i] += x[:, (i + j + M - y_half_width) % M] * y[:, j]
return out
class TestConvShiftOp(OpTest):
def setUp(self):
self.op_type = "conv_shift"
batch_size = 4
x_dim = 17
y_dim = 3 # must be odd and <= x_dim
x = np.random.random((batch_size, x_dim)).astype("float32")
y = np.random.random((batch_size, y_dim)).astype("float32")
self.inputs = {'X': x, 'Y': y}
out = conv_shift_forward(x, y)
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05)
def test_check_grad_ignore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=0.05, no_grad_set=set("X"))
def test_check_grad_ignore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y'))
if __name__ == '__main__':
unittest.main()
import unittest
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
class TestInferShape(unittest.TestCase):
......@@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase):
sum_op_desc.set_input("X", ["x1", "x2"])
sum_op_desc.set_output("Out", ["out"])
        sum_op_desc.infer_shape(block)
self.assertEqual(out.shape(), shape)
def test_mul_op(self):
......@@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase):
mul_op_desc.set_attr("x_num_col_dims", 1)
mul_op_desc.set_attr("y_num_col_dims", 1)
        mul_op_desc.infer_shape(block)
self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
......
import unittest
from paddle.v2.framework.graph import g_program
import paddle.v2.framework.core as core
class TestParameter(unittest.TestCase):
def test_param(self):
b = g_program.create_block()
param = b.create_parameter(
name='fc.w',
shape=[784, 100],
dtype='float32',
initialize_attr={
'type': 'uniform_random',
'seed': 13,
'min': -5.0,
'max': 5.0
})
self.assertIsNotNone(param)
self.assertEqual('fc.w', param.name)
self.assertEqual((784, 100), param.shape)
self.assertEqual(core.DataType.FP32, param.data_type)
self.assertEqual(0, param.block.idx)
if __name__ == '__main__':
unittest.main()
import unittest
from paddle.v2.framework.graph import g_program
class TestProgram(unittest.TestCase):
def test_program(self):
b = g_program.current_block()
self.assertEqual(-1, b.parent_idx)
self.assertEqual(0, b.idx)
b = g_program.create_block()
self.assertEqual(1, b.idx)
self.assertEqual(0, b.parent_idx)
b = g_program.create_block()
self.assertEqual(2, b.idx)
self.assertEqual(1, b.parent_idx)
g_program.rollback()
b = g_program.current_block()
self.assertEqual(1, b.idx)
self.assertEqual(0, b.parent_idx)
b = g_program.create_block()
self.assertEqual(3, b.idx)
self.assertEqual(1, b.parent_idx)
g_program.rollback()
b = g_program.current_block()
self.assertEqual(1, b.idx)
self.assertEqual(0, b.parent_idx)
if __name__ == '__main__':
unittest.main()
import unittest
from paddle.v2.framework.graph import Variable, g_program
import paddle.v2.framework.core as core
import numpy as np
class TestVariable(unittest.TestCase):
def test_np_dtype_convert(self):
DT = core.DataType
convert = Variable._convert_np_dtype_to_dtype_
self.assertEqual(DT.FP32, convert(np.float32))
self.assertEqual(DT.FP16, convert("float16"))
self.assertEqual(DT.FP64, convert("float64"))
self.assertEqual(DT.INT32, convert("int32"))
self.assertEqual(DT.INT16, convert("int16"))
self.assertEqual(DT.INT64, convert("int64"))
self.assertEqual(DT.BOOL, convert("bool"))
self.assertRaises(ValueError, lambda: convert("int8"))
def test_var(self):
b = g_program.current_block()
w = b.create_var(
dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")
self.assertEqual(core.DataType.FP64, w.data_type)
self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name)
self.assertEqual(0, w.lod_level)
w = b.create_var(name='fc.w')
self.assertEqual(core.DataType.FP64, w.data_type)
self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name)
self.assertEqual(0, w.lod_level)
self.assertRaises(ValueError,
lambda: b.create_var(name="fc.w", shape=(24, 100)))
if __name__ == '__main__':
unittest.main()