diff --git a/.gitignore b/.gitignore
index 4f21fefda9f64a0392881971a715b97c234030e3..351b8204100dfd71e94cb3efa2e946b44b9e4285 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,3 +27,4 @@ CMakeFiles
 cmake_install.cmake
 paddle/.timestamp
 python/paddlepaddle.egg-info/
+paddle/pybind/pybind.h
diff --git a/doc/design/refactorization.md b/doc/design/refactorization.md
new file mode 100644
index 0000000000000000000000000000000000000000..e105861e926411a269b0b52dd4688744912c9ab3
--- /dev/null
+++ b/doc/design/refactorization.md
@@ -0,0 +1,253 @@
+# Design Doc: Refactorization Overview
+
+The goals of the refactorization include:
+
+1. Make it easy for external contributors to write new elementary computation operations.
+1. Make the codebase clean and readable.
+1. Introduce a new design of computation representation -- a computation graph of operators and variables.
+1. The graph representation helps implement auto-scalable and auto-fault-recoverable distributed computing.
+
+## Computation Graphs
+
+1. PaddlePaddle represents the computation, i.e., the training and inference of DL models, by computation graphs.
+
+   1. Please dig into [computation graphs](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/graph.md) for a solid example.
+
+1. Users write Python programs to describe the graphs and run them (locally or remotely).
+
+1. A graph is composed of *variables* and *operators*.
+
+1. The description of a graph must be able to be serialized/deserialized, so that it
+
+   1. can be sent to the cloud for distributed execution, and
+   1. can be sent to clients for mobile or enterprise deployment.
+
+1. The Python program does two things:
+
+   1. *compilation*: runs a Python program to generate a protobuf message representation of the graph and sends it to
+      1. the C++ library `libpaddle.so` for local execution,
+      1. the master process of a distributed training job for training, or
+      1. the server process of a Kubernetes serving job for distributed serving.
+   1. *execution*: according to the protobuf message, constructs instances of class `Variable` and `OperatorBase`, and runs them.
+
+## Description and Realization
+
+At compile time, the Python program generates the protobuf message representation of the graph, i.e., the description of the graph.
+
+At runtime, the C++ program realizes the graph and runs it.
+
+| | Representation (protobuf messages) | Realization (C++ class objects) |
+|---|---|---|
+|Data|[VarDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L107)|[Variable](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/variable.h#L24)|
+|Operation|[OpDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L35)|[Operator](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L64)|
+|Block|BlockDesc|Block|
+
+The word *graph* is interchangeable with *block* in this document. A graph represents computation steps and local variables just as a C++/Java program block, i.e., a pair of `{` and `}`, does.
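+
+The following is a minimal, self-contained sketch of the description/realization split above. `VarDescLike` and `VariableLike` are hypothetical stand-ins, not Paddle's actual `VarDesc` and `Variable` classes; the point is only that a description carries serializable metadata while a realization owns runtime memory:
+
+```cpp
+#include <iostream>
+#include <string>
+#include <vector>
+
+struct VarDescLike {   // description: the part that would be serialized
+  std::string name;
+  std::vector<int> shape;
+};
+
+struct VariableLike {  // realization: owns actual memory at runtime
+  std::vector<float> data;
+};
+
+int main() {
+  // "Compile time": build the description only; nothing is allocated yet.
+  VarDescLike desc{"fc_out", {2, 3}};
+  // "Runtime": realize the description as a concrete variable.
+  VariableLike var;
+  var.data.resize(desc.shape[0] * desc.shape[1]);  // allocate 2 x 3 floats
+  std::cout << desc.name << " realized with " << var.data.size()
+            << " elements\n";
+  return 0;
+}
+```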
+
+## Compilation and Execution
+
+1. Run an application Python program to describe the graph. In particular,
+
+   1. create VarDesc to represent local/intermediate variables,
+   1. create operators and set attributes,
+   1. validate attribute values,
+   1. infer the types and shapes of variables,
+   1. plan for memory reuse of variables,
+   1. generate the backward and optimization parts of the Graph,
+   1. possibly split the graph for distributed training.
+
+1. The invocation of `train` or `infer` in the application Python program will:
+
+   1. create a new Scope instance in the [scope hierarchy](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/scope.md) for each run of a block,
+      1. realize local variables defined in the BlockDesc message in the new scope,
+      1. a scope is similar to a stack frame in programming languages,
+
+   1. create an instance of class `Block`, in which,
+      1. realize operators in the BlockDesc message,
+
+   1. run the Block by calling
+      1. `Block::Eval(vector<Variable>* targets)` for forward and backward computations, or
+      1. `Block::Eval(vector<Variable>* targets)` for optimization.
+
+
+## Intermediate Representation (IR)
+
+```text
+Compile Time -> IR -> Runtime
+```
+
+### Benefits
+
+- Optimization
+  ```text
+  Compile Time -> IR -> Optimized IR -> Runtime
+  ```
+- Send automatically partitioned IR to different nodes.
+  - Automatic data parallelism
+    ```text
+    Compile Time
+      |-> Single GPU IR
+            |-> [trainer-IR-0, trainer-IR-1, pserver-IR]
+                  |-> Node-0 (runs trainer-IR-0)
+                  |-> Node-1 (runs trainer-IR-1)
+                  |-> Node-2 (runs pserver-IR)
+    ```
+  - Automatic model parallelism (planned for the future)
+
+---
+
+# Operator/OpWithKernel/OpKernel
+
+![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/49caf1fb70820fb4a6c217634317c9306f361f36/op_op_with_kern_class_diagram.dot)
+
+---
+
+# Operator
+![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/dd598e8f1976f5759f58af5e5ef94738a6b2e661/op.dot)
+
+* `Operator` is the fundamental building block of the user interface.
+  * An Operator stores input/output variable names and attributes.
+  * The `InferShape` interface is used to infer the shapes of output variables from the shapes of its inputs.
+  * Use `Run` to compute `output variables` from `input variables`.
+
+---
+
+# OpWithKernel/Kernel
+
+![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/9d7f4eba185cf41c8e2fbfb40ae21890dbddcd39/op_with_kernel.dot)
+
+* `OpWithKernel` inherits `Operator`.
+* `OpWithKernel` contains a Kernel map.
+  * `OpWithKernel::Run` gets the device's kernel and invokes `OpKernel::Compute`.
+  * `OpKernelKey` is the map key. It only contains the device place for now, but may include the data type later.
+
+---
+
+# Why separate Kernel and Operator
+
+* Separate GPU and CPU code.
+  * Make it possible for Paddle to run without a GPU.
+* Let one operator (the user interface) contain many implementations.
+  * For example, the same mul op can have different FP16 and FP32 kernels, or different MKL and Eigen kernels.
+---
+
+# Libraries for Kernel development
+
+* `Eigen::Tensor` contains basic math and element-wise functions.
+  * Note that `Eigen::Tensor` has a broadcast implementation.
+  * Limit the number of `tensor.device(dev) = ` statements in your code.
+* `thrust::transform` and `std::transform`.
+  * `thrust` has the same API as the C++ standard library. Using `transform`, one can quickly implement a customized element-wise kernel.
+  * `thrust` also has more complex APIs, like `scan`, `reduce`, and `reduce_by_key`.
+* Hand-written `GPUKernel` and `CPU` code.
+  * Do not write kernels in `.h` files. CPU kernels should be in `.cc` files; GPU kernels should be in `.cu` files. (`GCC` cannot compile GPU code.)
+---
+# Operator Registration
+
+## Why is registration necessary?
+We need a method to build mappings between Op type names and Op classes.
+
+## How is registration done?
+
+Maintain a map whose key is the type name and whose value is the corresponding Op constructor, as in the sketch below.
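+
+A minimal sketch of such a registry, using simplified standalone types (`OpBase`, `Registry`) rather than Paddle's actual `OpInfoMap` machinery:
+
+```cpp
+#include <functional>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <string>
+
+struct OpBase {
+  virtual ~OpBase() = default;
+  virtual void Run() = 0;
+};
+
+using OpCreator = std::function<std::unique_ptr<OpBase>()>;
+
+// The registry: type name -> Op constructor.
+std::map<std::string, OpCreator>& Registry() {
+  static std::map<std::string, OpCreator> registry;
+  return registry;
+}
+
+struct MulOp : OpBase {
+  void Run() override { std::cout << "running mul\n"; }
+};
+
+// A static initializer plays the role of the REGISTER_OP macro.
+static bool mul_registered = [] {
+  Registry()["mul"] = [] { return std::unique_ptr<OpBase>(new MulOp); };
+  return true;
+}();
+
+int main() {
+  auto op = Registry().at("mul")();  // look up the constructor, build the op
+  op->Run();
+  return 0;
+}
+```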
+
+---
+# The Registry Map
+
+### `OpInfoMap`
+
+`op_type(string)` -> `OpInfo`
+
+`OpInfo`:
+
+- **`creator`**: The Op constructor.
+- **`grad_op_type`**: The type of the gradient Op.
+- **`proto`**: The Op's Protobuf message, including inputs, outputs, and required attributes.
+- **`checker`**: Used to check attributes.
+
+---
+# Related Concepts
+
+### Op_Maker
+Its constructor takes `proto` and `checker`, which are completed during the Op_Maker's construction. ([ScaleOpMaker](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/scale_op.cc#L37))
+
+### Register Macros
+```cpp
+REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, grad_op_class)
+REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class)
+```
+
+### `USE` Macros
+These make sure the registration process is executed and linked.
+
+---
+# Registration Process
+1. Write an Op class, as well as its gradient Op class if one is needed.
+2. Write an Op maker class. In its constructor, describe the inputs, outputs, and attributes of the operator.
+3. Invoke the macro `REGISTER_OP`. The macro will
+   1. call the maker class to complete `proto` and `checker`, and
+   2. with the completed `proto` and `checker`, build a new key-value pair in the `OpInfoMap`.
+
+4. Invoke the `USE` macro where the Op is used to make sure it is linked.
+
+---
+# Backward Module (1/2)
+### Create Backward Operator
+- Mapping from a forward Op to its backward Op
+![backward](https://gist.githubusercontent.com/dzhwinter/a6fbd4623ee76c459f7f94591fd1abf0/raw/61026ab6e518e66bde66a889bc42557a1fccff33/backward.png)
+
+---
+# Backward Module (2/2)
+### Build Backward Network
+- **Input**: a graph of forward operators
+- **Output**: a graph of backward operators
+- **Corner cases in construction** (the shared-variable case is illustrated in the sketch below):
+  - shared variable => insert an `Add` operator to combine gradients
+  - no gradient => insert a `fill_zero_grad` operator
+  - recursive NetOp => call `Backward` recursively
+  - RNN Op => recursively call `Backward` on the stepnet
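+
+The sketch below shows this construction for the shared-variable corner case: backward ops are emitted in reverse order, and when two forward ops read the same variable, an add op is inserted to sum their gradients. The types, the traversal, and the `@GRAD` naming are simplified illustrations, not Paddle's actual backward implementation:
+
+```cpp
+#include <iostream>
+#include <map>
+#include <string>
+#include <vector>
+
+struct OpDescLike {
+  std::string type;
+  std::vector<std::string> inputs, outputs;
+};
+
+std::vector<OpDescLike> BuildBackward(const std::vector<OpDescLike>& fwd) {
+  std::vector<OpDescLike> bwd;
+  std::map<std::string, int> use_count;  // how many forward ops read each var
+  for (auto it = fwd.rbegin(); it != fwd.rend(); ++it) {
+    // Each forward op maps to one backward op with inputs/outputs swapped.
+    bwd.push_back({it->type + "_grad", it->outputs, it->inputs});
+    for (const auto& in : it->inputs) ++use_count[in];
+  }
+  for (const auto& kv : use_count) {
+    if (kv.second > 1) {  // shared variable => insert an add op to sum grads
+      bwd.push_back({"add", {kv.first + "@GRAD_0", kv.first + "@GRAD_1"},
+                     {kv.first + "@GRAD"}});
+    }
+  }
+  return bwd;
+}
+
+int main() {
+  std::vector<OpDescLike> fwd = {{"mul", {"x", "w"}, {"h"}},
+                                 {"relu", {"h"}, {"y"}},
+                                 {"scale", {"h"}, {"z"}}};  // h is read twice
+  for (const auto& op : BuildBackward(fwd))
+    std::cout << op.type << "\n";  // scale_grad relu_grad mul_grad add
+  return 0;
+}
+```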
+
+---
+# Scope, Variable, Tensor
+
+* `Tensor` is an n-dimensional array with a type.
+  * Only dims and data pointers are stored in `Tensor`.
+  * All operations on `Tensor` are written in `Operator` or as global functions.
+  * For the variable-length Tensor design, see [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md).
+* `Variable` holds the inputs and outputs of an operator, which are not necessarily `Tensor`s.
+  * For example, `step_scopes` in RNN is a variable but not a tensor.
+* `Scope` is where variables are stored.
+  * It is a map keyed by variable names.
+  * `Scope` has a hierarchical structure. A local scope can get variables from its parent scope.
+
+---
+# Block (in design)
+## Differences from the original RNNOp
+- a block as an operator is more intuitive than `RNNOp`,
+- it offers a new interface `Eval(targets)` to deduce the minimal block to `Run`,
+- it fits the compile-time/runtime separation design:
+  - during compilation, `SymbolTable` stores `VarDesc`s and `OpDesc`s and serializes them to a `BlockDesc`;
+  - when the graph executes, a Block constructed with the `BlockDesc` passed in creates the `Op`s and `Var`s and then calls `Run`.
+
+---
+# Milestone
+- take Paddle/books as the main line; the requirements of the models motivate the framework refactoring,
+- model migration
+  - framework development gives **priority support** to model migration; for example,
+    - the MNIST demo needs a Python interface,
+    - the RNN models require the framework to support `LoDTensor`.
+  - determine some timelines,
+  - heavily relied-upon Ops need to be migrated first,
+  - different models can be migrated in parallel.
+- improve the framework at the same time,
+- accept imperfection; concentrate on solving the specific problem at the right price.
+
+---
+# Control the migration quality
+- compare the performance of migrated models with the old ones.
+- follow the Google C++ style guide.
+- build an automatic workflow for generating the Python/C++ documentation:
+  - the documentation of layers and ops should be written inside the code,
+  - take documentation quality into account when reviewing PRs,
+  - preview the documentation, and read and improve it from the user's perspective.
diff --git a/paddle/cuda/include/hl_cuda_cudnn.h b/paddle/cuda/include/hl_cuda_cudnn.h
index 3f68c62de6d9b3aaadc9180d86159089dc728ea9..b44b071bd1b3b6e9e5539d5dc0c2b155c524fd57 100644
--- a/paddle/cuda/include/hl_cuda_cudnn.h
+++ b/paddle/cuda/include/hl_cuda_cudnn.h
@@ -22,10 +22,10 @@ limitations under the License. */
  */
 typedef enum {
   HL_POOLING_MAX = 0,
-  // average includes padded values
-  HL_POOLING_AVERAGE = 1,
   // average does not include padded values
-  HL_POOLING_AVERAGE_EXCLUDE_PADDING = 2,
+  HL_POOLING_AVERAGE = 1,
+  // average includes padded values
+  HL_POOLING_AVERAGE_INCLUDE_PADDING = 2,
   HL_POOLING_END
 } hl_pooling_mode_t;
diff --git a/paddle/cuda/include/hl_tensor_ops.h b/paddle/cuda/include/hl_tensor_ops.h
index 93d38b7d2299d994cde0934213668a525bffa80c..b2bf334dab9799153fe1d4fe2c74cce9d57168b9 100644
--- a/paddle/cuda/include/hl_tensor_ops.h
+++ b/paddle/cuda/include/hl_tensor_ops.h
@@ -461,7 +461,7 @@ class add {
 public:
   INLINE float32x4_t operator()(const float32x4_t a,
                                 const float32x4_t b) const {
-    return vmulq_f32(a, b);
+    return vaddq_f32(a, b);
   }
 };
diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu
index 9ba3d142617537c0160f6dccb86ddca43ada15a5..58674febdc4a094c95ff03701e4586c32729847d 100644
--- a/paddle/cuda/src/hl_cuda_cnn.cu
+++ b/paddle/cuda/src/hl_cuda_cnn.cu
@@ -211,13 +211,11 @@ __global__ void KeAvgPoolForward(const int nthreads,
     int hstart = ph * strideH - padH;
     int wstart = pw * strideW - padW;
-    int hend = min(hstart + sizeY, height + padH);
-    int wend = min(wstart + sizeX, width + padW);
-    int pool_size = (hend - hstart) * (wend - wstart);
+    int hend = min(hstart + sizeY, height);
+    int wend = min(wstart + sizeX, width);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
-    hend = min(hend, height);
-    wend = min(wend, width);
+    int pool_size = (hend - hstart) * (wend - wstart);
     real aveval = 0;
     inputData += (frameNum * channels + c) * height * width;
@@ -299,12 +297,14 @@ __global__ void KeAvgPoolBackward(const int nthreads,
     outGrad += (frameNum * outStride + offsetC * pooledH * pooledW);
     for (int ph = phstart; ph < phend; ++ph) {
+      int hstart = ph * strideH - padH;
+      int hend = min(hstart + sizeY, height);
+      hstart = max(hstart, 0);
       for (int pw = pwstart; pw < pwend; ++pw) {
         // figure out the pooling size
-        int hstart = ph * strideH - padH;
         int wstart = pw * strideW - padW;
-        int hend = min(hstart + sizeY, height + padH);
-        int wend = min(wstart + sizeX, width + padW);
+        int wend = min(wstart + sizeX, width);
+        wstart = max(wstart, 0);
         int poolsize = (hend - hstart) * (wend - wstart);
         gradient += outGrad[ph * pooledW + pw] / poolsize;
       }
     }
@@ -600,16 +600,13 @@ __global__ void KeAvgPool3DForward(const int nthreads,
     int dstart = pd * strideD - padD;
     int hstart = ph * strideH - padH;
     int wstart = pw * strideW - padW;
-    int dend = min(dstart + sizeZ, depth + padD);
-    int hend = min(hstart + sizeY, height + padH);
-    int wend = min(wstart + sizeX, width + padW);
-    int pool_size = (dend - dstart) * (hend
- hstart) * (wend - wstart); + int dend = min(dstart + sizeZ, depth); + int hend = min(hstart + sizeY, height); + int wend = min(wstart + sizeX, width); dstart = max(dstart, 0); hstart = max(hstart, 0); wstart = max(wstart, 0); - dend = min(dend, depth); - hend = min(hend, height); - wend = min(wend, width); + int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); real aveval = 0; inputData += (frameNum * channels + c) * depth * height * width; @@ -712,15 +709,18 @@ __global__ void KeAvgPool3DBackward(const int nthreads, outGrad += (frameNum * channels + offsetC) * pooledD * pooledH * pooledW; for (int pd = pdstart; pd < pdend; ++pd) { + int dstart = pd * strideD - padD; + int dend = min(dstart + sizeZ, depth); + dstart = max(dstart, 0); for (int ph = phstart; ph < phend; ++ph) { + int hstart = ph * strideH - padH; + int hend = min(hstart + sizeY, height); + hstart = max(hstart, 0); for (int pw = pwstart; pw < pwend; ++pw) { // figure out the pooling size - int dstart = pd * strideD - padD; - int hstart = ph * strideH - padH; int wstart = pw * strideW - padW; - int dend = min(dstart + sizeZ, depth + padD); - int hend = min(hstart + sizeY, height + padH); - int wend = min(wstart + sizeX, width + padW); + int wend = min(wstart + sizeX, width); + wstart = max(wstart, 0); int poolsize = (dend - dstart) * (hend - hstart) * (wend - wstart); gradient += outGrad[(pd * pooledH + ph) * pooledW + pw] / poolsize; } diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc index f38ef692558b908ed65d2c84821bbb7c3b439742..b8caf48f9c06094e85765f7aa5a3f4195d0ca931 100644 --- a/paddle/cuda/src/hl_cuda_cudnn.cc +++ b/paddle/cuda/src/hl_cuda_cudnn.cc @@ -432,11 +432,11 @@ void hl_create_pooling_descriptor(hl_pooling_descriptor* pooling_desc, cudnn_mode = CUDNN_POOLING_MAX; break; case HL_POOLING_AVERAGE: - cudnn_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - break; - case HL_POOLING_AVERAGE_EXCLUDE_PADDING: cudnn_mode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; break; + case HL_POOLING_AVERAGE_INCLUDE_PADDING: + cudnn_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + break; default: LOG(FATAL) << "parameter mode error"; } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index c57537be4bf67a8db6a49669ab8d2ed1b1324bdc..f8a64a786611ef872dbbfced10919e00c4d46715 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -22,14 +22,14 @@ namespace framework { template <> Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { - return *device_context_->get_eigen_device(); + return *device_context_.get_eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> Eigen::GpuDevice& ExecutionContext::GetEigenDevice() const { - return *device_context_->get_eigen_device(); + return *device_context_.get_eigen_device(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index adae7bfc3d7d31b1ed0373f01db4ef80343a08f7..b7c9c39402d57daf0aec97d98535ac8a8d9c0150 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -366,7 +366,7 @@ struct EigenDeviceConverter { class ExecutionContext : public InferShapeContext { public: ExecutionContext(const OperatorBase& op, const Scope& scope, - const platform::DeviceContext* device_context) + const platform::DeviceContext& device_context) : InferShapeContext(op, scope), device_context_(device_context) {} template ::EigenDeviceType> DeviceType& GetEigenDevice() const; - platform::Place GetPlace() 
const { return device_context_->GetPlace(); } + platform::Place GetPlace() const { return device_context_.GetPlace(); } - const platform::DeviceContext* device_context() const { + const platform::DeviceContext& device_context() const { return device_context_; } @@ -401,7 +401,8 @@ class ExecutionContext : public InferShapeContext { return res; } - const platform::DeviceContext* device_context_; + private: + const platform::DeviceContext& device_context_; }; template <> @@ -461,7 +462,7 @@ class OperatorWithKernel : public OperatorBase { void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const final { auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); - opKernel->Compute(ExecutionContext(*this, scope, &dev_ctx)); + opKernel->Compute(ExecutionContext(*this, scope, dev_ctx)); } static std::unordered_map& diff --git a/paddle/gserver/layers/CudnnPoolLayer.cpp b/paddle/gserver/layers/CudnnPoolLayer.cpp index 4adb2d4709e585a6fec052435c33714d6e3a3f0e..810a1af2d09c63c3787a1ac225c2c7de4238d609 100644 --- a/paddle/gserver/layers/CudnnPoolLayer.cpp +++ b/paddle/gserver/layers/CudnnPoolLayer.cpp @@ -29,9 +29,9 @@ bool CudnnPoolLayer::typeCheck(const std::string &poolType, if (mode) { *mode = HL_POOLING_AVERAGE; } - } else if (poolType == "cudnn-avg-excl-pad-pool") { + } else if (poolType == "cudnn-avg-incl-pad-pool") { if (mode) { - *mode = HL_POOLING_AVERAGE_EXCLUDE_PADDING; + *mode = HL_POOLING_AVERAGE_INCLUDE_PADDING; } } else { return false; diff --git a/paddle/gserver/layers/DetectionOutputLayer.cpp b/paddle/gserver/layers/DetectionOutputLayer.cpp index 0cf0a92bf4bd8f9b8eba2016b2377d9dfb18c70a..f9040f7ae746f9ae1736cd477d3a69a2c49e9d34 100644 --- a/paddle/gserver/layers/DetectionOutputLayer.cpp +++ b/paddle/gserver/layers/DetectionOutputLayer.cpp @@ -143,7 +143,7 @@ void DetectionOutputLayer::forward(PassType passType) { resetOutput(numKept, 7); } else { MatrixPtr outV = getOutputValue(); - outV = NULL; + if (outV) outV->resize(0, 0); return; } MatrixPtr outV = getOutputValue(); diff --git a/paddle/gserver/layers/MKLDNNPoolLayer.cpp b/paddle/gserver/layers/MKLDNNPoolLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..48b2f5a4cb37f6a9c4b1fdc6178c914b46c76e63 --- /dev/null +++ b/paddle/gserver/layers/MKLDNNPoolLayer.cpp @@ -0,0 +1,277 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "MKLDNNPoolLayer.h" +#include "paddle/math/MathUtils.h" +#include "paddle/utils/Logging.h" + +using namespace mkldnn; // NOLINT +typedef memory::format format; + +namespace paddle { + +REGISTER_LAYER(mkldnn_pool, MKLDNNPoolLayer); + +bool MKLDNNPoolLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + if (!MKLDNNLayer::init(layerMap, parameterMap)) { + return false; + } + + /* the size of inputs for pool-layer is 1 */ + CHECK_EQ(config_.inputs_size(), 1); + const PoolConfig& conf = config_.inputs(0).pool_conf(); + ic_ = conf.channels(); + ih_ = conf.img_size_y(); + iw_ = conf.img_size(); + oc_ = ic_; + oh_ = conf.output_y(); + ow_ = conf.output_x(); + fh_ = conf.size_y(); + fw_ = conf.size_x(); + ph_ = conf.padding_y(); + pw_ = conf.padding(); + sh_ = conf.stride_y(); + sw_ = conf.stride(); + + const std::string& type = conf.pool_type(); + if (type == "max-projection") { + poolAlgo_ = algorithm::pooling_max; + } else if (type == "avg-projection") { + // paddle only use exclude_padding + poolAlgo_ = algorithm::pooling_avg_exclude_padding; + } else { + LOG(FATAL) << "unknow pooling type!"; + } + return true; +} + +void MKLDNNPoolLayer::reshape( + int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) { + reshapeInput(bs, ih, iw); + // ic_ and oc can not be changed + CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic) + << "Input channel can not be changed"; + + // cal output sizes + // paddle used false caffeMode for pooling + oh = outputSize(ih, fh_, ph_, sh_, false); + ow = outputSize(iw, fw_, pw_, sw_, false); + reshapeOutput(oh, ow); + + resizeOutput(bs, oc * oh * ow); + + printSizeInfo(); +} + +void MKLDNNPoolLayer::resetFwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + resetFwdBuffers(in, out); + + resetFwdPD(fwdPD_, in, out); + + resetFwdPipeline(pipeline, fwdPD_, in, out); + + printValueFormatFlow(); +} + +void MKLDNNPoolLayer::resetBwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) { + std::shared_ptr pd; + + resetBwdBuffers(in, out); + + resetBwdPD(pd, in, out); + + resetBwdPipeline(pipeline, pd, in, out); + + printGradFormatFlow(); +} + +void MKLDNNPoolLayer::updateInputData() { + inVal_->setData(getInputValue(0, CPU_DEVICE)->getData()); +} + +void MKLDNNPoolLayer::resetFwdBuffers(MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& out) { + resetInValue(in); + + resetOutValue(out); +} + +void MKLDNNPoolLayer::resetInValue(MKLDNNMatrixPtr& in) { + if (inputIsOnlyMKLDNN()) { + const MatrixPtr& dnnIn = getInputValue(0); + in = std::dynamic_pointer_cast(dnnIn); + CHECK(in) << "Input should be MKLDNNMatrix"; + } else { + CHECK_EQ(getPrev(0)->getDeviceId(), CPU_DEVICE) << "Only support CPU yet"; + const MatrixPtr& cpuIn = getInputValue(0, CPU_DEVICE); + in = MKLDNNMatrix::create( + cpuIn, {bs_, ic_, ih_, iw_}, format::nchw, engine_); + } +} + +void MKLDNNPoolLayer::resetOutValue(MKLDNNMatrixPtr& out) { + CHECK(inVal_) << "Should reset input value first"; + memory::dims outDims = memory::dims{bs_, oc_, oh_, ow_}; + out = MKLDNNMatrix::create( + output_.value, outDims, inVal_->getFormat(), engine_); + output_.value = std::dynamic_pointer_cast(out); + + // create reorder if output value has cpu device and pd do not match + cpuOutVal_ = nullptr; + cvtOutVal_ = nullptr; + if (!outputIsOnlyMKLDNN()) { + const MatrixPtr& cpuOut = getOutput(CPU_DEVICE).value; + cpuOutVal_ = MKLDNNMatrix::create(cpuOut, outDims, 
format::nchw, engine_); + if (cpuOutVal_->getPrimitiveDesc() != out->getPrimitiveDesc()) { + cvtOutVal_ = MKLDNNMatrix::createReorder(out, cpuOutVal_); + CHECK(cvtOutVal_) << "should not be emptry"; + } else { + // CPU output share the same data of MKLDNN output + cpuOut->setData(out->getData()); + cpuOutVal_ = out; + } + } +} + +void MKLDNNPoolLayer::resetFwdPD(std::shared_ptr& pd, + MKLDNNMatrixPtr in, + MKLDNNMatrixPtr out) { + memory::dims inDims = memory::dims{bs_, ic_, ih_, iw_}; + memory::dims outDims = memory::dims{bs_, oc_, oh_, ow_}; + memory::dims kernels = memory::dims{fh_, fw_}; + memory::dims strides = memory::dims{sh_, sw_}; + memory::dims padL = memory::dims{ph_, pw_}; + memory::dims padR = getPaddingR(); + padding_kind padKind = padding_kind::zero; + prop_kind pk = passType_ == PASS_TEST ? prop_kind::forward_scoring + : prop_kind::forward_training; + auto fwdDesc = pool_fwd::desc(pk, + poolAlgo_, + in->getMemoryDesc(), + out->getMemoryDesc(), + strides, + kernels, + padL, + padR, + padKind); + pd.reset(new pool_fwd::primitive_desc(fwdDesc, engine_)); + + // prepare workspace if necessary + workspace_ = + (passType_ != PASS_TEST && poolAlgo_ == algorithm::pooling_max) + ? std::make_shared(memory(pd->workspace_primitive_desc())) + : nullptr; +} + +void MKLDNNPoolLayer::resetFwdPipeline( + std::vector& pipeline, + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& out) { + pipeline.clear(); + fwd_ = workspace_ + ? std::make_shared(pool_fwd(*pd, *in, *out, *workspace_)) + : std::make_shared(pool_fwd(*pd, *in, *out)); + pipeline.push_back(*fwd_); + + if (cvtOutVal_) { + pipeline.push_back(*cvtOutVal_); + } +} + +void MKLDNNPoolLayer::resetBwdBuffers(MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& out) { + resetOutGrad(out); + + resetInGrad(in); +} +void MKLDNNPoolLayer::resetOutGrad(MKLDNNMatrixPtr& out) { + CHECK(outVal_) << "Should have output value"; + out = MKLDNNMatrix::create(output_.grad, outVal_->getPrimitiveDesc()); + + // create reorder if output value has cpu device and pd do not match + cpuOutGrad_ = nullptr; + cvtOutGrad_ = nullptr; + if (!outputIsOnlyMKLDNN()) { + const MatrixPtr& cpuOut = getOutput(CPU_DEVICE).grad; + cpuOutGrad_ = MKLDNNMatrix::create( + cpuOut, memory::dims{bs_, oc_, oh_, ow_}, format::nchw, engine_); + if (cpuOutGrad_->getPrimitiveDesc() != out->getPrimitiveDesc()) { + cvtOutGrad_ = MKLDNNMatrix::createReorder(cpuOutGrad_, out); + CHECK(cvtOutGrad_) << "should not be emptry"; + } else { + // share the same data of CPU output + output_.grad->setData(cpuOut->getData()); + out = cpuOutGrad_; + } + } +} + +void MKLDNNPoolLayer::resetInGrad(MKLDNNMatrixPtr& in) { + in = nullptr; + const MatrixPtr& inGrad = inputLayers_[0]->getOutput().grad; + if (inGrad == nullptr) { + return; + } + CHECK(inVal_); + in = MKLDNNMatrix::create(inGrad, inVal_->getPrimitiveDesc()); +} + +void MKLDNNPoolLayer::resetBwdPD(std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& out) { + memory::dims kernels = memory::dims{fh_, fw_}; + memory::dims strides = memory::dims{sh_, sw_}; + memory::dims padL = memory::dims{ph_, pw_}; + memory::dims padR = getPaddingR(); + CHECK(in); + CHECK(out); + auto bwdDesc = pool_bwd::desc(poolAlgo_, + in->getMemoryDesc(), + out->getMemoryDesc(), + strides, + kernels, + padL, + padR, + padding_kind::zero); + pd.reset(new pool_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_)); +} + +void MKLDNNPoolLayer::resetBwdPipeline( + std::vector& pipeline, + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& out) { + 
pipeline.clear(); + if (cvtOutGrad_) { + pipeline.push_back(*cvtOutGrad_); + } + + bwdData_ = + workspace_ + ? std::make_shared(pool_bwd(*pd, *out, *workspace_, *in)) + : std::make_shared(pool_bwd(*pd, *out, *in)); + pipeline.push_back(*bwdData_); +} + +} // namespace paddle diff --git a/paddle/gserver/layers/MKLDNNPoolLayer.h b/paddle/gserver/layers/MKLDNNPoolLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..891e15a7efcdd2e54f61352efc1ba7345b91c76b --- /dev/null +++ b/paddle/gserver/layers/MKLDNNPoolLayer.h @@ -0,0 +1,138 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "MKLDNNLayer.h" +#include "mkldnn.hpp" + +namespace paddle { +typedef mkldnn::pooling_forward pool_fwd; +typedef mkldnn::pooling_backward pool_bwd; + +/** + * @brief A subclass of MKLDNNLayer pool layer. + * + * The config file api is mkldnn_pool + */ +class MKLDNNPoolLayer : public MKLDNNLayer { +protected: + // padding height and width + int ph_, pw_; + // stride height and width + int sh_, sw_; + // filter(kenerl) height and width + int fh_, fw_; + + // pooling_avg or pooling_max + mkldnn::algorithm poolAlgo_; + + // MKLDNNMatrixPtr which should be created from CPU Device + MKLDNNMatrixPtr cpuOutVal_; + MKLDNNMatrixPtr cpuOutGrad_; + // convert handle between CPU device and MKLDNN device + std::shared_ptr cvtOutVal_; + std::shared_ptr cvtOutGrad_; + + // save forward primitive_desc, which can be used backward + std::shared_ptr fwdPD_; + // according to https://github.com/01org/mkl-dnn/blob/master/tests/gtests/ + // test_pooling_forward.cpp, pool need workspace for backward + std::shared_ptr workspace_; + +public: + explicit MKLDNNPoolLayer(const LayerConfig& config) : MKLDNNLayer(config) {} + + ~MKLDNNPoolLayer() {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void reshape( + int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override; + + void resetFwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) override; + + void resetBwd(std::vector& pipeline, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& wgt, + MKLDNNMatrixPtr& bias, + MKLDNNMatrixPtr& out) override; + + void updateInputData() override; + + void printSizeInfo() override { + MKLDNNLayer::printSizeInfo(); + VLOG(MKLDNN_SIZES) << getName() << ": fh: " << fh_ << ", fw: " << fw_ + << ": ph: " << ph_ << ", pw: " << pw_ << ", sh: " << sh_ + << ", sw: " << sw_; + } + +protected: + /** + * Forward functions: reset buffers(input, output), + * reset primitive descriptor, + * reset pipeline. 
+ */ + void resetFwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out); + void resetInValue(MKLDNNMatrixPtr& in); + void resetOutValue(MKLDNNMatrixPtr& out); + void resetFwdPD(std::shared_ptr& pd, + MKLDNNMatrixPtr in, + MKLDNNMatrixPtr out); + void resetFwdPipeline(std::vector& pipeline, + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& out); + + /** + * Backward functions: reset buffers(input, output), + * reset primitive descriptor, + * reset pipeline. + */ + void resetBwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out); + void resetOutGrad(MKLDNNMatrixPtr& out); + void resetInGrad(MKLDNNMatrixPtr& in); + void resetBwdPD(std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& out); + void resetBwdPipeline(std::vector& pipeline, + std::shared_ptr& pd, + MKLDNNMatrixPtr& in, + MKLDNNMatrixPtr& out); + + /** + * get padding_r according to + * https://github.com/01org/mkl-dnn/blob/master/tests/gtests/ + * test_pooling_forward.cpp + */ + mkldnn::memory::dims getPaddingR() const { + mkldnn::memory::dims padR = {ph_, pw_}; + for (int i = 0; i < 2; ++i) { + if ((ih_ + ph_ + padR[0] - fh_) / sh_ + 1 < oh_) { + ++padR[0]; + } + if ((iw_ + pw_ + padR[1] - fw_) / sw_ + 1 < ow_) { + ++padR[1]; + } + } + return padR; + } +}; + +} // namespace paddle diff --git a/paddle/gserver/tests/test_MKLDNN.cpp b/paddle/gserver/tests/test_MKLDNN.cpp index e70802881e3f22160a87b7a4babda07ffbcf9d6f..b593f65fe49ef2271ad7cd0f609c9b828be03037 100644 --- a/paddle/gserver/tests/test_MKLDNN.cpp +++ b/paddle/gserver/tests/test_MKLDNN.cpp @@ -141,6 +141,68 @@ TEST(MKLDNNLayer, ConvLayer) { testConvLayer({4, 4, 16, 3, 3, 16, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1}); } +struct testPoolDesc { + int bs, ch; // input channel and output channel are the same + int ih, iw; + int oh, ow; + int fh, fw; + int ph, pw; + int sh, sw; +}; + +void testPoolLayer(const testPoolDesc& pm) { + const std::string compareTypes[] = {"mkldnn_pool", "pool"}; + TestConfig cfg; + cfg.layerConfig.set_type(compareTypes[0]); + cfg.layerConfig.set_size(pm.ch * pm.oh * pm.ow); + cfg.inputDefs.push_back( + {INPUT_DATA, + "layer_0", + /* size of input layer= */ size_t(pm.ch * pm.ih * pm.iw), + 0}); + LayerInputConfig* input = cfg.layerConfig.add_inputs(); + PoolConfig* pool = input->mutable_pool_conf(); + // pool->set_pool_type(poolType); + pool->set_channels(pm.ch); + pool->set_img_size(pm.iw); + pool->set_img_size_y(pm.ih); + pool->set_output_x(pm.ow); + pool->set_output_y(pm.oh); + pool->set_size_x(pm.fw); + pool->set_size_y(pm.fh); + pool->set_padding(pm.pw); + pool->set_padding_y(pm.ph); + pool->set_stride(pm.sw); + pool->set_stride_y(pm.sh); + + int oh = outputSize(pm.ih, pm.fh, pm.ph, pm.sh, false); + int ow = outputSize(pm.iw, pm.fw, pm.pw, pm.sw, false); + CHECK_EQ(ow, pm.ow) << "output size check failed"; + CHECK_EQ(oh, pm.oh) << "output size check failed"; + + MKLDNNTester tester; + for (auto type : {"max-projection", "avg-projection"}) { + pool->set_pool_type(type); + TestConfig ref = cfg; + ref.layerConfig.set_type(compareTypes[1]); + for (auto bs : {pm.bs, 1}) { + tester.run(cfg, ref, bs, pm.ih, pm.iw); + } + } +} + +TEST(MkldnnLayer, PoolLayer) { + /* bs, ch, ih, iw, oh, ow, fh, fw, ph, pw, sh, sw*/ + testPoolLayer({2, 1, 4, 4, 2, 2, 3, 3, 0, 0, 2, 2}); + testPoolLayer({10, 8, 16, 16, 8, 8, 2, 2, 0, 0, 2, 2}); + testPoolLayer({4, 2, 5, 5, 3, 3, 3, 3, 1, 1, 2, 2}); + testPoolLayer({8, 16, 56, 56, 28, 28, 3, 3, 0, 0, 2, 2}); + testPoolLayer({8, 16, 14, 14, 7, 7, 3, 3, 0, 0, 2, 2}); + testPoolLayer({4, 16, 7, 7, 1, 1, 7, 7, 0, 0, 1, 
1}); + testPoolLayer({4, 2, 5, 5, 3, 3, 5, 5, 1, 1, 1, 1}); + testPoolLayer({2, 8, 56, 56, 29, 29, 3, 3, 1, 1, 2, 2}); +} + // TODO(TJ): add branch test int main(int argc, char** argv) { diff --git a/paddle/math/BaseMatrix.cu b/paddle/math/BaseMatrix.cu index 5435808fb7f70fdf1ac98815f7fe8890fb85527c..53dd5383601782231e6e742784007d1c9154dc6b 100644 --- a/paddle/math/BaseMatrix.cu +++ b/paddle/math/BaseMatrix.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "BaseMatrix.h" #include "MathFunctions.h" +#include "NEONFunctions.h" #include "SIMDFunctions.h" #include "hl_matrix_apply.cuh" #include "hl_matrix_base.cuh" @@ -666,6 +667,13 @@ void BaseMatrixT::relu(BaseMatrixT& b) { applyBinary(binary::Relu(), b); } +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +template <> +void BaseMatrixT::relu(BaseMatrixT& b) { + neon::relu(data_, b.data_, height_ * width_); +} +#endif + DEFINE_MATRIX_BINARY_OP(ReluDerivative, a *= (b > 0.0f ? 1.0f : 0.0f)); template void BaseMatrixT::reluDerivative(BaseMatrixT& b) { diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 4a2132c8d1bfa329ced575f9b78052bdbfe3e4d5..0023b4d0f5da500f380ecb836b7c54e050b13d67 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1033,17 +1033,15 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat, real* inputData = inputMat.getData(); size_t frameNum = inputMat.getHeight(); - size_t width = imgSizeW; - size_t height = imgSizeH; - CHECK(height * width * channels == inputMat.getWidth()); + CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth()); CHECK(height_ == inputMat.getHeight()); CHECK(width_ == outputH * outputW * channels); hl_maxpool_forward(frameNum, inputData, channels, - height, - width, + imgSizeH, + imgSizeW, outputH, outputW, sizeX, @@ -1080,11 +1078,8 @@ void GpuMatrix::maxPoolBackward(Matrix& inputMat, real* outDiff = outGrad.getData(); size_t frameNum = inputMat.getHeight(); size_t channels = outV.getWidth() / outputH / outputW; - size_t width = imgSizeW; - size_t height = imgSizeH; - CHECK(height * width * channels == inputMat.getWidth()); + CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth()); CHECK(height_ == inputMat.getHeight()); - CHECK(width_ == width * height * channels); CHECK(outGrad.getHeight() == outV.getHeight() && outGrad.getWidth() == outV.getWidth()); @@ -1093,8 +1088,8 @@ void GpuMatrix::maxPoolBackward(Matrix& inputMat, outData, outDiff, channels, - height, - width, + imgSizeH, + imgSizeW, outputH, outputW, sizeX, @@ -1125,17 +1120,15 @@ void GpuMatrix::avgPoolForward(Matrix& inputMat, real* inputData = inputMat.getData(); size_t frameNum = inputMat.getHeight(); - size_t height = imgSizeH; - size_t width = imgSizeW; - CHECK(height * width * channels == inputMat.getWidth()); + CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth()); CHECK(height_ == inputMat.getHeight()); CHECK(width_ == outputH * outputW * channels); hl_avgpool_forward(frameNum, inputData, channels, - height, - width, + imgSizeH, + imgSizeW, outputH, outputW, sizeX, @@ -1166,17 +1159,15 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad, real* outDiff = outGrad.getData(); size_t frameNum = outGrad.getHeight(); size_t channels = outGrad.getWidth() / outputH / outputW; - size_t height = imgSizeH; - size_t width = imgSizeW; - CHECK(height * width * channels == width_); + CHECK(imgSizeH * imgSizeW * channels == width_); CHECK(height_ == outGrad.getHeight()); CHECK(outGrad.getWidth() == outputH * outputW * channels); hl_avgpool_backward(frameNum, outDiff, channels, - 
height, - width, + imgSizeH, + imgSizeW, outputH, outputW, sizeX, @@ -1214,19 +1205,16 @@ void GpuMatrix::maxPool3DForward(Matrix& inputMat, real* inputData = inputMat.getData(); real* maxPoolIdxData = maxPoolIdx.getData(); size_t num = inputMat.getHeight(); - size_t width = imgSizeW; - size_t height = imgSizeH; - size_t depth = imgSizeD; - CHECK(depth * height * width * channels == inputMat.getWidth()); + CHECK(imgSizeD * imgSizeH * imgSizeW * channels == inputMat.getWidth()); CHECK(height_ == inputMat.getHeight()); CHECK(width_ == outputD * outputH * outputW * channels); hl_maxpool3D_forward(num, inputData, channels, - depth, - height, - width, + imgSizeD, + imgSizeH, + imgSizeW, outputD, outputH, outputW, @@ -1269,20 +1257,16 @@ void GpuMatrix::maxPool3DBackward(Matrix& outGrad, real* maxPoolIdxData = maxPoolIdx.getData(); size_t frameNum = getHeight(); size_t channels = outGrad.getWidth() / outputD / outputH / outputW; - size_t width = imgSizeW; - size_t height = imgSizeH; - size_t depth = imgSizeD; - CHECK(depth * height * width * channels == getWidth()); - CHECK(width_ == depth * width * height * channels); + CHECK(imgSizeD * imgSizeH * imgSizeW * channels == getWidth()); CHECK(outGrad.getHeight() == maxPoolIdx.getHeight() && outGrad.getWidth() == maxPoolIdx.getWidth()); hl_maxpool3D_backward(frameNum, outDiff, channels, - depth, - height, - width, + imgSizeD, + imgSizeH, + imgSizeW, outputD, outputH, outputW, @@ -1323,19 +1307,16 @@ void GpuMatrix::avgPool3DForward(Matrix& inputMat, real* inputData = inputMat.getData(); size_t frameNum = inputMat.getHeight(); - size_t height = imgSizeH; - size_t width = imgSizeW; - size_t depth = imgSizeD; - CHECK(depth * height * width * channels == inputMat.getWidth()); + CHECK(imgSizeD * imgSizeH * imgSizeW * channels == inputMat.getWidth()); CHECK(height_ == inputMat.getHeight()); CHECK(width_ == outputD * outputH * outputW * channels); hl_avgpool3D_forward(frameNum, inputData, channels, - depth, - height, - width, + imgSizeD, + imgSizeH, + imgSizeW, outputD, outputH, outputW, @@ -1375,19 +1356,16 @@ void GpuMatrix::avgPool3DBackward(Matrix& outGrad, real* outDiff = outGrad.getData(); size_t frameNum = outGrad.getHeight(); size_t channels = outGrad.getWidth() / outputD / outputH / outputW; - size_t height = imgSizeH; - size_t width = imgSizeW; - size_t depth = imgSizeD; - CHECK(depth * height * width * channels == width_); + CHECK(imgSizeD * imgSizeH * imgSizeW * channels == width_); CHECK(height_ == outGrad.getHeight()); CHECK(outGrad.getWidth() == outputD * outputH * outputW * channels); hl_avgpool3D_backward(frameNum, outDiff, channels, - depth, - height, - width, + imgSizeD, + imgSizeH, + imgSizeW, outputD, outputH, outputW, @@ -1999,11 +1977,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, real* inputData = inputMat.getData(); real* outData = data_; size_t num = inputMat.getHeight(); - size_t inWidth = imgSizeW; - size_t inHeight = imgSizeH; - CHECK(inHeight * inWidth == inputMat.getWidth() / channels); + size_t inLength = imgSizeH * imgSizeW; + size_t outLength = outputH * outputW; + CHECK(inLength == inputMat.getWidth() / channels); CHECK_EQ(num, this->getHeight()); - CHECK_EQ(channels * outputH * outputW, this->getWidth()); + CHECK_EQ(channels * outLength, this->getWidth()); size_t outStride = getStride(); /* initialize the data_ */ @@ -2020,24 +1998,24 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, } for (size_t c = 0; c < channels; ++c) { // channel by channel for (size_t ph = 0; ph < outputH; ++ph) { + int hstart = ph * 
strideH - paddingH; + int hend = std::min(hstart + sizeY, imgSizeH); + hstart = std::max(hstart, 0); for (size_t pw = 0; pw < outputW; ++pw) { - int hstart = ph * strideH - paddingH; int wstart = pw * strideW - paddingW; - int hend = std::min(hstart + sizeY, inHeight); - int wend = std::min(wstart + sizeX, inWidth); - hstart = std::max(hstart, 0); + int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - outData[ph * outputW + pw] = std::max(outData[ph * outputW + pw], - inputData[h * inWidth + w]); + outData[ph * outputW + pw] = std::max( + outData[ph * outputW + pw], inputData[h * imgSizeW + w]); } } } } // compute offset - inputData += inHeight * inWidth; - outData += outputH * outputW; + inputData += inLength; + outData += outLength; } } } @@ -2058,8 +2036,10 @@ void CpuMatrix::maxPoolBackward(Matrix& image, size_t paddingH, size_t paddingW) { size_t num = image.getHeight(); - size_t channels = size_t(width_ / imgSizeH / imgSizeW); - CHECK(image.getWidth() == imgSizeH * imgSizeW * channels); + size_t inLength = imgSizeH * imgSizeW; + size_t outLength = outputH * outputW; + size_t channels = size_t(width_ / inLength); + CHECK(image.getWidth() == inLength * channels); CHECK(image.getHeight() == height_ && image.getWidth() == width_); CHECK(outV.getHeight() == outGrad.getHeight() && outV.getWidth() == outGrad.getWidth()); @@ -2080,12 +2060,12 @@ void CpuMatrix::maxPoolBackward(Matrix& image, } for (size_t c = 0; c < channels; ++c) { for (size_t ph = 0; ph < outputH; ++ph) { + int hstart = ph * strideH - paddingH; + int hend = std::min(hstart + sizeY, imgSizeH); + hstart = std::max(hstart, 0); for (size_t pw = 0; pw < outputW; ++pw) { - int hstart = ph * strideH - paddingH; int wstart = pw * strideW - paddingW; - int hend = std::min(hstart + sizeY, imgSizeH); int wend = std::min(wstart + sizeX, imgSizeW); - hstart = std::max(hstart, 0); wstart = std::max(wstart, 0); for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { @@ -2098,10 +2078,10 @@ void CpuMatrix::maxPoolBackward(Matrix& image, } } // offset - inData += imgSizeH * imgSizeW; - tgtGrad += imgSizeH * imgSizeW; - otData += outputH * outputW; - otGrad += outputH * outputW; + inData += inLength; + tgtGrad += inLength; + otData += outLength; + otGrad += outLength; } } } @@ -2120,10 +2100,10 @@ void CpuMatrix::avgPoolForward(Matrix& input, size_t paddingW) { // The main loop size_t num = input.getHeight(); - size_t inHeight = imgSizeH; - size_t inWidth = imgSizeW; - CHECK(inHeight * inWidth * channels == input.getWidth()); - CHECK(outputH * outputW * channels * num == height_ * width_); + size_t inLength = imgSizeH * imgSizeW; + size_t outLength = outputH * outputW; + CHECK(inLength * channels == input.getWidth()); + CHECK(outLength * channels * num == height_ * width_); real* tgtData = data_; real* inData = input.getData(); @@ -2133,30 +2113,27 @@ void CpuMatrix::avgPoolForward(Matrix& input, } for (size_t c = 0; c < channels; ++c) { for (size_t ph = 0; ph < outputH; ++ph) { + int hstart = ph * strideH - paddingH; + int hend = std::min(hstart + sizeY, imgSizeH); + hstart = std::max(hstart, 0); for (size_t pw = 0; pw < outputW; ++pw) { - int hstart = ph * strideH - paddingH; int wstart = pw * strideW - paddingW; - int hend = std::min(hstart + sizeY, inHeight + paddingH); - int wend = std::min(wstart + sizeX, inWidth + paddingW); - int poolSize = (hend - hstart) * (wend - wstart); - hstart = std::max(hstart, 0); + int 
wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); - hend = std::min(hend, static_cast(inHeight)); - wend = std::min(wend, static_cast(inWidth)); - - CHECK(poolSize); tgtData[ph * outputW + pw] = 0; // clear for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - tgtData[ph * outputW + pw] += inData[h * inWidth + w]; + tgtData[ph * outputW + pw] += inData[h * imgSizeW + w]; } } + int poolSize = (hend - hstart) * (wend - wstart); + CHECK(poolSize); tgtData[ph * outputW + pw] /= poolSize; } } // compute offset - inData += inHeight * inWidth; - tgtData += outputH * outputW; + inData += inLength; + tgtData += outLength; } } } @@ -2176,7 +2153,9 @@ void CpuMatrix::avgPoolBackward(Matrix& input, size_t paddingW) { size_t num = input.getHeight(); size_t channels = input.getWidth() / outputH / outputW; - CHECK(imgSizeH * imgSizeW * channels == getWidth()); + size_t inLength = imgSizeH * imgSizeW; + size_t outLength = outputH * outputW; + CHECK(inLength * channels == getWidth()); real* inData = input.getData(); real* outData = getData(); @@ -2186,16 +2165,14 @@ void CpuMatrix::avgPoolBackward(Matrix& input, } for (size_t c = 0; c < channels; ++c) { for (size_t ph = 0; ph < outputH; ++ph) { + int hstart = ph * strideH - paddingH; + int hend = std::min(hstart + sizeY, imgSizeH); + hstart = std::max(hstart, 0); for (size_t pw = 0; pw < outputW; ++pw) { - int hstart = ph * strideH - paddingH; int wstart = pw * strideW - paddingW; - int hend = std::min(hstart + sizeY, imgSizeH + paddingH); - int wend = std::min(wstart + sizeX, imgSizeW + paddingW); - int poolSize = (hend - hstart) * (wend - wstart); - hstart = std::max(hstart, 0); + int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); - hend = std::min(hend, static_cast(imgSizeH)); - wend = std::min(wend, static_cast(imgSizeW)); + int poolSize = (hend - hstart) * (wend - wstart); CHECK(poolSize); for (int h = hstart; h < hend; ++h) { @@ -2206,8 +2183,8 @@ void CpuMatrix::avgPoolBackward(Matrix& input, } } // offset - outData += imgSizeH * imgSizeW; - inData += outputH * outputW; + outData += inLength; + inData += outLength; } } } @@ -2234,12 +2211,11 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat, real* outData = getData(); real* maxPoolIdxData = maxPoolIdx.getData(); size_t num = inputMat.getHeight(); - size_t inWidth = imgSizeW; - size_t inHeight = imgSizeH; - size_t inDepth = imgSizeD; - CHECK(inHeight * inWidth * inDepth == inputMat.getWidth() / channels); + size_t inLength = imgSizeH * imgSizeW * imgSizeD; + size_t outLength = outputH * outputW * outputD; + CHECK(inLength == inputMat.getWidth() / channels); CHECK_EQ(num, this->getHeight()); - CHECK_EQ(channels * outputH * outputW * outputD, this->getWidth()); + CHECK_EQ(channels * outLength, this->getWidth()); size_t outStride = getStride(); /* initialize the data_ */ @@ -2258,16 +2234,16 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat, } for (size_t c = 0; c < channels; ++c) { // channel by channel for (size_t pd = 0; pd < outputD; ++pd) { + int dstart = pd * strideD - paddingD; + int dend = std::min(dstart + sizeZ, imgSizeD); + dstart = std::max(dstart, 0); for (size_t ph = 0; ph < outputH; ++ph) { + int hstart = ph * strideH - paddingH; + int hend = std::min(hstart + sizeY, imgSizeH); + hstart = std::max(hstart, 0); for (size_t pw = 0; pw < outputW; ++pw) { - int dstart = pd * strideD - paddingD; - int hstart = ph * strideH - paddingH; int wstart = pw * strideW - paddingW; - int dend = std::min(dstart + sizeZ, 
inDepth); - int hend = std::min(hstart + sizeY, inHeight); - int wend = std::min(wstart + sizeX, inWidth); - dstart = std::max(dstart, 0); - hstart = std::max(hstart, 0); + int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); int maxIdx = -1; real maxOutData = outData[(pd * outputH + ph) * outputW + pw]; @@ -2275,9 +2251,9 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat, for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { if (maxOutData < - inputData[(d * inHeight + h) * inWidth + w]) { - maxOutData = inputData[(d * inHeight + h) * inWidth + w]; - maxIdx = (d * inHeight + h) * inWidth + w; + inputData[(d * imgSizeH + h) * imgSizeW + w]) { + maxOutData = inputData[(d * imgSizeH + h) * imgSizeW + w]; + maxIdx = (d * imgSizeH + h) * imgSizeW + w; } } } @@ -2288,9 +2264,9 @@ void CpuMatrix::maxPool3DForward(Matrix& inputMat, } } // compute offset - inputData += inDepth * inHeight * inWidth; - outData += outputD * outputH * outputW; - maxPoolIdxData += outputD * outputH * outputW; + inputData += inLength; + outData += outLength; + maxPoolIdxData += outLength; } } } @@ -2315,7 +2291,9 @@ void CpuMatrix::maxPool3DBackward(Matrix& outGrad, real scaleTargets, real scaleOutput) { size_t num = getHeight(); - size_t channels = size_t(width_ / imgSizeD / imgSizeH / imgSizeW); + size_t inLength = imgSizeH * imgSizeW * imgSizeD; + size_t outLength = outputH * outputW * outputD; + size_t channels = size_t(width_ / inLength); CHECK(maxPoolIdx.getHeight() == outGrad.getHeight() && maxPoolIdx.getWidth() == outGrad.getWidth()); @@ -2341,9 +2319,9 @@ void CpuMatrix::maxPool3DBackward(Matrix& outGrad, } } // offset - tgtGrad += imgSizeD * imgSizeH * imgSizeW; - otGrad += outputD * outputH * outputW; - maxPoolIdxData += outputD * outputH * outputW; + tgtGrad += inLength; + otGrad += outLength; + maxPoolIdxData += outLength; } } } @@ -2367,11 +2345,10 @@ void CpuMatrix::avgPool3DForward(Matrix& input, size_t paddingW) { // The main loop size_t num = input.getHeight(); - size_t inDepth = imgSizeD; - size_t inHeight = imgSizeH; - size_t inWidth = imgSizeW; - CHECK(inDepth * inHeight * inWidth * channels == input.getWidth()); - CHECK(outputD * outputH * outputW * channels * num == height_ * width_); + size_t inLength = imgSizeH * imgSizeW * imgSizeD; + size_t outLength = outputH * outputW * outputD; + CHECK(inLength * channels == input.getWidth()); + CHECK(outLength * channels * num == height_ * width_); real* tgtData = getData(); real* inData = input.getData(); @@ -2381,39 +2358,36 @@ void CpuMatrix::avgPool3DForward(Matrix& input, } for (size_t c = 0; c < channels; ++c) { for (size_t pd = 0; pd < outputD; ++pd) { + int dstart = pd * strideD - paddingD; + int dend = std::min(dstart + sizeZ, imgSizeD); + dstart = std::max(dstart, 0); for (size_t ph = 0; ph < outputH; ++ph) { + int hstart = ph * strideH - paddingH; + int hend = std::min(hstart + sizeY, imgSizeH); + hstart = std::max(hstart, 0); for (size_t pw = 0; pw < outputW; ++pw) { - int dstart = pd * strideD - paddingD; - int hstart = ph * strideH - paddingH; int wstart = pw * strideW - paddingW; - int dend = std::min(dstart + sizeZ, inDepth + paddingD); - int hend = std::min(hstart + sizeY, inHeight + paddingH); - int wend = std::min(wstart + sizeX, inWidth + paddingW); - int poolSize = (dend - dstart) * (hend - hstart) * (wend - wstart); - dstart = std::max(dstart, 0); - hstart = std::max(hstart, 0); + int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); - dend = std::min(dend, 
static_cast(inDepth)); - hend = std::min(hend, static_cast(inHeight)); - wend = std::min(wend, static_cast(inWidth)); - CHECK(poolSize); tgtData[(pd * outputH + ph) * outputW + pw] = 0; // clear for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { tgtData[(pd * outputH + ph) * outputW + pw] += - inData[(d * inHeight + h) * inWidth + w]; + inData[(d * imgSizeH + h) * imgSizeW + w]; } } } + int poolSize = (dend - dstart) * (hend - hstart) * (wend - wstart); + CHECK(poolSize); tgtData[(pd * outputH + ph) * outputW + pw] /= poolSize; } } } // compute offset - inData += inDepth * inHeight * inWidth; - tgtData += outputD * outputH * outputW; + inData += inLength; + tgtData += outLength; } } } @@ -2437,8 +2411,10 @@ void CpuMatrix::avgPool3DBackward(Matrix& input, real scaleTargets, real scaleOutput) { size_t num = input.getHeight(); - size_t channels = input.getWidth() / outputD / outputH / outputW; - CHECK(imgSizeD * imgSizeH * imgSizeW * channels == getWidth()); + size_t inLength = imgSizeH * imgSizeW * imgSizeD; + size_t outLength = outputH * outputW * outputD; + size_t channels = input.getWidth() / outLength; + CHECK(inLength * channels == getWidth()); real* inData = input.getData(); real* outData = getData(); @@ -2448,21 +2424,18 @@ void CpuMatrix::avgPool3DBackward(Matrix& input, } for (size_t c = 0; c < channels; ++c) { for (size_t pd = 0; pd < outputD; ++pd) { + int dstart = pd * strideD - paddingD; + int dend = std::min(dstart + sizeZ, imgSizeD); + dstart = std::max(dstart, 0); for (size_t ph = 0; ph < outputH; ++ph) { + int hstart = ph * strideH - paddingH; + int hend = std::min(hstart + sizeY, imgSizeH); + hstart = std::max(hstart, 0); for (size_t pw = 0; pw < outputW; ++pw) { - int dstart = pd * strideD - paddingD; - int hstart = ph * strideH - paddingH; int wstart = pw * strideW - paddingW; - int dend = std::min(dstart + sizeZ, imgSizeD + paddingD); - int hend = std::min(hstart + sizeY, imgSizeH + paddingH); - int wend = std::min(wstart + sizeX, imgSizeW + paddingW); - int poolSize = (dend - dstart) * (hend - hstart) * (wend - wstart); - dstart = std::max(dstart, 0); - hstart = std::max(hstart, 0); + int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); - dend = std::min(dend, static_cast(imgSizeD)); - hend = std::min(hend, static_cast(imgSizeH)); - wend = std::min(wend, static_cast(imgSizeW)); + int poolSize = (dend - dstart) * (hend - hstart) * (wend - wstart); CHECK(poolSize); for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { @@ -2476,8 +2449,8 @@ void CpuMatrix::avgPool3DBackward(Matrix& input, } } // offset - outData += imgSizeD * imgSizeH * imgSizeW; - inData += outputD * outputH * outputW; + outData += inLength; + inData += outLength; } } } diff --git a/paddle/math/NEONFunctions.cpp b/paddle/math/NEONFunctions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3bf47901f1069ac228fa1b877e29848d8cc130e8 --- /dev/null +++ b/paddle/math/NEONFunctions.cpp @@ -0,0 +1,55 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + +#include "NEONFunctions.h" +#include + +namespace paddle { +namespace neon { + +// b[i] = a[i] > 0.0f ? a[i] : 0.0f +void relu(const float* a, float* b, int len) { + int offset = len % 16; + float32x4_t ma0, ma1, ma2, ma3; + float32x4_t mb0, mb1, mb2, mb3; + + float32x4_t zero = vdupq_n_f32(0.f); + for (int k = 0; k < len / 16; k++, a += 16, b += 16) { + ma0 = vld1q_f32(a); + ma1 = vld1q_f32(a + 4); + ma2 = vld1q_f32(a + 8); + ma3 = vld1q_f32(a + 12); + + mb0 = vmaxq_f32(ma0, zero); + mb1 = vmaxq_f32(ma1, zero); + mb2 = vmaxq_f32(ma2, zero); + mb3 = vmaxq_f32(ma3, zero); + + vst1q_f32(b, mb0); + vst1q_f32(b + 4, mb1); + vst1q_f32(b + 8, mb2); + vst1q_f32(b + 12, mb3); + } + + for (int i = 0; i < offset; i++) { + b[i] = a[i] > 0.0f ? a[i] : 0.0f; + } +} + +} // namespace neon +} // namespace paddle + +#endif diff --git a/paddle/math/NEONFunctions.h b/paddle/math/NEONFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..69085e333547a31a341fbfde247f1e30adb957ee --- /dev/null +++ b/paddle/math/NEONFunctions.h @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +namespace paddle { +namespace neon { + +void relu(const float* a, float* b, int len); + +} // namespace neon +} // namespace paddle diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 103f06acc57d7a23f019f5e713f6cacf2179e9e0..061fb22e3fd744d9d9895fd1008089e4a6ce6a0f 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -825,9 +825,8 @@ void testMaxPoolFwdBwd(int numSamples, int strideW, int padH, int padW) { - int outH = 0, outW = 0; - outH = (imgSizeH - ksizeH + 2 * padH + strideH - 1) / strideH + 1; - outW = (imgSizeW - ksizeW + 2 * padW + strideW - 1) / strideW + 1; + int outH = outputSize(imgSizeH, ksizeH, padH, strideH, true); + int outW = outputSize(imgSizeW, ksizeW, padW, strideW, true); int inWidth = imgSizeH * imgSizeW * channels; MatrixPtr input = CpuMatrix::create(numSamples, inWidth, false, false); @@ -927,9 +926,8 @@ void testAvgPoolFwdBwd(int numSamples, int strideW, int padH, int padW) { - int outH = 0, outW = 0; - outH = (imgSizeH - ksizeH + 2 * padH + strideH - 1) / strideH + 1; - outW = (imgSizeW - ksizeW + 2 * padW + strideW - 1) / strideW + 1; + int outH = outputSize(imgSizeH, ksizeH, padH, strideH, true); + int outW = outputSize(imgSizeW, ksizeW, padW, strideW, true); int inWidth = imgSizeH * imgSizeW * channels; MatrixPtr input = CpuMatrix::create(numSamples, inWidth, false, false); diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..953367eb8bcd1282ab6c7e1189d778f0ce3da541 --- /dev/null +++ b/paddle/operators/cross_entropy_op.cc @@ -0,0 +1,147 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/operators/cross_entropy_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::LoDTensor;
+
+class CrossEntropyOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
+                            "Input(Label) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null.");
+
+    auto x = ctx.Input<framework::Tensor>("X");
+    auto label = ctx.Input<framework::Tensor>("Label");
+    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(label->dims().size(), 2,
+                      "Input(Label)'s rank must be 2.");
+    // TODO(xinghai-sun): remove this check after switching to bool
+    PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
+                   ctx.Attr<int>("soft_label") == 1);
+    PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
+                      "The 1st dimension of Input(X) and Input(Label) must "
+                      "be equal.");
+    if (ctx.Attr<int>("soft_label") == 1) {
+      PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
+                        "If Attr(soft_label) == 1, the 2nd dimension of "
+                        "Input(X) and Input(Label) must be equal.");
+    } else {
+      PADDLE_ENFORCE_EQ(label->dims()[1], 1,
+                        "If Attr(soft_label) == 0, the 2nd dimension of "
+                        "Input(Label) must be 1.");
+    }
+
+    ctx.Output<LoDTensor>("Y")->Resize({x->dims()[0], 1});
+  }
+};
+
+class CrossEntropyGradientOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
+                            "Input(Label) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
+                            "Input(Y@GRAD) must not be null.");
+
+    auto x = ctx.Input<framework::Tensor>("X");
+    auto label = ctx.Input<framework::Tensor>("Label");
+    auto dy = ctx.Input<framework::Tensor>(framework::GradVarName("Y"));
+    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(label->dims().size(), 2,
+                      "Input(Label)'s rank must be 2.");
+    // TODO(xinghai-sun): remove this check after switching to bool
+    PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
+                   ctx.Attr<int>("soft_label") == 1);
+    PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
+                      "The 1st dimension of Input(X) and Input(Label) must "
+                      "be equal.");
+    PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
+                      "The 1st dimension of Input(X) and Input(Y@Grad) must "
+                      "be equal.");
+    PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
+                      "The 2nd dimension of Input(Y@Grad) must be 1.");
+    if (ctx.Attr<int>("soft_label") == 1) {
+      PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
+                        "If Attr(soft_label) == 1, the 2nd dimension of "
+                        "Input(X) and Input(Label) must be equal.");
+    } else {
+      PADDLE_ENFORCE_EQ(label->dims()[1], 1,
+                        "If Attr(soft_label) == 0, the 2nd dimension of "
+                        "Input(Label) must be 1.");
+    }
+
+    auto dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
+    dx->Resize(x->dims());
+  }
+};
+
+class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  CrossEntropyOpMaker(framework::OpProto *proto,
+                      framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The first input of CrossEntropyOp");
+    AddInput("Label", "The second input of CrossEntropyOp");
AddOutput("Y", "The output of CrossEntropyOp"); + AddAttr("soft_label", "Is soft label. Default zero.").SetDefault(0); + + AddComment(R"DOC( +CrossEntropy Operator. + +It supports both standard cross-entropy and soft-label cross-entropy loss +computation. +1) One-hot cross-entropy: + soft_label = 0, Label[i, 0] indicates the class index for sample i: + + Y[i] = -log(X[i, Label[i]]) + +2) Soft-label cross-entropy: + soft_label = 1, Label[i, j] indicates the soft label of class j + for sample i: + + Y[i] = \sum_j{-Label[i, j] * log(X[i, j])} + + Please make sure that in this case the summuation of each row of Label + equals one. + +3) One-hot cross-entropy with vecterized Input(Label): + As a special case of 2), when each row of Input(Label) has only one + non-zero element (equals 1), soft-label cross-entropy degenerates to a + one-hot cross-entropy with one-hot label representation. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker, + cross_entropy_grad, ops::CrossEntropyGradientOp); +REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel); +REGISTER_OP_CPU_KERNEL(cross_entropy_grad, + ops::CrossEntropyGradientOpKernel); diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..ab6ad0e062269483948bf70e492c9431991221fb --- /dev/null +++ b/paddle/operators/cross_entropy_op.cu @@ -0,0 +1,158 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/cross_entropy_op.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +template +__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, + const int N, const int D) { + // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. + // CUDA_1D_KERNEL_LOOP(i, N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + PADDLE_ASSERT(label[i] >= 0 && label[i] < D); + Y[i] = -tolerable_value(log(X[i * D + label[i]])); + } +} + +template +__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, + const int N, const int D) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + T sum = static_cast(0); + for (int j = 0; j < D; j++) { + sum += label[i * D + j] * tolerable_value(log(X[i * D + j])); + } + Y[i] = -sum; + } +} + +// TODO(qingqing): make zero setting an common function. 
+// TODO(qingqing): make zero setting a common function.
+template <typename T>
+__global__ void zero(T* X, const int N) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    X[i] = 0.0;
+  }
+}
+
+template <typename T>
+__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
+                                           const int* label, const int N,
+                                           const int D) {
+  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
+  // CUDA_1D_KERNEL_LOOP(i, N) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    int idx = i * D + label[i];
+    dX[idx] = -dY[i] / X[idx];
+  }
+}
+
+template <typename T>
+__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
+                                               const T* label, const int N,
+                                               const int D) {
+  // TODO(qingqing): optimize this kernel.
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    for (int j = 0; j < D; ++j) {
+      int idx = i * D + j;
+      dX[idx] = -label[idx] * dY[i] / X[idx];
+    }
+  }
+}
+
+template <typename T>
+class CrossEntropyOpCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "It must use GPUPlace.");
+
+    auto x = ctx.Input<Tensor>("X");
+    auto y = ctx.Output<Tensor>("Y");
+    auto label = ctx.Input<Tensor>("Label");
+
+    auto* x_data = x->data<T>();
+    y->mutable_data<T>(ctx.GetPlace());
+    auto* y_data = y->data<T>();
+
+    int n = x->dims()[0];
+    int d = x->dims()[1];
+    int block = 512;
+    int grid = (n + block - 1) / block;
+    // TODO(qingqing) launch kernel on specified stream
+    // base on ExecutionContext.
+    if (ctx.Attr<int>("soft_label") == 1) {
+      auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
+      SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
+                                                 d);
+    } else {
+      auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
+      CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
+    }
+  }
+};
+
+template <typename T>
+class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "It must use GPUPlace.");
+
+    auto x = ctx.Input<Tensor>("X");
+    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto label = ctx.Input<Tensor>("Label");
+
+    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    auto* dy_data = dy->data<T>();
+    auto* x_data = x->data<T>();
+
+    int n = x->dims()[0];
+    int d = x->dims()[1];
+    int block = 512;
+    int grid = (n * d + block - 1) / block;
+    zero<T><<<grid, block>>>(dx_data, n * d);
+    grid = (n + block - 1) / block;
+    // TODO(qingqing): launch kernel on specified stream
+    // base on ExecutionContext.
+    if (ctx.Attr<int>("soft_label") == 1) {
+      auto* label_data = label->data<T>();
+      SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
+          dx_data, dy_data, x_data, label_data, n, d);
+    } else {
+      auto* label_data = label->data<int>();
+      CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
+                                                     label_data, n, d);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
+                       ops::CrossEntropyGradientOpCUDAKernel<float>);
diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..1b4b23ac2029138afadef0168262203ac2e20430
--- /dev/null
+++ b/paddle/operators/cross_entropy_op.h
@@ -0,0 +1,117 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/op_registry.h"
+#include "paddle/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+HOSTDEVICE T tolerable_value(const T x) {
+  PADDLE_ASSERT(std::is_floating_point<T>::value);
+  const T kApproInf = 1e20;
+  if (x == INFINITY) {
+    return kApproInf;
+  }
+  if (x == -INFINITY) {
+    return -kApproInf;
+  }
+  return x;
+}
+
+template <typename T>
+class CrossEntropyOpKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
+    auto x = ctx.Input<Tensor>("X");
+    auto y = ctx.Output<Tensor>("Y");
+
+    auto* x_data = x->data<T>();
+    y->mutable_data<T>(ctx.GetPlace());
+    auto* y_data = y->data<T>();
+
+    int batch_size = x->dims()[0];
+    int class_num = x->dims()[1];
+
+    if (ctx.Attr<int>("soft_label") == 1) {
+      auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
+      int index = 0;
+      for (int i = 0; i < batch_size; ++i) {
+        T sum = static_cast<T>(0);
+        for (int j = 0; j < class_num; ++j) {
+          sum += label_data[index] * tolerable_value(std::log(x_data[index]));
+          index++;
+        }
+        y_data[i] = -sum;
+      }
+    } else {
+      auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
+      for (int i = 0; i < batch_size; ++i) {
+        int index = i * class_num + label_data[i];
+        y_data[i] = -tolerable_value(std::log(x_data[index]));
+      }
+    }
+  }
+};
+
+template <typename T>
+class CrossEntropyGradientOpKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
+    auto x = ctx.Input<Tensor>("X");
+    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto label = ctx.Input<Tensor>("Label");
+
+    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    auto* dy_data = dy->data<T>();
+    auto* x_data = x->data<T>();
+
+    int batch_size = x->dims()[0];
+    int class_num = x->dims()[1];
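+    // Gradient math (added note): the soft-label loss
+    // Y[i] = -sum_j Label[i, j] * log(X[i, j]) gives a dense gradient
+    // dX[i, j] = -Label[i, j] * dY[i] / X[i, j], so every entry is written.
+    // In the one-hot case only the labeled column of each row is non-zero,
+    // which is why that branch zero-fills dX first and then scatters.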
+    // TODO(qingqing): make zero setting a common function.
+    if (ctx.Attr<int>("soft_label") == 1) {
+      auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
+      int index = 0;
+      for (int i = 0; i < batch_size; ++i) {
+        for (int j = 0; j < class_num; ++j) {
+          dx_data[index] = -label_data[index] * dy_data[i] / x_data[index];
+          index++;
+        }
+      }
+    } else {
+      auto* label_data = label->data<int>();
+      memset(dx_data, 0, sizeof(T) * batch_size * class_num);
+      for (int i = 0; i < batch_size; ++i) {
+        PADDLE_ASSERT(label_data[i] >= 0 && label_data[i] < class_num);
+        int index = i * class_num + label_data[i];
+        dx_data[index] = -dy_data[i] / x_data[index];
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b111b9fccb2310bd5fb92bda878a497c51f62ce0
--- /dev/null
+++ b/paddle/operators/dropout_op.cc
@@ -0,0 +1,113 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/operators/dropout_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+using framework::LoDTensor;
+
+class DropoutOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0);
+    PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1);
+    // TODO(xinghai-sun): remove this check after switching to bool
+    PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
+                   ctx.Attr<int>("is_training") == 1);
+
+    auto dims = ctx.Input<Tensor>("X")->dims();
+    ctx.Output<LoDTensor>("Out")->Resize(dims);
+    if (ctx.Attr<int>("is_training") == 1) {
+      ctx.Output<LoDTensor>("Mask")->Resize(dims);
+    }
+  }
+};
+
+template <typename AttrType>
+class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  DropoutOpMaker(framework::OpProto *proto,
+                 framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
+        .SetDefault(.5f);
+    // TODO(xinghai-sun): use bool for is_training after bool is supported.
+    AddAttr<int>("is_training", "Whether in training phase.").SetDefault(1);
+    AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
+    AddInput("X", "The input of dropout op.");
+    AddOutput("Out", "The output of dropout op.");
+    AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate();
+
+    AddComment(R"DOC(
+Dropout Operator.
+
+"Dropout" refers to randomly dropping out units in a neural network. It is a
+regularization technique for reducing overfitting by preventing neuron
+co-adaptation during training. The dropout operator randomly sets (according
+to the given dropout probability) the outputs of some units to zero, while
+the others keep their input values.
+)DOC"); + } +}; + +template +class DropoutOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_EQ(ctx.Attr("is_training"), 1, + "GradOp is only callable when is_training is true"); + + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) must not be null."); + + PADDLE_ENFORCE_GE(ctx.Attr("dropout_prob"), 0); + PADDLE_ENFORCE_LE(ctx.Attr("dropout_prob"), 1); + // TODO(xinghai-sun): remove this check after swtiching to bool + PADDLE_ENFORCE(ctx.Attr("is_training") == 0 || + ctx.Attr("is_training") == 1); + auto x_dims = ctx.Input("X")->dims(); + auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); + PADDLE_ENFORCE_EQ(x_dims, out_dims, + "Dimensions of Input(X) and Out@Grad must be the same."); + auto mask_dims = ctx.Input("Mask")->dims(); + PADDLE_ENFORCE_EQ(x_dims, mask_dims, + "Dimensions of Input(X) and Mask must be the same."); + + auto *x_grad = ctx.Output(framework::GradVarName("X")); + x_grad->Resize(x_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad, + ops::DropoutOpGrad); +REGISTER_OP_CPU_KERNEL( + dropout, ops::CPUDropoutKernel); +REGISTER_OP_CPU_KERNEL( + dropout_grad, ops::DropoutGradKernel); diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..186237fb238add37f32403309a0f7e8a9846d335 --- /dev/null +++ b/paddle/operators/dropout_op.cu @@ -0,0 +1,86 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include +#include +#include +#include +#include "paddle/operators/dropout_op.h" + +namespace paddle { +namespace operators { + +template +struct MaskGenerator { + AttrType dropout_prob; + int seed; + + __host__ __device__ MaskGenerator(AttrType dropout_prob, int seed) + : dropout_prob(dropout_prob), seed(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed); + thrust::uniform_real_distribution dist(0, 1); + rng.discard(n); + if (dist(rng) < dropout_prob) { + return static_cast(0); + } else { + return static_cast(1); + } + } +}; + +// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT. +// Use std::random and thrust::random(thrust is a std library in CUDA) to +// implement uniform random. 
+template <typename Place, typename T, typename AttrType>
+class GPUDropoutKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<Tensor>("X");
+    auto* y = context.Output<Tensor>("Out");
+    y->mutable_data<T>(context.GetPlace());
+    AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
+
+    auto X = EigenMatrix<T>::Reshape(*x, 1);
+    auto Y = EigenMatrix<T>::Reshape(*y, 1);
+
+    auto place = context.GetEigenDevice<Place>();
+    if (context.Attr<int>("is_training") == 1) {
+      auto* mask = context.Output<Tensor>("Mask");
+      auto* mask_data = mask->mutable_data<T>(context.GetPlace());
+      int size = framework::product(mask->dims());
+      int seed = context.Attr<int>("seed");
+      thrust::counting_iterator<unsigned int> index_sequence_begin(0);
+      thrust::transform(index_sequence_begin, index_sequence_begin + size,
+                        thrust::device_ptr<T>(mask_data),
+                        MaskGenerator<T, AttrType>(dropout_prob, seed));
+      auto M = EigenMatrix<T>::Reshape(*mask, 1);
+      Y.device(place) = X * M;
+    } else {
+      // At inference time, scale by the keep probability (1 - dropout_prob)
+      // so the expected activation matches training.
+      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(
+    dropout, ops::GPUDropoutKernel<paddle::platform::GPUPlace, float, float>);
+REGISTER_OP_GPU_KERNEL(
+    dropout_grad, ops::DropoutGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/dropout_op.h b/paddle/operators/dropout_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..82eafee0e0e7db7b4b4ae5405f37146d061aefd5
--- /dev/null
+++ b/paddle/operators/dropout_op.h
@@ -0,0 +1,86 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
*/
+
+#pragma once
+#include <random>
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename Place, typename T, typename AttrType>
+class CPUDropoutKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<Tensor>("X");
+    auto* y = context.Output<Tensor>("Out");
+    const auto* x_data = x->data<T>();
+    auto* y_data = y->mutable_data<T>(context.GetPlace());
+    AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
+
+    if (context.Attr<int>("is_training") == 1) {
+      auto* mask = context.Output<Tensor>("Mask");
+      auto* mask_data = mask->mutable_data<T>(context.GetPlace());
+      int seed = context.Attr<int>("seed");
+      std::minstd_rand engine;
+      engine.seed(seed);
+      std::uniform_real_distribution<AttrType> dist(0, 1);
+      size_t size = framework::product(mask->dims());
+      for (size_t i = 0; i < size; ++i) {
+        if (dist(engine) < dropout_prob) {
+          mask_data[i] = 0;
+          y_data[i] = 0;
+        } else {
+          mask_data[i] = 1;
+          y_data[i] = x_data[i];
+        }
+      }
+    } else {
+      auto X = EigenMatrix<T>::Reshape(*x, 1);
+      auto Y = EigenMatrix<T>::Reshape(*y, 1);
+      auto place = context.GetEigenDevice<Place>();
+      // Scale by the keep probability (1 - dropout_prob) at inference time.
+      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
+    }
+  }
+};
+
+template <typename Place, typename T>
+class DropoutGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE_EQ(context.Attr<int>("is_training"), 1,
+                      "GradOp is only callable when is_training is true");
+
+    auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* mask = context.Input<Tensor>("Mask");
+    grad_x->mutable_data<T>(context.GetPlace());
+
+    auto M = EigenMatrix<T>::Reshape(*mask, 1);
+    auto dX = EigenMatrix<T>::Reshape(*grad_x, 1);
+    auto dY = EigenMatrix<T>::Reshape(*grad_y, 1);
+
+    auto place = context.GetEigenDevice<Place>();
+    dX.device(place) = dY * M;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc
index 1e86fc3d166077265e0f433a6712b0665ea5a152..def4b01da098fc960ce7c0e497732fbcc2579945 100644
--- a/paddle/operators/math/math_function.cc
+++ b/paddle/operators/math/math_function.cc
@@ -19,12 +19,13 @@ namespace operators {
 namespace math {
 
 template <>
-void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
+void gemm<platform::CPUPlace, float>(const platform::DeviceContext& context,
+                                     const CBLAS_TRANSPOSE transA,
                                      const CBLAS_TRANSPOSE transB, const int M,
                                      const int N, const int K,
                                      const float alpha, const float* A,
-                                     const float* B, const float beta, float* C,
-                                     platform::DeviceContext* context) {
+                                     const float* B, const float beta,
+                                     float* C) {
   int lda = (transA == CblasNoTrans) ? K : M;
   int ldb = (transB == CblasNoTrans) ? N : K;
   int ldc = N;
@@ -33,13 +34,13 @@ void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
 }
 
 template <>
-void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
+void gemm<platform::CPUPlace, double>(const platform::DeviceContext& context,
+                                      const CBLAS_TRANSPOSE transA,
                                       const CBLAS_TRANSPOSE transB, const int M,
                                       const int N, const int K,
                                       const double alpha, const double* A,
                                       const double* B, const double beta,
-                                      double* C,
-                                      platform::DeviceContext* context) {
+                                      double* C) {
   int lda = (transA == CblasNoTrans) ? K : M;
   int ldb = (transB == CblasNoTrans) ?
N : K; int ldc = N; @@ -48,13 +49,10 @@ void gemm(const CBLAS_TRANSPOSE transA, } template <> -void matmul(const framework::Tensor& matrix_a, - bool trans_a, - const framework::Tensor& matrix_b, - bool trans_b, float alpha, - framework::Tensor* matrix_out, - float beta, - platform::DeviceContext* context) { +void matmul( + const platform::DeviceContext& context, const framework::Tensor& matrix_a, + bool trans_a, const framework::Tensor& matrix_b, bool trans_b, float alpha, + framework::Tensor* matrix_out, float beta) { auto dim_a = matrix_a.dims(); auto dim_b = matrix_b.dims(); auto dim_out = matrix_out->dims(); @@ -74,18 +72,15 @@ void matmul(const framework::Tensor& matrix_a, CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; gemm( - transA, transB, M, N, K, alpha, matrix_a.data(), - matrix_b.data(), beta, matrix_out->data(), context); + context, transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data()); } template <> -void matmul(const framework::Tensor& matrix_a, - bool trans_a, - const framework::Tensor& matrix_b, - bool trans_b, double alpha, - framework::Tensor* matrix_out, - double beta, - platform::DeviceContext* context) { +void matmul( + const platform::DeviceContext& context, const framework::Tensor& matrix_a, + bool trans_a, const framework::Tensor& matrix_b, bool trans_b, double alpha, + framework::Tensor* matrix_out, double beta) { auto dim_a = matrix_a.dims(); auto dim_b = matrix_b.dims(); auto dim_out = matrix_out->dims(); @@ -105,8 +100,8 @@ void matmul(const framework::Tensor& matrix_a, CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; gemm( - transA, transB, M, N, K, alpha, matrix_a.data(), - matrix_b.data(), beta, matrix_out->data(), context); + context, transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data()); } } // namespace math diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index da40b27c948918e4997f4a046d2145552296158b..71563b77b4b262c3f1e17ae7c4381da56ba780a3 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -19,12 +19,13 @@ namespace operators { namespace math { template <> -void gemm(const CBLAS_TRANSPOSE transA, +void gemm(const platform::DeviceContext& context, + const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const float alpha, const float* A, - const float* B, const float beta, float* C, - platform::DeviceContext* context) { + const float* B, const float beta, + float* C) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. int lda = (transA == CblasNoTrans) ? K : M; @@ -35,18 +36,19 @@ void gemm(const CBLAS_TRANSPOSE transA, (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; PADDLE_ENFORCE(platform::dynload::cublasSgemm( - reinterpret_cast(context)->cublas_handle(), + reinterpret_cast(context) + .cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); } template <> -void gemm(const CBLAS_TRANSPOSE transA, +void gemm(const platform::DeviceContext& context, + const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const double alpha, const double* A, const double* B, const double beta, - double* C, - platform::DeviceContext* context) { + double* C) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. 
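  // Added note: the row-major product C = A * B is computed through cublas's
  // column-major interface as C^T = B^T * A^T, i.e. B is passed before A and
  // the M and N dimensions are exchanged, so no explicit transpose or copy of
  // the operands is needed.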
int lda = (transA == CblasNoTrans) ? K : M; @@ -56,18 +58,16 @@ void gemm(const CBLAS_TRANSPOSE transA, cublasOperation_t cuTransB = (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; PADDLE_ENFORCE(platform::dynload::cublasDgemm( - reinterpret_cast(context)->cublas_handle(), + reinterpret_cast(context) + .cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); } template <> -void matmul(const framework::Tensor& matrix_a, - bool trans_a, - const framework::Tensor& matrix_b, - bool trans_b, float alpha, - framework::Tensor* matrix_out, - float beta, - platform::DeviceContext* context) { +void matmul( + const platform::DeviceContext& context, const framework::Tensor& matrix_a, + bool trans_a, const framework::Tensor& matrix_b, bool trans_b, float alpha, + framework::Tensor* matrix_out, float beta) { auto dim_a = matrix_a.dims(); auto dim_b = matrix_b.dims(); auto dim_out = matrix_out->dims(); @@ -87,18 +87,15 @@ void matmul(const framework::Tensor& matrix_a, CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; gemm( - transA, transB, M, N, K, alpha, matrix_a.data(), - matrix_b.data(), beta, matrix_out->data(), context); + context, transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data()); } template <> -void matmul(const framework::Tensor& matrix_a, - bool trans_a, - const framework::Tensor& matrix_b, - bool trans_b, double alpha, - framework::Tensor* matrix_out, - double beta, - platform::DeviceContext* context) { +void matmul( + const platform::DeviceContext& context, const framework::Tensor& matrix_a, + bool trans_a, const framework::Tensor& matrix_b, bool trans_b, double alpha, + framework::Tensor* matrix_out, double beta) { auto dim_a = matrix_a.dims(); auto dim_b = matrix_b.dims(); auto dim_out = matrix_out->dims(); @@ -118,8 +115,8 @@ void matmul(const framework::Tensor& matrix_a, CBLAS_TRANSPOSE transB = (trans_b == false) ? 
CblasNoTrans : CblasTrans; gemm( - transA, transB, M, N, K, alpha, matrix_a.data(), - matrix_b.data(), beta, matrix_out->data(), context); + context, transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data()); } } // namespace math diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 155589fadb3ed9f59160a750d546dd8093a56cbe..d8518e77fa7b4abdbcf08b7983013c24806e14ca 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -66,16 +66,16 @@ namespace math { // For more detailed info, please refer to // http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html template -void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, - const int M, const int N, const int K, const T alpha, const T* A, - const T* B, const T beta, T* C, platform::DeviceContext* context); +void gemm(const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, + const T alpha, const T* A, const T* B, const T beta, T* C); // matrix multiply with continuous memory template -void matmul(const framework::Tensor& matrix_a, bool trans_a, +void matmul(const platform::DeviceContext& context, + const framework::Tensor& matrix_a, bool trans_a, const framework::Tensor& matrix_b, bool trans_b, T alpha, - framework::Tensor* matrix_out, T beta, - platform::DeviceContext* context); + framework::Tensor* matrix_out, T beta); } // namespace math } // namespace operators diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc index 6c020c4ff7285b43bc5836d80c173d3a068e72b3..7e339457f7f08ff16162f399064a4b4dca594d7f 100644 --- a/paddle/operators/math/math_function_test.cc +++ b/paddle/operators/math/math_function_test.cc @@ -15,8 +15,7 @@ TEST(math_function, notrans_mul_trans) { memcpy(input1_ptr, arr, 6 * sizeof(float)); auto* gpu_place = new paddle::platform::GPUPlace(0); - paddle::platform::DeviceContext* context = - new paddle::platform::CUDADeviceContext(*gpu_place); + paddle::platform::CUDADeviceContext context(*gpu_place); input1_gpu.CopyFrom(input1, *gpu_place); input2_gpu.CopyFrom(input1, *gpu_place); @@ -24,7 +23,7 @@ TEST(math_function, notrans_mul_trans) { out_gpu.mutable_data({2, 2}, *gpu_place); paddle::operators::math::matmul( - input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0, context); + context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0); out.CopyFrom(out_gpu, *cpu_place); @@ -33,6 +32,7 @@ TEST(math_function, notrans_mul_trans) { EXPECT_EQ(out_ptr[1], 14); EXPECT_EQ(out_ptr[2], 14); EXPECT_EQ(out_ptr[3], 50); + delete gpu_place; } TEST(math_function, trans_mul_notrans) { @@ -48,8 +48,7 @@ TEST(math_function, trans_mul_notrans) { memcpy(input1_ptr, arr, 6 * sizeof(float)); auto* gpu_place = new paddle::platform::GPUPlace(0); - paddle::platform::DeviceContext* context = - new paddle::platform::CUDADeviceContext(*gpu_place); + paddle::platform::CUDADeviceContext context(*gpu_place); input1_gpu.CopyFrom(input1, *gpu_place); input2_gpu.CopyFrom(input1, *gpu_place); @@ -57,7 +56,7 @@ TEST(math_function, trans_mul_notrans) { out_gpu.mutable_data({3, 3}, *gpu_place); paddle::operators::math::matmul( - input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0, context); + context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0); out.CopyFrom(out_gpu, *cpu_place); @@ -71,5 +70,6 @@ TEST(math_function, trans_mul_notrans) { EXPECT_EQ(out_ptr[6], 15); 
EXPECT_EQ(out_ptr[7], 22); EXPECT_EQ(out_ptr[8], 29); + delete gpu_place; } #endif diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index 3c01f868bda8cba488b3403df456d63d6b082fa6..ac7136a76933d1f3ead86518c65d589747227631 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -46,10 +46,8 @@ class MulKernel : public framework::OpKernel { : *y; z->mutable_data(context.GetPlace()); - auto* device_context = - const_cast(context.device_context_); - math::matmul(x_matrix, false, y_matrix, false, 1, z, 0, - device_context); + math::matmul(context.device_context(), x_matrix, false, y_matrix, + false, 1, z, 0); } }; @@ -71,16 +69,14 @@ class MulGradKernel : public framework::OpKernel { Tensor* dx = ctx.Output(framework::GradVarName("X")); Tensor* dy = ctx.Output(framework::GradVarName("Y")); - auto* device_context = - const_cast(ctx.device_context_); if (dx) { dx->mutable_data(ctx.GetPlace()); Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix( *dx, x_num_col_dims) : *dx; // dx = dout * y'. dx: M x K, dout : M x N, y : K x N - math::matmul(*dout, false, y_matrix, true, 1, &dx_matrix, 0, - device_context); + math::matmul(ctx.device_context(), *dout, false, y_matrix, true, + 1, &dx_matrix, 0); } if (dy) { dy->mutable_data(ctx.GetPlace()); @@ -88,8 +84,8 @@ class MulGradKernel : public framework::OpKernel { *dy, y_num_col_dims) : *dy; // dy = x' * dout. dy K x N, dout : M x N, x : M x K - math::matmul(x_matrix, true, *dout, false, 1, &dy_matrix, 0, - device_context); + math::matmul(ctx.device_context(), x_matrix, true, *dout, false, + 1, &dy_matrix, 0); } } }; diff --git a/paddle/operators/onehot_cross_entropy_op.cc b/paddle/operators/onehot_cross_entropy_op.cc deleted file mode 100644 index f38be3549f3c5d2443f61739fc32cdca74197649..0000000000000000000000000000000000000000 --- a/paddle/operators/onehot_cross_entropy_op.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/operators/onehot_cross_entropy_op.h" - -namespace paddle { -namespace operators { - -class OnehotCrossEntropyOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL( - ctx.InputVar("X"), - "Input(X) of OnehotCrossEntropyOp should not be null."); - PADDLE_ENFORCE_NOT_NULL( - ctx.InputVar("label"), - "Input(label) of OnehotCrossEntropyOp should not be null."); - PADDLE_ENFORCE_NOT_NULL( - ctx.OutputVar("Y"), - "Output(Y) of OnehotCrossEntropyOp should not be null."); - - auto *X = ctx.Input("X"); - auto *label = ctx.Input("label"); - - PADDLE_ENFORCE_EQ(X->dims().size(), 2, "X's dimension must be 2."); - PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label's dimension must be 1."); - PADDLE_ENFORCE_EQ(X->dims()[0], label->dims()[0]); - ctx.Output("Y")->Resize({X->dims()[0], 1}); - } -}; - -class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - auto dX = ctx.Output(framework::GradVarName("X")); - auto X = ctx.Input("X"); - - dX->Resize(X->dims()); - } -}; - -class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { - public: - OnehotCrossEntropyOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The first input of OnehotCrossEntropyOp"); - AddInput("label", "The second input of OnehotCrossEntropyOp"); - AddOutput("Y", "The output of OnehotCrossEntropyOp"); - AddComment(R"DOC( -OnehotCrossEntropy Operator. - - Y[i] = -log(X[i][j]) - -)DOC"); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, - ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOp); -REGISTER_OP_CPU_KERNEL(onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); -REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOpKernel); diff --git a/paddle/operators/onehot_cross_entropy_op.cu b/paddle/operators/onehot_cross_entropy_op.cu deleted file mode 100644 index d999bfce58c8a6db5c811aad677c07094b881841..0000000000000000000000000000000000000000 --- a/paddle/operators/onehot_cross_entropy_op.cu +++ /dev/null @@ -1,133 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
*/ - -#include "paddle/framework/op_registry.h" -#include "paddle/platform/assert.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -__host__ __device__ T clipping_log(const T x) { - PADDLE_ASSERT(std::is_floating_point::value); - const T kApproInf = 1e20; - T v = log(x); - if (v == INFINITY) { - return kApproInf; - } - if (v == -INFINITY) { - return -kApproInf; - } - return v; -} - -template -__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, - const int N, const int D) { - // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. - // CUDA_1D_KERNEL_LOOP(i, N) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - PADDLE_ASSERT(label[i] >= 0 && label[i] < D); - Y[i] = -clipping_log(X[i * D + label[i]]); - } -} - -// TODO(qingqing): make zero setting an common function. -template -__global__ void zero(T* X, const int N) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - X[i] = 0.0; - } -} - -template -__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, - const int* label, const int N, - const int D) { - // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. - // CUDA_1D_KERNEL_LOOP(i, N) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - int idx = i * D + label[i]; - dX[idx] = -dY[i] / X[idx]; - } -} - -template -class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use GPUPlace."); - - auto X = ctx.Input("X"); - const T* Xdata = X->data(); - const int* label_data = ctx.Input("label")->data(); - auto Y = ctx.Output("Y"); - Y->mutable_data(ctx.GetPlace()); - T* Ydata = Y->data(); - - int N = X->dims()[0]; - int D = X->dims()[1]; - int block = 512; - int grid = (N + block - 1) / block; - // TODO(qingqing) launch kernel on specified stream - // base on ExecutionContext. - CrossEntropyKernel<<>>(Ydata, Xdata, label_data, N, D); - } -}; - -template -class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use GPUPlace."); - - auto X = ctx.Input("X"); - auto dX = ctx.Output(framework::GradVarName("X")); - auto dY = ctx.Input(framework::GradVarName("Y")); - auto label = ctx.Input("label"); - - auto* dXdata = dX->template mutable_data(ctx.GetPlace()); - auto* dYdata = dY->template data(); - auto* Xdata = X->template data(); - auto* label_data = label->data(); - - int N = X->dims()[0]; - int D = X->dims()[1]; - int block = 512; - int grid = (N * D + block - 1) / block; - zero<<>>(dXdata, N * D); - - grid = (N + block - 1) / block; - // TODO(qingqing): launch kernel on specified stream - // base on ExecutionContext. 
- CrossEntropyGradientKernel<<>>(dXdata, dYdata, Xdata, - label_data, N, D); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, - ops::OnehotCrossEntropyOpCUDAKernel); -REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOpCUDAKernel); diff --git a/paddle/operators/onehot_cross_entropy_op.h b/paddle/operators/onehot_cross_entropy_op.h deleted file mode 100644 index eb4d1348de1d940e2648c83c8ba94b289f10c5b2..0000000000000000000000000000000000000000 --- a/paddle/operators/onehot_cross_entropy_op.h +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -inline T tolerable_value(const T x) { - static_assert(std::is_floating_point::value, - "tolerable_value works only on float, " - "double and double double."); - - const T kApproInf = 1e20; - - if (x == INFINITY) { - return kApproInf; - } - - if (x == -INFINITY) { - return -kApproInf; - } - - return x; -} - -template -class OnehotCrossEntropyOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "It must use CPUPlace."); - - auto X = ctx.Input("X"); - const T* Xdata = X->data(); - const int* label_data = ctx.Input("label")->data(); - auto Y = ctx.Output("Y"); - - Y->mutable_data(ctx.GetPlace()); - - T* Ydata = Y->data(); - - int batch_size = X->dims()[0]; - int class_num = X->dims()[1]; - - for (int i = 0; i < batch_size; ++i) { - int index = i * class_num + label_data[i]; - Ydata[i] = -tolerable_value(std::log(Xdata[index])); - } - } -}; - -template -class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "It must use CPUPlace."); - - auto X = ctx.Input("X"); - auto dX = ctx.Output(framework::GradVarName("X")); - auto dY = ctx.Input(framework::GradVarName("Y")); - auto label = ctx.Input("label"); - - auto* dXdata = dX->template mutable_data(ctx.GetPlace()); - auto* dYdata = dY->template data(); - auto* Xdata = X->template data(); - auto* label_data = label->data(); - - const int batch_size = X->dims()[0]; - const int class_num = X->dims()[1]; - - // TODO(qingqing): make zero setting an common function. 
- memset(dXdata, 0, sizeof(T) * batch_size * class_num); - for (int i = 0; i < batch_size; ++i) { - int index = i * class_num + label_data[i]; - dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/prelu_op.cc b/paddle/operators/prelu_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ae80b296850f2f433c89d904ebf32355b2a29c7 --- /dev/null +++ b/paddle/operators/prelu_op.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/prelu_op.h" +#include "paddle/operators/net_op.h" + +namespace paddle { +namespace operators { + +class PReluOp : public framework::OperatorWithKernel { + public: + PReluOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + auto *in = ctx.Input("X"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Alpha"), + "Input(Alpha) should not be null"); + auto *alpha = ctx.Input("Alpha"); + PADDLE_ENFORCE(alpha->numel() == 1, "Size of weight Alpha must be one."); + + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) should not be null"); + auto *out = ctx.Output("Out"); + out->Resize(in->dims()); + } +}; + +class PReluOpMaker : public framework::OpProtoAndCheckerMaker { + public: + PReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input tensor of prelu operator."); + AddInput("Alpha", "The alpha weight of PRelu operator."); + AddOutput("Out", "The output tensor of PRelu operator."); + AddComment(R"DOC(PRelu operator + +The equation is: + + f(x) = alpha * x , for x < 0 + f(x) = x , for x >= 0 + +)DOC"); + } +}; + +// The operator to calculate gradients of a prelu operator. 
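+// Added note: with f(x) = x for x > 0 and f(x) = alpha * x otherwise, the
+// gradients are df/dx = 1 (x > 0) or alpha (x <= 0), and
+// df/dalpha = 0 (x > 0) or x (x <= 0).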
+class PReluGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *x = ctx.Input("X"); + + auto *dalpha = + ctx.Output(framework::GradVarName("Alpha")); + auto *alpha = ctx.Input("Alpha"); + + dx->Resize(x->dims()); + dalpha->Resize(alpha->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP(prelu, ops::PReluOp, ops::PReluOpMaker, prelu_grad, + ops::PReluGradOp); +REGISTER_OP_CPU_KERNEL(prelu, + ops::PReluKernel); +REGISTER_OP_CPU_KERNEL(prelu_grad, + ops::PReluGradKernel); diff --git a/paddle/operators/prelu_op.cu b/paddle/operators/prelu_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..9e391dabae735cc8a740b46b50d31d271f99b65d --- /dev/null +++ b/paddle/operators/prelu_op.cu @@ -0,0 +1,21 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/prelu_op.h" + +REGISTER_OP_GPU_KERNEL( + prelu, paddle::operators::PReluKernel); +REGISTER_OP_GPU_KERNEL( + prelu_grad, + paddle::operators::PReluGradKernel); diff --git a/paddle/operators/prelu_op.h b/paddle/operators/prelu_op.h new file mode 100644 index 0000000000000000000000000000000000000000..63031c25cc3570cf40440726ea76976953d5417a --- /dev/null +++ b/paddle/operators/prelu_op.h @@ -0,0 +1,102 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/platform/transform.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using platform::Transform; + +template +class PReluFunctor { + public: + explicit PReluFunctor(const T* alpha) : alpha_(alpha) {} + + HOSTDEVICE T operator()(const T& x) const { + if (x > 0) + return x; + else + return x * (*alpha_); + } + + private: + const T* alpha_; +}; + +template +class PReluKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* alpha = context.Input("Alpha"); + auto* out = context.Output("Out"); + + const T* x_ptr = x->data(); + T* o_ptr = out->mutable_data(context.GetPlace()); + + auto* alpha_ptr = alpha->data(); + + int numel = x->numel(); + + Transform(context.device_context(), x_ptr, x_ptr + numel, o_ptr, + PReluFunctor(alpha_ptr)); + } +}; + +template +class PReluGradFunctor { + public: + explicit PReluGradFunctor(const T* alpha) : alpha_(alpha) {} + + HOSTDEVICE T operator()(const T& out, const T& dout) const { + if (out > 0) + return dout; + else + return dout * (*alpha_); + } + + private: + const T* alpha_; +}; + +template +class PReluGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* dx = context.Output(framework::GradVarName("X")); + auto* dout = context.Input(framework::GradVarName("Out")); + + auto* out = context.Input("Out"); + auto* alpha = context.Input("Alpha"); + auto* alpha_ptr = alpha->data(); + + T* dx_ptr = dx->mutable_data(context.GetPlace()); + const T* dout_ptr = dout->data(); + const T* out_ptr = out->data(); + int numel = dx->numel(); + + Transform(context.device_context(), out_ptr, out_ptr + numel, dout_ptr, + dx_ptr, PReluGradFunctor(alpha_ptr)); + + // TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/split_op.cc b/paddle/operators/split_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..61296f5c8122fdce7083e9a91dc313482875c805 --- /dev/null +++ b/paddle/operators/split_op.cc @@ -0,0 +1,118 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/operators/split_op.h"
+#include "paddle/operators/net_op.h"
+
+namespace paddle {
+namespace operators {
+using framework::Tensor;
+
+class SplitOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    // infer the output shapes from either Attr(num) or Attr(sections)
+    auto *in = ctx.Input<framework::Tensor>("X");
+    auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
+    size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
+    size_t num = static_cast<size_t>(ctx.Attr<int>("num"));
+    std::vector<int> sections =
+        static_cast<std::vector<int>>(ctx.Attr<std::vector<int>>("sections"));
+    const size_t n = outs.size();
+
+    if (num > 0) {
+      int64_t in_axis_dim = in->dims()[axis];
+      PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
+                        "tensor split does not result"
+                        " in an equal division");
+      size_t out_axis_dim = in_axis_dim / num;
+      for (size_t i = 0; i < n; ++i) {
+        auto dim = in->dims();
+        dim[axis] = out_axis_dim;
+        outs[i]->Resize(dim);
+      }
+    } else if (sections.size() > 0) {
+      PADDLE_ENFORCE_EQ(sections.size(), n,
+                        "tensor split sections size "
+                        "should be equal to output size.");
+      for (size_t i = 0; i < n; ++i) {
+        auto dim = in->dims();
+        dim[axis] = sections[i];
+        outs[i]->Resize(dim);
+      }
+    } else {
+      PADDLE_ENFORCE_NOT_NULL(nullptr,
+                              "split operator should"
+                              " specify indices or sections.");
+    }
+  }
+};
+
+class SplitOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SplitOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "the input tensor of split operator.");
+    AddOutput("Out", "the output tensors of split operator.").AsDuplicable();
+    AddComment(R"DOC(
+      Split the input tensor into multiple sub-tensors.
+      Example:
+        Input = [[1,2],
+                 [3,4],
+                 [5,6]]
+        sections = [2,1]
+        axis = 0
+        Output[0] = [[1,2],
+                     [3,4]]
+        Output[1] = [[5,6]]
+
+    )DOC");
+    AddAttr<std::vector<int>>("sections",
+                              "the length of each"
+                              " output along the specified axis.")
+        .SetDefault(std::vector<int>{});
+    AddAttr<int>("num",
+                 "number of the sub-tensors; it must evenly divide "
+                 "Input.dims()[axis]")
+        .SetDefault(0);
+    AddAttr<int>("axis", "The axis along which the input will be split.")
+        .SetDefault(0);
+  }
+};
+
+class SplitOpGrad : public NetOp {
+ public:
+  SplitOpGrad(const std::string &type, const framework::VariableNameMap &inputs,
+              const framework::VariableNameMap &outputs,
+              const framework::AttributeMap &attrs)
+      : NetOp(type, inputs, outputs, attrs) {
+    auto out_grad = Inputs(framework::GradVarName("Out"));
+    auto x_grad = Output(framework::GradVarName("X"));
+    AppendOp(framework::OpRegistry::CreateOp("concat", {{"X", out_grad}},
+                                             {{"Out", {x_grad}}}, attrs));
+    CompleteAddOp(false);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+USE_CPU_ONLY_OP(concat);
+REGISTER_OP(split, ops::SplitOp, ops::SplitOpMaker, split_grad,
+            ops::SplitOpGrad);
+REGISTER_OP_CPU_KERNEL(split,
+                       ops::SplitKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/split_op.h b/paddle/operators/split_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..860690ee895075fda9ddef08776a2102642efff9
--- /dev/null
+++ b/paddle/operators/split_op.h
@@ -0,0 +1,62 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
diff --git a/paddle/operators/split_op.h b/paddle/operators/split_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..860690ee895075fda9ddef08776a2102642efff9
--- /dev/null
+++ b/paddle/operators/split_op.h
@@ -0,0 +1,62 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <cstring>  // for memcpy
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename Place, typename T>
+class SplitKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in = ctx.Input<framework::Tensor>("X");
+    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
+    int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
+    size_t before = 1, after = 1;
+    const size_t n = outs.size();
+    size_t input_axis_dim = in->dims()[axis];
+
+    // View the input as a [before, dims()[axis], after] tensor.
+    for (int64_t i = 0; i < in->dims().size(); ++i) {
+      if (i == axis) {
+        continue;
+      }
+      if (i < axis) {
+        before *= in->dims()[i];
+      } else {
+        after *= in->dims()[i];
+      }
+    }
+    size_t input_offset = 0;
+    for (size_t i = 0; i < n; i++) {
+      auto& out = outs[i];
+      size_t axis_dim = out->dims()[axis];
+      // Copy one contiguous run of axis_dim * after elements per slice.
+      for (size_t j = 0; j < before; j++) {
+        size_t len = axis_dim * after * sizeof(T);
+        T* dest =
+            out->mutable_data<T>(platform::CPUPlace()) + axis_dim * after * j;
+        const T* src =
+            in->data<T>() + input_offset + input_axis_dim * after * j;
+        memcpy(dest, src, len);
+      }
+      input_offset += axis_dim * after;
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
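The `before`/`after` bookkeeping above linearizes the input as a `[before, dims[axis], after]` view and copies one contiguous `axis_dim * after` run per `before`-slice. A NumPy sketch of the same copy pattern, checked against `np.split` (illustrative only):

```python
import numpy as np

def split_like_kernel(x, axis, sections):
    # reshape to [before, axis_dim, after], the same linearization the
    # memcpy loop in SplitKernel uses
    before = int(np.prod(x.shape[:axis], dtype=np.int64))
    after = int(np.prod(x.shape[axis + 1:], dtype=np.int64))
    flat = x.reshape(before, x.shape[axis], after)
    outs, offset = [], 0
    for s in sections:
        part = flat[:, offset:offset + s, :]
        outs.append(part.reshape(x.shape[:axis] + (s,) + x.shape[axis + 1:]))
        offset += s
    return outs

x = np.arange(12, dtype=np.float32).reshape(3, 4)
a, b = split_like_kernel(x, axis=0, sections=[2, 1])
assert all(np.array_equal(p, q)
           for p, q in zip((a, b), np.split(x, [2], axis=0)))
```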
diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt
index 8b605e51c3f4ea38fc358ce054bb36fcc82063c4..daf519b91d623d4369774dc4e37dcb7b1733666b 100644
--- a/paddle/platform/CMakeLists.txt
+++ b/paddle/platform/CMakeLists.txt
@@ -24,4 +24,4 @@ cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator
 nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info)
 
 nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
-nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place)
+nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context)
diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc
index ad212c5b2c47312743362db4926c80bf056e100d..93b472b41c8a4c3a2bfada9d4fbf0e9e1b0cc736 100644
--- a/paddle/platform/device_context.cc
+++ b/paddle/platform/device_context.cc
@@ -101,19 +101,17 @@ CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
   eigen_stream_.reset(new EigenCudaStreamDevice());
   eigen_stream_->Reinitialize(&stream_, place);
   eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
+  PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
+  PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
+  PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
+  PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
 }
 
 CUDADeviceContext::~CUDADeviceContext() {
   SetDeviceId(place_.device);
   Wait();
-  if (cublas_handle_) {
-    PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
-  }
-
-  if (cudnn_handle_) {
-    PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
-  }
-
+  PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
+  PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
   eigen_stream_.reset();
   eigen_device_.reset();
   PADDLE_ENFORCE(cudaStreamDestroy(stream_));
@@ -129,25 +127,13 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
   return eigen_device_.get();
 }
 
-cublasHandle_t CUDADeviceContext::cublas_handle() {
-  if (!cublas_handle_) {
-    SetDeviceId(place_.device);
-    PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
-    PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
-  }
+cublasHandle_t CUDADeviceContext::cublas_handle() const {
   return cublas_handle_;
 }
 
-cudnnHandle_t CUDADeviceContext::cudnn_handle() {
-  if (!cudnn_handle_) {
-    SetDeviceId(place_.device);
-    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
-    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
-  }
-  return cudnn_handle_;
-}
+cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
 
-cudaStream_t CUDADeviceContext::stream() { return stream_; }
+cudaStream_t CUDADeviceContext::stream() const { return stream_; }
 
 #endif  // PADDLE_ONLY_CPU
diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h
index 11528e1194e4516891034fa8febdac3ba6eed204..a106592e454e21c46cd2f87f1bbf6694955d6e23 100644
--- a/paddle/platform/device_context.h
+++ b/paddle/platform/device_context.h
@@ -67,16 +67,14 @@ class CUDADeviceContext : public DeviceContext {
   /*! \brief  Return eigen device in the device context. */
   Eigen::GpuDevice* eigen_device() const;
 
-  // clang-format off
   /*! \brief  Return cublas handle in the device context. */
-  cublasHandle_t    cublas_handle();
+  cublasHandle_t cublas_handle() const;
 
   /*! \brief  Return cudnn handle in the device context. */
-  cudnnHandle_t     cudnn_handle();
+  cudnnHandle_t cudnn_handle() const;
 
   /*! \brief  Return cuda stream in the device context. */
-  cudaStream_t      stream();
-  // clang-format on
+  cudaStream_t stream() const;
 
  private:
   GPUPlace place_;
@@ -84,11 +82,9 @@ class CUDADeviceContext : public DeviceContext {
   std::unique_ptr<Eigen::GpuDevice> eigen_device_;
   std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
 
-  // clang-format off
-  cudaStream_t   stream_{nullptr};
-  cudnnHandle_t  cudnn_handle_{nullptr};
-  cublasHandle_t cublas_handle_{nullptr};
-  // clang-format on
+  cudaStream_t stream_;
+  cudnnHandle_t cudnn_handle_;
+  cublasHandle_t cublas_handle_;
 };
 
 #endif
diff --git a/paddle/platform/transform.h b/paddle/platform/transform.h
index 3ee4acd29660f201d318ce6d39baa6f3999ae274..8eaab047fd4daa386f5ebdbb99a4caeed5fe2fbf 100644
--- a/paddle/platform/transform.h
+++ b/paddle/platform/transform.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include "paddle/platform/device_context.h"
 #include "paddle/platform/enforce.h"
 #include "paddle/platform/hostdevice.h"
 #include "paddle/platform/place.h"
@@ -21,6 +22,7 @@
 #include <algorithm>
 #include <type_traits>
 #ifdef __NVCC__
+#include <thrust/execution_policy.h>
 #include <thrust/transform.h>
 #include "paddle/platform/details/device_ptr_cast.h"
 #endif
@@ -28,34 +30,39 @@
 namespace paddle {
 namespace platform {
 
 // Transform on host or device. It provides the same API as the std library.
-template <typename Place, typename InputIter, typename OutputIter,
-          typename UnaryOperation>
-void Transform(Place place, InputIter first, InputIter last, OutputIter result,
-               UnaryOperation op) {
+template <typename InputIter, typename OutputIter, typename UnaryOperation>
+void Transform(const DeviceContext& context, InputIter first, InputIter last,
+               OutputIter result, UnaryOperation op) {
+  auto place = context.GetPlace();
   if (is_cpu_place(place)) {
     std::transform(first, last, result, op);
   } else {
 #ifdef __NVCC__
+    auto& ctx = reinterpret_cast<const CUDADeviceContext&>(context);
     using namespace details;
-    thrust::transform(DevPtrCast(first), DevPtrCast(last), DevPtrCast(result),
-                      op);
+    thrust::transform(thrust::cuda::par.on(ctx.stream()), DevPtrCast(first),
+                      DevPtrCast(last), DevPtrCast(result), op);
#else
     PADDLE_THROW("Do not invoke `Transform` in .cc file");
 #endif
   }
 }
 
-template <typename Place, typename InputIter1, typename InputIter2,
-          typename OutputIter, typename BinaryOperation>
-void Transform(Place place, InputIter1 first1, InputIter1 last1,
-               InputIter2 first2, OutputIter result, BinaryOperation op) {
+template <typename InputIter1, typename InputIter2, typename OutputIter,
+          typename BinaryOperation>
+void Transform(const DeviceContext& context, InputIter1 first1,
+               InputIter1 last1, InputIter2 first2, OutputIter result,
+               BinaryOperation op) {
+  auto place = context.GetPlace();
   if (is_cpu_place(place)) {
     std::transform(first1, last1, first2, result, op);
   } else {
 #ifdef __NVCC__
+    auto& ctx = reinterpret_cast<const CUDADeviceContext&>(context);
     using namespace details;
-    thrust::transform(DevPtrCast(first1), DevPtrCast(last1), DevPtrCast(first2),
-                      DevPtrCast(result), op);
+    thrust::transform(thrust::cuda::par.on(ctx.stream()), DevPtrCast(first1),
+                      DevPtrCast(last1), DevPtrCast(first2), DevPtrCast(result),
+                      op);
 #else
     PADDLE_THROW("Do not invoke `Transform` in .cc file");
 #endif
diff --git a/paddle/platform/transform_test.cu b/paddle/platform/transform_test.cu
index 600fed8f45077a6fee91f295aa854153c9cf9c01..b8a6200bb03c9a40b67be8d113012856e2a407e9 100644
--- a/paddle/platform/transform_test.cu
+++ b/paddle/platform/transform_test.cu
@@ -36,8 +36,9 @@ class Multiply {
 
 TEST(Transform, CPUUnary) {
   using namespace paddle::platform;
+  CPUDeviceContext ctx;
   float buf[4] = {0.1, 0.2, 0.3, 0.4};
-  Transform(CPUPlace(), buf, buf + 4, buf, Scale<float>(10));
+  Transform(ctx, buf, buf + 4, buf, Scale<float>(10));
   for (int i = 0; i < 4; ++i) {
     ASSERT_NEAR(buf[i], static_cast<float>(i + 1), 1e-5);
   }
 }
@@ -47,10 +48,12 @@ TEST(Transform, GPUUnary) {
   using namespace paddle::platform;
   using namespace paddle::memory;
   GPUPlace gpu0(0);
+  CUDADeviceContext ctx(gpu0);
   float cpu_buf[4] = {0.1, 0.2, 0.3, 0.4};
   float* gpu_buf = static_cast<float*>(Alloc(gpu0, sizeof(float) * 4));
   Copy(gpu0, gpu_buf, CPUPlace(), cpu_buf, sizeof(cpu_buf));
-  Transform(gpu0, gpu_buf, gpu_buf + 4, gpu_buf, Scale<float>(10));
+  Transform(ctx, gpu_buf, gpu_buf + 4, gpu_buf, Scale<float>(10));
+  ctx.Wait();
   Copy(CPUPlace(), cpu_buf, gpu0, gpu_buf, sizeof(cpu_buf));
   Free(gpu0, gpu_buf);
   for (int i = 0; i < 4; ++i) {
@@ -62,7 +65,7 @@ TEST(Transform, CPUBinary) {
   using namespace paddle::platform;
   using namespace paddle::memory;
   int buf[4] = {1, 2, 3, 4};
-  Transform(CPUPlace(), buf, buf + 4, buf, buf, Multiply<int>());
+  Transform(CPUDeviceContext(), buf, buf + 4, buf, buf, Multiply<int>());
   for (int i = 0; i < 4; ++i) {
     ASSERT_EQ((i + 1) * (i + 1), buf[i]);
   }
 }
@@ -73,9 +76,11 @@ TEST(Transform, GPUBinary) {
   using namespace paddle::memory;
   int buf[4] = {1, 2, 3, 4};
   GPUPlace gpu0(0);
+  CUDADeviceContext ctx(gpu0);
   int* gpu_buf = static_cast<int*>(Alloc(gpu0, sizeof(buf)));
   Copy(gpu0, gpu_buf, CPUPlace(), buf, sizeof(buf));
-  Transform(gpu0, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply<int>());
+  Transform(ctx, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply<int>());
+  ctx.Wait();
   Copy(CPUPlace(), buf, gpu0, gpu_buf, sizeof(buf));
   Free(gpu0, gpu_buf);
   for (int i = 0; i < 4; ++i) {
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index a9e1d6d2e06d56f837690ec95fa8f8d41a90725f..7c32eb0069f4075d72cd4c3654c83e3d5c98fb1c 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -2286,8 +2286,15 @@ class NormLayer(LayerBase):
 
 @config_layer('pool')
 class PoolLayer(LayerBase):
+    layer_type = 'pool'
+
     def __init__(self, name, inputs, ceil_mode=True, **xargs):
-        super(PoolLayer, self).__init__(name, 'pool', 0, inputs=inputs, **xargs)
+        use_mkldnn = int(g_command_config_args.get("use_mkldnn", 0))
+        if self.layer_type == "mkldnn_pool":
+            config_assert(use_mkldnn, "mkldnn_pool only supports MKLDNN")
+        self.layer_type = 'mkldnn_pool' if use_mkldnn else 'pool'
+        super(PoolLayer, self).__init__(
+            name, self.layer_type, 0, inputs=inputs, **xargs)
         for input_index in xrange(len(self.inputs)):
             input_layer = self.get_input_layer(input_index)
             pool_conf = self.config.inputs[input_index].pool_conf
@@ -2297,6 +2304,11 @@ class PoolLayer(LayerBase):
                 pool_conf.channels)
 
 
+@config_layer('mkldnn_pool')
+class MKLDNNPoolLayer(PoolLayer):
+    layer_type = 'mkldnn_pool'
+
+
 @config_layer('pool3d')
 class Pool3DLayer(LayerBase):
     def __init__(self, name, inputs, ceil_mode=True, **xargs):
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 8c7d1738ad9753eb7afb27e893f979f8bce70a0d..9bdcca1716fca73e87adee861e91b6b90d1ef70d 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -1224,8 +1224,8 @@ def detection_output_layer(input_loc,
                            name=None):
     """
     Apply the NMS to the output of network and compute the predict bounding
-    box location. The output of this layer could be None if there is no valid
-    bounding box.
+    box location. The shape of this layer's output can be zero if there is
+    no valid bounding box.
 
     :param name: The Layer Name.
     :type name: basestring
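The type-selection logic added to `PoolLayer.__init__` above boils down to the following (a sketch; `declared_type` stands in for the subclass's `layer_type` attribute and the helper name is invented for this note):

```python
def select_pool_layer_type(declared_type, use_mkldnn):
    # mkldnn_pool may only be requested when use_mkldnn is on, and turning
    # use_mkldnn on upgrades a plain 'pool' layer to 'mkldnn_pool'
    if declared_type == "mkldnn_pool":
        assert use_mkldnn, "mkldnn_pool only supports MKLDNN"
    return "mkldnn_pool" if use_mkldnn else "pool"
```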
diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py
index 75b689982a5b797dd8b5b9ee868b2b3676278f4e..0a5673868c547d9e184e8ce05346c3ebabe06892 100644
--- a/python/paddle/v2/framework/tests/op_test.py
+++ b/python/paddle/v2/framework/tests/op_test.py
@@ -192,25 +192,29 @@ class OpTest(unittest.TestCase):
         self.op.run(self.scope, ctx)
 
         for out_name, out_dup in Operator.get_op_outputs(self.op.type()):
+            if out_name not in self.outputs:
+                continue
+
             if out_dup:
                 sub_out = self.outputs[out_name]
-                for sub_out_name, sub_out_array in sub_out:
+                if not isinstance(sub_out, list):
+                    raise AssertionError("sub_out type %s is not list" %
+                                         type(sub_out))
+
+                for sub_out_name, expect in sub_out:
                     actual = np.array(
                         self.scope.find_var(sub_out_name).get_tensor())
-                    expect = sub_out_array
                     self.assertTrue(
                         np.allclose(
                             actual, expect, atol=1e-05),
                         "output name: " + out_name + " has diff")
             else:
-                var = self.scope.find_var(out_name)
-                if var is not None:
-                    actual = np.array(var.get_tensor())
-                    expect = self.outputs[out_name]
-                    self.assertTrue(
-                        np.allclose(
-                            actual, expect, atol=1e-05),
-                        "output name: " + out_name + " has diff")
+                actual = np.array(self.scope.find_var(out_name).get_tensor())
+                expect = self.outputs[out_name]
+                self.assertTrue(
+                    np.allclose(
+                        actual, expect, atol=1e-05),
+                    "output name: " + out_name + " has diff")
 
     def check_output(self):
         places = [core.CPUPlace()]
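With this change, an entry in `self.outputs` for a duplicable output must be a list of `(sub_output_name, ndarray)` pairs, a plain output stays a bare array, and outputs missing from `self.outputs` are skipped instead of failing. A sketch of the expected structure (`test_split_op.py` at the end of this patch uses exactly this form):

```python
import numpy as np

x = np.random.random((4, 2)).astype("float32")
halves = np.split(x, 2, axis=0)

outputs = {
    # duplicable output (declared with AsDuplicable()): list of (name, array)
    "Out": [("out0", halves[0]), ("out1", halves[1])],
    # a non-duplicable output would instead be a bare array, e.g.
    # "Y": some_array,
}
```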
+ """ + + def setUp(self): + self.op_type = "cross_entropy" + batch_size = 30 + class_num = 10 + X = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + label_index = np.random.randint( + 0, class_num, (batch_size), dtype="int32") + label = np.zeros(X.shape) + label[np.arange(batch_size), label_index] = 1 + cross_entropy = np.asmatrix( + [[-np.log(X[i][label_index[i]])] for i in range(X.shape[0])], + dtype="float32") + cross_entropy2 = (-label * np.log(X)).sum( + axis=1, keepdims=True).astype("float32") + self.inputs = {'X': X, 'Label': label} + self.outputs = {'Y': cross_entropy} + self.attrs = {'soft_label': 1} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_dropout_op.py b/python/paddle/v2/framework/tests/test_dropout_op.py new file mode 100644 index 0000000000000000000000000000000000000000..3638fee1a1c26195791bc1f5a46dd749da0aee95 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_dropout_op.py @@ -0,0 +1,59 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestDropoutOp(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = {'dropout_prob': 0.0, 'is_training': 1} + self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X'], 'Out', max_relative_error=0.05) + + +class TestDropoutOp2(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = {'dropout_prob': 1.0, 'is_training': 1} + self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))} + + +class TestDropoutOp3(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} + self.attrs = {'dropout_prob': 0.0, 'is_training': 1} + self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64, 2))} + + +class TestDropoutOp4(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = {'dropout_prob': 0.35, 'is_training': 0} + self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} + + def test_check_output(self): + self.check_output() + + +class TestDropoutOp5(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")} + self.attrs = {'dropout_prob': 0.75, 'is_training': 0} + self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_mnist.py b/python/paddle/v2/framework/tests/test_mnist.py index f6f8f49b797fb6e5016a5e309f12f192d5096431..66452cb3965d28fd15e814833079621410775c17 100644 --- a/python/paddle/v2/framework/tests/test_mnist.py +++ b/python/paddle/v2/framework/tests/test_mnist.py @@ -128,7 +128,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None): def cross_entropy_layer(net, input, label): cost_name = "cross_entropy_%d" % uniq_id() cross_entropy_op = Operator( - "onehot_cross_entropy", X=input, label=label, Y=cost_name) + "cross_entropy", X=input, Label=label, Y=cost_name) 
diff --git a/python/paddle/v2/framework/tests/test_dropout_op.py b/python/paddle/v2/framework/tests/test_dropout_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..3638fee1a1c26195791bc1f5a46dd749da0aee95
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_dropout_op.py
@@ -0,0 +1,59 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestDropoutOp(OpTest):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
+        self.attrs = {'dropout_prob': 0.0, 'is_training': 1}
+        self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X'], 'Out', max_relative_error=0.05)
+
+
+class TestDropoutOp2(TestDropoutOp):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
+        self.attrs = {'dropout_prob': 1.0, 'is_training': 1}
+        self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))}
+
+
+class TestDropoutOp3(TestDropoutOp):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
+        self.attrs = {'dropout_prob': 0.0, 'is_training': 1}
+        self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64, 2))}
+
+
+class TestDropoutOp4(OpTest):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
+        self.attrs = {'dropout_prob': 0.35, 'is_training': 0}
+        self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestDropoutOp5(OpTest):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
+        self.attrs = {'dropout_prob': 0.75, 'is_training': 0}
+        self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+if __name__ == '__main__':
+    unittest.main()
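A reference implementation of the behavior these tests encode (a sketch of this patch's convention, in which training applies the mask without rescaling and inference scales the whole activation by `dropout_prob`; the helper name is invented for this note):

```python
import numpy as np

def dropout_reference(x, dropout_prob, is_training, mask):
    # Mirrors what the tests above expect: during training the mask is
    # applied without rescaling, and at inference the output is
    # x * dropout_prob (this patch's convention).
    if is_training:
        return x * mask
    return x * dropout_prob

x = np.random.random((2, 3)).astype("float32")
out = dropout_reference(x, dropout_prob=0.35, is_training=False, mask=None)
assert np.allclose(out, x * 0.35)
```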
diff --git a/python/paddle/v2/framework/tests/test_mnist.py b/python/paddle/v2/framework/tests/test_mnist.py
index f6f8f49b797fb6e5016a5e309f12f192d5096431..66452cb3965d28fd15e814833079621410775c17 100644
--- a/python/paddle/v2/framework/tests/test_mnist.py
+++ b/python/paddle/v2/framework/tests/test_mnist.py
@@ -128,7 +128,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None):
 def cross_entropy_layer(net, input, label):
     cost_name = "cross_entropy_%d" % uniq_id()
     cross_entropy_op = Operator(
-        "onehot_cross_entropy", X=input, label=label, Y=cost_name)
+        "cross_entropy", X=input, Label=label, Y=cost_name)
     net.append_op(cross_entropy_op)
     scope.new_var(cost_name)
     net.infer_shape(scope)
@@ -181,7 +181,7 @@ def error_rate(predict, label):
 
 
 images = data_layer(name="pixel", dims=[BATCH_SIZE, 784])
-labels = data_layer(name="label", dims=[BATCH_SIZE])
+labels = data_layer(name="label", dims=[BATCH_SIZE, 1])
 fc1 = fc_layer(net=forward_net, input=images, size=100, act="sigmoid")
 fc2 = fc_layer(net=forward_net, input=fc1, size=100, act="sigmoid")
 predict = fc_layer(net=forward_net, input=fc2, size=10, act="softmax")
@@ -215,6 +215,7 @@ def test(cost_name):
     for data in test_reader():
         image_data = numpy.array(map(lambda x: x[0], data)).astype("float32")
         label_data = numpy.array(map(lambda x: x[1], data)).astype("int32")
+        label_data = numpy.expand_dims(label_data, axis=1)
         feed_data(images, image_data)
         feed_data(labels, label_data)
 
@@ -235,6 +236,7 @@ for pass_id in range(PASS_NUM):
     for data in train_reader():
         image_data = numpy.array(map(lambda x: x[0], data)).astype("float32")
         label_data = numpy.array(map(lambda x: x[1], data)).astype("int32")
+        label_data = numpy.expand_dims(label_data, axis=1)
         feed_data(images, image_data)
         feed_data(labels, label_data)
diff --git a/python/paddle/v2/framework/tests/test_onehot_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_onehot_cross_entropy_op.py
deleted file mode 100644
index fd3cbdb80374865ccf113768856096bf49dce643..0000000000000000000000000000000000000000
--- a/python/paddle/v2/framework/tests/test_onehot_cross_entropy_op.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import unittest
-import numpy
-from op_test import OpTest
-
-
-class TestOnehotCrossEntropyOp(OpTest):
-    def setUp(self):
-        self.op_type = "onehot_cross_entropy"
-        batch_size = 30
-        class_num = 10
-
-        X = numpy.random.uniform(0.1, 1.0,
-                                 [batch_size, class_num]).astype("float32")
-        labels = numpy.random.randint(0, class_num, batch_size, dtype="int32")
-
-        cross_entropy = numpy.asmatrix(
-            [[-numpy.log(X[i][labels[i]])] for i in range(X.shape[0])],
-            dtype="float32")
-        self.inputs = {"X": X, "label": labels}
-        self.outputs = {"Y": cross_entropy}
-
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(["X"], "Y")
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_prelu_op.py b/python/paddle/v2/framework/tests/test_prelu_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..76d1f1d5a418b7a2a91b36360a79317d063a72e7
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_prelu_op.py
@@ -0,0 +1,28 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class PReluTest(OpTest):
+    def setUp(self):
+        self.op_type = "prelu"
+        x_np = np.random.normal(size=(10, 10)).astype("float32")
+        # Keep inputs away from zero, where PReLU is not differentiable.
+        x_np_sign = np.sign(x_np)
+        x_np = x_np_sign * np.maximum(np.abs(x_np), .005)
+        alpha_np = np.array([.1])
+        self.inputs = {'X': x_np, 'Alpha': alpha_np}
+        out_np = np.maximum(self.inputs['X'], 0.)
+        out_np = out_np + np.minimum(self.inputs['X'],
+                                     0.) * self.inputs['Alpha']
+        assert out_np is not self.inputs['X']
+        self.outputs = {'Out': out_np}
+
+    # NOTE: the not_ prefix keeps these checks disabled, presumably until
+    # the GPU kernel work flagged in prelu_op.h lands.
+    def not_test_check_output(self):
+        self.check_output()
+
+    def not_test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_split_op.py b/python/paddle/v2/framework/tests/test_split_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4420db9d71b99556e305104ac17ef5e4b4bd0f2
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_split_op.py
@@ -0,0 +1,26 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestSplitOp(OpTest):
+    def setUp(self):
+        self.op_type = "split"
+        axis = 0
+        num = 2
+        x = np.random.random((4, 2)).astype('float32')
+        out = np.split(x, num, axis)
+        self.inputs = {'X': x}
+        self.attrs = {'axis': axis, 'num': num}
+        self.outputs = {'Out': [('out%d' % i, out[i]) \
+                        for i in xrange(len(out))]}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], ['out0', 'out1'])
+
+
+if __name__ == '__main__':
+    unittest.main()