diff --git a/CMakeLists.txt b/CMakeLists.txt index 0cc4e4768280d9e0f368b15857b3952fbe312aaf..264420ad830ed39b38f1918951d8d66c84fd5ee9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -127,9 +127,9 @@ include(external/warpctc) # download, build, install warpctc include(external/any) # download libn::any include(external/eigen) # download eigen3 include(external/pybind11) # download pybind11 +include(external/nccl) include(cudnn) # set cudnn libraries, must before configure -include(nccl) # set nccl libraries include(configure) # add paddle env configuration include(generic) # simplify cmake module include(package) # set paddle packages diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 00dc33514133c9f5a789a4db765b33fb525421ec..24ddb24399dabeec9b8e5faf36be3eb21f420111 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -68,13 +68,6 @@ else() if(NOT CUDNN_FOUND) message(FATAL_ERROR "Paddle needs cudnn to compile") endif() - if (NOT NCCL_INCLUDE_DIR) - message(FATAL_ERROR "Paddle needs nccl header to compile") - endif() - if (NOT WITH_DSO AND NOT NCCL_LIBRARY) - message(FATAL_ERROR "Paddle needs nccl libraries when WITH_DSO=OFF") - endif() - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} "-Xcompiler ${SIMD_FLAG}") diff --git a/cmake/external/nccl.cmake b/cmake/external/nccl.cmake new file mode 100644 index 0000000000000000000000000000000000000000..dfbbed58c9ed7cc57809b3d33a29ce26a35d75a2 --- /dev/null +++ b/cmake/external/nccl.cmake @@ -0,0 +1,50 @@ +INCLUDE(ExternalProject) + +SET(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl) + +INCLUDE_DIRECTORIES(${NCCL_SOURCE_DIR}/src/extern_nccl/src) + + +if(WITH_DSO) + # If we use DSO, we do not build nccl, just download the dependencies + set(NCCL_BUILD_COMMAND "") + set(NCCL_INSTALL_COMMAND "") + set(NCCL_INSTALL_DIR "") +else() + # otherwise, we build nccl and link it. 
+  set(NCCL_BUILD_COMMAND "make -j 8")
+  set(NCCL_INSTALL_COMMAND "make install")
+  SET(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl)
+endif()
+
+ExternalProject_Add(
+  extern_nccl
+  ${EXTERNAL_PROJECT_LOG_ARGS}
+  GIT_REPOSITORY  "https://github.com/NVIDIA/nccl.git"
+  GIT_TAG         "v1.3.4-1"
+  PREFIX          "${NCCL_SOURCE_DIR}"
+  UPDATE_COMMAND  ""
+  CONFIGURE_COMMAND ""
+  BUILD_COMMAND     "${NCCL_BUILD_COMMAND}"
+  INSTALL_COMMAND   "${NCCL_INSTALL_COMMAND}"
+  INSTALL_DIR       "${NCCL_INSTALL_DIR}"
+  TEST_COMMAND      ""
+)
+
+if (WITH_DSO)
+  if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
+    set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_any_dummy.c)
+    file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";")
+    add_library(nccl STATIC ${dummyfile})
+  else()
+    add_library(nccl INTERFACE)
+  endif()
+else()
+  ADD_LIBRARY(nccl STATIC IMPORTED GLOBAL)
+  SET_PROPERTY(TARGET nccl PROPERTY IMPORTED_LOCATION
+    ${NCCL_INSTALL_DIR}/lib/libnccl.a)
+endif()
+
+add_dependencies(nccl extern_nccl)
+
+LIST(APPEND external_project_dependencies nccl)
diff --git a/cmake/nccl.cmake b/cmake/nccl.cmake
deleted file mode 100644
index 872b4d56fb2a0a30809b937ecd7bae1656a518d6..0000000000000000000000000000000000000000
--- a/cmake/nccl.cmake
+++ /dev/null
@@ -1,30 +0,0 @@
-if (NOT WITH_GPU)
-  return ()
-endif()
-
-set(NCCL_ROOT "/usr" CACHE PATH "CUDNN ROOT")
-find_path(NCCL_INCLUDE_DIR nccl.h PATHS
-    ${NCCL_ROOT} ${NCCL_ROOT}/include
-    $ENV{NCCL_ROOT} $ENV{NCCL_ROOT}/include ${CUDA_TOOLKIT_INCLUDE}
-    NO_DEFAULT_PATH)
-
-get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH)
-
-set(TARGET_ARCH "x86_64")
-if(NOT ${CMAKE_SYSTEM_PROCESSOR})
-  set(TARGET_ARCH ${CMAKE_SYSTEM_PROCESSOR})
-endif()
-
-list(APPEND NCCL_CHECK_LIBRARY_DIRS
-    ${NCCL_ROOT}
-    ${NCCL_ROOT}/lib64
-    ${NCCL_ROOT}/lib
-    ${NCCL_ROOT}/lib/${TARGET_ARCH}-linux-gnu
-    $ENV{NCCL_ROOT}
-    $ENV{NCCL_ROOT}/lib64
-    $ENV{NCCL_ROOT}/lib
-    /usr/lib)
-find_library(NCCL_LIBRARY NAMES libnccl.so libnccl.dylib # libcudnn_static.a
-  PATHS ${NCCL_CHECK_LIBRARY_DIRS} ${NCCL_INCLUDE_DIR} ${__libpath_hist}
-        NO_DEFAULT_PATH
-  DOC "Path to nccl library.")
diff --git a/doc/design/graph_survey.md b/doc/design/graph_survey.md
new file mode 100644
index 0000000000000000000000000000000000000000..6c6db08f463ae0a2b94fc4546f123a1d7c151870
--- /dev/null
+++ b/doc/design/graph_survey.md
@@ -0,0 +1,232 @@
+## Survey on Graph
+
+Neural network frameworks often provide a symbolic API so that users can describe a network topology conveniently. This doc mainly focuses on the symbolic APIs of the most popular neural network frameworks, and tries to find out how to parse a symbolic configuration into a portable file, such as protobuf or JSON.
+
+### Mxnet
+
+The core concept of the symbolic API is `Symbol`. Mxnet implements the `Symbol` class in C++ and exports it to Python through the C API. Please refer to the comments in Mxnet:
+
+
+`Symbol` is a helper class used to represent the operator node in a Graph.
+`Symbol` acts as an interface for building graphs from different components like Variable, Functor and Group. `Symbol` is also exported to the Python front-end (while Graph is not) to enable quick testing and deployment. Conceptually, a symbol is the final operation of a graph and thus includes all the information required (the graph) to evaluate its output value.
+
+
+A simple network topology written with Symbol is as follows:
+
+```python
+def get_symbol(num_classes=10, **kwargs):
+    data = mx.symbol.Variable('data')
+    data = mx.symbol.Flatten(data=data)
+    fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
+    act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
+    fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
+    act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
+    fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
+    mlp = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
+    return mlp
+```
+
+
+
+Variable here is actually a Symbol. Every basic Symbol corresponds to one Node, and every Node has its own NodeAttr. There is an op field in the NodeAttr class; when a Symbol represents a Variable (often input data), the op field is null.
+
+Symbol contains a data member, `std::vector<NodeEntry> outputs`, and NodeEntry contains a pointer to Node. By following these Node pointers we can recover the whole Graph.
+
+A Symbol can also be saved to a JSON file.
+
+Here is a detailed example:
+
+```
+>>> import mxnet as mx
+>>> data = mx.symbol.Variable('data')
+>>> print data.debug_str()
+Variable:data
+
+>>> data = mx.symbol.Flatten(data=data)
+>>> print data.debug_str()
+Symbol Outputs:
+    output[0]=flatten0(0)
+Variable:data
+--------------------
+Op:Flatten, Name=flatten0
+Inputs:
+    arg[0]=data(0) version=0
+
+>>> fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
+>>> print fc1.debug_str()
+Symbol Outputs:
+    output[0]=fc1(0)
+Variable:data
+--------------------
+Op:Flatten, Name=flatten0
+Inputs:
+    arg[0]=data(0) version=0
+Variable:fc1_weight
+Variable:fc1_bias
+--------------------
+Op:FullyConnected, Name=fc1
+Inputs:
+    arg[0]=flatten0(0)
+    arg[1]=fc1_weight(0) version=0
+    arg[2]=fc1_bias(0) version=0
+Attrs:
+    num_hidden=128
+
+```
+
+
+### TensorFlow
+
+
+The core concept of the symbolic API is `Tensor`. TensorFlow defines `Tensor` in Python. Please refer to the comments in TensorFlow:
+
+A `Tensor` is a symbolic handle to one of the outputs of an `Operation`. It does not hold the values of that operation's output, but instead provides a means of computing those values in a TensorFlow [Session](https://www.tensorflow.org/api_docs/python/tf/Session).
+
+A simple example is as follows:
+
+```python
+  # Build a dataflow graph.
+  c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
+  d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
+  e = tf.matmul(c, d)
+
+  # Construct a `Session` to execute the graph.
+  sess = tf.Session()
+
+  # Execute the graph and store the value that `e` represents in `result`.
+  result = sess.run(e)
+```
+
+
+The main methods of `Tensor` are as follows:
+
+
+```python
+@property
+def op(self):
+  """The `Operation` that produces this tensor as an output."""
+  return self._op
+
+@property
+def dtype(self):
+  """The `DType` of elements in this tensor."""
+  return self._dtype
+
+@property
+def graph(self):
+  """The `Graph` that contains this tensor."""
+  return self._op.graph
+
+@property
+def name(self):
+  """The string name of this tensor."""
+  if not self._op.name:
+    raise ValueError("Operation was not named: %s" % self._op)
+  return "%s:%d" % (self._op.name, self._value_index)
+
+@property
+def device(self):
+  """The name of the device on which this tensor will be produced, or None."""
+  return self._op.device
+```
+
+
+A Tensor can be passed as a run target to a session. A Tensor holds all the information of the Graph it belongs to, and tracks its data dependencies.
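+
+For instance, the following minimal sketch (written against the TensorFlow 1.x Python API; `c`, `d` and `e` are the tensors from the simple example above) shows how the whole graph can be recovered from a single `Tensor` by walking backwards through the `Operation` that produces it:
+
+```python
+import tensorflow as tf
+
+c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
+d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
+e = tf.matmul(c, d)
+
+
+def walk(tensor, depth=0):
+    # Print the operation producing `tensor`, then recurse into its inputs.
+    op = tensor.op
+    print("  " * depth + "%s (%s)" % (op.name, op.type))
+    for inp in op.inputs:
+        walk(inp, depth + 1)
+
+
+walk(e)
+# Expected output (op names may differ):
+# MatMul (MatMul)
+#   Const (Const)
+#   Const_1 (Const)
+```
+
+Each `Operation` already records its inputs, which is what makes it possible to serialize the whole graph (for example, to a protobuf message) starting from any output tensor.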
+
+
+Here is a detailed example:
+
+
+```
+>>> import tensorflow as tf
+>>> c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
+>>> print c.graph
+
+>>> d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
+>>> print d.graph
+
+>>> e = tf.matmul(c, d)
+>>> print e.graph
+
+```
+
+### Dynet
+
+
+The core concept of the symbolic API is `Expression`; Dynet defines the `Expression` class in C++.
+
+
+A simple example is as follows:
+
+```cpp
+ComputationGraph cg;
+Expression W = parameter(cg, pW);
+
+Expression in = input(cg, xs[i]);
+Expression label = input(cg, ys[i]);
+Expression pred = W * in;
+Expression loss = square(pred - label);
+```
+
+The input data and parameters are also represented by Expression. Every basic Expression corresponds to a Node, and the input data is also a Node.
+
+Expression has a data member ComputationGraph, which is modified while the user configures the network. An Expression can be used as a run target, because an Expression contains all of its dependencies.
+
+
+Here is a detailed example:
+
+Write the topology in C++:
+
+```
+ComputationGraph cg;
+Expression W = parameter(cg, pW);
+cg.print_graphviz();
+
+Expression pred = W * xs[i];
+cg.print_graphviz();
+
+Expression loss = square(pred - ys[i]);
+cg.print_graphviz();
+```
+
+Compile and print:
+
+```
+# first print
+digraph G {
+  rankdir=LR;
+  nodesep=.05;
+  N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"];
+}
+# second print
+digraph G {
+  rankdir=LR;
+  nodesep=.05;
+  N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"];
+  N1 [label="v1 = v0 * -0.98"];
+  N0 -> N1;
+}
+# third print
+digraph G {
+  rankdir=LR;
+  nodesep=.05;
+  N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"];
+  N1 [label="v1 = v0 * -0.98"];
+  N0 -> N1;
+  N2 [label="v2 = -1.88387 - v1"];
+  N1 -> N2;
+  N3 [label="v3 = -v2"];
+  N2 -> N3;
+  N4 [label="v4 = square(v3)"];
+  N3 -> N4;
+}
+```
+
+### Conclusion
+
+
+Symbol/Tensor/Expression in Mxnet/TensorFlow/Dynet are concepts at the same level. We use the unified name Expression here; this level of concept has the following features:
+
+- Users write the topology with a symbolic API, and every return value is an Expression, including the input data and the parameters.
+- An Expression corresponds to a global Graph, and Expressions can also be composed.
+- Expression tracks all dependency and can be taken as a run target diff --git a/paddle/operators/conv2d_op.h b/paddle/operators/conv2d_op.h index f629728f68d65ce81b4910cae7f89ab06d6d94b8..0621389a79eee6b5e75b1eab309b49f8aa4a97ca 100644 --- a/paddle/operators/conv2d_op.h +++ b/paddle/operators/conv2d_op.h @@ -114,7 +114,7 @@ class GemmConv2DKernel : public framework::OpKernel { // im2col Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); im2col(context.device_context(), in_slice, col, strides[0], strides[1], - paddings[0], paddings[1]); + paddings[0], paddings[0], paddings[1], paddings[1]); // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); @@ -213,7 +213,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel { Tensor in_grad_slice = in_grad_batch.Slice(g * in_step, (g + 1) * in_step); col2im(context.device_context(), in_grad_slice, col, strides[0], - strides[1], paddings[0], paddings[1]); + strides[1], paddings[0], paddings[0], paddings[1], + paddings[1]); } } } @@ -235,7 +236,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel { out_grad_batch.Slice(g * out_step, (g + 1) * out_step); Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); im2col(context.device_context(), in_slice, col, strides[0], - strides[1], paddings[0], paddings[1]); + strides[1], paddings[0], paddings[0], paddings[1], + paddings[1]); // gemm Tensor filter_grad_slice = diff --git a/paddle/operators/conv2dtranspose_op.cc b/paddle/operators/conv2dtranspose_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c1b231906e2f172b6f9cee55f850d1a5ec6c3221 --- /dev/null +++ b/paddle/operators/conv2dtranspose_op.cc @@ -0,0 +1,107 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include "paddle/operators/conv2dtranspose_op.h" + +namespace paddle { +namespace operators { + +void Conv2DTransposeOp::InferShape(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of Conv2DTransposeOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Filter"), + "Input(Filter) of Conv2DTransposeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output(Output) of Conv2DTransposeOp should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + + for (size_t i = 0; i < paddings.size(); ++i) { + PADDLE_ENFORCE_EQ(paddings[i], 0, + "No Padding allowed in conv transpose op."); + } + + PADDLE_ENFORCE_EQ(in_dims.size(), 4, + "Conv2DTransposeOp input should be 4-D tensor."); + PADDLE_ENFORCE_EQ(filter_dims.size(), 4, + "Conv2DTransposeOp filter should be 4-D tensor."); + PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0], + "input and kernel input dimension should be equal."); + + auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2]; + auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3]; + ctx->SetOutputDim("Output", + {in_dims[0], filter_dims[1], output_height, output_width}); +} + +Conv2DTransposeOpMaker::Conv2DTransposeOpMaker( + framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "Input", + "(Tensor) The input tensor of convolution transpose operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of input channels, H and W is the height and width of image."); + AddInput("Filter", + "(Tensor) The filter tensor of convolution transpose operator." + "The format of the filter tensor is CMHW, where C is the number of " + "output image channels, M is the number of input image channels, " + "H and W is height and width of filter. " + "We enforce groups number == 1 and padding == 0 in " + "convolution transpose Scenario."); + AddOutput("Output", + "(Tensor) The output tensor of convolution transpose operator." + "The format of output tensor is also NCHW."); + AddAttr>("strides", + "strides of convolution transpose operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", + "paddings of convolution transpose operator.") + .SetDefault({0, 0}); + AddComment(R"DOC( +The convolution transpose operation calculates the output based on the input, filter +and strides, paddings, groups parameters. The size of each dimension of the +parameters is checked in the infer-shape. 
+)DOC"); +} + +void Conv2DTransposeOpGrad::InferShape( + framework::InferShapeContext* ctx) const { + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + if (ctx->HasOutput(framework::GradVarName("Input"))) { + ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); + } + if (ctx->HasOutput(framework::GradVarName("Filter"))) { + ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); + } +} + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(conv2dtranspose, ops::Conv2DTransposeOp, + ops::Conv2DTransposeOpMaker, conv2dtranspose_grad, + ops::Conv2DTransposeOpGrad); + +REGISTER_OP_CPU_KERNEL( + conv2dtranspose, + ops::GemmConv2DTransposeKernel); +REGISTER_OP_CPU_KERNEL( + conv2dtranspose_grad, + ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/conv2dtranspose_op.cu b/paddle/operators/conv2dtranspose_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..761bc1959e69be94f43571728e6b92a322558b99 --- /dev/null +++ b/paddle/operators/conv2dtranspose_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/conv2dtranspose_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + conv2dtranspose, + ops::GemmConv2DTransposeKernel); +REGISTER_OP_GPU_KERNEL( + conv2dtranspose_grad, + ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/conv2dtranspose_op.h b/paddle/operators/conv2dtranspose_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8c70b3dcec1e26ab3d8a42d88040764c643b5ae6 --- /dev/null +++ b/paddle/operators/conv2dtranspose_op.h @@ -0,0 +1,254 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/im2col.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +// Define Op classes in .h file so that other conv transpose +// operator implementations can reuse the code. 
+class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Conv2DTransposeOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker); +}; + +class Conv2DTransposeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override; +}; + +class Conv2DTransposeOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override; +}; + +template +class GemmConv2DTransposeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + // The filter will be reshaped, so it should not be constant pointer + Tensor filter = *context.Input("Filter"); + + Tensor* output = context.Output("Output"); + + std::vector strides = context.Attr>("strides"); + + // TODO(Zhuoyuan): Paddings can be added in future. + // groups will alway be disabled in conv2dtranspose. + + const int batch_size = input->dims()[0]; + const int m = input->dims()[1]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + + const int k_h = filter.dims()[2]; + const int k_w = filter.dims()[3]; + + const int c = output->dims()[1]; // output channels + const int o_h = output->dims()[2]; + const int o_w = output->dims()[3]; + + paddle::operators::math::Col2ImFunctor< + paddle::operators::math::ColFormat::kCFO, Place, T> + col2im; + + // use col_shape in the im2col and col2im calculation + DDim col_shape = {c, k_h, k_w, h, w}; + + // use col_matrix_shape in the gemm calculation + DDim col_matrix_shape = {c * k_h * k_w, h * w}; + + Tensor col; + col.mutable_data(col_shape, context.GetPlace()); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. 
+ Tensor col_matrix; + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + + DDim output_shape = {c, o_h, o_w}; + DDim input_matrix_shape = {m, h * w}; + + DDim filter_matrix_shape = {m, c * k_h * k_w}; + filter.Resize(filter_matrix_shape); + + // convolution transpose: gemm + col2im (similar to conv-backward on input) + + output->mutable_data(context.GetPlace()); + auto t = framework::EigenVector::Flatten(*output); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; i++) { + // batch with size (M, h * w) + Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); + // filter size: (M, c * k_h * k_w) + + // output size: (c, o_h, o_w) + Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape); + + // col_matrix = filter * input_batch + // of shape (c * k_h * k_w, h * w) + math::matmul(context.device_context(), filter, true, + input_batch, false, T(1.0), &col_matrix, T(0.0)); + col2im(context.device_context(), output_batch, col, strides[0], + strides[1], 0, 0, 0, 0); + } + } +}; + +template +class GemmConv2DTransposeGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + const Tensor* output_grad = + context.Input(framework::GradVarName("Output")); + + // For filter, we do not use const pointer b/c we will do reshape, + // but we should avoid modifying its value. + Tensor filter = *context.Input("Filter"); + + Tensor* input_grad = + context.Output(framework::GradVarName("Input")); + Tensor* filter_grad = + context.Output(framework::GradVarName("Filter")); + + std::vector strides = context.Attr>("strides"); + // Actually, no paddings and groups allowed in conv transpose. + std::vector paddings = context.Attr>("paddings"); + + const int batch_size = input->dims()[0]; + const int m = input->dims()[1]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + + const int k_h = filter.dims()[2]; + const int k_w = filter.dims()[3]; + + const int c = output_grad->dims()[1]; // output channels + const int o_h = output_grad->dims()[2]; + const int o_w = output_grad->dims()[3]; + + // Only im2col functor required for bp to get to the right shape + paddle::operators::math::Im2ColFunctor< + paddle::operators::math::ColFormat::kCFO, Place, T> + im2col; + + // use col_shape in the im2col and col2im calculation + DDim col_shape = {c, k_h, k_w, h, w}; + + // use col_matrix_shape in the gemm calculation + DDim col_matrix_shape_f = {c * h * w, k_h * k_w}; + + Tensor col; + col.mutable_data(col_shape, context.GetPlace()); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. 
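+    // (The actual col_matrix / col_matrix_f tensors are created below, inside
+    // the input_grad and filter_grad branches, each with its own shape.)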
+ + DDim output_shape = {c, o_h, o_w}; + DDim input_matrix_shape = {m, h * w}; + + DDim filter_matrix_shape = {m, c * k_h * k_w}; + filter.Resize(filter_matrix_shape); + + // convolution transpose grad on input: + // im2col + gemm (similar to conv-forward) + // input need to compute gradient + if (input_grad) { + Tensor col_matrix; + col_matrix.ShareDataWith(col); + DDim col_matrix_shape = {c * k_h * k_w, h * w}; + col_matrix.Resize(col_matrix_shape); + + input_grad->mutable_data(context.GetPlace()); + auto t = framework::EigenVector::Flatten(*input_grad); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; i++) { + // batch with size (c, o_h * o_w) + Tensor output_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_shape); + // filter of size (m, c * k_h * k_w) + + // batch with size (m, h, w) + Tensor input_grad_batch = + input_grad->Slice(i, i + 1).Resize(input_matrix_shape); + + // im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w) + im2col(context.device_context(), output_grad_batch, col, strides[0], + strides[1], paddings[0], paddings[0], paddings[1], paddings[1]); + + // gemm: dx = filter * dy + // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h) + math::matmul(context.device_context(), filter, false, + col_matrix, false, T(1.0), &input_grad_batch, + T(0.0)); + } + } + + // filter gradient required + if (filter_grad) { + Tensor col_matrix_f; + col_matrix_f.ShareDataWith(col); + DDim col_matrix_shape_f = {c * h * w, k_h * k_w}; + col_matrix_f.Resize(col_matrix_shape_f); + + filter_grad->mutable_data(context.GetPlace()); + Tensor filter_grad_ = *filter_grad; + filter_grad_.Resize(filter_matrix_shape); + auto t = framework::EigenVector::Flatten(filter_grad_); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; ++i) { + // batch with size (c, o_h, o_w) + Tensor output_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_shape); + // input batch + Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); + + // im2col: (c * h * w, k_h * k_w) + im2col(context.device_context(), output_grad_batch, col, strides[0], + strides[1], paddings[0], paddings[0], paddings[1], paddings[1]); + + // gemm: d_filter = x * y_grad^T + // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h) + math::matmul(context.device_context(), in_batch, false, + col_matrix_f, true, T(1.0), &filter_grad_, + T(1.0)); + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc index c1b3d66bac4c703ce78b247aadc2975bb146b5b0..c35d7d49e31f6ca11e2b37a455af430aac50a232 100644 --- a/paddle/operators/fetch_op.cc +++ b/paddle/operators/fetch_op.cc @@ -52,6 +52,7 @@ class FetchOp : public framework::OperatorBase { // FIXME(yuyang18): Should we assume the fetch operator always generate // CPU outputs? 
dst_item.CopyFrom(src_item, platform::CPUPlace(), dev_ctx); + dst_item.set_lod(src_item.lod()); VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name; } diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index c08a3380f042886cd400df0d840e61856274619c..3b1b0bd71dd3768b932864e185af8dc839b4653e 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -29,8 +29,8 @@ class Im2ColFunctor(); @@ -52,16 +68,14 @@ class Im2ColFunctor= input_height || - (im_col_idx - padding_width) < 0 || - (im_col_idx - padding_width) >= input_width) { + int im_row_idx = h * stride_height + h_offset - padding_up; + int im_col_idx = w * stride_width + w_offset - padding_left; + + if (im_row_idx < 0 || im_row_idx >= input_height || im_col_idx < 0 || + im_col_idx >= input_width) { col_data[(c * output_height + h) * output_width + w] = T(0); } else { - im_row_idx += c_im * input_height - padding_height; - im_col_idx -= padding_width; + im_row_idx += c_im * input_height; col_data[(c * output_height + h) * output_width + w] = im_data[im_row_idx * input_width + im_col_idx]; } @@ -82,7 +96,8 @@ class Col2ImFunctor(); @@ -103,14 +134,12 @@ class Col2ImFunctor= 0 && - (im_row_idx - padding_height) < input_height && - (im_col_idx - padding_width) >= 0 && - (im_col_idx - padding_width) < input_width) { - im_row_idx += c_im * input_height - padding_height; - im_col_idx -= padding_width; + int im_row_idx = h * stride_height + h_offset - padding_up; + int im_col_idx = w * stride_width + w_offset - padding_left; + + if ((im_row_idx) >= 0 && (im_row_idx) < input_height && + (im_col_idx) >= 0 && (im_col_idx) < input_width) { + im_row_idx += c_im * input_height; im_data[im_row_idx * input_width + im_col_idx] += col_data[(c * output_height + h) * output_width + w]; } @@ -140,8 +169,8 @@ class Im2ColFunctor(); T* col_data = col.data(); @@ -163,10 +207,10 @@ class Im2ColFunctor(); const T* col_data = col.data(); @@ -223,9 +283,9 @@ class Col2ImFunctor(context) .stream()>>>( im.data(), num_outputs, input_height, input_width, filter_height, - filter_width, stride_height, stride_width, padding_height, - padding_width, output_height, output_width, col.data()); + filter_width, stride_height, stride_width, padding_up, padding_left, + output_height, output_width, col.data()); } }; @@ -152,7 +161,8 @@ class Col2ImFunctor<<(context) .stream()>>>( - num_kernels, col.data(), input_height + 2 * padding_height, - input_width + 2 * padding_width, input_channels, filter_height, - filter_width, stride_height, stride_width, padding_height, - padding_width, output_height, output_width, im.data()); + num_kernels, col.data(), input_height + padding_up + padding_down, + input_width + padding_left + padding_left, input_channels, + filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, output_height, output_width, im.data()); } }; @@ -238,8 +258,8 @@ class Im2ColFunctor(context) .stream()>>>( im.data(), col.data(), input_channels, input_height, input_width, - filter_height, filter_width, stride_height, stride_width, - padding_height, padding_width, output_height, output_width); + filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, output_height, output_width); } }; @@ -322,7 +351,8 @@ class Col2ImFunctor(context) .stream()>>>( im.data(), col.data(), input_channels, input_height, input_width, - filter_height, filter_width, stride_height, stride_width, - padding_height, padding_width, output_height, output_width); + 
filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, output_height, output_width); } }; diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index 7b717e1603c94cd77c74cb0d86f1d23e2692f9d8..c736d4fa523c2af3e3dd7a11114d7f84021bc5c1 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -74,8 +74,8 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& im, framework::Tensor& col, - int stride_height, int stride_width, int padding_height, - int padding_width); + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right); }; template @@ -83,7 +83,8 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, const framework::Tensor& col, int stride_height, - int stride_width, int padding_height, int padding_width); + int stride_width, int padding_up, int padding_down, + int padding_left, int padding_right); }; } // namespace math diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index 443c94b83f0bf24837afe703b19e2ab47a0dd786..5763782c4edec87f44dabef2ccffe3097eeb2421 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -35,6 +35,12 @@ void testIm2col() { * * output_ocf = [0, 1, 3, 4 * 1, 2, 4, 5] + * + * col2im_cfo = [0, 2, 2 + * 3, 4, 5] + * + * col2im_ocf = [0, 2, 2 + * 3, 4, 5] */ int input_height = 2; int input_width = 3; @@ -59,7 +65,7 @@ void testIm2col() { new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace()); #else PADDLE_THROW("no GPU support"); -#endif // PADDLE_ONLY_CPU +#endif // PADDLE_WITH_CUDA } if (paddle::platform::is_cpu_place(*place)) { input = input_tmp; @@ -71,6 +77,7 @@ void testIm2col() { output_ocf.mutable_data( {output_height, output_width, 1, filter_size, filter_size}, *place); + // Im2Col paddle::operators::math::Im2ColFunctor< paddle::operators::math::ColFormat::kCFO, Place, float> im2col; @@ -78,8 +85,13 @@ void testIm2col() { paddle::operators::math::ColFormat::kOCF, Place, float> im2col_ocf; - im2col(*context, input, output_cfo, stride, stride, padding, padding); - im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding); + im2col(*context, input, output_cfo, stride, stride, padding, padding, padding, + padding); + im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding, + padding, padding); + + float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5}; + float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5}; float* out_cfo_ptr; if (paddle::platform::is_cpu_place(*place)) { @@ -88,14 +100,9 @@ void testIm2col() { output_tmp.CopyFrom(output_cfo, paddle::platform::CPUPlace(), *context); out_cfo_ptr = output_tmp.data(); } - EXPECT_EQ(out_cfo_ptr[0], 0); - EXPECT_EQ(out_cfo_ptr[1], 1); - EXPECT_EQ(out_cfo_ptr[2], 1); - EXPECT_EQ(out_cfo_ptr[3], 2); - EXPECT_EQ(out_cfo_ptr[4], 3); - EXPECT_EQ(out_cfo_ptr[5], 4); - EXPECT_EQ(out_cfo_ptr[6], 4); - EXPECT_EQ(out_cfo_ptr[7], 5); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(out_cfo_ptr[i], out_cfo_data[i]); + } float* out_ocf_ptr; if (paddle::platform::is_cpu_place(*place)) { @@ -104,14 +111,60 @@ void testIm2col() { output_tmp.CopyFrom(output_ocf, paddle::platform::CPUPlace(), *context); out_ocf_ptr = output_tmp.data(); } - EXPECT_EQ(out_ocf_ptr[0], 0); - EXPECT_EQ(out_ocf_ptr[1], 1); - EXPECT_EQ(out_ocf_ptr[2], 3); - EXPECT_EQ(out_ocf_ptr[3], 4); - 
EXPECT_EQ(out_ocf_ptr[4], 1); - EXPECT_EQ(out_ocf_ptr[5], 2); - EXPECT_EQ(out_ocf_ptr[6], 4); - EXPECT_EQ(out_ocf_ptr[7], 5); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(out_ocf_ptr[i], out_ocf_data[i]); + } + + // Col2Im: kCFO + paddle::operators::math::Col2ImFunctor< + paddle::operators::math::ColFormat::kCFO, Place, float> + col2im; + paddle::operators::math::Col2ImFunctor< + paddle::operators::math::ColFormat::kOCF, Place, float> + col2im_ocf; + float col2im_data[] = {0, 2, 2, 3, 8, 5}; + + memset(input_ptr, 0, 6 * sizeof(float)); + if (paddle::platform::is_cpu_place(*place)) { + input = input_tmp; + } else { + input.CopyFrom(input_tmp, *place, *context); + } + + col2im(*context, input, output_cfo, stride, stride, padding, padding, padding, + padding); + + float* in_ptr; + if (paddle::platform::is_cpu_place(*place)) { + in_ptr = input.data(); + } else { + input_tmp.CopyFrom(input, paddle::platform::CPUPlace(), *context); + in_ptr = input_tmp.data(); + } + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(in_ptr[i], col2im_data[i]); + } + + // Col2Im: kOCF + memset(input_ptr, 0, 6 * sizeof(float)); + if (paddle::platform::is_cpu_place(*place)) { + input = input_tmp; + } else { + input.CopyFrom(input_tmp, *place, *context); + } + + col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding, + padding, padding); + + if (paddle::platform::is_cpu_place(*place)) { + in_ptr = input.data(); + } else { + input_tmp.CopyFrom(input, paddle::platform::CPUPlace(), *context); + in_ptr = input_tmp.data(); + } + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(in_ptr[i], col2im_data[i]); + } } TEST(math, im2col) { diff --git a/paddle/platform/dynload/CMakeLists.txt b/paddle/platform/dynload/CMakeLists.txt index 4c8be33480e80e586a4bbbdb2664eb4cb2db8f0b..bb3fec1be9e811c26cc6851314e960e96fc366b3 100644 --- a/paddle/platform/dynload/CMakeLists.txt +++ b/paddle/platform/dynload/CMakeLists.txt @@ -1,2 +1,3 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) -nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc DEPS dynamic_loader) +nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc + DEPS dynamic_loader nccl) diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py index 813e25816d4792cc1b1adaf5212dbd3ad6a29516..40b9008d67b4e42093b9b9cbdecd1dbff4150b41 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/framework/framework.py @@ -113,6 +113,10 @@ class Variable(object): def lod_level(self): return self.desc.lod_level() + @property + def type(self): + return self.desc.type() + @staticmethod def _unique_var_name_(): uid = core.unique_integer() # unique during whole process. diff --git a/python/paddle/v2/framework/layer_helper.py b/python/paddle/v2/framework/layer_helper.py index 5e14f39e3363b984d94fa472397af723a1fc4590..f3da32f0e07a22204b3feaed5d1d8d01556e4655 100644 --- a/python/paddle/v2/framework/layer_helper.py +++ b/python/paddle/v2/framework/layer_helper.py @@ -1,8 +1,11 @@ -from paddle.v2.framework.framework import Variable, OpProtoHolder, g_program, g_init_program -import paddle.v2.framework.core as core import copy import itertools +import paddle.v2.framework.core as core + +from paddle.v2.framework.framework import Variable, g_program, \ + g_init_program + def unique_name(prefix): uid = core.unique_integer() # unique during whole process. 
@@ -130,6 +133,9 @@ class LayerHelper(object): dtype=dtype, persistable=False) + def create_variable(self, *args, **kwargs): + return self.program.current_block().create_var(*args, **kwargs) + def create_global_variable(self, *args, **kwargs): return self.program.global_block().create_var( *args, persistable=False, **kwargs) diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index b7e914d734fd08a43de0c18bfa8fc14061a5cf3e..6894c40c3a6514f448133f029c4de8cc30405515 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -1,10 +1,11 @@ -from paddle.v2.framework.layer_helper import LayerHelper +from paddle.v2.framework.layer_helper import LayerHelper, unique_name import paddle.v2.framework.core as core -from paddle.v2.framework.framework import OpProtoHolder, Variable +from paddle.v2.framework.framework import OpProtoHolder, Variable, Program import re __all__ = [ - 'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat' + 'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat', + 'StaticRNN' ] @@ -26,7 +27,9 @@ def fc(input, mul_results = [] for input_var, param_attr in helper.iter_inputs_and_params(): input_shape = input_var.shape - param_shape = list(input_shape[num_flatten_dims:]) + [size] + param_shape = [ + reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1) + ] + [size] w = helper.create_parameter( attr=param_attr, shape=param_shape, dtype=dtype) @@ -38,10 +41,8 @@ def fc(input, "Y": w, }, outputs={"Out": tmp}, - attrs={ - 'x_num_col_dims': num_flatten_dims, - 'y_num_col_dims': len(input_shape) - num_flatten_dims - }) + attrs={'x_num_col_dims': num_flatten_dims, + 'y_num_col_dims': 1}) mul_results.append(tmp) # sum @@ -273,3 +274,170 @@ def pool2d(input, }) return pool_out + + +class BlockGuard(object): + """ + BlockGuard used to create sub-block in program by using Python `with` + keyword. 
+ """ + + def __init__(self, program): + if not isinstance(program, Program): + raise TypeError("BlockGuard takes a program") + self.program = program + + def __enter__(self): + self.program.create_block() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.program.rollback() + if exc_type is not None: + return False # re-raise exception + return True + + +class StaticRNNGuard(BlockGuard): + def __init__(self, rnn): + if not isinstance(rnn, StaticRNN): + raise TypeError("StaticRNNGuard takes an StaticRNN") + super(StaticRNNGuard, self).__init__(rnn.helper.program) + self.rnn = rnn + + def __enter__(self): + self.rnn.status = StaticRNN.IN_RNN_BLOCK + return super(StaticRNNGuard, self).__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.rnn.status = StaticRNN.AFTER_RNN_BLOCK + self.rnn.complete_rnn_op() + return super(StaticRNNGuard, self).__exit__(exc_type, exc_val, exc_tb) + + +class StaticRNNMemoryLink(object): + """ + :param init: the initial variable for Memory + :type init: Variable + :param pre_mem: the memory variable in previous time step + :type pre_mem: Variable + :param mem: the memory variable in current time step + :type mem: Variable + """ + + def __init__(self, init, pre_mem, mem=None): + self.init = init + self.pre_mem = pre_mem + self.mem = mem + + +class StaticRNN(object): + BEFORE_RNN_BLOCK = 0 + IN_RNN_BLOCK = 1 + AFTER_RNN_BLOCK = 2 + + def __init__(self, name=None, program=None): + self.helper = LayerHelper("static_rnn", name=name, program=program) + self.memories = {} # memory map, from pre_mem.name --> MemoryLink + self.inputs = [] # input variable list in current block + self.outputs = [] # output variable list in parent block + self.status = StaticRNN.BEFORE_RNN_BLOCK # status flag. + # sequence length, since it is a static RNN, sequence length are fixed. 
+ self.seq_len = None + + def step(self): + return StaticRNNGuard(self) + + def _assert_in_rnn_block_(self, method): + if self.status != StaticRNN.IN_RNN_BLOCK: + raise ValueError("You must invoke {0} in rnn block".format(method)) + + def memory(self, init=None, shape=None, dtype=None, init_value=0): + self._assert_in_rnn_block_('memory') + if init is None: + if shape is None or dtype is None: + raise ValueError( + "if init is None, memory at least need shape and dtype") + parent_block = self.parent_block() + var_name = unique_name("@".join([self.helper.name, "memory_boot"])) + boot_var = parent_block.create_var( + name=var_name, shape=shape, dtype=dtype, persistable=False) + + parent_block.append_op( + type="fill_constant", + inputs={}, + outputs={'Out': [boot_var]}, + attrs={ + 'value': init_value, + 'shape': boot_var.shape, + 'data_type': boot_var.data_type + }) + + return self.memory(init=boot_var) + else: + pre_mem = self.helper.create_variable( + name=unique_name("@".join([self.helper.name, "mem"])), + dtype=init.data_type, + shape=init.shape) + self.memories[pre_mem.name] = StaticRNNMemoryLink( + init=init, pre_mem=pre_mem) + return pre_mem + + def step_input(self, x): + self._assert_in_rnn_block_('step_input') + if not isinstance(x, Variable): + raise TypeError("step input takes a Variable") + if self.seq_len is None: + self.seq_len = x.shape[1] + elif self.seq_len != x.shape[1]: + raise ValueError("Static RNN only take fix seq_len input") + + ipt = self.helper.create_variable( + name=x.name, + dtype=x.data_type, + shape=[-1] + list(x.shape[2:]), + type=x.type) + self.inputs.append(ipt) + return ipt + + def step_output(self, o): + self._assert_in_rnn_block_('step_output') + if not isinstance(o, Variable): + raise TypeError("step output takes a Variable") + + out_var = self.parent_block().create_var( + name=o.name, + shape=[-1, self.seq_len] + list(o.shape[1:]), + dtype=o.data_type) + + self.outputs.append(out_var) + + def output(self, *outputs): + for each in outputs: + self.step_output(each) + + def update_memory(self, mem, var): + if not isinstance(mem, Variable) or not isinstance(var, Variable): + raise TypeError("update memory should take variables") + self.memories[mem.name].mem = var + + def parent_block(self): + prog = self.helper.program + parent_idx = prog.current_block().parent_idx + assert parent_idx >= 0 + parent_block = prog.block(parent_idx) + return parent_block + + def __call__(self, *args, **kwargs): + if self.status != StaticRNN.AFTER_RNN_BLOCK: + raise ValueError("RNN output can only be retrieved after rnn block") + if len(self.outputs) == 0: + raise ValueError("RNN has no output") + elif len(self.outputs) == 1: + return self.outputs[0] + else: + return self.outputs + + def complete_rnn_op(self): + # TODO(yuyang18): Create RNN Op here. + # Implement this method after RNN op complete. 
+ pass diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py index 1c6dce96346cecdf3f5fe5bcaf4ff4cd61da94ac..0f8c61a2ab5a9d046f80c5ead89e0e4e8f7422ce 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/framework/tests/op_test.py @@ -179,7 +179,12 @@ def get_backward_op(scope, op, no_grad_set): return backward_op -def get_gradient(scope, op, inputs, outputs, grad_name, place, +def get_gradient(scope, + op, + inputs, + outputs, + grad_names, + place, no_grad_set=None): ctx = core.DeviceContext.create(place) @@ -195,8 +200,10 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place, backward_op.run(scope, ctx) - out = np.array(scope.find_var(grad_name).get_tensor()) - return out + return [ + np.array(scope.find_var(grad_name).get_tensor()) + for grad_name in grad_names + ] def append_input_output(block, op_proto, np_list, is_input): @@ -326,20 +333,31 @@ class OpTest(unittest.TestCase): type(sub_out)) for sub_out_name, expect in sub_out: idx = find_actual(sub_out_name, fetch_list) - actual = outs[idx] + actual_t = np.array(outs[idx]) + expect_t = expect[0] \ + if isinstance(expect, tuple) else expect self.assertTrue( np.allclose( - actual, expect, atol=atol), + actual_t, expect_t, atol=atol), "Output (" + sub_out_name + ") has diff at " + str(place)) + if isinstance(expect, tuple): + self.assertListEqual( + actual_t.lod(), expect[1], "Output (" + sub_out_name + + ") has different lod at " + str(place)) else: idx = find_actual(out_name, fetch_list) - actual = outs[idx] + actual_t = outs[idx] expect = self.outputs[out_name] + expect_t = expect[0] if isinstance(expect, tuple) else expect self.assertTrue( np.allclose( - actual, expect, atol=atol), + actual_t, expect_t, atol=atol), "Output (" + out_name + ") has diff at " + str(place)) + if isinstance(expect, tuple): + self.assertListEqual(actual_t.lod(), expect[1], + "Output (" + out_name + + ") has different lod at " + str(place)) def check_output(self, atol=1e-5): places = [core.CPUPlace()] @@ -399,11 +417,9 @@ class OpTest(unittest.TestCase): ] cpu_place = core.CPUPlace() - cpu_analytic_grads = [ - get_gradient(self.scope, self.op, self.inputs, self.outputs, - grad_name, cpu_place, no_grad_set) - for grad_name in grad_names - ] + cpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs, + self.outputs, grad_names, cpu_place, + no_grad_set) self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names, max_relative_error, @@ -411,11 +427,9 @@ class OpTest(unittest.TestCase): if core.is_compile_gpu() and self.op.support_gpu(): gpu_place = core.GPUPlace(0) - gpu_analytic_grads = [ - get_gradient(self.scope, self.op, self.inputs, self.outputs, - grad_name, gpu_place, no_grad_set) - for grad_name in grad_names - ] + gpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs, + self.outputs, grad_names, + gpu_place, no_grad_set) self.__assert_is_close(numeric_grads, gpu_analytic_grads, grad_names, max_relative_error, diff --git a/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py b/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca262f00378381d2d65e87d198d6b1755e9a2b --- /dev/null +++ b/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py @@ -0,0 +1,102 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def conv2dtranspose_forward_naive(input_, filter_, conv2dtranspose_param): + # [2, 3, 5, 5] + in_n, in_c, in_h, in_w = 
input_.shape + # [3, 6, 3, 3] + f_c, out_c, f_h, f_w = filter_.shape + assert in_c == f_c + + stride, pad = conv2dtranspose_param['stride'], conv2dtranspose_param['pad'] + out_h = (in_h - 1) * stride[0] + f_h + out_w = (in_w - 1) * stride[1] + f_w + + out = np.zeros((in_n, out_c, out_h, out_w)) + + for n in range(in_n): + for i in range(in_h): + for j in range(in_w): + input_masked = input_[n, :, i, j] # (c) + input_masked = np.reshape(input_masked, (in_c, 1, 1)) + input_masked = np.tile(input_masked, (1, f_h, f_w)) + + for k in range(out_c): + tmp_out = np.sum(input_masked * filter_[:, k, :, :], axis=0) + i1, i2 = i * stride[0], i * stride[0] + f_h + j1, j2 = j * stride[0], j * stride[0] + f_w + out[n, k, i1:i2, j1:j2] += tmp_out + + return out + + +class TestConv2dTransposeOp(OpTest): + def setUp(self): + # init as conv transpose + self.init_op_type() + + # [2, 3, 5, 5] -> kernel [3, 6, 3, 3] -> output [2, 6, 7, 7] + self.init_test_case() + + conv2dtranspose_param = {'stride': self.stride, 'pad': self.pad} + input_ = np.random.random(self.input_size).astype("float32") + filter_ = np.random.random(self.filter_size).astype("float32") + output = conv2dtranspose_forward_naive(input_, filter_, + conv2dtranspose_param) + # print 'deconv output py', output, output.shape + + self.inputs = {'Input': input_, 'Filter': filter_} + self.attrs = { + 'strides': self.stride, + 'paddings': self.pad, + # 'dilations': self.dilations + } + self.outputs = {'Output': output} + + def test_check_output(self): + print 'check output here' + self.check_output() + + def test_check_grad(self): + self.check_grad( + set(['Input', 'Filter']), 'Output', max_relative_error=0.05) + + def test_check_grad_no_filter(self): + self.check_grad( + ['Input'], + 'Output', + max_relative_error=0.05, + no_grad_set=set(['Filter'])) + + def test_check_grad_no_input(self): + self.check_grad( + ['Filter'], + 'Output', + max_relative_error=0.05, + no_grad_set=set(['Input'])) + + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.dilations = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + def init_op_type(self): + self.op_type = "conv2dtranspose" + + +""" +class TestCudnn(TestConv2dOp): + def init_group(self): + self.groups = 1 + + def init_op_type(self): + self.op_type = "conv_cudnn" +""" + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index bcce8d32c944a39e6d6aad4c99f8aa152222c3c1..93a4e450e916716e27573d192bace73f271733de 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -155,7 +155,11 @@ class TestLstmOp(OpTest): 'Weight': w, 'Bias': b } - self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort} + self.outputs = { + 'Hidden': (h, self.lod), + 'Cell': (c, self.lod), + 'BatchGate': g_sort + } self.attrs = { 'usePeepholes': True, 'isReverse': self.is_reverse, diff --git a/python/paddle/v2/framework/tests/test_rnn_helpers.py b/python/paddle/v2/framework/tests/test_rnn_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..be0ecfb129aa181229bc42d8d6818ad860991965 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_rnn_helpers.py @@ -0,0 +1,38 @@ +import unittest +from paddle.v2.framework.layers import * +from paddle.v2.framework.framework import g_program + + +class TestRNN(unittest.TestCase): + def test_rnn(self): + img = data( + 
shape=[ + 80, # sequence length + 22, # image height + 22 + ], # image width + data_type='float32', + name='image') + hidden = fc(input=img, size=100, act='sigmoid', num_flatten_dims=2) + self.assertEqual((-1, 80, 100), hidden.shape) + hidden = fc(input=hidden, size=100, act='sigmoid', num_flatten_dims=2) + self.assertEqual((-1, 80, 100), hidden.shape) + + rnn = StaticRNN() + with rnn.step(): + hidden = rnn.step_input(hidden) + self.assertEqual((-1, 100), hidden.shape) + memory = rnn.memory(shape=(-1, 32), dtype='float32', init_value=0.0) + + rnn_out = fc(input=[hidden, memory], size=32, act='sigmoid') + self.assertEqual((-1, 32), rnn_out.shape) + rnn.update_memory(memory, rnn_out) + rnn.output(rnn_out) + + out = rnn() + self.assertEqual((-1, 80, 32), out.shape) + print g_program + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index d0b5ff12f2748909745294d4ba96888f2baf2b8d..bd97dc1199fedc8ac91c1c6086957e8cce88bdc4 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -326,6 +326,17 @@ class Parameters(object): self.set(name, arr.reshape(self.get_shape(name))) def to_tar(self, f): + """ + Save parameters to a tar file. + + WARNING: You should use `paddle.v2.trainer.SGD.save_parameter_to_tar(f)` + to save parameters most of the time. Otherwise, some settings such + as model average will not take effect. + + :param f: + :type f: file + :return: + """ tar = tarfile.TarFile(fileobj=f, mode='w') for nm in self.names(): buf = cStringIO.StringIO()