Commit da1181bf authored by Dong Zhihong

Merge remote-tracking branch 'origin/develop' into feature/multigpu

......@@ -127,9 +127,9 @@ include(external/warpctc) # download, build, install warpctc
include(external/any) # download libn::any
include(external/eigen) # download eigen3
include(external/pybind11) # download pybind11
include(external/nccl)
include(cudnn) # set cudnn libraries, must before configure
include(nccl) # set nccl libraries
include(configure) # add paddle env configuration
include(generic) # simplify cmake module
include(package) # set paddle packages
......
......@@ -68,13 +68,6 @@ else()
if(NOT CUDNN_FOUND)
message(FATAL_ERROR "Paddle needs cudnn to compile")
endif()
if (NOT NCCL_INCLUDE_DIR)
message(FATAL_ERROR "Paddle needs nccl header to compile")
endif()
if (NOT WITH_DSO AND NOT NCCL_LIBRARY)
message(FATAL_ERROR "Paddle needs nccl libraries when WITH_DSO=OFF")
endif()
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} "-Xcompiler ${SIMD_FLAG}")
......
INCLUDE(ExternalProject)
SET(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl)
INCLUDE_DIRECTORIES(${NCCL_SOURCE_DIR}/src/extern_nccl/src)
if(WITH_DSO)
# If we use DSO, we do not build nccl, just download the dependencies
set(NCCL_BUILD_COMMAND "")
set(NCCL_INSTALL_COMMAND "")
set(NCCL_INSTALL_DIR "")
else()
# otherwise, we build nccl and link it.
set(NCCL_BUILD_COMMAND "make -j 8")
set(NCCL_INSTALL_COMMAND "make install")
SET(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl)
endif()
ExternalProject_Add(
extern_nccl
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/NVIDIA/nccl.git"
GIT_TAG "v1.3.4-1"
PREFIX "${NCCL_SOURCE_DIR}"
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND "${NCCL_BUILD_COMMAND}"
INSTALL_COMMAND "${NCCL_INSTALL_COMMAND}"
INSTALL_DIR "${NCCL_INSTALL_DIR}"
TEST_COMMAND ""
)
if (WITH_DSO)
if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_any_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";")
add_library(nccl STATIC ${dummyfile})
else()
add_library(nccl INTERFACE)
endif()
else()
ADD_LIBRARY(nccl STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET nccl PROPERTY IMPORTED_LOCATION
${NCCL_INSTALL_DIR}/lib/libnccl.a)
endif()
add_dependencies(nccl extern_nccl)
LIST(APPEND external_project_dependencies nccl)
if (NOT WITH_GPU)
return ()
endif()
set(NCCL_ROOT "/usr" CACHE PATH "NCCL ROOT")
find_path(NCCL_INCLUDE_DIR nccl.h PATHS
${NCCL_ROOT} ${NCCL_ROOT}/include
$ENV{NCCL_ROOT} $ENV{NCCL_ROOT}/include ${CUDA_TOOLKIT_INCLUDE}
NO_DEFAULT_PATH)
get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH)
set(TARGET_ARCH "x86_64")
if(NOT ${CMAKE_SYSTEM_PROCESSOR})
set(TARGET_ARCH ${CMAKE_SYSTEM_PROCESSOR})
endif()
list(APPEND NCCL_CHECK_LIBRARY_DIRS
${NCCL_ROOT}
${NCCL_ROOT}/lib64
${NCCL_ROOT}/lib
${NCCL_ROOT}/lib/${TARGET_ARCH}-linux-gnu
$ENV{NCCL_ROOT}
$ENV{NCCL_ROOT}/lib64
$ENV{NCCL_ROOT}/lib
/usr/lib)
find_library(NCCL_LIBRARY NAMES libnccl.so libnccl.dylib # libcudnn_static.a
PATHS ${NCCL_CHECK_LIBRARY_DIRS} ${NCCL_INCLUDE_DIR} ${__libpath_hist}
NO_DEFAULT_PATH
DOC "Path to nccl library.")
## Survey on Graph
Neural network frameworks often provide a symbolic API for users to write network topology conveniently. This doc mainly focuses on the symbolic APIs of the most popular neural network frameworks, and tries to figure out how to parse a symbolic configuration into a portable file, such as protobuf or JSON.
### Mxnet
The core concept of the symbolic API is `Symbol`. Mxnet implements the `Symbol` class in C++ and exports it to Python through the C API. Please refer to the comments in Mxnet:
`Symbol` is a help class used to represent the operator node in Graph.
`Symbol` acts as an interface for building graphs from different components like Variable, Functor and Group. `Symbol` is also exported to the python front-end (while Graph is not) to enable quick test and deployment. Conceptually, a symbol is the final operation of a graph and thus includes all the information required (the graph) to evaluate its output value.
A simple network topology written with Symbol is as follows:
```python
def get_symbol(num_classes=10, **kwargs):
data = mx.symbol.Variable('data')
data = mx.symbol.Flatten(data=data)
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
mlp = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
return mlp
```
Variable here is actually a Symbol. Every basic Symbol corresponds to one Node, and every Node has its own NodeAttr. There is an op field in the NodeAttr class; when a Symbol represents a Variable (often input data), the op field is null.
Symbol contains a data member, std::vector<NodeEntry> outputs, and NodeEntry contains a pointer to Node. We can follow the Node pointers to traverse the whole Graph.
A Symbol can also be saved to a JSON file.
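For example, a graph built with Symbol can be serialized to JSON and restored later. A minimal sketch, assuming MXNet's `tojson()` and `load_json()` API:
```python
import mxnet as mx

# build a tiny symbolic graph
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)

# serialize the whole graph reachable from fc1 into a JSON string
json_str = fc1.tojson()

# restore an equivalent Symbol from the JSON description
fc1_restored = mx.symbol.load_json(json_str)
```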
Here is a detailed example:
```
>>> import mxnet as mx
>>> data = mx.symbol.Variable('data')
>>> print data.debug_str()
Variable:data
>>> data = mx.symbol.Flatten(data=data)
>>> print data.debug_str()
Symbol Outputs:
output[0]=flatten0(0)
Variable:data
--------------------
Op:Flatten, Name=flatten0
Inputs:
arg[0]=data(0) version=0
>>> fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
>>> print fc1.debug_str()
Symbol Outputs:
output[0]=fc1(0)
Variable:data
--------------------
Op:Flatten, Name=flatten0
Inputs:
arg[0]=data(0) version=0
Variable:fc1_weight
Variable:fc1_bias
--------------------
Op:FullyConnected, Name=fc1
Inputs:
arg[0]=flatten0(0)
arg[1]=fc1_weight(0) version=0
arg[2]=fc1_bias(0) version=0
Attrs:
num_hidden=128
```
### TensorFlow
The core concept of the symbolic API is `Tensor`. TensorFlow defines `Tensor` in Python. Please refer to the comments in TensorFlow:
A `Tensor` is a symbolic handle to one of the outputs of an `Operation`. It does not hold the values of that operation's output, but instead provides a means of computing those values in a TensorFlow [Session](https://www.tensorflow.org/api_docs/python/tf/Session).
A simple example is as follows:
```python
# Build a dataflow graph.
c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d)
# Construct a `Session` to execute the graph.
sess = tf.Session()
# Execute the graph and store the value that `e` represents in `result`.
result = sess.run(e)
```
The main methods of `Tensor` are as follows:
```python
@property
def op(self):
"""The `Operation` that produces this tensor as an output."""
return self._op
@property
def dtype(self):
"""The `DType` of elements in this tensor."""
return self._dtype
@property
def graph(self):
"""The `Graph` that contains this tensor."""
return self._op.graph
@property
def name(self):
"""The string name of this tensor."""
if not self._op.name:
raise ValueError("Operation was not named: %s" % self._op)
return "%s:%d" % (self._op.name, self._value_index)
@property
def device(self):
"""The name of the device on which this tensor will be produced, or None."""
return self._op.device
```
A Tensor can be taken as a target to run by a session. A Tensor contains all the information of the Graph and tracks data dependencies.
Here is a detailed example:
```
>>> import tensorflow as tf
>>> c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
>>> print c.graph
<tensorflow.python.framework.ops.Graph object at 0x10f256d50>
>>> d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
>>> print d.graph
<tensorflow.python.framework.ops.Graph object at 0x10f256d50>
>>> e = tf.matmul(c, d)
>>> print e.graph
<tensorflow.python.framework.ops.Graph object at 0x10f256d50>
```
### Dynet
The core concept of the symbolic API is `Expression`, and Dynet defines the `Expression` class in C++.
A simple example is as follows:
```cpp
ComputationGraph cg;
Expression W = parameter(cg, pW);
Expression in = input(cg, xs[i]);
Expression label = input(cg, ys[i]);
Expression pred = W * in;
Expression loss = square(pred - label);
```
The input data and parameters are also represented by Expressions. Every basic Expression corresponds to a Node, and input data is also a Node.
Expression has a data member ComputationGraph, and the ComputationGraph is modified during the user's configuration process. An Expression can be a running target, because an Expression contains all of its dependencies.
Here is a detailed example:
Write the topology in C++:
```
ComputationGraph cg;
Expression W = parameter(cg, pW);
cg.print_graphviz();
Expression pred = W * xs[i];
cg.print_graphviz();
Expression loss = square(pred - ys[i]);
cg.print_graphviz();
```
Compile and print:
```
# first print
digraph G {
rankdir=LR;
nodesep=.05;
N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"];
}
# second print
digraph G {
rankdir=LR;
nodesep=.05;
N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"];
N1 [label="v1 = v0 * -0.98"];
N0 -> N1;
}
# third print
digraph G {
rankdir=LR;
nodesep=.05;
N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"];
N1 [label="v1 = v0 * -0.98"];
N0 -> N1;
N2 [label="v2 = -1.88387 - v1"];
N1 -> N2;
N3 [label="v3 = -v2"];
N2 -> N3;
N4 [label="v4 = square(v3)"];
N3 -> N4;
}
```
### Conclusion
Actually, Symbol/Tensor/Expression in Mxnet/TensorFlow/Dynet are concepts at the same level. We use the unified name Expression here; this level of concept has the following features:
- Users write topology with the symbolic API, and every return value is an Expression, including input data and parameters.
- An Expression corresponds to a global Graph, and Expressions can also be composed.
- An Expression tracks all of its dependencies and can be taken as a run target.
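A minimal sketch of what an Expression at this level could look like (hypothetical names for illustration only, not an actual framework API):
```python
class Expression(object):
    """One node of a global graph; input data and parameters are also Expressions."""

    def __init__(self, op=None, inputs=None):
        self.op = op                # producing operation; None for input data / parameters
        self.inputs = inputs or []  # upstream Expressions, i.e. the tracked dependencies

    def __add__(self, other):
        # composing Expressions extends the graph
        return Expression(op='add', inputs=[self, other])


def evaluate(target):
    """Run an Expression as a target: resolve its dependencies first, then its op."""
    inputs = [evaluate(dep) for dep in target.inputs]
    # a real framework would dispatch target.op on `inputs` here
    return (target.op, inputs)
```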
......@@ -114,7 +114,7 @@ class GemmConv2DKernel : public framework::OpKernel<T> {
// im2col
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
im2col(context.device_context(), in_slice, col, strides[0], strides[1],
paddings[0], paddings[1]);
paddings[0], paddings[0], paddings[1], paddings[1]);
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
......@@ -213,7 +213,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
Tensor in_grad_slice =
in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
col2im(context.device_context(), in_grad_slice, col, strides[0],
strides[1], paddings[0], paddings[1]);
strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
}
}
}
......@@ -235,7 +236,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
im2col(context.device_context(), in_slice, col, strides[0],
strides[1], paddings[0], paddings[1]);
strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
// gemm
Tensor filter_grad_slice =
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv2dtranspose_op.h"
namespace paddle {
namespace operators {
void Conv2DTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of Conv2DTransposeOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
"Input(Filter) of Conv2DTransposeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output(Output) of Conv2DTransposeOp should not be null.");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
for (size_t i = 0; i < paddings.size(); ++i) {
PADDLE_ENFORCE_EQ(paddings[i], 0,
"No Padding allowed in conv transpose op.");
}
PADDLE_ENFORCE_EQ(in_dims.size(), 4,
"Conv2DTransposeOp input should be 4-D tensor.");
PADDLE_ENFORCE_EQ(filter_dims.size(), 4,
"Conv2DTransposeOp filter should be 4-D tensor.");
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
"input and kernel input dimension should be equal.");
auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2];
auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3];
ctx->SetOutputDim("Output",
{in_dims[0], filter_dims[1], output_height, output_width});
}
Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
"(Tensor) The input tensor of convolution transpose operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of input channels, H and W is the height and width of image.");
AddInput("Filter",
"(Tensor) The filter tensor of convolution transpose operator."
"The format of the filter tensor is CMHW, where C is the number of "
"output image channels, M is the number of input image channels, "
"H and W is height and width of filter. "
"We enforce groups number == 1 and padding == 0 in "
"convolution transpose Scenario.");
AddOutput("Output",
"(Tensor) The output tensor of convolution transpose operator."
"The format of output tensor is also NCHW.");
AddAttr<std::vector<int>>("strides",
"strides of convolution transpose operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings",
"paddings of convolution transpose operator.")
.SetDefault({0, 0});
AddComment(R"DOC(
The convolution transpose operation calculates the output based on the input, filter
and strides, paddings, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
)DOC");
}
void Conv2DTransposeOpGrad::InferShape(
framework::InferShapeContext* ctx) const {
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
if (ctx->HasOutput(framework::GradVarName("Input"))) {
ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
}
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
}
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(conv2dtranspose, ops::Conv2DTransposeOp,
ops::Conv2DTransposeOpMaker, conv2dtranspose_grad,
ops::Conv2DTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv2dtranspose,
ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv2dtranspose_grad,
ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv2dtranspose_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
conv2dtranspose,
ops::GemmConv2DTransposeKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv2dtranspose_grad,
ops::GemmConv2DTransposeGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
// Define Op classes in .h file so that other conv transpose
// operator implementations can reuse the code.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv2DTransposeOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker);
};
class Conv2DTransposeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override;
};
class Conv2DTransposeOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override;
};
template <typename Place, typename T>
class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
// The filter will be reshaped, so it should not be constant pointer
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
// TODO(Zhuoyuan): Paddings can be added in future.
// groups will always be disabled in conv2dtranspose.
const int batch_size = input->dims()[0];
const int m = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int k_h = filter.dims()[2];
const int k_w = filter.dims()[3];
const int c = output->dims()[1]; // output channels
const int o_h = output->dims()[2];
const int o_w = output->dims()[3];
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
col2im;
// use col_shape in the im2col and col2im calculation
DDim col_shape = {c, k_h, k_w, h, w};
// use col_matrix_shape in the gemm calculation
DDim col_matrix_shape = {c * k_h * k_w, h * w};
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
Tensor col_matrix;
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
DDim output_shape = {c, o_h, o_w};
DDim input_matrix_shape = {m, h * w};
DDim filter_matrix_shape = {m, c * k_h * k_w};
filter.Resize(filter_matrix_shape);
// convolution transpose: gemm + col2im (similar to conv-backward on input)
output->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*output);
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < batch_size; i++) {
// batch with size (M, h * w)
Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
// filter size: (M, c * k_h * k_w)
// output size: (c, o_h, o_w)
Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);
// col_matrix = filter * input_batch
// of shape (c * k_h * k_w, h * w)
math::matmul<Place, T>(context.device_context(), filter, true,
input_batch, false, T(1.0), &col_matrix, T(0.0));
col2im(context.device_context(), output_batch, col, strides[0],
strides[1], 0, 0, 0, 0);
}
}
};
template <typename Place, typename T>
class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
// For filter, we do not use const pointer b/c we will do reshape,
// but we should avoid modifying its value.
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* input_grad =
context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter"));
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
// Actually, no paddings and groups allowed in conv transpose.
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
const int batch_size = input->dims()[0];
const int m = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int k_h = filter.dims()[2];
const int k_w = filter.dims()[3];
const int c = output_grad->dims()[1]; // output channels
const int o_h = output_grad->dims()[2];
const int o_w = output_grad->dims()[3];
// Only im2col functor required for bp to get to the right shape
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
im2col;
// use col_shape in the im2col and col2im calculation
DDim col_shape = {c, k_h, k_w, h, w};
// use col_matrix_shape in the gemm calculation
DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
DDim output_shape = {c, o_h, o_w};
DDim input_matrix_shape = {m, h * w};
DDim filter_matrix_shape = {m, c * k_h * k_w};
filter.Resize(filter_matrix_shape);
// convolution transpose grad on input:
// im2col + gemm (similar to conv-forward)
// input need to compute gradient
if (input_grad) {
Tensor col_matrix;
col_matrix.ShareDataWith(col);
DDim col_matrix_shape = {c * k_h * k_w, h * w};
col_matrix.Resize(col_matrix_shape);
input_grad->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*input_grad);
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < batch_size; i++) {
// batch with size (c, o_h * o_w)
Tensor output_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_shape);
// filter of size (m, c * k_h * k_w)
// batch with size (m, h, w)
Tensor input_grad_batch =
input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
// im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w)
im2col(context.device_context(), output_grad_batch, col, strides[0],
strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
// gemm: dx = filter * dy
// (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, h * w)
math::matmul<Place, T>(context.device_context(), filter, false,
col_matrix, false, T(1.0), &input_grad_batch,
T(0.0));
}
}
// filter gradient required
if (filter_grad) {
Tensor col_matrix_f;
col_matrix_f.ShareDataWith(col);
DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
col_matrix_f.Resize(col_matrix_shape_f);
filter_grad->mutable_data<T>(context.GetPlace());
Tensor filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape);
auto t = framework::EigenVector<T>::Flatten(filter_grad_);
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < batch_size; ++i) {
// batch with size (c, o_h, o_w)
Tensor output_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_shape);
// input batch
Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
// im2col: (c * h * w, k_h * k_w)
im2col(context.device_context(), output_grad_batch, col, strides[0],
strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
// gemm: d_filter = x * y_grad^T
// (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h)
math::matmul<Place, T>(context.device_context(), in_batch, false,
col_matrix_f, true, T(1.0), &filter_grad_,
T(1.0));
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -52,6 +52,7 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generates
// CPU outputs?
dst_item.CopyFrom(src_item, platform::CPUPlace(), dev_ctx);
dst_item.set_lod(src_item.lod());
VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name;
}
......
......@@ -29,8 +29,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width) {
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
......@@ -41,6 +41,22 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int filter_width = col.dims()[2];
int output_height = col.dims()[3];
int output_width = col.dims()[4];
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
int channels_col = input_channels * filter_height * filter_width;
const T* im_data = im.data<T>();
......@@ -52,16 +68,14 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) {
int im_row_idx = h * stride_height + h_offset;
int im_col_idx = w * stride_width + w_offset;
if ((im_row_idx - padding_height) < 0 ||
(im_row_idx - padding_height) >= input_height ||
(im_col_idx - padding_width) < 0 ||
(im_col_idx - padding_width) >= input_width) {
int im_row_idx = h * stride_height + h_offset - padding_up;
int im_col_idx = w * stride_width + w_offset - padding_left;
if (im_row_idx < 0 || im_row_idx >= input_height || im_col_idx < 0 ||
im_col_idx >= input_width) {
col_data[(c * output_height + h) * output_width + w] = T(0);
} else {
im_row_idx += c_im * input_height - padding_height;
im_col_idx -= padding_width;
im_row_idx += c_im * input_height;
col_data[(c * output_height + h) * output_width + w] =
im_data[im_row_idx * input_width + im_col_idx];
}
......@@ -82,7 +96,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -92,6 +107,22 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int filter_width = col.dims()[2];
int output_height = col.dims()[3];
int output_width = col.dims()[4];
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
int channels_col = input_channels * filter_height * filter_width;
T* im_data = im.data<T>();
......@@ -103,14 +134,12 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) {
int im_row_idx = h * stride_height + h_offset;
int im_col_idx = w * stride_width + w_offset;
if ((im_row_idx - padding_height) >= 0 &&
(im_row_idx - padding_height) < input_height &&
(im_col_idx - padding_width) >= 0 &&
(im_col_idx - padding_width) < input_width) {
im_row_idx += c_im * input_height - padding_height;
im_col_idx -= padding_width;
int im_row_idx = h * stride_height + h_offset - padding_up;
int im_col_idx = w * stride_width + w_offset - padding_left;
if ((im_row_idx) >= 0 && (im_row_idx) < input_height &&
(im_col_idx) >= 0 && (im_col_idx) < input_width) {
im_row_idx += c_im * input_height;
im_data[im_row_idx * input_width + im_col_idx] +=
col_data[(c * output_height + h) * output_width + w];
}
......@@ -140,8 +169,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width) {
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -152,6 +181,21 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int output_height = col.dims()[0];
int output_width = col.dims()[1];
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
const T* im_data = im.data<T>();
T* col_data = col.data<T>();
......@@ -163,10 +207,10 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset =
col_row_idx * stride_height + filter_row_idx - padding_height;
col_row_idx * stride_height + filter_row_idx - padding_up;
int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_width;
int col_offset = (((col_row_idx * output_width + col_col_idx) *
col_col_idx * stride_width + filter_col_idx - padding_left;
int col_offset = ((((col_row_idx)*output_width + col_col_idx) *
input_channels +
channel) *
filter_height +
......@@ -201,7 +245,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -212,6 +257,21 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int output_height = col.dims()[0];
int output_width = col.dims()[1];
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
T* im_data = im.data<T>();
const T* col_data = col.data<T>();
......@@ -223,9 +283,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset =
col_row_idx * stride_height + filter_row_idx - padding_height;
col_row_idx * stride_height + filter_row_idx - padding_up;
int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_width;
col_col_idx * stride_width + filter_col_idx - padding_left;
int col_offset = (((col_row_idx * output_width + col_col_idx) *
input_channels +
channel) *
......
......@@ -66,8 +66,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width) {
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
......@@ -79,6 +79,15 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int output_height = col.dims()[3];
int output_width = col.dims()[4];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
int num_outputs = input_channels * output_height * output_width;
int blocks = (num_outputs + 1024 - 1) / 1024;
int block_x = 512;
......@@ -89,8 +98,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), num_outputs, input_height, input_width, filter_height,
filter_width, stride_height, stride_width, padding_height,
padding_width, output_height, output_width, col.data<T>());
filter_width, stride_height, stride_width, padding_up, padding_left,
output_height, output_width, col.data<T>());
}
};
......@@ -152,7 +161,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
......@@ -164,8 +174,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int output_height = col.dims()[3];
int output_width = col.dims()[4];
size_t num_kernels = input_channels * (input_height + 2 * padding_height) *
(input_width + 2 * padding_width);
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
size_t num_kernels = input_channels *
(input_height + padding_up + padding_down) *
(input_width + padding_left + padding_right);
size_t blocks = (num_kernels + 1024 - 1) / 1024;
size_t block_x = 512;
......@@ -178,10 +198,10 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
col2im<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
num_kernels, col.data<T>(), input_height + 2 * padding_height,
input_width + 2 * padding_width, input_channels, filter_height,
filter_width, stride_height, stride_width, padding_height,
padding_width, output_height, output_width, im.data<T>());
num_kernels, col.data<T>(), input_height + padding_up + padding_down,
input_width + padding_left + padding_right, input_channels,
filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, output_height, output_width, im.data<T>());
}
};
......@@ -238,8 +258,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width) {
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -250,6 +270,15 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int output_height = col.dims()[0];
int output_width = col.dims()[1];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
int block_dim_x = 0;
int block_dim_y = 0;
if (filter_height <= 4 && filter_width <= 4) {
......@@ -274,8 +303,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width);
filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, output_height, output_width);
}
};
......@@ -322,7 +351,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -333,6 +363,15 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int output_height = col.dims()[0];
int output_width = col.dims()[1];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
int block_dim_x = 0;
int block_dim_y = 0;
if (filter_height <= 4 && filter_width <= 4) {
......@@ -357,8 +396,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width);
filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, output_height, output_width);
}
};
......
......@@ -74,8 +74,8 @@ class Im2ColFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width);
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right);
};
template <ColFormat Format, typename Place, typename T>
......@@ -83,7 +83,8 @@ class Col2ImFunctor {
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width);
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right);
};
} // namespace math
......
......@@ -35,6 +35,12 @@ void testIm2col() {
*
* output_ocf = [0, 1, 3, 4
* 1, 2, 4, 5]
*
* col2im_cfo = [0, 2, 2
* 3, 8, 5]
*
* col2im_ocf = [0, 2, 2
* 3, 8, 5]
*/
int input_height = 2;
int input_width = 3;
......@@ -59,7 +65,7 @@ void testIm2col() {
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
#else
PADDLE_THROW("no GPU support");
#endif // PADDLE_ONLY_CPU
#endif // PADDLE_WITH_CUDA
}
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
......@@ -71,6 +77,7 @@ void testIm2col() {
output_ocf.mutable_data<float>(
{output_height, output_width, 1, filter_size, filter_size}, *place);
// Im2Col
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, Place, float>
im2col;
......@@ -78,8 +85,13 @@ void testIm2col() {
paddle::operators::math::ColFormat::kOCF, Place, float>
im2col_ocf;
im2col(*context, input, output_cfo, stride, stride, padding, padding);
im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding);
im2col(*context, input, output_cfo, stride, stride, padding, padding, padding,
padding);
im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding,
padding, padding);
float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5};
float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5};
float* out_cfo_ptr;
if (paddle::platform::is_cpu_place(*place)) {
......@@ -88,14 +100,9 @@ void testIm2col() {
output_tmp.CopyFrom(output_cfo, paddle::platform::CPUPlace(), *context);
out_cfo_ptr = output_tmp.data<float>();
}
EXPECT_EQ(out_cfo_ptr[0], 0);
EXPECT_EQ(out_cfo_ptr[1], 1);
EXPECT_EQ(out_cfo_ptr[2], 1);
EXPECT_EQ(out_cfo_ptr[3], 2);
EXPECT_EQ(out_cfo_ptr[4], 3);
EXPECT_EQ(out_cfo_ptr[5], 4);
EXPECT_EQ(out_cfo_ptr[6], 4);
EXPECT_EQ(out_cfo_ptr[7], 5);
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(out_cfo_ptr[i], out_cfo_data[i]);
}
float* out_ocf_ptr;
if (paddle::platform::is_cpu_place(*place)) {
......@@ -104,14 +111,60 @@ void testIm2col() {
output_tmp.CopyFrom(output_ocf, paddle::platform::CPUPlace(), *context);
out_ocf_ptr = output_tmp.data<float>();
}
EXPECT_EQ(out_ocf_ptr[0], 0);
EXPECT_EQ(out_ocf_ptr[1], 1);
EXPECT_EQ(out_ocf_ptr[2], 3);
EXPECT_EQ(out_ocf_ptr[3], 4);
EXPECT_EQ(out_ocf_ptr[4], 1);
EXPECT_EQ(out_ocf_ptr[5], 2);
EXPECT_EQ(out_ocf_ptr[6], 4);
EXPECT_EQ(out_ocf_ptr[7], 5);
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(out_ocf_ptr[i], out_ocf_data[i]);
}
// Col2Im: kCFO
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kCFO, Place, float>
col2im;
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kOCF, Place, float>
col2im_ocf;
float col2im_data[] = {0, 2, 2, 3, 8, 5};
memset(input_ptr, 0, 6 * sizeof(float));
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom(input_tmp, *place, *context);
}
col2im(*context, input, output_cfo, stride, stride, padding, padding, padding,
padding);
float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>();
} else {
input_tmp.CopyFrom(input, paddle::platform::CPUPlace(), *context);
in_ptr = input_tmp.data<float>();
}
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(in_ptr[i], col2im_data[i]);
}
// Col2Im: kOCF
memset(input_ptr, 0, 6 * sizeof(float));
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom(input_tmp, *place, *context);
}
col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding,
padding, padding);
if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>();
} else {
input_tmp.CopyFrom(input, paddle::platform::CPUPlace(), *context);
in_ptr = input_tmp.data<float>();
}
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(in_ptr[i], col2im_data[i]);
}
}
TEST(math, im2col) {
......
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags)
nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc DEPS dynamic_loader)
nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc
DEPS dynamic_loader nccl)
......@@ -113,6 +113,10 @@ class Variable(object):
def lod_level(self):
return self.desc.lod_level()
@property
def type(self):
return self.desc.type()
@staticmethod
def _unique_var_name_():
uid = core.unique_integer() # unique during whole process.
......
from paddle.v2.framework.framework import Variable, OpProtoHolder, g_program, g_init_program
import paddle.v2.framework.core as core
import copy
import itertools
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import Variable, g_program, \
g_init_program
def unique_name(prefix):
uid = core.unique_integer() # unique during whole process.
......@@ -130,6 +133,9 @@ class LayerHelper(object):
dtype=dtype,
persistable=False)
def create_variable(self, *args, **kwargs):
return self.program.current_block().create_var(*args, **kwargs)
def create_global_variable(self, *args, **kwargs):
return self.program.global_block().create_var(
*args, persistable=False, **kwargs)
......
from paddle.v2.framework.layer_helper import LayerHelper
from paddle.v2.framework.layer_helper import LayerHelper, unique_name
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import OpProtoHolder, Variable
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program
import re
__all__ = [
'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat'
'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
'StaticRNN'
]
......@@ -26,7 +27,9 @@ def fc(input,
mul_results = []
for input_var, param_attr in helper.iter_inputs_and_params():
input_shape = input_var.shape
param_shape = list(input_shape[num_flatten_dims:]) + [size]
param_shape = [
reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
] + [size]
w = helper.create_parameter(
attr=param_attr, shape=param_shape, dtype=dtype)
......@@ -38,10 +41,8 @@ def fc(input,
"Y": w,
},
outputs={"Out": tmp},
attrs={
'x_num_col_dims': num_flatten_dims,
'y_num_col_dims': len(input_shape) - num_flatten_dims
})
attrs={'x_num_col_dims': num_flatten_dims,
'y_num_col_dims': 1})
mul_results.append(tmp)
# sum
......@@ -273,3 +274,170 @@ def pool2d(input,
})
return pool_out
class BlockGuard(object):
"""
BlockGuard is used to create a sub-block in a program by using the Python
`with` keyword.
"""
def __init__(self, program):
if not isinstance(program, Program):
raise TypeError("BlockGuard takes a program")
self.program = program
def __enter__(self):
self.program.create_block()
def __exit__(self, exc_type, exc_val, exc_tb):
self.program.rollback()
if exc_type is not None:
return False # re-raise exception
return True
class StaticRNNGuard(BlockGuard):
def __init__(self, rnn):
if not isinstance(rnn, StaticRNN):
raise TypeError("StaticRNNGuard takes an StaticRNN")
super(StaticRNNGuard, self).__init__(rnn.helper.program)
self.rnn = rnn
def __enter__(self):
self.rnn.status = StaticRNN.IN_RNN_BLOCK
return super(StaticRNNGuard, self).__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
self.rnn.complete_rnn_op()
return super(StaticRNNGuard, self).__exit__(exc_type, exc_val, exc_tb)
class StaticRNNMemoryLink(object):
"""
:param init: the initial variable for Memory
:type init: Variable
:param pre_mem: the memory variable in previous time step
:type pre_mem: Variable
:param mem: the memory variable in current time step
:type mem: Variable
"""
def __init__(self, init, pre_mem, mem=None):
self.init = init
self.pre_mem = pre_mem
self.mem = mem
class StaticRNN(object):
BEFORE_RNN_BLOCK = 0
IN_RNN_BLOCK = 1
AFTER_RNN_BLOCK = 2
def __init__(self, name=None, program=None):
self.helper = LayerHelper("static_rnn", name=name, program=program)
self.memories = {} # memory map, from pre_mem.name --> MemoryLink
self.inputs = [] # input variable list in current block
self.outputs = [] # output variable list in parent block
self.status = StaticRNN.BEFORE_RNN_BLOCK # status flag.
# sequence length; since it is a static RNN, the sequence length is fixed.
self.seq_len = None
def step(self):
return StaticRNNGuard(self)
def _assert_in_rnn_block_(self, method):
if self.status != StaticRNN.IN_RNN_BLOCK:
raise ValueError("You must invoke {0} in rnn block".format(method))
def memory(self, init=None, shape=None, dtype=None, init_value=0):
self._assert_in_rnn_block_('memory')
if init is None:
if shape is None or dtype is None:
raise ValueError(
"if init is None, memory at least need shape and dtype")
parent_block = self.parent_block()
var_name = unique_name("@".join([self.helper.name, "memory_boot"]))
boot_var = parent_block.create_var(
name=var_name, shape=shape, dtype=dtype, persistable=False)
parent_block.append_op(
type="fill_constant",
inputs={},
outputs={'Out': [boot_var]},
attrs={
'value': init_value,
'shape': boot_var.shape,
'data_type': boot_var.data_type
})
return self.memory(init=boot_var)
else:
pre_mem = self.helper.create_variable(
name=unique_name("@".join([self.helper.name, "mem"])),
dtype=init.data_type,
shape=init.shape)
self.memories[pre_mem.name] = StaticRNNMemoryLink(
init=init, pre_mem=pre_mem)
return pre_mem
def step_input(self, x):
self._assert_in_rnn_block_('step_input')
if not isinstance(x, Variable):
raise TypeError("step input takes a Variable")
if self.seq_len is None:
self.seq_len = x.shape[1]
elif self.seq_len != x.shape[1]:
raise ValueError("Static RNN only take fix seq_len input")
ipt = self.helper.create_variable(
name=x.name,
dtype=x.data_type,
shape=[-1] + list(x.shape[2:]),
type=x.type)
self.inputs.append(ipt)
return ipt
def step_output(self, o):
self._assert_in_rnn_block_('step_output')
if not isinstance(o, Variable):
raise TypeError("step output takes a Variable")
out_var = self.parent_block().create_var(
name=o.name,
shape=[-1, self.seq_len] + list(o.shape[1:]),
dtype=o.data_type)
self.outputs.append(out_var)
def output(self, *outputs):
for each in outputs:
self.step_output(each)
def update_memory(self, mem, var):
if not isinstance(mem, Variable) or not isinstance(var, Variable):
raise TypeError("update memory should take variables")
self.memories[mem.name].mem = var
def parent_block(self):
prog = self.helper.program
parent_idx = prog.current_block().parent_idx
assert parent_idx >= 0
parent_block = prog.block(parent_idx)
return parent_block
def __call__(self, *args, **kwargs):
if self.status != StaticRNN.AFTER_RNN_BLOCK:
raise ValueError("RNN output can only be retrieved after rnn block")
if len(self.outputs) == 0:
raise ValueError("RNN has no output")
elif len(self.outputs) == 1:
return self.outputs[0]
else:
return self.outputs
def complete_rnn_op(self):
# TODO(yuyang18): Create RNN Op here.
# Implement this method after RNN op complete.
pass
......@@ -179,7 +179,12 @@ def get_backward_op(scope, op, no_grad_set):
return backward_op
def get_gradient(scope, op, inputs, outputs, grad_name, place,
def get_gradient(scope,
op,
inputs,
outputs,
grad_names,
place,
no_grad_set=None):
ctx = core.DeviceContext.create(place)
......@@ -195,8 +200,10 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place,
backward_op.run(scope, ctx)
out = np.array(scope.find_var(grad_name).get_tensor())
return out
return [
np.array(scope.find_var(grad_name).get_tensor())
for grad_name in grad_names
]
def append_input_output(block, op_proto, np_list, is_input):
......@@ -326,20 +333,31 @@ class OpTest(unittest.TestCase):
type(sub_out))
for sub_out_name, expect in sub_out:
idx = find_actual(sub_out_name, fetch_list)
actual = outs[idx]
actual_t = np.array(outs[idx])
expect_t = expect[0] \
if isinstance(expect, tuple) else expect
self.assertTrue(
np.allclose(
actual, expect, atol=atol),
actual_t, expect_t, atol=atol),
"Output (" + sub_out_name + ") has diff at " +
str(place))
if isinstance(expect, tuple):
self.assertListEqual(
actual_t.lod(), expect[1], "Output (" + sub_out_name
+ ") has different lod at " + str(place))
else:
idx = find_actual(out_name, fetch_list)
actual = outs[idx]
actual_t = outs[idx]
expect = self.outputs[out_name]
expect_t = expect[0] if isinstance(expect, tuple) else expect
self.assertTrue(
np.allclose(
actual, expect, atol=atol),
actual_t, expect_t, atol=atol),
"Output (" + out_name + ") has diff at " + str(place))
if isinstance(expect, tuple):
self.assertListEqual(actual_t.lod(), expect[1],
"Output (" + out_name +
") has different lod at " + str(place))
def check_output(self, atol=1e-5):
places = [core.CPUPlace()]
......@@ -399,11 +417,9 @@ class OpTest(unittest.TestCase):
]
cpu_place = core.CPUPlace()
cpu_analytic_grads = [
get_gradient(self.scope, self.op, self.inputs, self.outputs,
grad_name, cpu_place, no_grad_set)
for grad_name in grad_names
]
cpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs,
self.outputs, grad_names, cpu_place,
no_grad_set)
self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names,
max_relative_error,
......@@ -411,11 +427,9 @@ class OpTest(unittest.TestCase):
if core.is_compile_gpu() and self.op.support_gpu():
gpu_place = core.GPUPlace(0)
gpu_analytic_grads = [
get_gradient(self.scope, self.op, self.inputs, self.outputs,
grad_name, gpu_place, no_grad_set)
for grad_name in grad_names
]
gpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs,
self.outputs, grad_names,
gpu_place, no_grad_set)
self.__assert_is_close(numeric_grads, gpu_analytic_grads,
grad_names, max_relative_error,
......
import unittest
import numpy as np
from op_test import OpTest
def conv2dtranspose_forward_naive(input_, filter_, conv2dtranspose_param):
# [2, 3, 5, 5]
in_n, in_c, in_h, in_w = input_.shape
# [3, 6, 3, 3]
f_c, out_c, f_h, f_w = filter_.shape
assert in_c == f_c
stride, pad = conv2dtranspose_param['stride'], conv2dtranspose_param['pad']
out_h = (in_h - 1) * stride[0] + f_h
out_w = (in_w - 1) * stride[1] + f_w
out = np.zeros((in_n, out_c, out_h, out_w))
for n in range(in_n):
for i in range(in_h):
for j in range(in_w):
input_masked = input_[n, :, i, j] # (c)
input_masked = np.reshape(input_masked, (in_c, 1, 1))
input_masked = np.tile(input_masked, (1, f_h, f_w))
for k in range(out_c):
tmp_out = np.sum(input_masked * filter_[:, k, :, :], axis=0)
i1, i2 = i * stride[0], i * stride[0] + f_h
j1, j2 = j * stride[0], j * stride[0] + f_w
out[n, k, i1:i2, j1:j2] += tmp_out
return out
class TestConv2dTransposeOp(OpTest):
def setUp(self):
# init as conv transpose
self.init_op_type()
# [2, 3, 5, 5] -> kernel [3, 6, 3, 3] -> output [2, 6, 7, 7]
self.init_test_case()
conv2dtranspose_param = {'stride': self.stride, 'pad': self.pad}
input_ = np.random.random(self.input_size).astype("float32")
filter_ = np.random.random(self.filter_size).astype("float32")
output = conv2dtranspose_forward_naive(input_, filter_,
conv2dtranspose_param)
# print 'deconv output py', output, output.shape
self.inputs = {'Input': input_, 'Filter': filter_}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
# 'dilations': self.dilations
}
self.outputs = {'Output': output}
def test_check_output(self):
print 'check output here'
self.check_output()
def test_check_grad(self):
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.05)
def test_check_grad_no_filter(self):
self.check_grad(
['Input'],
'Output',
max_relative_error=0.05,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
self.check_grad(
['Filter'],
'Output',
max_relative_error=0.05,
no_grad_set=set(['Input']))
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.op_type = "conv2dtranspose"
"""
class TestCudnn(TestConv2dOp):
def init_group(self):
self.groups = 1
def init_op_type(self):
self.op_type = "conv_cudnn"
"""
if __name__ == '__main__':
unittest.main()
......@@ -155,7 +155,11 @@ class TestLstmOp(OpTest):
'Weight': w,
'Bias': b
}
self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort}
self.outputs = {
'Hidden': (h, self.lod),
'Cell': (c, self.lod),
'BatchGate': g_sort
}
self.attrs = {
'usePeepholes': True,
'isReverse': self.is_reverse,
......
import unittest
from paddle.v2.framework.layers import *
from paddle.v2.framework.framework import g_program
class TestRNN(unittest.TestCase):
def test_rnn(self):
img = data(
shape=[
80, # sequence length
22, # image height
22
], # image width
data_type='float32',
name='image')
hidden = fc(input=img, size=100, act='sigmoid', num_flatten_dims=2)
self.assertEqual((-1, 80, 100), hidden.shape)
hidden = fc(input=hidden, size=100, act='sigmoid', num_flatten_dims=2)
self.assertEqual((-1, 80, 100), hidden.shape)
rnn = StaticRNN()
with rnn.step():
hidden = rnn.step_input(hidden)
self.assertEqual((-1, 100), hidden.shape)
memory = rnn.memory(shape=(-1, 32), dtype='float32', init_value=0.0)
rnn_out = fc(input=[hidden, memory], size=32, act='sigmoid')
self.assertEqual((-1, 32), rnn_out.shape)
rnn.update_memory(memory, rnn_out)
rnn.output(rnn_out)
out = rnn()
self.assertEqual((-1, 80, 32), out.shape)
print g_program
if __name__ == '__main__':
unittest.main()
......@@ -326,6 +326,17 @@ class Parameters(object):
self.set(name, arr.reshape(self.get_shape(name)))
def to_tar(self, f):
"""
Save parameters to a tar file.
WARNING: You should use `paddle.v2.trainer.SGD.save_parameter_to_tar(f)`
to save parameters most of the time. Otherwise, some settings such
as model average will not take effect.
:param f:
:type f: file
:return:
"""
tar = tarfile.TarFile(fileobj=f, mode='w')
for nm in self.names():
buf = cStringIO.StringIO()
......