diff --git a/Dockerfile b/Dockerfile index 8cfb16928c95dcbfac08383d32562ff67933d873..5dd9b0be4f7e0a304108abfdfb089fea4faa4d38 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,7 +27,7 @@ RUN apt-get update && \ git python-pip python-dev openssh-server bison \ wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \ curl sed grep graphviz libjpeg-dev zlib1g-dev \ - python-numpy python-matplotlib gcc g++ \ + python-numpy python-matplotlib gcc-4.8 g++-4.8 \ automake locales clang-format-3.8 swig doxygen cmake \ liblapack-dev liblapacke-dev libboost-dev \ clang-3.8 llvm-3.8 libclang-3.8-dev \ diff --git a/cmake/flags.cmake b/cmake/flags.cmake index ef31c252038ce18655913c0f41343fe6dc7dbb86..d00a9bb3a30cfb16623e073414088059481c3e1a 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -9,6 +9,11 @@ function(CheckCompilerCXX11Flag) if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") endif() + # TODO(qijun) gcc 4.9 or later versions raise SEGV due to the optimization problem. + # Use Debug mode instead for now. + if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9) + set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "" FORCE) + endif() elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" # Apple Clang is a different compiler than upstream Clang which havs different version numbers. diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index ec7f1446cfb74842af7d0c7152bebf58619f3861..372272a53c12c314fc80eebbce5eae9fcabc55ba 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -104,6 +104,11 @@ cross_channel_norm ------------------ .. autoclass:: paddle.v2.layer.cross_channel_norm :noindex: + +row_l2_norm +----------- +.. autoclass:: paddle.v2.layer.row_l2_norm + :noindex: Recurrent Layers ================ @@ -320,6 +325,11 @@ scaling .. autoclass:: paddle.v2.layer.scaling :noindex: +clip +---- +.. autoclass:: paddle.v2.layer.clip + :noindex: + slope_intercept --------------- .. autoclass:: paddle.v2.layer.slope_intercept diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc index c53a5636829cab9d575f58cc2326cb3efe383e1c..7ad8a39768a064140a08c912a5a467bc24a12adf 100644 --- a/paddle/cuda/src/hl_cuda_cudnn.cc +++ b/paddle/cuda/src/hl_cuda_cudnn.cc @@ -1022,6 +1022,15 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc, real alpha = 1.0f; real beta = 1.0f; cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; + + int batch_size = ((cudnn_tensor_descriptor)inputDesc)->batch_size; + if (batch_size > 1024 && g_cudnn_lib_version < 6000) { + LOG(INFO) << " To process current batch data with size " << batch_size + << " (>1024), cudnnBatchNorm requires cuDNN version >= 6000." + << " If there is an error complaining CUDNN_STATUS_NOT_SUPPORTED," + << " just recompile PaddlePaddle with cuDNN >= 6000, replacing" + << " current version " << g_cudnn_lib_version; + } CHECK_CUDNN( dynload::cudnnBatchNormalizationForwardInference(t_resource.cudnn_handle, mode, diff --git a/paddle/framework/detail/tensor-inl.h b/paddle/framework/detail/tensor-inl.h index e7ff09dd5c954378afeca299e901277c3ebdb96a..92621f8c18ec0d03160a23c462830d14272c7f64 100644 --- a/paddle/framework/detail/tensor-inl.h +++ b/paddle/framework/detail/tensor-inl.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once - #include "paddle/memory/memcpy.h" namespace paddle { @@ -62,9 +61,11 @@ inline T* Tensor::mutable_data(platform::Place place) { if (platform::is_cpu_place(place)) { holder_.reset(new PlaceholderImpl( boost::get(place), size)); + } else if (platform::is_gpu_place(place)) { +#ifdef PADDLE_ONLY_CPU + PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); } -#ifndef PADDLE_ONLY_CPU - else if (platform::is_gpu_place(place)) { +#else holder_.reset(new PlaceholderImpl( boost::get(place), size)); } diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index f10c9297981a4c6aefc6c2072d0ac2b8e562a7a0..3e72e391266066de9e4114e68b43b066c15254db 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -400,6 +400,14 @@ class GradOpRegisterHelper { return 0; \ } +/** + * Macro to Forbid user register Gradient Operator. + */ +#define NO_GRADIENT(__op_type) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_gradient_op__##__op_type##__op_type##_grad, \ + "NO_GRADIENT must be in global namespace") + /** * Macro to Register OperatorKernel. */ diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index cfe9cba308556475ef64b45e7178dfc418761598..cb86e6be2be3624bf54ee28193ca5d4c7bafa0eb 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -20,16 +20,16 @@ namespace paddle { namespace framework { template <> -Eigen::DefaultDevice* ExecutionContext::GetEigenDevice< +Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { - return device_context_.get_eigen_device(); + return *device_context_.get_eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> -Eigen::GpuDevice* +Eigen::GpuDevice& ExecutionContext::GetEigenDevice() const { - return device_context_.get_eigen_device(); + return *device_context_.get_eigen_device(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 0832a663dd01fe2921366d70599bc867e73af47c..55435103489ace11868eed61c38018d8ba357e65 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -253,7 +253,7 @@ class ExecutionContext : public OperatorContext { template ::EigenDeviceType> - DeviceType* GetEigenDevice() const; + DeviceType& GetEigenDevice() const; platform::Place GetPlace() const { return device_context_.GetPlace(); } diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index bb4f48364b9b454af7d37fe4d3c340666e53285c..baf78bc6c88d0d294f4457b81c52b22e425d9fdb 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -109,6 +109,13 @@ protected: return filter[filter.ndims() - 1]; } + // determine whether im2col needs to be performed + inline bool isNeedIm2col(const TensorShape& filter) const { + return !(getFilterHeight(filter) == 1 && getFilterWidth(filter) == 1 && + strideH() == 1 && strideW() == 1 && paddingH() == 0 && + paddingW() == 0); + } + std::vector strides_; std::vector paddings_; diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 9deb2739fcfff935a98a0b5b31b5d11819d81227..0ada4d70a0c7d13f9b5fb1a42eac07fc4c775a87 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -66,16 +66,23 @@ public: real* inputData = inputs[0].data(); real* filterData = inputs[1].data(); real* outputData = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Im2ColFunctor im2col; GemmFunctor gemm; @@ -86,15 +93,18 @@ public: for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { - im2col(inputData + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); - + if (needIm2col) { + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } else { + colData = inputData + g * inputOffset; + } int M = outputChannels / groups_; int N = outputHeight * outputWidth; int K = inputChannels / groups_ * filterHeight * filterWidth; @@ -159,19 +169,27 @@ public: real* outputGrad = inputs[0].data(); real* filterData = inputs[1].data(); real* inputGrad = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Col2ImFunctor col2im; GemmFunctor gemm; + size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -182,6 +200,11 @@ public: int K = outputChannels / groups_; int N = outputHeight * outputWidth; int M = inputChannels / groups_ * filterHeight * filterWidth; + real scale = 0.0f; + if (!needIm2col) { + colData = inputGrad + g * inputOffset; + scale = 1.0f; + } gemm(CblasTrans, CblasNoTrans, M, @@ -192,17 +215,19 @@ public: M, outputGrad + g * outputOffset, N, - 0.0f, + scale, colData, N); - col2im(inputGrad + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); + if (needIm2col) { + col2im(inputGrad + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } } inputGrad += inputChannels * inputHeight * inputWidth; outputGrad += outputChannels * outputHeight * outputWidth; @@ -255,16 +280,23 @@ public: real* outputGrad = inputs[0].data(); real* inputData = inputs[1].data(); real* filterGrad = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Im2ColFunctor im2col; GemmFunctor gemm; @@ -274,15 +306,18 @@ public: size_t filterOffset = filter.getElements() / groups_; for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { - im2col(inputData + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); - + if (needIm2col) { + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } else { + colData = inputData + g * inputOffset; + } int M = outputChannels / groups_; int K = outputHeight * outputWidth; int N = inputChannels / groups_ * filterHeight * filterWidth; diff --git a/paddle/gserver/layers/ClipLayer.cpp b/paddle/gserver/layers/ClipLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13f16c953793b82183237188b56eb61d76ecd2fd --- /dev/null +++ b/paddle/gserver/layers/ClipLayer.cpp @@ -0,0 +1,79 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Layer.h" + +namespace paddle { + +/** + * A layer for clipping the input value by the threshold. + * \f[ + * out[i] = \min\left(\max\left(in[i],p_{1}\right),p_{2}\right) + * \f] + */ + +class ClipLayer : public Layer { +protected: + double min_; + double max_; + +public: + explicit ClipLayer(const LayerConfig& config) : Layer(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; +}; + +REGISTER_LAYER(clip, ClipLayer); + +bool ClipLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + + CHECK_EQ(inputLayers_.size(), 1U); + auto layerConf = config_.inputs(0).clip_conf(); + min_ = layerConf.min(); + max_ = layerConf.max(); + CHECK_LT(min_, max_); + return true; +} + +void ClipLayer::forward(PassType passType) { + Layer::forward(passType); + + MatrixPtr inV = getInputValue(0); + resetOutput(inV->getHeight(), inV->getWidth()); + MatrixPtr outV = getOutputValue(); + outV->copyFrom(*inV); + outV->clip(min_, max_); +} + +void ClipLayer::backward(const UpdateCallback& callback) { + MatrixPtr inV = getInputValue(0); + MatrixPtr inG = getInputGrad(0); + if (inG) { + MatrixPtr outV = getOutputValue(); + MatrixPtr outG = getOutputGrad(); + MatrixPtr tmpMtx; + Matrix::resizeOrCreate( + tmpMtx, outG->getHeight(), outG->getWidth(), false, useGpu_); + tmpMtx->clipDerivative(*inV, min_, max_); + inG->addDotMul(*outG, *tmpMtx, 1, 1); + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/RowL2NormLayer.cpp b/paddle/gserver/layers/RowL2NormLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d609be43b73a86d0d0f7b60be993836e2ea6fff --- /dev/null +++ b/paddle/gserver/layers/RowL2NormLayer.cpp @@ -0,0 +1,98 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Layer.h" + +namespace paddle { + +/** + * A layer for L2 normalization in each row, + * \f[ + * out[i] = \frac{in[i]}{\sqrt{\sum_{k=1}^N in[k]^{2}}} + * \f] + * where the size of \f$in\f$ is (batchSize x dataDim), + * and the size of \f$out\f$ is (batchSize x dataDim). + */ + +class RowL2NormLayer : public Layer { +protected: + MatrixPtr inSquare_; + MatrixPtr l2NormReciprocal_; + MatrixPtr dotSum_; + +public: + explicit RowL2NormLayer(const LayerConfig& config) : Layer(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; +}; + +REGISTER_LAYER(row_l2_norm, RowL2NormLayer); + +bool RowL2NormLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + + CHECK_EQ(inputLayers_.size(), 1U); + + return true; +} + +void RowL2NormLayer::forward(PassType passType) { + Layer::forward(passType); + + MatrixPtr inV = getInputValue(0); + + /* malloc memory for the output_ if necessary */ + size_t batchSize = inV->getHeight(); + size_t dataDim = getSize(); + CHECK_EQ(dataDim, inV->getWidth()); + resetOutput(batchSize, dataDim); + MatrixPtr outV = getOutputValue(); + + Matrix::resizeOrCreate(inSquare_, batchSize, dataDim, false, useGpu_); + inV->square2(*inSquare_); + Matrix::resizeOrCreate(l2NormReciprocal_, batchSize, 1, false, useGpu_); + inSquare_->rowSum(*l2NormReciprocal_); + l2NormReciprocal_->sqrt2(*l2NormReciprocal_); + l2NormReciprocal_->scalarDiv(*l2NormReciprocal_, 1.0); + outV->rowScale(0, *inV, *l2NormReciprocal_); +} + +void RowL2NormLayer::backward(const UpdateCallback& callback) { + MatrixPtr inV = getInputValue(0); + MatrixPtr inG = getInputGrad(0); + MatrixPtr outV = getOutputValue(); + MatrixPtr outG = getOutputGrad(); + size_t batchSize = inV->getHeight(); + + // inG[ij] += outG[ij] / l2NormReciprocal + // inG[ij] += -inV[ij] * l2NormReciprocal * l2NormReciprocal * DotMul(outG[i], + // inV[i]) + if (inG) { + Matrix::resizeOrCreate(dotSum_, batchSize, 1, false, useGpu_); + dotSum_->zeroMem(); + dotSum_->rowDotMul(0, *outG, *outV); + dotSum_->dotMul(*dotSum_, *l2NormReciprocal_); + dotSum_->dotMul(*dotSum_, *l2NormReciprocal_); + inSquare_->rowScale(0, *inV, *dotSum_); + inG->sub(*inSquare_); + inG->addRowScale(0, *outG, *l2NormReciprocal_); + } +} + +} // namespace paddle diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 8ce8600c6743779899b2685c1c12053922265411..fe11278f41c0118ee0bdb34f17fbf9602e0fa76b 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1899,6 +1899,36 @@ TEST(Layer, CropLayer) { } } +TEST(Layer, ClipLayer) { + const size_t batchSize = 128; + const size_t size = 512; + TestConfig config; + config.layerConfig.set_type("clip"); + config.inputDefs.push_back({INPUT_DATA, "input", size, 0}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + ClipConfig* layerConf = input->mutable_clip_conf(); + double p1 = std::rand() / (double)RAND_MAX; + double p2 = std::rand() / (double)RAND_MAX; + layerConf->set_min(std::min(p1, p2)); + layerConf->set_max(std::max(p1, p2)); + for (auto useGpu : {false, true}) { + testLayerGrad(config, "clip", batchSize, false, useGpu, false); + } +} + +TEST(Layer, RowL2NormLayer) { + const size_t batchSize = 128; + const size_t size = 512; + TestConfig config; + config.layerConfig.set_type("row_l2_norm"); + config.layerConfig.set_size(size); + config.inputDefs.push_back({INPUT_DATA, "input", size, 0}); + config.layerConfig.add_inputs(); + for (auto useGpu : {false, true}) { + testLayerGrad(config, "row_l2_norm", batchSize, false, useGpu, false); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/paddle/math/BaseMatrix.cu b/paddle/math/BaseMatrix.cu index de48b6fac9c7d8125a552022c52353ef6bcef995..6db5965789b3750f46731f157167150583130d0a 100644 --- a/paddle/math/BaseMatrix.cu +++ b/paddle/math/BaseMatrix.cu @@ -442,6 +442,12 @@ DEFINE_MATRIX_UNARY_PARAMETER_OP(Clip, TWO_PARAMETER, template void BaseMatrixT::clip(T p1, T p2) { applyUnary(unary::Clip(p1, p2)); } +DEFINE_MATRIX_BINARY_PARAMETER_OP(ClipDerivative, TWO_PARAMETER, a = b < p1 ? 0 : (b > p2 ? 0 : 1)); +template +void BaseMatrixT::clipDerivative(BaseMatrixT& b, T p1, T p2) { + applyBinary(binary::ClipDerivative(p1, p2), b); +} + DEFINE_MATRIX_UNARY_PARAMETER_OP(BiggerThanScalar, ONE_PARAMETER, a = a > p ? 1.0f : 0.0f); template diff --git a/paddle/math/BaseMatrix.h b/paddle/math/BaseMatrix.h index 120d69f718b954925438fbd2119d69f0be13b3e9..12ad2d45a0bbff182e78da6efb3c5ff4c6b59b55 100644 --- a/paddle/math/BaseMatrix.h +++ b/paddle/math/BaseMatrix.h @@ -488,6 +488,13 @@ public: */ void clip(T p1, T p2); + /** + * this = b < low ? 0 : 1 + * + * this = b > high ? 0 : 1 + */ + void clipDerivative(BaseMatrixT& b, T p1, T p2); + /** * @code * a = a > p ? 1.0f : 0.0f diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 3a43dbfbada87e458109d8ca22effdb4407b4c1d..85269a5f7445a1745d9be68417789e33eb725d5c 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -50,10 +50,6 @@ The equation is: Out = X + Y class AddOpGrad : public OperatorWithKernel { protected: void InferShape(const InferShapeContext &ctx) const override {} - std::string DebugString() const override { - LOG(INFO) << "AddOpGrad"; - return ""; - } }; } // namespace operators diff --git a/paddle/operators/add_op.cu b/paddle/operators/add_op.cu index 79d8de6cd46e1c72b14b0554c7be7b4eee281f4c..f961b37565f400b5c26844b9e7a3cff5e682340b 100644 --- a/paddle/operators/add_op.cu +++ b/paddle/operators/add_op.cu @@ -1,3 +1,4 @@ +#define EIGEN_USE_GPU #include "paddle/framework/op_registry.h" #include "paddle/operators/add_op.h" diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index d2b649fcbd1e5cac1c8cfcfd4e522e41135f7d1f..54d2231425293f6cfb3adc9cb34d903a75fcdcd0 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -28,10 +28,13 @@ public: output->mutable_data(context.GetPlace()); - EigenVector::Flatten(*output).device( - *(context.GetEigenDevice())) = - framework::EigenVector::Flatten(*input0) + - framework::EigenVector::Flatten(*input1); + auto X = EigenVector::Flatten(*input0); + auto Y = EigenVector::Flatten(*input1); + auto Z = EigenVector::Flatten(*output); + + auto place = context.GetEigenDevice(); + + Z.device(place) = X + Y; } }; diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 19e4b74596a0f59edd04db830ec6f6f481373465..926a0c616b957d8e542c1f3dee227a718fb29f07 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -1,3 +1,4 @@ +#define EIGEN_USE_GPU #include "paddle/operators/cross_entropy_op.h" REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index 5f7d443751d1cdd7de3b67b0de2758ba1d566fb3..5c339bffbf8e39f36ee9b4f857ab380cbac82879 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -27,8 +27,11 @@ public: output->mutable_data(context.GetPlace()); - EigenScalar::From(*output).device(*(context.GetEigenDevice())) = - EigenVector::Flatten(*input).mean(); + auto X = EigenVector::Flatten(*input); + auto y = EigenScalar::From(*output); + auto place = context.GetEigenDevice(); + + y.device(place) = X.mean(); } }; diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index c27fc886ce7238a13c8ef86bce673a2b54949a9d..dc9236701627dc9335b844d2a82e18eb1f7dfd42 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -12,6 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ +#define EIGEN_USE_GPU #include "paddle/operators/mul_op.h" REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); \ No newline at end of file diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index eef72ab293e13a9d05ce0013be41ec4bb75d6077..c7b78ad39045d25d73bfc2c930063c255a514864 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -26,13 +26,18 @@ public: Eigen::array, 1> dim_pair = { {Eigen::IndexPair(1, 0)}}; + auto input0 = context.Input("X"); + auto input1 = context.Input("Y"); auto output = context.Output(0); + output->mutable_data(context.GetPlace()); - EigenMatrix::From(*output).device(*(context.GetEigenDevice())) = - EigenMatrix::From(*context.Input("X")) - .contract(EigenMatrix::From(*context.Input("Y")), - dim_pair); + auto X = EigenMatrix::From(*input0); + auto Y = EigenMatrix::From(*input1); + auto Z = EigenMatrix::From(*output); + auto place = context.GetEigenDevice(); + + Z.device(place) = X.contract(Y, dim_pair); } }; } // namespace operators diff --git a/paddle/operators/rowwise_add_op.cu b/paddle/operators/rowwise_add_op.cu index 4b33e38ebabe853e179fe70ef7fde0a80b9050e2..82338ceccc06653791b26472e18d804f62735649 100644 --- a/paddle/operators/rowwise_add_op.cu +++ b/paddle/operators/rowwise_add_op.cu @@ -1,3 +1,4 @@ +#define EIGEN_USE_GPU #include "paddle/operators/rowwise_add_op.h" REGISTER_OP_GPU_KERNEL(rowwise_add, diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index b86dd5463436bf521f9939b1c421b39f11102769..bd4d1128955fb718d3a84dfd96d8c68d7196e9cc 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -33,7 +33,7 @@ public: const int rest_size = input.size() / bias_size; Eigen::DSizes one_d(input.size()); Eigen::DSizes bcast(rest_size); - output.reshape(one_d).device(*(context.GetEigenDevice())) = + output.reshape(one_d).device(context.GetEigenDevice()) = input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d); } }; diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu index f8f5b90cab460b4457cfb0a88bfc012bafe0fbc2..d79258cbf13c699cfb2afaee229cf96a3e377b5e 100644 --- a/paddle/operators/sgd_op.cu +++ b/paddle/operators/sgd_op.cu @@ -1,3 +1,4 @@ +#define EIGEN_USE_GPU #include "paddle/operators/sgd_op.h" REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel); \ No newline at end of file diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index af1dfdd756ceb9991bee6b85c3281c05f0fb5a9f..0c3a240f9a4a5fc7bc4898e82786810cee2f7010 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -29,8 +29,12 @@ public: param_out->mutable_data(ctx.GetPlace()); - EigenVector::Flatten(*param_out).device(*(ctx.GetEigenDevice())) = - EigenVector::Flatten(*param) - lr * EigenVector::Flatten(*grad); + auto p = EigenVector::Flatten(*param); + auto g = EigenVector::Flatten(*grad); + auto o = EigenVector::Flatten(*param_out); + auto place = ctx.GetEigenDevice(); + + o.device(place) = p - lr * g; } }; diff --git a/paddle/operators/sigmoid_op.cu b/paddle/operators/sigmoid_op.cu index f679b20418f04eff4310efe4e121963ce5a235e0..c9d11a2e1f9dcc563765c9e8cc1bae6beff57f18 100644 --- a/paddle/operators/sigmoid_op.cu +++ b/paddle/operators/sigmoid_op.cu @@ -1,3 +1,4 @@ +#define EIGEN_USE_GPU #include "paddle/operators/sigmoid_op.h" REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel); diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h index 3dd23a9ebc7ac0972d6ee07b9ac051d59e66f62f..1412e4398440c8e946d3ab434a50e978079637ab 100644 --- a/paddle/operators/sigmoid_op.h +++ b/paddle/operators/sigmoid_op.h @@ -27,9 +27,11 @@ public: auto output = context.Output(0); output->mutable_data(context.GetPlace()); - EigenVector::Flatten(*output).device( - *(context.GetEigenDevice())) = - 1.0 / (1.0 + (-1.0 * EigenVector::Flatten(*input)).exp()); + auto X = EigenVector::Flatten(*input); + auto Y = EigenVector::Flatten(*output); + auto place = context.GetEigenDevice(); + + Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp()); } }; } // namespace operators diff --git a/paddle/operators/softmax_op.cu b/paddle/operators/softmax_op.cu index a1f6944a369fe5148ffcfeabf3bf7063dcbc2664..ddf8f6e913ccf450185f377f531bf978f69ed1fc 100644 --- a/paddle/operators/softmax_op.cu +++ b/paddle/operators/softmax_op.cu @@ -1,3 +1,4 @@ +#define EIGEN_USE_GPU #include "paddle/framework/op_registry.h" #include "paddle/operators/softmax_op.h" diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index a5c19c5fc7c6f5909dbb355aff09bf15405b6957..75c5197697dada58e09f4cda41cea13af56e79a3 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -46,9 +46,9 @@ public: .reshape(batch_by_one) .broadcast(one_by_class)); - softmax.device(*(context.GetEigenDevice())) = shifted_logits.exp(); + softmax.device(context.GetEigenDevice()) = shifted_logits.exp(); - softmax.device(*(context.GetEigenDevice())) = + softmax.device(context.GetEigenDevice()) = (softmax * softmax.sum(along_class) .inverse() diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 26c8eb78e614a68ec9728aad727d8fe3e08547ae..60a42c777d1c2ebbc22fdb77b1100cc6fcf7ff35 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -144,12 +144,12 @@ inline void throw_on_error(T e) { throw_on_error(e, ""); } -#define PADDLE_THROW(...) \ - do { \ - throw ::paddle::platform::EnforceNotMet( \ - std::make_exception_ptr( \ - std::runtime_error(string::Sprintf(__VA_ARGS__))), \ - __FILE__, __LINE__); \ +#define PADDLE_THROW(...) \ + do { \ + throw ::paddle::platform::EnforceNotMet( \ + std::make_exception_ptr( \ + std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \ + __FILE__, __LINE__); \ } while (0) #define PADDLE_ENFORCE(...) \ diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 845589dcb1997b662b5175e5cce320eec4be4a8d..ac12b504b5cbee778c7c0a74a84a7729f210e01e 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,6 +1,6 @@ cc_library(paddle_pybind SHARED SRCS pybind.cc - DEPS pybind python + DEPS pybind python backward fc_op sgd_op add_op diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index dc6f29d026b62dc819fe3d0fdf301e1a0d2886a9..4df02d40fe08290d1165e032f0ad013c5986feaa 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -16,10 +16,13 @@ limitations under the License. */ #include #include +#include "paddle/framework/backward.h" #include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" +#include "paddle/platform/enforce.h" +#include "paddle/platform/place.h" #include "paddle/pybind/tensor_bind.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" @@ -43,6 +46,10 @@ template void ExposeOperator(ClassType& m) { m.def("infer_shape", &ClassType::type::InferShape) .def("run", &ClassType::type::Run) + .def("type", + [](const typename ClassType::type& op) -> std::string { + return op.type_; + }) .def("outputs", [](const typename ClassType::type& op) -> std::vector { return op.outputs_; @@ -55,6 +62,14 @@ static size_t UniqueIntegerGenerator() { return generator.fetch_add(1); } +bool IsCompileGPU() { +#ifdef PADDLE_ONLY_CPU + return false; +#else + return true; +#endif +} + PYBIND11_PLUGIN(core) { py::module m("core", "C++ core of PaddlePaddle"); @@ -69,15 +84,27 @@ PYBIND11_PLUGIN(core) { self.Resize(pd::make_ddim(dim)); }) .def("alloc_float", - [](pd::Tensor& self) { - self.mutable_data(paddle::platform::CPUPlace()); + [](pd::Tensor& self, paddle::platform::GPUPlace& place) { + self.mutable_data(place); + }) + .def("alloc_float", + [](pd::Tensor& self, paddle::platform::CPUPlace& place) { + self.mutable_data(place); + }) + .def("alloc_int", + [](pd::Tensor& self, paddle::platform::CPUPlace& place) { + self.mutable_data(place); }) .def("alloc_int", - [](pd::Tensor& self) { - self.mutable_data(paddle::platform::CPUPlace()); + [](pd::Tensor& self, paddle::platform::GPUPlace& place) { + self.mutable_data(place); }) - .def("set", paddle::pybind::PyTensorSetFromArray) - .def("set", paddle::pybind::PyTensorSetFromArray) + .def("set", paddle::pybind::PyCPUTensorSetFromArray) + .def("set", paddle::pybind::PyCPUTensorSetFromArray) +#ifndef PADDLE_ONLY_CPU + .def("set", paddle::pybind::PyCUDATensorSetFromArray) + .def("set", paddle::pybind::PyCUDATensorSetFromArray) +#endif .def("shape", [](pd::Tensor& self) { return pd::vectorize(self.dims()); }) .def("set_float_element", [](pd::Tensor& self, size_t offset, float f) { @@ -144,11 +171,27 @@ All parameter, weight, gradient are variables in Paddle. "The module will return special predefined variable name in Paddle") .def("empty", pd::OperatorBase::EMPTY_VAR_NAME) .def("temp", pd::OperatorBase::TMP_VAR_NAME); - + // clang-format off py::class_(m, "DeviceContext") - .def_static("cpu_context", []() -> paddle::platform::DeviceContext* { - return new paddle::platform::CPUDeviceContext(); - }); + .def_static("create", + [](paddle::platform::CPUPlace& place) + -> paddle::platform::DeviceContext* { + return new paddle::platform::CPUDeviceContext(); + }) + .def_static("create", + [](paddle::platform::GPUPlace& place) + -> paddle::platform::DeviceContext* { +#ifdef PADDLE_ONLY_CPU + PADDLE_THROW("GPUPlace is not supported in CPU device."); +#else + return new paddle::platform::CUDADeviceContext(place); +#endif + }); + // clang-format on + + py::class_(m, "GPUPlace").def(py::init()); + + py::class_(m, "CPUPlace").def(py::init<>()); py::class_> operator_base( m, "Operator"); @@ -162,6 +205,13 @@ All parameter, weight, gradient are variables in Paddle. desc.InitializationErrorString()); return pd::OpRegistry::CreateOp(desc); }); + + operator_base.def("backward", + [](const pd::OperatorBase& forwardOp, + const std::unordered_set& no_grad_vars) { + return pd::Backward(forwardOp, no_grad_vars); + }); + ExposeOperator(operator_base); py::class_> net(m, "Net"); @@ -184,5 +234,7 @@ All parameter, weight, gradient are variables in Paddle. m.def("unique_integer", UniqueIntegerGenerator); + m.def("is_compile_gpu", IsCompileGPU); + return m.ptr(); } diff --git a/paddle/pybind/tensor_bind.h b/paddle/pybind/tensor_bind.h index 995e102bf9d342e1604f5ae704288d6cf68d97a4..def37219ccefd5435f1212c4e4daac5a351d76f4 100644 --- a/paddle/pybind/tensor_bind.h +++ b/paddle/pybind/tensor_bind.h @@ -13,9 +13,11 @@ limitations under the License. */ #pragma once -#include -#include -#include +#include +#include "paddle/framework/tensor.h" +#include "paddle/memory/memcpy.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" namespace py = pybind11; @@ -40,9 +42,6 @@ template struct CastToPyBufferImpl { using CUR_TYPE = typename std::tuple_element>::type; py::buffer_info operator()(framework::Tensor &tensor) { - PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()), - "Only CPU tensor can cast to numpy array"); - if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) { auto dim_vec = framework::vectorize(tensor.dims()); std::vector dims_outside; @@ -56,12 +55,17 @@ struct CastToPyBufferImpl { strides[i - 1] = sizeof(CUR_TYPE) * prod; prod *= dims_outside[i - 1]; } - + framework::Tensor dst_tensor; + if (paddle::platform::is_gpu_place(tensor.holder_->place())) { + dst_tensor.CopyFrom(tensor, platform::CPUPlace()); + } else if (paddle::platform::is_cpu_place(tensor.holder_->place())) { + dst_tensor = tensor; + } return py::buffer_info( - tensor.mutable_data(tensor.holder_->place()), + dst_tensor.mutable_data(dst_tensor.holder_->place()), sizeof(CUR_TYPE), py::format_descriptor::format(), - (size_t)framework::arity(tensor.dims()), + (size_t)framework::arity(dst_tensor.dims()), dims_outside, strides); } else { @@ -77,9 +81,10 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) { } template -void PyTensorSetFromArray( +void PyCPUTensorSetFromArray( framework::Tensor &self, - py::array_t array) { + py::array_t array, + paddle::platform::CPUPlace &place) { std::vector dims; dims.reserve(array.ndim()); for (size_t i = 0; i < array.ndim(); ++i) { @@ -87,9 +92,28 @@ void PyTensorSetFromArray( } self.Resize(framework::make_ddim(dims)); - auto *dst = self.mutable_data(paddle::platform::CPUPlace()); + auto *dst = self.mutable_data(place); std::memcpy(dst, array.data(), sizeof(T) * array.size()); } +#ifndef PADDLE_ONLY_CPU +template +void PyCUDATensorSetFromArray( + framework::Tensor &self, + py::array_t array, + paddle::platform::GPUPlace &place) { + std::vector dims; + dims.reserve(array.ndim()); + for (size_t i = 0; i < array.ndim(); ++i) { + dims.push_back((int)array.shape()[i]); + } + + self.Resize(framework::make_ddim(dims)); + auto *dst = self.mutable_data(place); + paddle::platform::GpuMemcpySync( + dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice); +} +#endif + } // namespace pybind } // namespace paddle diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 3860facb099950a5287d3f6b89c3de38f588f568..69ae0ea2d72c199a8e17c0595693e5e0b2f79ee1 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -148,7 +148,7 @@ cat >> /paddle/build/Dockerfile <