diff --git a/benchmark/paddle/image/provider.py b/benchmark/paddle/image/provider.py index 1ac47212b5a75667e8e9d4465b33f575516e2836..4703944c8722552d56ba80a8e0663de5fb4df53d 100644 --- a/benchmark/paddle/image/provider.py +++ b/benchmark/paddle/image/provider.py @@ -22,5 +22,5 @@ def initHook(settings, height, width, color, num_class, **kwargs): def process(settings, file_list): for i in xrange(1024): img = np.random.rand(1, settings.data_size).reshape(-1, 1).flatten() - lab = random.randint(0, settings.num_class) + lab = random.randint(0, settings.num_class - 1) yield img.astype('float32'), int(lab) diff --git a/benchmark/paddle/image/run_mkldnn.sh b/benchmark/paddle/image/run_mkldnn.sh new file mode 100755 index 0000000000000000000000000000000000000000..5b0a0373448e5b81ff0718db3465a4694690ec37 --- /dev/null +++ b/benchmark/paddle/image/run_mkldnn.sh @@ -0,0 +1,51 @@ +set -e + +unset OMP_NUM_THREADS MKL_NUM_THREADS +export OMP_DYNAMIC="FALSE" +export KMP_AFFINITY="granularity=fine,compact,0,0" + +function train() { + topology=$1 + bs=$2 + use_mkldnn=$3 + if [ $3 == "True" ]; then + use_mkldnn=$3 + thread=1 + log="logs/${topology}-mkldnn-${bs}.log" + elif [ $3 == "False" ]; then + use_mkldnn=$3 + thread=`nproc` + log="logs/${topology}-${thread}mklml-${bs}.log" + else + echo "Wrong input $3, use True or False." + fi + args="batch_size=${bs}" + config="${topology}.py" + paddle train --job=time \ + --config=$config \ + --use_mkldnn=$use_mkldnn \ + --use_gpu=False \ + --trainer_count=$thread \ + --log_period=10 \ + --test_period=100 \ + --config_args=$args \ + 2>&1 | tee ${log} +} + +if [ ! -d "train.list" ]; then + echo " " > train.list +fi +if [ ! -d "logs" ]; then + mkdir logs +fi + +#========= mkldnn =========# +# vgg +train vgg 64 True +train vgg 128 True +train vgg 256 True + +#========== mklml ===========# +train vgg 64 False +train vgg 128 False +train vgg 256 False diff --git a/benchmark/paddle/image/vgg.py b/benchmark/paddle/image/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..b8429975f5c83df6996e71478fe276b246e8b77b --- /dev/null +++ b/benchmark/paddle/image/vgg.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +from paddle.trainer_config_helpers import * + +height = 224 +width = 224 +num_class = 1000 +batch_size = get_config_arg('batch_size', int, 64) +layer_num = get_config_arg('layer_num', int, 19) + +args = {'height': height, 'width': width, 'color': True, 'num_class': num_class} +define_py_data_sources2( + "train.list", None, module="provider", obj="process", args=args) + +settings( + batch_size=batch_size, + learning_rate=0.01 / batch_size, + learning_method=MomentumOptimizer(0.9), + regularization=L2Regularization(0.0005 * batch_size)) + +img = data_layer(name='image', size=height * width * 3) + + +def vgg_network(vgg_num=3): + tmp = img_conv_group( + input=img, + num_channels=3, + conv_padding=1, + conv_num_filter=[64, 64], + conv_filter_size=3, + conv_act=ReluActivation(), + pool_size=2, + pool_stride=2, + pool_type=MaxPooling()) + + tmp = img_conv_group( + input=tmp, + conv_num_filter=[128, 128], + conv_padding=1, + conv_filter_size=3, + conv_act=ReluActivation(), + pool_stride=2, + pool_type=MaxPooling(), + pool_size=2) + + channels = [] + for i in range(vgg_num): + channels.append(256) + tmp = img_conv_group( + input=tmp, + conv_num_filter=channels, + conv_padding=1, + conv_filter_size=3, + conv_act=ReluActivation(), + pool_stride=2, + pool_type=MaxPooling(), + pool_size=2) + channels = [] + for i in range(vgg_num): + channels.append(512) + tmp = img_conv_group( + input=tmp, + conv_num_filter=channels, + conv_padding=1, + conv_filter_size=3, + conv_act=ReluActivation(), + pool_stride=2, + pool_type=MaxPooling(), + pool_size=2) + tmp = img_conv_group( + input=tmp, + conv_num_filter=channels, + conv_padding=1, + conv_filter_size=3, + conv_act=ReluActivation(), + pool_stride=2, + pool_type=MaxPooling(), + pool_size=2) + + tmp = fc_layer( + input=tmp, + size=4096, + act=ReluActivation(), + layer_attr=ExtraAttr(drop_rate=0.5)) + + tmp = fc_layer( + input=tmp, + size=4096, + act=ReluActivation(), + layer_attr=ExtraAttr(drop_rate=0.5)) + + return fc_layer(input=tmp, size=num_class, act=SoftmaxActivation()) + + +if layer_num == 16: + vgg = vgg_network(3) +elif layer_num == 19: + vgg = vgg_network(4) +else: + print("Wrong layer number.") + +lab = data_layer('label', num_class) +loss = cross_entropy(input=vgg, label=lab) +outputs(loss) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 0bbf92293168d4e3af2c1ed0e82b75e6a8d6c0cd..ff9868fc4e0d970b11e4763d2e0c8581f4f85907 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -253,7 +253,7 @@ function(nv_library TARGET_NAME) foreach(source_file ${nv_library_SRCS}) string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file}) if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) - list(APPEND cc_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) + list(APPEND nv_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) endif() endforeach() add_style_check_target(${TARGET_NAME} ${nv_library_SRCS} ${nv_library_HEADERS}) diff --git a/cmake/util.cmake b/cmake/util.cmake index ac911052eb970c5a3e485e3178dd788b1517ca30..d1aee3e170a2d143ac06b438725e907e96f041c8 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -97,6 +97,10 @@ function(link_paddle_exe TARGET_NAME) target_link_libraries(${TARGET_NAME} log) endif(ANDROID) + if(WITH_MKLDNN AND WITH_MKLML AND MKLDNN_IOMP_DIR) + target_link_libraries(${TARGET_NAME} "-L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed") + endif() + add_dependencies(${TARGET_NAME} ${external_project_dependencies}) endfunction() diff --git a/doc/getstarted/build_and_install/docker_install_cn.rst b/doc/getstarted/build_and_install/docker_install_cn.rst index 84e33177740ca1652efc09c8081c2519b4366906..30b144d849bec367cd0197b6082889e011193a9a 100644 --- a/doc/getstarted/build_and_install/docker_install_cn.rst +++ b/doc/getstarted/build_and_install/docker_install_cn.rst @@ -20,7 +20,7 @@ Docker使用入门 docker pull paddlepaddle/paddle:0.10.0 - 来下载Docker镜像,paddlepaddle/paddle是从官方镜像源Dockerhub.com下载的,推荐国内用户使用ocker.paddlepaddle.org/paddle下载。 + 来下载Docker镜像,paddlepaddle/paddle是从官方镜像源Dockerhub.com下载的,推荐国内用户使用docker.paddlepaddle.org/paddle下载。 - *容器*: 如果说一个Docker镜像就是一个程序,那容器就是这个程序运行时产生的“进程”。 实际上,一个容器就是一个操作系统的进程,但是是运行在独立的进程空间,文件系统以及网络之上。 diff --git a/doc/howto/dev/new_op_en.md b/doc/howto/dev/new_op_en.md new file mode 100644 index 0000000000000000000000000000000000000000..b7aa501db9e5c7378398fad48503f82bff893b60 --- /dev/null +++ b/doc/howto/dev/new_op_en.md @@ -0,0 +1,235 @@ +# How to write a new operator + + - [Background](#Background) + - [Implementing C++ Types](#Implementing_C++_Types) + - [Defining ProtoMaker](#Defining_ProtoMaker) + - [Defining Operator](#Defining_Operator) + - [Registering Operator](#Registering_Operator) + - [Compilation](#Compilation) + - [Python Binding](#Python_Binding) + - [Unit Tests](#Unit_Tests) + +## Background + +Here are the base types needed. For details, please refer to the design docs. + +- `framework::OperatorBase`: Operator (Op)base class. +- `framework::OpKernel`: Base class for Op computation. +- `framework::OperatorWithKernel`: Inherited from OperatorBase, describing an operator with computation. +- `class OpProtoAndCheckerMaker`: Describes an Operator's input, output, attributes and description, mainly used to interface with Python API. + +An operator can be differentiated by whether in has kernel methods. An operator with kernel inherits from `OperatorWithKernel` while the ones without inherit from `OperatorBase`. This tutorial focuses on implementing operators with kernels. In short, an operator includes the following information: + + + Information | Where is it defined +-------------- | :---------------------- +OpProtoMake definition | `.cc`files, Backward Op does not need an OpProtoMake interface. +Op definition | `.cc` files +Kernel implementation | The kernel methods shared between CPU and GPU are defined in `.h` files. CPU-specific kernels live in `.cc` files, while GPU-specific kernels are implemented in `.cu`files. +Registering the Op | Ops are registered in `.cc` files; For Kernel registration, `.cc` files contain the CPU implementation, while `.cu` files contain the GPU implementation. + + +New Operator implementations are added to the list [paddle/operators](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators), with file names in the format `*_op.h` (if applicable), `*_op.cc`, `*_op.cu` (if applicable).** The system will use the naming scheme to automatically build operators and their corresponding Python extensions. ** + + +Let's take matrix multiplication operator, [MulOp](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc), as an example to introduce the writing of an Operator with Kernel. + + +## Implementing C++ Types + + +### 1. Defining Class ProtoMaker + +Matrix Multiplication can be written as $Out = X * Y$, meaning that the operation consists of two inputs and pne output. + +First, define `ProtoMaker` to describe the Operator's input, output, and additional comments: + +```cpp +class MulOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "(Tensor), 2D tensor of size (M x K)"); + AddInput("Y", "(Tensor), 2D tensor of size (K x N)"); + AddOutput("Out", "(Tensor), 2D tensor of size (M x N)"); + AddComment(R"DOC( +Two Element Mul Operator. +The equation is: Out = X * Y +)DOC"); + } +}; +``` + +[`MulOpMaker`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc#L43)is inherited from`framework::OpProtoAndCheckerMaker`, consisting of 2 variables in the constructor: + + - `framework::OpProto` stores Operator input and variable attribute, used for generating Python API interfaces. + - `framework::OpAttrChecker` is used to validate variable attributes. + +The constructor utilizes `AddInput`, `AddOutput`, and `AddComment`, so that the corresponding information will be added to `OpProto`. + +The code above adds two inputs `X` and `Y` to `MulOp`, an output `Out`, and their corresponding descriptions, in accordance to Paddle's [naming convention](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md). + + +An additional example [`ScaleOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/scale_op.cc#L37) is implemented as follows: + +```cpp +template +class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input tensor of scale operator.").NotInGradient(); + AddOutput("Out", "The output tensor of scale operator.").NotInGradient(); + AddComment(R"DOC(Scale operator +The equation is: Out = scale*X +)DOC"); + AddAttr("scale", "scale of scale operator.").SetDefault(1.0); + } +}; +``` + +There are two changes in this example: + +- `AddInput("X","...").NotInGradient()` expresses that input `X` is not involved in `ScaleOp`'s corresponding computation. If an input to an operator is not participating in back-propagation, please explicitly set `.NotInGradient()`. + +- `AddAttr("scale", "...").SetDefault(1.0);` adds `scale`constant as an attribute, and sets the default value to 1.0. + + +### 2. Defining Operator + +The following code defines the interface for MulOp: + +```cpp +class MulOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + auto dim0 = ctx.Input("X")->dims(); + auto dim1 = ctx.Input("Y")->dims(); + PADDLE_ENFORCE_EQ(dim0.size(), 2, + "input X(%s) should be a tensor with 2 dims, a matrix", + ctx.op_.Input("X")); + PADDLE_ENFORCE_EQ(dim1.size(), 2, + "input Y(%s) should be a tensor with 2 dims, a matrix", + ctx.op_.Input("Y")); + PADDLE_ENFORCE_EQ( + dim0[1], dim1[0], + "First matrix's width must be equal with second matrix's height."); + ctx.Output("Out")->Resize({dim0[0], dim1[1]}); + } +}; +``` + +[`MulOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc#L22) is inherited from `OperatorWithKernel`. Its `public` member + +```cpp +using framework::OperatorWithKernel::OperatorWithKernel; +``` + +expresses an operator constructor using base class `OperatorWithKernel`, alternatively written as + +```cpp +MulOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} +``` + +`InferShape` interface needs to be re-written.`InferShape` is a constant method and cannot modify Op's member variables, its constant member `const framework::InferShapeContext &ctx` can be used to extract input, output, and attributes. It functions to + + - 1). validate and error out early: it checks input data dimensions and types. + - 2). configures the tensor shape in the output. + +Usually `OpProtoMaker` and `Op`'s type definitions are written in `.cc` files, which also include the registration methods introduced later. + +### 3. Defining OpKernel + +`MulKernel` inherits `framework::OpKernel`, which includes the following templates: + +- `typename Place` denotes device type. When different devices, namely the CPU and the GPU, share the same kernel, this template needs to be added. If they don't share kernels, this must not be added. An example of a non-sharing kernel is [`OnehotCrossEntropyOpKernel`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cross_entropy_op.h#L43). + +- `typename T` denotes data type, such as `float` or `double`. + +`MulKernel` types need to rewrite the interface for `Compute`. +- `Compute` takes one input variable `const framework::ExecutionContext& context`. +- Compared with `InferShapeContext`, `ExecutionContext` includes device types, and can similarly extract input, output, and attribute variables. +- `Compute` implements the computation logics of an `OpKernel`. + +`MulKernel`'s implementation of `Compute` is as follows: + + ```cpp + template + class MulKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* Z = context.Output("Out"); + Z->mutable_data(context.GetPlace()); + auto* device_context = + const_cast(context.device_context_); + math::matmul(*X, false, *Y, false, 1, Z, 0, device_context); + } + }; + ``` + +Note that **different devices (CPU, GPU)share an Op definition; whether or not they share the same `OpKernel` depends on whether `Compute` calls functions that support both devices.** + +`MulOp`'s CPU and GPU share the same `Kernel`. A non-sharing `OpKernel` example can be seen in [`OnehotCrossEntropyOpKernel`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cross_entropy_op.h#L43). + +To ease the writing of `OpKernel` compute, and for reusing code cross-device, `Eigen unsupported Tensor` module is used to implement `Compute` interface. To learn about how the Eigen library is used in PaddlePaddle, please see [usage document](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/use_eigen_cn.md). + + +This concludes the forward implementation of an operator. Next its operation and kernel need to be registered in a `.cc` file. + +The definition of its corresponding backward operator, if applicable, is similar to that of an forward operator. **Note that a backward operator does not include a `ProtoMaker`**. + +### 4. Registering Operator + +- In `.cc` files, register forward and backward operator classes and the CPU kernel. + + ```cpp + namespace ops = paddle::operators; + REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad); + REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel); + REGISTER_OP_CPU_KERNEL(mul_grad, + ops::MulGradKernel); + ``` + + In that code block, + + - `REGISTER_OP` registers the `ops::MulOp` class, type named `mul`, its type `ProtoMaker` is `ops::MulOpMaker`, registering `ops::MulOpGrad` as `mul_grad`. + - `REGISTER_OP_WITHOUT_GRADIENT` registers an operator without gradient. + - `REGISTER_OP_CPU_KERNEL` registers `ops::MulKernel` class and specialized template types `paddle::platform::CPUPlace` and `float`, which also registers `ops::MulKernel`. + + +- Registering GPU Kernel in `.cu` files + - Note that if GPU Kernel is implemented using the `Eigen unsupported` module, then on top of `.cu`, a macro definition `#define EIGEN_USE_GPU` is needed, such as + + ```cpp + // if use Eigen unsupported module before include head files + #define EIGEN_USE_GPU + + namespace ops = paddle::operators; + REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); + REGISTER_OP_GPU_KERNEL(mul_grad, + ops::MulGradKernel); + ``` + +### 5. Compilation + +Run the following commands to compile. + +``` +make mul_op +``` + +## Python Binding + +The system will automatically bind to Python and link it to a generated library. + +## Unit Tests + +Unit tests include comparing a forward operator's implementations on different devices, comparing a backward operator's implementation on different devices, and a scaling test for the backward operator. Here, we introduce the [unit tests for `MulOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/framework/tests/test_mul_op.py). diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index fda89252e35c382468877e8cab148e5f91d77ac2..510dc28c57f642786e7c64d86961c76ac80014a8 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -28,47 +28,6 @@ ProgramDesc& GetProgramDesc() { return *g_program_desc; } -template <> -AttrType AttrTypeID() { - return BOOLEAN; -} -template <> -AttrType AttrTypeID() { - return INT; -} -template <> -AttrType AttrTypeID() { - return FLOAT; -} -template <> -AttrType AttrTypeID() { - return STRING; -} -template <> -AttrType AttrTypeID>() { - return BOOLEANS; -} -template <> -AttrType AttrTypeID>() { - return INTS; -} -template <> -AttrType AttrTypeID>() { - return FLOATS; -} -template <> -AttrType AttrTypeID>() { - return STRINGS; -} -template <> -AttrType AttrTypeID>>() { - return INT_PAIRS; -} -template <> -AttrType AttrTypeID() { - return BLOCK; -} - Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { switch (attr_desc.type()) { case framework::AttrType::BOOLEAN: { @@ -111,14 +70,6 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { } return val; } - case framework::AttrType::INT_PAIRS: { - std::vector> val(attr_desc.int_pairs_size()); - for (int i = 0; i < attr_desc.int_pairs_size(); ++i) { - val[i].first = attr_desc.int_pairs(i).first(); - val[i].second = attr_desc.int_pairs(i).second(); - } - return val; - } case framework::AttrType::BLOCK: { return GetProgramDesc().mutable_blocks(attr_desc.block_idx()); } diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 48b54b5422de8c45e15a1b7040b78373dce8fa3a..488fa38faf12ee51087643f79295f36bfd33ee22 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -27,10 +27,10 @@ limitations under the License. */ namespace paddle { namespace framework { -typedef boost::variant, std::vector, std::vector, - std::vector, - std::vector>, BlockDesc*> +// The order should be as same as framework.proto +typedef boost::variant, + std::vector, std::vector, bool, + std::vector, BlockDesc*> Attribute; typedef std::unordered_map AttributeMap; @@ -38,7 +38,10 @@ typedef std::unordered_map AttributeMap; ProgramDesc& GetProgramDesc(); template -AttrType AttrTypeID(); +inline AttrType AttrTypeID() { + Attribute tmp = T(); + return static_cast(tmp.which() - 1); +} Attribute GetAttrValue(const OpDesc::Attr& attr_desc); diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index 6fcfe6de25737b66a2ea6c1a438636f072a513bb..951c7afbc14e2d9119169c1351d38ff0b67bdc5b 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -22,17 +22,11 @@ enum AttrType { INTS = 3; FLOATS = 4; STRINGS = 5; - INT_PAIRS = 6; - BOOLEAN = 7; - BOOLEANS = 8; - BLOCK = 9; + BOOLEAN = 6; + BOOLEANS = 7; + BLOCK = 8; } -message IntPair { - required int32 first = 1; - required int32 second = 2; -}; - // OpDesc describes an instance of a C++ framework::OperatorBase // derived class type. message OpDesc { @@ -46,7 +40,6 @@ message OpDesc { repeated int32 ints = 6; repeated float floats = 7; repeated string strings = 8; - repeated IntPair int_pairs = 9; optional bool b = 10; repeated bool bools = 11; optional int32 block_idx = 12; diff --git a/paddle/gserver/activations/MKLDNNActivation.h b/paddle/gserver/activations/MKLDNNActivation.h index 86ffe387366409d81a91740cc8cea886e618f7e2..40dd8c618aa2b70d410130e12efc54520218afea 100644 --- a/paddle/gserver/activations/MKLDNNActivation.h +++ b/paddle/gserver/activations/MKLDNNActivation.h @@ -100,6 +100,7 @@ public: if (cnt_ == act.value->getElementCnt()) { return; } + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward"; cnt_ = act.value->getElementCnt(); stream_.reset(new MKLDNNStream()); auto eng = CPUEngine::Instance().getEngine(); @@ -110,7 +111,6 @@ public: float alpha = getAlpha(); float beta = getBeta(); - /// forward pipelineFwd_.clear(); val_ = std::dynamic_pointer_cast(act.value); if (val_ == nullptr) { @@ -152,6 +152,7 @@ public: if (!needResetBwd_) { return; } + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward"; needResetBwd_ = false; mkldnn::algorithm algo = getAlgo(this->getName()); float alpha = getBwdAlpha(); diff --git a/paddle/gserver/layers/MKLDNNConvLayer.cpp b/paddle/gserver/layers/MKLDNNConvLayer.cpp index 88b047c89bd40aba1afc456c22a2870c62989c1c..9a0abd291ae8fae43b0e95c7371f3ce35d1261ec 100644 --- a/paddle/gserver/layers/MKLDNNConvLayer.cpp +++ b/paddle/gserver/layers/MKLDNNConvLayer.cpp @@ -64,7 +64,7 @@ bool MKLDNNConvLayer::init(const LayerMap& layerMap, // create biases if (biasParameter_.get() != NULL) { - biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_)); + biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_, 0)); } return true; } @@ -251,22 +251,31 @@ void MKLDNNConvLayer::resetInValue( // create buffer and reorder if input value do not match cpuInVal_ = nullptr; cvtInVal_ = nullptr; - if (inputIsOnlyMKLDNN()) { - MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast(inMat); - CHECK(dnnIn) << "Input should be MKLDNNMatrix"; - if (dnnIn->getPrimitiveDesc() != in->getPrimitiveDesc()) { - CHECK_EQ(dnnIn->getFormat(), format::nc); + + MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast(inMat); + CHECK_EQ(inputIsOnlyMKLDNN(), dnnIn != nullptr); + if (dnnIn != nullptr && dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) { + in = dnnIn; + return; + } + if (dnnIn) { + if (dnnIn->getFormat() == format::nc) { CHECK(ih_ == 1 && iw_ == 1) << "when input is nc format"; // create a new one with nchw format and same data memory::dims inDims = memory::dims{bs_, ic_, 1, 1}; dnnIn = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_); - CHECK(dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()); } - in = dnnIn; + if (dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) { + in = dnnIn; + return; + } + cpuInVal_ = dnnIn; + in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc()); + cvtInVal_ = MKLDNNMatrix::createReorder(cpuInVal_, in); + CHECK(cvtInVal_) << "should not be emptry"; } else { - const MatrixPtr& cpuIn = getInputValue(0, CPU_DEVICE); memory::dims inDims = memory::dims{bs_, ic_, ih_, iw_}; - cpuInVal_ = MKLDNNMatrix::create(cpuIn, inDims, format::nchw, engine_); + cpuInVal_ = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_); if (cpuInVal_->getPrimitiveDesc() != in->getPrimitiveDesc()) { // create new mkldnn matrix in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc()); @@ -535,7 +544,7 @@ void MKLDNNConvLayer::resetWgtValBwdData( } else { wgtValBwdData_ = wgtVal_; } - VLOG(MKLDNN_FMTS) << "weight value format for backward data" + VLOG(MKLDNN_FMTS) << "weight value format for backward data: " << wgtValBwdData_->getFormat(); } diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index afd092666bf8b8a3389b36aa1f0edb256a9968e6..8cbfbd0d2b9f2149f7c959aec5c4ae1de952f903 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -49,7 +49,7 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap, // create biases if (biasParameter_.get() != NULL) { - biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_)); + biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_, 0)); } return true; } @@ -161,9 +161,16 @@ void MKLDNNFcLayer::resetInValue(MKLDNNMatrixPtr& in) { void MKLDNNFcLayer::resetWgtBiasValue(MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias) { + format wgtFmt = format::oihw; + if (inVal_->getFormat() == format::nChw8c) { + wgtFmt = format::oIhw8i; + } else if (inVal_->getFormat() == format::nChw16c) { + wgtFmt = format::oIhw16i; + } wgt = MKLDNNMatrix::create( - weight_->getW(), {oc_, ic_, ih_, iw_}, format::oihw, engine_); + weight_->getW(), {oc_, ic_, ih_, iw_}, wgtFmt, engine_); wgt->downSpatial(); + VLOG(MKLDNN_FMTS) << "Weight value format: " << wgt->getFormat(); bias = (biases_ && biases_->getW()) ? MKLDNNMatrix::create(biases_->getW(), {oc_}, format::x, engine_) diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index d8555a833187ddf64b096135e920e5be2b3a8c2f..c09fd89462ef4fdaeaae3e122f96b0cc6ce373ea 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -115,6 +115,7 @@ public: copySeqInfoToOutputs(); size_t elemenCnt = inputLayers_[0]->getOutput().value->getElementCnt(); if (inputElemenCnt_ != elemenCnt) { + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward"; // reset when input total sizes changed, not only the batchsize inputElemenCnt_ = elemenCnt; reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_); @@ -142,6 +143,7 @@ public: void backward(const UpdateCallback& callback) override { if (needResetBwd_) { + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward"; resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_); needResetBwd_ = false; } diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3600f199770c4b8c9a6561b4c270a91bc8b20c0b --- /dev/null +++ b/paddle/operators/lstm_unit_op.cc @@ -0,0 +1,103 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/lstm_unit_op.h" + +namespace paddle { +namespace operators { + +class LstmUnitOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of LSTM should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("C_prev"), + "Input(C_prev) of LSTM should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("C"), + "Output(C) of LSTM should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("H"), + "Output(H) of LSTM should not be null."); + + auto *x = ctx.Input("X"); + auto *c_prev = ctx.Input("C_prev"); + + PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); + PADDLE_ENFORCE(x->dims()[0] == c_prev->dims()[0], + "Batch size of inputs and states must be equal"); + PADDLE_ENFORCE(x->dims()[1] == c_prev->dims()[1] * 4, + "Dimension of FC should equal to prev state * 4"); + + int b_size = c_prev->dims()[0]; // batch size + int s_dim = c_prev->dims()[1]; // state dim + ctx.Output("C")->Resize({b_size, s_dim}); + ctx.Output("H")->Resize({b_size, s_dim}); + } +}; + +template +class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LstmUnitOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "FC input before the non-linear activation."); + AddInput( + "C_prev", + "The cell state tensor of last time-step in the Lstm Unit operator."); + AddOutput("C", "The cell tensor of Lstm Unit operator."); + AddOutput("H", "The hidden state tensor of Lstm Unit operator."); + + AddComment(R"DOC(Lstm-Unit Operator + +Equation: + i, f, o, j = split(X) + C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j) + H = C * sigm(o) + +)DOC"); + AddAttr("forget_bias", "The forget bias of Lstm Unit.") + .SetDefault(0.0); + } +}; + +class LstmUnitGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("C")), + "Input(C@GRAD) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("H")), + "Input(H@GRAD) should not be null"); + ctx.Output(framework::GradVarName("X")) + ->Resize(ctx.Input("X")->dims()); + ctx.Output(framework::GradVarName("C_prev")) + ->Resize(ctx.Input("C_prev")->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker, + lstm_unit_grad, ops::LstmUnitGradOp); +REGISTER_OP_CPU_KERNEL(lstm_unit, + ops::LstmUnitKernel); +REGISTER_OP_CPU_KERNEL( + lstm_unit_grad, ops::LstmUnitGradKernel); diff --git a/paddle/operators/lstm_unit_op.cu b/paddle/operators/lstm_unit_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..6e5e4978994c281416a65af5f8ffdec688768d63 --- /dev/null +++ b/paddle/operators/lstm_unit_op.cu @@ -0,0 +1,173 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/cross_entropy_op.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__device__ Dtype cuda_sigmoid(const Dtype x) { + return Dtype(1) / (Dtype(1) + exp(-x)); +} + +template +__device__ Dtype cuda_tanh(const Dtype x) { + return Dtype(1 - exp(-2. * x)) / (Dtype(1) + exp(-2. * x)); +} + +template +__global__ void LSTMUnitKernel(const int nthreads, const int dim, + const T* C_prev, const T* X, T* C, T* H, + const T forget_bias) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + const int n = index / dim; + const int d = index % dim; + + const T* X_offset = X + 4 * dim * n; + const T i = cuda_sigmoid(X_offset[d]); + const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias); + const T o = cuda_sigmoid(X_offset[2 * dim + d]); + const T g = cuda_tanh(X_offset[3 * dim + d]); + const T c_prev = C_prev[index]; + const T c = f * c_prev + i * g; + C[index] = c; + const T tanh_c = cuda_tanh(c); + H[index] = o * tanh_c; + } +} + +template +__global__ void LSTMUnitGradientKernel(const int nthreads, const int dim, + const T* C_prev, const T* X, const T* C, + const T* H, const T* C_diff, + const T* H_diff, T* C_prev_diff, + T* X_diff, const T forget_bias) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + const int n = index / dim; + const int d = index % dim; + const T* X_offset = X + 4 * dim * n; + T* c_prev_diff = C_prev_diff + index; + T* X_diff_offset = X_diff + 4 * dim * n; + T* i_diff = X_diff_offset + d; + T* f_diff = X_diff_offset + 1 * dim + d; + T* o_diff = X_diff_offset + 2 * dim + d; + T* g_diff = X_diff_offset + 3 * dim + d; + + const T i = cuda_sigmoid(X_offset[d]); + const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias); + const T o = cuda_sigmoid(X_offset[2 * dim + d]); + const T g = cuda_tanh(X_offset[3 * dim + d]); + const T c_prev = C_prev[index]; + const T c = C[index]; + const T tanh_c = cuda_tanh(c); + const T c_term_diff = + C_diff[index] + H_diff[index] * o * (1 - tanh_c * tanh_c); + *c_prev_diff = c_term_diff * f; + *i_diff = c_term_diff * g * i * (1 - i); + *f_diff = c_term_diff * c_prev * f * (1 - f); + *o_diff = H_diff[index] * tanh_c * o * (1 - o); + *g_diff = c_term_diff * i * (1 - g * g); + } +} + +template +class LstmUnitOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto* x_tensor = ctx.Input("X"); + auto* c_prev_tensor = ctx.Input("C_prev"); + auto* c_tensor = ctx.Output("C"); + auto* h_tensor = ctx.Output("H"); + + auto forget_bias = static_cast(ctx.Attr("forget_bias")); + + int b_size = c_tensor->dims()[0]; + int D = c_tensor->dims()[1]; + + const T* X = x_tensor->data(); + const T* C_prev = c_prev_tensor->data(); + + T* C = c_tensor->mutable_data(ctx.GetPlace()); + T* H = h_tensor->mutable_data(ctx.GetPlace()); + + int block = 512; + int n = b_size * D; + int grid = (n + block - 1) / block; + + LSTMUnitKernel<<>>(n, D, C_prev, X, C, H, forget_bias); + } +}; + +template +class LstmUnitGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto x_tensor = ctx.Input("X"); + auto c_prev_tensor = ctx.Input("C_prev"); + auto c_tensor = ctx.Input("C"); + auto h_tensor = ctx.Input("H"); + + auto hdiff_tensor = ctx.Input(framework::GradVarName("H")); + auto cdiff_tensor = ctx.Input(framework::GradVarName("C")); + + auto xdiff_tensor = ctx.Output(framework::GradVarName("X")); + auto c_prev_diff_tensor = + ctx.Output(framework::GradVarName("C_prev")); + + auto* X = x_tensor->data(); + auto* C_prev = c_prev_tensor->data(); + auto* C = c_tensor->data(); + auto* H = h_tensor->data(); + + auto* H_diff = hdiff_tensor->data(); + auto* C_diff = cdiff_tensor->data(); + + auto* C_prev_diff = c_prev_diff_tensor->mutable_data(ctx.GetPlace()); + auto* X_diff = xdiff_tensor->mutable_data(ctx.GetPlace()); + + int N = c_tensor->dims()[0]; + int D = c_tensor->dims()[1]; + + auto forget_bias = static_cast(ctx.Attr("forget_bias")); + + int block = 512; + int n = N * D; + int grid = (n + block - 1) / block; + + LSTMUnitGradientKernel<<>>(n, D, C_prev, X, C, H, C_diff, + H_diff, C_prev_diff, X_diff, + forget_bias); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(lstm_unit, ops::LstmUnitOpCUDAKernel); +REGISTER_OP_GPU_KERNEL(lstm_unit_grad, ops::LstmUnitGradOpCUDAKernel); diff --git a/paddle/operators/lstm_unit_op.h b/paddle/operators/lstm_unit_op.h new file mode 100644 index 0000000000000000000000000000000000000000..683034fe15df8cabfdff5e856adb5c0467055064 --- /dev/null +++ b/paddle/operators/lstm_unit_op.h @@ -0,0 +1,148 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "glog/logging.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::LoDTensor; +using framework::Tensor; + +template +inline T sigmoid(T x) { + return 1. / (1. + exp(-x)); +} + +template +inline T tanh(T x) { + return 2. * sigmoid(2. * x) - 1.; +} + +template +class LstmUnitKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto* x_tensor = ctx.Input("X"); + auto* c_prev_tensor = ctx.Input("C_prev"); + auto* c_tensor = ctx.Output("C"); + auto* h_tensor = ctx.Output("H"); + + auto forget_bias = static_cast(ctx.Attr("forget_bias")); + + int b_size = c_tensor->dims()[0]; + int D = c_tensor->dims()[1]; + + T* C = c_tensor->mutable_data(ctx.GetPlace()); + T* H = h_tensor->mutable_data(ctx.GetPlace()); + + const T* X = x_tensor->data(); + const T* C_prev = c_prev_tensor->data(); + + for (int n = 0; n < b_size; ++n) { + for (int d = 0; d < D; ++d) { + const T i = sigmoid(X[d]); + const T f = sigmoid(X[1 * D + d] + forget_bias); + const T o = sigmoid(X[2 * D + d]); + const T g = tanh(X[3 * D + d]); + const T c_prev = C_prev[d]; + const T c = f * c_prev + i * g; + C[d] = c; + const T tanh_c = tanh(c); + H[d] = o * tanh_c; + } + C_prev += D; + X += 4 * D; + C += D; + H += D; + } + } +}; + +template +class LstmUnitGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto x_tensor = ctx.Input("X"); + auto c_prev_tensor = ctx.Input("C_prev"); + auto c_tensor = ctx.Input("C"); + auto h_tensor = ctx.Input("H"); + + auto hdiff_tensor = ctx.Input(framework::GradVarName("H")); + auto cdiff_tensor = ctx.Input(framework::GradVarName("C")); + + auto xdiff_tensor = ctx.Output(framework::GradVarName("X")); + auto c_prev_diff_tensor = + ctx.Output(framework::GradVarName("C_prev")); + + auto* X = x_tensor->data(); + auto* C_prev = c_prev_tensor->data(); + auto* C = c_tensor->data(); + auto* H = h_tensor->data(); + + auto* H_diff = hdiff_tensor->data(); + auto* C_diff = cdiff_tensor->data(); + + auto* C_prev_diff = c_prev_diff_tensor->mutable_data(ctx.GetPlace()); + auto* X_diff = xdiff_tensor->mutable_data(ctx.GetPlace()); + + int N = c_tensor->dims()[0]; + int D = c_tensor->dims()[1]; + + auto forget_bias = static_cast(ctx.Attr("forget_bias")); + + for (int n = 0; n < N; ++n) { + for (int d = 0; d < D; ++d) { + T* c_prev_diff = C_prev_diff + d; + T* i_diff = X_diff + d; + T* f_diff = X_diff + 1 * D + d; + T* o_diff = X_diff + 2 * D + d; + T* g_diff = X_diff + 3 * D + d; + + const T i = sigmoid(X[d]); + const T f = sigmoid(X[1 * D + d] + forget_bias); + const T o = sigmoid(X[2 * D + d]); + const T g = tanh(X[3 * D + d]); + const T c_prev = C_prev[d]; + const T c = C[d]; + const T tanh_c = tanh(c); + const T c_term_diff = C_diff[d] + H_diff[d] * o * (1 - tanh_c * tanh_c); + *c_prev_diff = c_term_diff * f; + *i_diff = c_term_diff * g * i * (1 - i); + *f_diff = c_term_diff * c_prev * f * (1 - f); + *o_diff = H_diff[d] * tanh_c * o * (1 - o); + *g_diff = c_term_diff * i * (1 - g * g); + } + C_prev += D; + X += 4 * D; + C += D; + H += D; + C_diff += D; + H_diff += D; + X_diff += 4 * D; + C_prev_diff += D; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/parameter/FirstOrderOptimizer.h b/paddle/parameter/FirstOrderOptimizer.h index caa78acd98ea4b35fc69643689cfce23026275e0..895e8d6a63d1fad0ee7a6f5647402435d418b2f1 100644 --- a/paddle/parameter/FirstOrderOptimizer.h +++ b/paddle/parameter/FirstOrderOptimizer.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "ParameterOptimizer.h" +#include "ParameterUpdateFunctions.h" #include "Regularizer.h" namespace paddle { @@ -37,6 +38,15 @@ public: real torch_learningRate = optConfig_.learning_method() == "torch_momentum" ? 1.0 - paraConfig.momentum() : 1.0; +#ifdef PADDLE_USE_MKLDNN + sgdUpdate(learningRate_ * paraConfig.learning_rate() * + (firstTime_ ? 1.0 : torch_learningRate), + paraConfig.momentum(), + applyDecay_ ? paraConfig.decay_rate() : 0, + vecs[PARAMETER_VALUE].get(), + vecs[PARAMETER_GRADIENT].get(), + vecs[PARAMETER_MOMENTUM].get()); +#else vecs[PARAMETER_VALUE]->sgdUpdate( *vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM], @@ -44,6 +54,7 @@ public: (firstTime_ ? 1.0 : torch_learningRate), paraConfig.momentum(), applyDecay_ ? paraConfig.decay_rate() : 0); +#endif } virtual void finishBatch() { firstTime_ = false; } }; diff --git a/paddle/parameter/ParameterUpdateFunctions.cpp b/paddle/parameter/ParameterUpdateFunctions.cpp index c8af7105c78dcbf9f625a348b7f38efcf278469e..8b3be062b654a52e667626199be8c8bb4a2a96d7 100644 --- a/paddle/parameter/ParameterUpdateFunctions.cpp +++ b/paddle/parameter/ParameterUpdateFunctions.cpp @@ -30,6 +30,9 @@ void sgdUpdateCpu(real learningRate, const real* grad, real* momentumVec) { decayRate *= learningRate; +#ifdef PADDLE_USE_MKLDNN +#pragma omp parallel for +#endif for (size_t i = 0; i < size; ++i) { momentumVec[i] = momentum * momentumVec[i] - learningRate * grad[i] - decayRate * value[i]; diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index a106592e454e21c46cd2f87f1bbf6694955d6e23..f6a39a8e26c301296aac0af7f4e8b2c6c97ece24 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -34,13 +34,14 @@ class DeviceContext { template DeviceType* get_eigen_device() const; + + virtual void Wait() const {} }; class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); explicit CPUDeviceContext(CPUPlace place); - virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; @@ -59,7 +60,7 @@ class CUDADeviceContext : public DeviceContext { virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ - void Wait() const; + void Wait() const override; /*! \brief Return place in the device context. */ Place GetPlace() const override; diff --git a/paddle/platform/gpu_info.h b/paddle/platform/gpu_info.h index ed2420b8740e583d307f6836a70fe7e1c780e28b..f0c825bd9b0bc41396b8fdb95f0b4337cbe3db02 100644 --- a/paddle/platform/gpu_info.h +++ b/paddle/platform/gpu_info.h @@ -36,7 +36,7 @@ int GetCurrentDeviceId(); //! Set the GPU device id for next execution. void SetDeviceId(int device_id); -//!Get the memory usage of current GPU device. +//! Get the memory usage of current GPU device. void GpuMemoryUsage(size_t &available, size_t &total); //! Get the maximum allocation size of current GPU device. diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index fbe074188e5870de4b00fa4fff733035739974ea..25e290ffbb94354da3393ca0b769aff512d74a41 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -237,7 +237,13 @@ All parameter, weight, gradient are variables in Paddle. return Backward(forwardOp, no_grad_vars).release(); }) .def("infer_shape", &OperatorBase::InferShape) - .def("run", &OperatorBase::Run) + .def("run", + [](OperatorBase &self, + const Scope &scope, + const platform::DeviceContext &dev_ctx) { + self.Run(scope, dev_ctx); + dev_ctx.Wait(); + }) .def("type", [](const OperatorBase &op) -> std::string { return op.Type(); }) .def("outputs", diff --git a/paddle/trainer/tests/CMakeLists.txt b/paddle/trainer/tests/CMakeLists.txt index f01ad4142d4fe7c7f7d7aac60d967ea114b93e56..066837ca959e46dbe3b39c661aa1bab11cbf2734 100644 --- a/paddle/trainer/tests/CMakeLists.txt +++ b/paddle/trainer/tests/CMakeLists.txt @@ -37,6 +37,19 @@ add_test(NAME test_CompareTwoNets --config_file_a=trainer/tests/sample_trainer_config_qb_rnn.conf --config_file_b=trainer/tests/sample_trainer_config_rnn.conf WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/) +################ test_CompareMKLDNNandCPU ###################### +if(WITH_MKLDNN) + add_unittest_without_exec(test_CompareMKLDNNandCPU + test_CompareTwoNets.cpp) + add_test(NAME test_CompareMKLDNNandCPU + COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python/ + ${CMAKE_CURRENT_BINARY_DIR}/test_CompareMKLDNNandCPU + --config_file_a=trainer/tests/sample_trainer_config_simple_net.conf --use_mkldnn_a=True + --config_file_b=trainer/tests/sample_trainer_config_simple_net.conf --use_mkldnn_b=False + --use_gpu=False + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/) +endif() + ############### test_CompareTwoOpts ################### add_unittest_without_exec(test_CompareTwoOpts test_CompareTwoOpts.cpp) diff --git a/paddle/trainer/tests/sample_trainer_config_simple_net.conf b/paddle/trainer/tests/sample_trainer_config_simple_net.conf new file mode 100644 index 0000000000000000000000000000000000000000..77f78161535c49da4ef7fc1563cff58c021aecef --- /dev/null +++ b/paddle/trainer/tests/sample_trainer_config_simple_net.conf @@ -0,0 +1,63 @@ +# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +################################### Data Configuration ################################### +TrainData(ProtoData(files = "trainer/tests/mnist.list")) +################################### Algorithm Configuration ################################### +settings(batch_size = 1000, + learning_method = MomentumOptimizer(momentum=0.5, sparse=False)) +################################### Network Configuration ################################### +data = data_layer(name ="input", size=784) + +tmp = img_conv_layer(input=data, + num_channels=1, + filter_size=3, + num_filters=32, + padding=1, + shared_biases=True, + act=ReluActivation()) + +tmp = img_pool_layer(input=tmp, + pool_size=3, + stride=2, + padding=1, + pool_type=AvgPooling()) + +tmp = img_conv_layer(input=tmp, + filter_size=3, + num_filters=64, + padding=1, + shared_biases=True, + act=ReluActivation()) + +tmp = img_pool_layer(input=tmp, + pool_size=3, + stride=2, + padding=1, + pool_type=MaxPooling()) + +tmp = fc_layer(input=tmp, size=64, + bias_attr=True, + act=ReluActivation()) + +output = fc_layer(input=tmp, size=10, + bias_attr=True, + act=SoftmaxActivation()) + +lbl = data_layer(name ="label", size=10) + +cost = classification_cost(input=output, label=lbl) +outputs(cost) diff --git a/paddle/trainer/tests/test_CompareTwoNets.cpp b/paddle/trainer/tests/test_CompareTwoNets.cpp index 94f65e545d116c802fb4877dc14f07aaaf83a4fb..307645d2c3d21d954371fcedb5f95a2536a0183e 100644 --- a/paddle/trainer/tests/test_CompareTwoNets.cpp +++ b/paddle/trainer/tests/test_CompareTwoNets.cpp @@ -26,12 +26,15 @@ DECLARE_int32(gpu_id); DECLARE_bool(local); DECLARE_bool(use_gpu); +DECLARE_bool(use_mkldnn); DECLARE_string(config); DECLARE_string(nics); DEFINE_string(config_file_a, "", "config of one network to compare"); DEFINE_string(config_file_b, "", "config of another network to compare"); +DEFINE_bool(use_mkldnn_a, false, "whether to use mkldnn to run config_file_a"); +DEFINE_bool(use_mkldnn_b, false, "whether to use mkldnn to run config_file_b"); DEFINE_bool(need_high_accuracy, false, "whether need to run in double accuracy"); @@ -128,6 +131,12 @@ void compareGradient(ComData& comDataA, ComData& comDataB) { matA.getWidth()); } + if (FLAGS_use_mkldnn_a || FLAGS_use_mkldnn_b) { + // some format of mkldnn parameter is different with cpu + // test_MKLDNN will check the parameters + return; + } + vector& parametersA = comDataA.parameters; vector& parametersB = comDataB.parameters; @@ -167,10 +176,12 @@ void compareGradient(ComData& comDataA, ComData& comDataB) { TEST(Trainer, create) { ComData dataA; + FLAGS_use_mkldnn = FLAGS_use_mkldnn_a; calcGradient(dataA, FLAGS_config_file_a); LOG(INFO) << "\n\nforwardBackward of Network A is finished\n\n"; ComData dataB; + FLAGS_use_mkldnn = FLAGS_use_mkldnn_b; calcGradient(dataB, FLAGS_config_file_b); LOG(INFO) << "\n\nforwardBackward of the Network B is finished\n\n"; diff --git a/python/paddle/v2/framework/tests/test_lstm_unit_op.py b/python/paddle/v2/framework/tests/test_lstm_unit_op.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce65bfc31d9fa2d3988759a197e2f497b8161b1 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_lstm_unit_op.py @@ -0,0 +1,38 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def sigmoid_np(x): + return 1. / (1. + np.exp(-x)) + + +def tanh_np(x): + return 2 * sigmoid_np(2. * x) - 1. + + +class LstmUnitTest(OpTest): + def setUp(self): + self.op_type = "lstm_unit" + x_np = np.random.normal(size=(5, 16)).astype("float32") + c_np = np.random.normal(size=(5, 4)).astype("float32") + i_np, f_np, o_np, j_np = np.split(x_np, 4, axis=1) + forget_bias_np = 0. + self.attrs = {'forget_bias': 0.} + + new_c = c_np * sigmoid_np(f_np + forget_bias_np) + sigmoid_np( + i_np) * tanh_np(j_np) + new_h = tanh_np(new_c) * sigmoid_np(o_np) + + self.inputs = {'X': x_np, 'C_prev': c_np} + self.outputs = {'C': new_c, 'H': new_h} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X', 'C_prev'], ['C', 'H'], max_relative_error=0.01) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_prelu_op.py b/python/paddle/v2/framework/tests/test_prelu_op.py index 2b6b7db36808a4b68c55328a1eb9ac212c18b678..676fd9f7c555fd5c8544e760345ab954cd137dc5 100644 --- a/python/paddle/v2/framework/tests/test_prelu_op.py +++ b/python/paddle/v2/framework/tests/test_prelu_op.py @@ -7,6 +7,14 @@ class PReluTest(OpTest): def setUp(self): self.op_type = "prelu" x_np = np.random.normal(size=(10, 10)).astype("float32") + + for pos, val in np.ndenumerate(x_np): + # Since zero point in prelu is not differentiable, avoid randomize + # zero. + while abs(val) < 1e-3: + x_np[pos] = np.random.normal() + val = x_np[pos] + x_np_sign = np.sign(x_np) x_np = x_np_sign * np.maximum(x_np, .005) alpha_np = np.array([.1])