diff --git a/Dockerfile b/Dockerfile index 8cfb16928c95dcbfac08383d32562ff67933d873..5dd9b0be4f7e0a304108abfdfb089fea4faa4d38 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,7 +27,7 @@ RUN apt-get update && \ git python-pip python-dev openssh-server bison \ wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \ curl sed grep graphviz libjpeg-dev zlib1g-dev \ - python-numpy python-matplotlib gcc g++ \ + python-numpy python-matplotlib gcc-4.8 g++-4.8 \ automake locales clang-format-3.8 swig doxygen cmake \ liblapack-dev liblapacke-dev libboost-dev \ clang-3.8 llvm-3.8 libclang-3.8-dev \ diff --git a/cmake/flags.cmake b/cmake/flags.cmake index ef31c252038ce18655913c0f41343fe6dc7dbb86..d00a9bb3a30cfb16623e073414088059481c3e1a 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -9,6 +9,11 @@ function(CheckCompilerCXX11Flag) if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") endif() + # TODO(qijun) gcc 4.9 or later versions raise SEGV due to the optimization problem. + # Use Debug mode instead for now. + if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9) + set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "" FORCE) + endif() elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" # Apple Clang is a different compiler than upstream Clang which havs different version numbers. diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index ec7f1446cfb74842af7d0c7152bebf58619f3861..372272a53c12c314fc80eebbce5eae9fcabc55ba 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -104,6 +104,11 @@ cross_channel_norm ------------------ .. autoclass:: paddle.v2.layer.cross_channel_norm :noindex: + +row_l2_norm +----------- +.. autoclass:: paddle.v2.layer.row_l2_norm + :noindex: Recurrent Layers ================ @@ -320,6 +325,11 @@ scaling .. autoclass:: paddle.v2.layer.scaling :noindex: +clip +---- +.. autoclass:: paddle.v2.layer.clip + :noindex: + slope_intercept --------------- .. autoclass:: paddle.v2.layer.slope_intercept diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc index c53a5636829cab9d575f58cc2326cb3efe383e1c..7ad8a39768a064140a08c912a5a467bc24a12adf 100644 --- a/paddle/cuda/src/hl_cuda_cudnn.cc +++ b/paddle/cuda/src/hl_cuda_cudnn.cc @@ -1022,6 +1022,15 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc, real alpha = 1.0f; real beta = 1.0f; cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; + + int batch_size = ((cudnn_tensor_descriptor)inputDesc)->batch_size; + if (batch_size > 1024 && g_cudnn_lib_version < 6000) { + LOG(INFO) << " To process current batch data with size " << batch_size + << " (>1024), cudnnBatchNorm requires cuDNN version >= 6000." 
+ << " If there is an error complaining CUDNN_STATUS_NOT_SUPPORTED," + << " just recompile PaddlePaddle with cuDNN >= 6000, replacing" + << " current version " << g_cudnn_lib_version; + } CHECK_CUDNN( dynload::cudnnBatchNormalizationForwardInference(t_resource.cudnn_handle, mode, diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index b73426eaacdf2eaf115b4ac02d58e02d24cc753d..f8f9bae12d42ccbc52b1046900d239ae0cde6940 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -38,7 +38,7 @@ cc_library(backward SRCS backward.cc DEPS net) cc_test(backward_test SRCS backward_test.cc DEPS backward) cc_library(paddle_pybind SHARED SRCS pybind.cc - DEPS pybind python + DEPS pybind python backward fc_op sgd_op add_op diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index f10c9297981a4c6aefc6c2072d0ac2b8e562a7a0..3e72e391266066de9e4114e68b43b066c15254db 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -400,6 +400,14 @@ class GradOpRegisterHelper { return 0; \ } +/** + * Macro to Forbid user register Gradient Operator. + */ +#define NO_GRADIENT(__op_type) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_gradient_op__##__op_type##__op_type##_grad, \ + "NO_GRADIENT must be in global namespace") + /** * Macro to Register OperatorKernel. */ diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index cfe9cba308556475ef64b45e7178dfc418761598..cb86e6be2be3624bf54ee28193ca5d4c7bafa0eb 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -20,16 +20,16 @@ namespace paddle { namespace framework { template <> -Eigen::DefaultDevice* ExecutionContext::GetEigenDevice< +Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { - return device_context_.get_eigen_device<Eigen::DefaultDevice>(); + return *device_context_.get_eigen_device<Eigen::DefaultDevice>(); } #ifndef PADDLE_ONLY_CPU template <> -Eigen::GpuDevice* +Eigen::GpuDevice& ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { - return device_context_.get_eigen_device<Eigen::GpuDevice>(); + return *device_context_.get_eigen_device<Eigen::GpuDevice>(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 0832a663dd01fe2921366d70599bc867e73af47c..55435103489ace11868eed61c38018d8ba357e65 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -253,7 +253,7 @@ class ExecutionContext : public OperatorContext { template <typename PlaceType, typename DeviceType = typename EigenDeviceConverter<PlaceType>::EigenDeviceType> - DeviceType* GetEigenDevice() const; + DeviceType& GetEigenDevice() const; platform::Place GetPlace() const { return device_context_.GetPlace(); } diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index a735cc2ad51aaf3eaa2ad05f2ab757448b31ed49..cc47469b4db53458f6a4314f4339b58a9527637e 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 +http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -16,11 +16,14 @@ limitations under the License. 
*/ #include <fstream> #include <vector> +#include "paddle/framework/backward.h" #include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" #include "paddle/framework/tensor_bind.h" +#include "paddle/platform/enforce.h" +#include "paddle/platform/place.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" @@ -43,6 +46,10 @@ template <typename ClassType> void ExposeOperator(ClassType &m) { m.def("infer_shape", &ClassType::type::InferShape) .def("run", &ClassType::type::Run) + .def("type", + [](const typename ClassType::type &op) -> std::string { + return op.type_; + }) .def("outputs", [](const typename ClassType::type &op) -> std::vector<std::string> { return op.outputs_; @@ -55,6 +62,14 @@ static size_t UniqueIntegerGenerator() { return generator.fetch_add(1); } +bool IsCompileGPU() { +#ifdef PADDLE_ONLY_CPU + return false; +#else + return true; +#endif +} + PYBIND11_PLUGIN(core) { py::module m("core", "C++ core of PaddlePaddle"); @@ -68,16 +83,29 @@ PYBIND11_PLUGIN(core) { self.Resize(make_ddim(dim)); }) .def("alloc_float", - [](Tensor &self) { - self.mutable_data<float>(paddle::platform::CPUPlace()); + [](pd::Tensor &self, paddle::platform::GPUPlace &place) { + self.mutable_data<float>(place); + }) + .def("alloc_float", + [](pd::Tensor &self, paddle::platform::CPUPlace &place) { + self.mutable_data<float>(place); }) .def("alloc_int", - [](Tensor &self) { - self.mutable_data<int>(paddle::platform::CPUPlace()); + [](pd::Tensor &self, paddle::platform::CPUPlace &place) { + self.mutable_data<int>(place); }) - .def("set", PyTensorSetFromArray<float>) - .def("set", PyTensorSetFromArray<int>) - .def("shape", [](Tensor &self) { return vectorize(self.dims()); }); + .def("alloc_int", + [](pd::Tensor &self, paddle::platform::GPUPlace &place) { + self.mutable_data<int>(place); + }) + .def("set", paddle::pybind::PyCPUTensorSetFromArray<float>) + .def("set", paddle::pybind::PyCPUTensorSetFromArray<int>) +#ifndef PADDLE_ONLY_CPU + .def("set", paddle::pybind::PyCUDATensorSetFromArray<float>) + .def("set", paddle::pybind::PyCUDATensorSetFromArray<int>) +#endif + .def("shape", + [](pd::Tensor &self) { return pd::vectorize(self.dims()); }); py::class_<Variable>(m, "Variable", R"DOC(Variable Class. @@ -124,13 +152,29 @@ All parameter, weight, gradient are variables in Paddle. 
m.def_submodule( "var_names", "The module will return special predefined variable name in Paddle") - .def("empty", OperatorBase::EMPTY_VAR_NAME) - .def("temp", OperatorBase::TMP_VAR_NAME); - + .def("empty", pd::OperatorBase::EMPTY_VAR_NAME) + .def("temp", pd::OperatorBase::TMP_VAR_NAME); + // clang-format off py::class_<paddle::platform::DeviceContext>(m, "DeviceContext") - .def_static("cpu_context", []() -> paddle::platform::DeviceContext * { - return new paddle::platform::CPUDeviceContext(); - }); + .def_static("create", + [](paddle::platform::CPUPlace& place) + -> paddle::platform::DeviceContext* { + return new paddle::platform::CPUDeviceContext(); + }) + .def_static("create", + [](paddle::platform::GPUPlace& place) + -> paddle::platform::DeviceContext* { +#ifdef PADDLE_ONLY_CPU + PADDLE_THROW("GPUPlace is not supported in CPU device."); +#else + return new paddle::platform::CUDADeviceContext(place); +#endif + }); + // clang-format on + + py::class_<paddle::platform::GPUPlace>(m, "GPUPlace").def(py::init<int>()); + + py::class_<paddle::platform::CPUPlace>(m, "CPUPlace").def(py::init<>()); py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base( m, "Operator"); @@ -144,6 +188,13 @@ All parameter, weight, gradient are variables in Paddle. desc.InitializationErrorString()); return OpRegistry::CreateOp(desc); }); + + operator_base.def("backward", + [](const pd::OperatorBase &forwardOp, + const std::unordered_set<std::string> &no_grad_vars) { + return pd::Backward(forwardOp, no_grad_vars); + }); + ExposeOperator(operator_base); py::class_<NetOp, std::shared_ptr<NetOp>> net(m, "Net"); @@ -166,6 +217,8 @@ All parameter, weight, gradient are variables in Paddle. m.def("unique_integer", UniqueIntegerGenerator); + m.def("is_compile_gpu", IsCompileGPU); + return m.ptr(); } } // namespace framework diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index c3e9a914f1d389b058d380b893441e34249f4293..4c3b14b83d841e88683a13634c93f51c012128b6 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -165,4 +165,4 @@ class Tensor { } // namespace framework } // namespace paddle -#include "paddle/framework/detail/tensor-inl.h" +#include "paddle/framework/tensor_impl.h" diff --git a/paddle/framework/tensor_bind.h b/paddle/framework/tensor_bind.h index 530b640f7051db2334c873bf4cd9608fcc0e88f1..4e1ab77b157fe1adaeac55c271c056236f2d40de 100644 --- a/paddle/framework/tensor_bind.h +++ b/paddle/framework/tensor_bind.h @@ -13,9 +13,11 @@ limitations under the License. */ #pragma once -#include <paddle/framework/tensor.h> -#include <pybind11/numpy.h> -#include <pybind11/pybind11.h> +#include <string> +#include "paddle/framework/tensor.h" +#include "paddle/memory/memcpy.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" namespace py = pybind11; @@ -40,9 +42,6 @@ template <size_t I, typename... 
ARGS> struct CastToPyBufferImpl<true, I, ARGS...> { using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type; py::buffer_info operator()(framework::Tensor &tensor) { - PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()), - "Only CPU tensor can cast to numpy array"); - if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) { auto dim_vec = framework::vectorize(tensor.dims()); std::vector<size_t> dims_outside; @@ -56,11 +55,16 @@ struct CastToPyBufferImpl<true, I, ARGS...> { strides[i - 1] = sizeof(CUR_TYPE) * prod; prod *= dims_outside[i - 1]; } - + framework::Tensor dst_tensor; + if (paddle::platform::is_gpu_place(tensor.holder_->place())) { + dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace()); + } else if (paddle::platform::is_cpu_place(tensor.holder_->place())) { + dst_tensor = tensor; + } return py::buffer_info( - tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()), + dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.holder_->place()), sizeof(CUR_TYPE), py::format_descriptor<CUR_TYPE>::format(), - (size_t)framework::arity(tensor.dims()), dims_outside, strides); + (size_t)framework::arity(dst_tensor.dims()), dims_outside, strides); } else { constexpr bool less = I + 1 < std::tuple_size<std::tuple<ARGS...>>::value; return CastToPyBufferImpl<less, I + 1, ARGS...>()(tensor); @@ -74,9 +78,10 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) { } template <typename T> -void PyTensorSetFromArray( +void PyCPUTensorSetFromArray( framework::Tensor &self, - py::array_t<T, py::array::c_style | py::array::forcecast> array) { + py::array_t<T, py::array::c_style | py::array::forcecast> array, + paddle::platform::CPUPlace &place) { std::vector<int> dims; dims.reserve(array.ndim()); for (size_t i = 0; i < array.ndim(); ++i) { @@ -84,9 +89,28 @@ void PyTensorSetFromArray( } self.Resize(framework::make_ddim(dims)); - auto *dst = self.mutable_data<T>(paddle::platform::CPUPlace()); + auto *dst = self.mutable_data<T>(place); std::memcpy(dst, array.data(), sizeof(T) * array.size()); } +#ifndef PADDLE_ONLY_CPU +template <typename T> +void PyCUDATensorSetFromArray( + framework::Tensor &self, + py::array_t<T, py::array::c_style | py::array::forcecast> array, + paddle::platform::GPUPlace &place) { + std::vector<int> dims; + dims.reserve(array.ndim()); + for (size_t i = 0; i < array.ndim(); ++i) { + dims.push_back((int)array.shape()[i]); + } + + self.Resize(framework::make_ddim(dims)); + auto *dst = self.mutable_data<T>(place); + paddle::platform::GpuMemcpySync(dst, array.data(), sizeof(T) * array.size(), + cudaMemcpyHostToDevice); +} +#endif + } // namespace pybind } // namespace paddle diff --git a/paddle/framework/detail/tensor-inl.h b/paddle/framework/tensor_impl.h similarity index 97% rename from paddle/framework/detail/tensor-inl.h rename to paddle/framework/tensor_impl.h index e7ff09dd5c954378afeca299e901277c3ebdb96a..92621f8c18ec0d03160a23c462830d14272c7f64 100644 --- a/paddle/framework/detail/tensor-inl.h +++ b/paddle/framework/tensor_impl.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once - #include "paddle/memory/memcpy.h" namespace paddle { @@ -62,9 +61,11 @@ inline T* Tensor::mutable_data(platform::Place place) { if (platform::is_cpu_place(place)) { holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>( boost::get<platform::CPUPlace>(place), size)); + } else if (platform::is_gpu_place(place)) { +#ifdef PADDLE_ONLY_CPU + PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); } -#ifndef PADDLE_ONLY_CPU - else if (platform::is_gpu_place(place)) { +#else holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>( boost::get<platform::GPUPlace>(place), size)); } diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index bb4f48364b9b454af7d37fe4d3c340666e53285c..baf78bc6c88d0d294f4457b81c52b22e425d9fdb 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -109,6 +109,13 @@ protected: return filter[filter.ndims() - 1]; } + // determine whether im2col needs to be performed + inline bool isNeedIm2col(const TensorShape& filter) const { + return !(getFilterHeight(filter) == 1 && getFilterWidth(filter) == 1 && + strideH() == 1 && strideW() == 1 && paddingH() == 0 && + paddingW() == 0); + } + std::vector<size_t> strides_; std::vector<size_t> paddings_; diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 9deb2739fcfff935a98a0b5b31b5d11819d81227..0ada4d70a0c7d13f9b5fb1a42eac07fc4c775a87 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -66,16 +66,23 @@ public: real* inputData = inputs[0].data<real>(); real* filterData = inputs[1].data<real>(); real* outputData = outputs[0].data<real>(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer<Device>(colShape.getElements()); - real* colData = reinterpret_cast<real*>(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer<Device>(colShape.getElements()); + colData = reinterpret_cast<real*>(memory_->getBuf()); + } Im2ColFunctor<kCFO, Device, real> im2col; GemmFunctor<Device, real> gemm; @@ -86,15 +93,18 @@ public: for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { - im2col(inputData + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); - + if (needIm2col) { + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } else { + colData = inputData + g * inputOffset; + } int M = outputChannels / groups_; int N = outputHeight * outputWidth; int K = inputChannels / groups_ * filterHeight * filterWidth; @@ -159,19 +169,27 @@ public: real* outputGrad = inputs[0].data<real>(); real* filterData = inputs[1].data<real>(); real* inputGrad = outputs[0].data<real>(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer<Device>(colShape.getElements()); - real* colData = reinterpret_cast<real*>(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = 
TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer<Device>(colShape.getElements()); + colData = reinterpret_cast<real*>(memory_->getBuf()); + } Col2ImFunctor<kCFO, Device, real> col2im; GemmFunctor<Device, real> gemm; + size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -182,6 +200,11 @@ public: int K = outputChannels / groups_; int N = outputHeight * outputWidth; int M = inputChannels / groups_ * filterHeight * filterWidth; + real scale = 0.0f; + if (!needIm2col) { + colData = inputGrad + g * inputOffset; + scale = 1.0f; + } gemm(CblasTrans, CblasNoTrans, M, @@ -192,17 +215,19 @@ public: M, outputGrad + g * outputOffset, N, - 0.0f, + scale, colData, N); - col2im(inputGrad + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); + if (needIm2col) { + col2im(inputGrad + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } } inputGrad += inputChannels * inputHeight * inputWidth; outputGrad += outputChannels * outputHeight * outputWidth; @@ -255,16 +280,23 @@ public: real* outputGrad = inputs[0].data<real>(); real* inputData = inputs[1].data<real>(); real* filterGrad = outputs[0].data<real>(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer<Device>(colShape.getElements()); - real* colData = reinterpret_cast<real*>(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer<Device>(colShape.getElements()); + colData = reinterpret_cast<real*>(memory_->getBuf()); + } Im2ColFunctor<kCFO, Device, real> im2col; GemmFunctor<Device, real> gemm; @@ -274,15 +306,18 @@ public: size_t filterOffset = filter.getElements() / groups_; for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { - im2col(inputData + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); - + if (needIm2col) { + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } else { + colData = inputData + g * inputOffset; + } int M = outputChannels / groups_; int K = outputHeight * outputWidth; int N = inputChannels / groups_ * filterHeight * filterWidth; diff --git a/paddle/gserver/layers/ClipLayer.cpp b/paddle/gserver/layers/ClipLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13f16c953793b82183237188b56eb61d76ecd2fd --- /dev/null +++ b/paddle/gserver/layers/ClipLayer.cpp @@ -0,0 +1,79 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Layer.h" + +namespace paddle { + +/** + * A layer for clipping the input value by the threshold. + * \f[ + * out[i] = \min\left(\max\left(in[i],p_{1}\right),p_{2}\right) + * \f] + */ + +class ClipLayer : public Layer { +protected: + double min_; + double max_; + +public: + explicit ClipLayer(const LayerConfig& config) : Layer(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; +}; + +REGISTER_LAYER(clip, ClipLayer); + +bool ClipLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + + CHECK_EQ(inputLayers_.size(), 1U); + auto layerConf = config_.inputs(0).clip_conf(); + min_ = layerConf.min(); + max_ = layerConf.max(); + CHECK_LT(min_, max_); + return true; +} + +void ClipLayer::forward(PassType passType) { + Layer::forward(passType); + + MatrixPtr inV = getInputValue(0); + resetOutput(inV->getHeight(), inV->getWidth()); + MatrixPtr outV = getOutputValue(); + outV->copyFrom(*inV); + outV->clip(min_, max_); +} + +void ClipLayer::backward(const UpdateCallback& callback) { + MatrixPtr inV = getInputValue(0); + MatrixPtr inG = getInputGrad(0); + if (inG) { + MatrixPtr outV = getOutputValue(); + MatrixPtr outG = getOutputGrad(); + MatrixPtr tmpMtx; + Matrix::resizeOrCreate( + tmpMtx, outG->getHeight(), outG->getWidth(), false, useGpu_); + tmpMtx->clipDerivative(*inV, min_, max_); + inG->addDotMul(*outG, *tmpMtx, 1, 1); + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/RowL2NormLayer.cpp b/paddle/gserver/layers/RowL2NormLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d609be43b73a86d0d0f7b60be993836e2ea6fff --- /dev/null +++ b/paddle/gserver/layers/RowL2NormLayer.cpp @@ -0,0 +1,98 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Layer.h" + +namespace paddle { + +/** + * A layer for L2 normalization in each row, + * \f[ + * out[i] = \frac{in[i]}{\sqrt{\sum_{k=1}^N in[k]^{2}}} + * \f] + * where the size of \f$in\f$ is (batchSize x dataDim), + * and the size of \f$out\f$ is (batchSize x dataDim). 
+ */ + +class RowL2NormLayer : public Layer { +protected: + MatrixPtr inSquare_; + MatrixPtr l2NormReciprocal_; + MatrixPtr dotSum_; + +public: + explicit RowL2NormLayer(const LayerConfig& config) : Layer(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; +}; + +REGISTER_LAYER(row_l2_norm, RowL2NormLayer); + +bool RowL2NormLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + + CHECK_EQ(inputLayers_.size(), 1U); + + return true; +} + +void RowL2NormLayer::forward(PassType passType) { + Layer::forward(passType); + + MatrixPtr inV = getInputValue(0); + + /* malloc memory for the output_ if necessary */ + size_t batchSize = inV->getHeight(); + size_t dataDim = getSize(); + CHECK_EQ(dataDim, inV->getWidth()); + resetOutput(batchSize, dataDim); + MatrixPtr outV = getOutputValue(); + + Matrix::resizeOrCreate(inSquare_, batchSize, dataDim, false, useGpu_); + inV->square2(*inSquare_); + Matrix::resizeOrCreate(l2NormReciprocal_, batchSize, 1, false, useGpu_); + inSquare_->rowSum(*l2NormReciprocal_); + l2NormReciprocal_->sqrt2(*l2NormReciprocal_); + l2NormReciprocal_->scalarDiv(*l2NormReciprocal_, 1.0); + outV->rowScale(0, *inV, *l2NormReciprocal_); +} + +void RowL2NormLayer::backward(const UpdateCallback& callback) { + MatrixPtr inV = getInputValue(0); + MatrixPtr inG = getInputGrad(0); + MatrixPtr outV = getOutputValue(); + MatrixPtr outG = getOutputGrad(); + size_t batchSize = inV->getHeight(); + + // inG[ij] += outG[ij] * l2NormReciprocal + // inG[ij] += -inV[ij] * l2NormReciprocal * l2NormReciprocal * DotMul(outG[i], + // outV[i]) + if (inG) { + Matrix::resizeOrCreate(dotSum_, batchSize, 1, false, useGpu_); + dotSum_->zeroMem(); + dotSum_->rowDotMul(0, *outG, *outV); + dotSum_->dotMul(*dotSum_, *l2NormReciprocal_); + dotSum_->dotMul(*dotSum_, *l2NormReciprocal_); + inSquare_->rowScale(0, *inV, *dotSum_); + inG->sub(*inSquare_); + inG->addRowScale(0, *outG, *l2NormReciprocal_); + } +} + +} // namespace paddle diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 8ce8600c6743779899b2685c1c12053922265411..fe11278f41c0118ee0bdb34f17fbf9602e0fa76b 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1899,6 +1899,36 @@ TEST(Layer, CropLayer) { } } +TEST(Layer, ClipLayer) { + const size_t batchSize = 128; + const size_t size = 512; + TestConfig config; + config.layerConfig.set_type("clip"); + config.inputDefs.push_back({INPUT_DATA, "input", size, 0}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + ClipConfig* layerConf = input->mutable_clip_conf(); + double p1 = std::rand() / (double)RAND_MAX; + double p2 = std::rand() / (double)RAND_MAX; + layerConf->set_min(std::min(p1, p2)); + layerConf->set_max(std::max(p1, p2)); + for (auto useGpu : {false, true}) { + testLayerGrad(config, "clip", batchSize, false, useGpu, false); + } +} + +TEST(Layer, RowL2NormLayer) { + const size_t batchSize = 128; + const size_t size = 512; + TestConfig config; + config.layerConfig.set_type("row_l2_norm"); + config.layerConfig.set_size(size); + config.inputDefs.push_back({INPUT_DATA, "input", size, 0}); + config.layerConfig.add_inputs(); + for (auto useGpu : {false, true}) { + testLayerGrad(config, "row_l2_norm", batchSize, false, useGpu, false); + } +} + int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/paddle/math/BaseMatrix.cu b/paddle/math/BaseMatrix.cu index de48b6fac9c7d8125a552022c52353ef6bcef995..6db5965789b3750f46731f157167150583130d0a 100644 --- a/paddle/math/BaseMatrix.cu +++ b/paddle/math/BaseMatrix.cu @@ -442,6 +442,12 @@ DEFINE_MATRIX_UNARY_PARAMETER_OP(Clip, TWO_PARAMETER, template<class T> void BaseMatrixT<T>::clip(T p1, T p2) { applyUnary(unary::Clip<T>(p1, p2)); } +DEFINE_MATRIX_BINARY_PARAMETER_OP(ClipDerivative, TWO_PARAMETER, a = b < p1 ? 0 : (b > p2 ? 0 : 1)); +template<class T> +void BaseMatrixT<T>::clipDerivative(BaseMatrixT& b, T p1, T p2) { + applyBinary(binary::ClipDerivative<T>(p1, p2), b); +} + DEFINE_MATRIX_UNARY_PARAMETER_OP(BiggerThanScalar, ONE_PARAMETER, a = a > p ? 1.0f : 0.0f); template<class T> diff --git a/paddle/math/BaseMatrix.h b/paddle/math/BaseMatrix.h index 120d69f718b954925438fbd2119d69f0be13b3e9..12ad2d45a0bbff182e78da6efb3c5ff4c6b59b55 100644 --- a/paddle/math/BaseMatrix.h +++ b/paddle/math/BaseMatrix.h @@ -488,6 +488,13 @@ public: */ void clip(T p1, T p2); + /** + * this = (b < p1 || b > p2) ? 0 : 1 + */ + void clipDerivative(BaseMatrixT& b, T p1, T p2); + /** * @code * a = a > p ? 1.0f : 0.0f diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index b910bee836ed488aeb34f28d0503b5efba396583..6465deeec93100f0238ac850b92f7f7c5a60b795 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -60,10 +60,5 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op softmax_op net) - -op_library(recurrent_network_op - SRCS recurrent_network_op.cc - DEPS op_desc tensor net) -cc_test(recurrent_network_op_test - SRCS recurrent_network_op_test.cc - DEPS recurrent_network_op mul_op add_op) +op_library(recurrent_op SRCS recurrent_op.cc DEPS op_desc tensor op_registry operator net) +cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op) diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 3a43dbfbada87e458109d8ca22effdb4407b4c1d..85269a5f7445a1745d9be68417789e33eb725d5c 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -50,10 +50,6 @@ The equation is: Out = X + Y class AddOpGrad : public OperatorWithKernel { protected: void InferShape(const InferShapeContext &ctx) const override {} - std::string DebugString() const override { - LOG(INFO) << "AddOpGrad"; - return ""; - } }; } // namespace operators diff --git a/paddle/operators/add_op.cu b/paddle/operators/add_op.cu index 79d8de6cd46e1c72b14b0554c7be7b4eee281f4c..f961b37565f400b5c26844b9e7a3cff5e682340b 100644 --- a/paddle/operators/add_op.cu +++ b/paddle/operators/add_op.cu @@ -1,3 +1,4 @@ +#define EIGEN_USE_GPU #include "paddle/framework/op_registry.h" #include "paddle/operators/add_op.h" diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index d2b649fcbd1e5cac1c8cfcfd4e522e41135f7d1f..54d2231425293f6cfb3adc9cb34d903a75fcdcd0 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -28,10 +28,13 @@ public: output->mutable_data<T>(context.GetPlace()); - EigenVector<T>::Flatten(*output).device( - *(context.GetEigenDevice<Place>())) = - framework::EigenVector<T>::Flatten(*input0) + - framework::EigenVector<T>::Flatten(*input1); + auto X = EigenVector<T>::Flatten(*input0); + auto Y = EigenVector<T>::Flatten(*input1); + auto Z = EigenVector<T>::Flatten(*output); + + auto place =
context.GetEigenDevice<Place>(); + + Z.device(place) = X + Y; } }; diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 19e4b74596a0f59edd04db830ec6f6f481373465..926a0c616b957d8e542c1f3dee227a718fb29f07 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -1,3 +1,4 @@ +#define EIGEN_USE_GPU #include "paddle/operators/cross_entropy_op.h" REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index fe34d6ad4015620cac520146850e10563d4c50e0..78131b26808b183ee107313374493ae870f1b641 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -33,13 +33,23 @@ public: MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of mean op"); - AddOutput("Out", "The output of mean op"); + AddOutput("Out", "The output of mean op").IgnoreGradient(); AddComment("Mean Operator"); } }; +class MeanGradOp : public OperatorWithKernel { +protected: + void InferShape(const InferShapeContext &ctx) const override { + ctx.Output<Tensor>("X" + GRAD_VAR_SUFFIX()) + ->Resize(ctx.Input<Tensor>("X")->dims()); + } +}; + } // namespace operators } // namespace paddle REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker); REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<ops::CPUPlace, float>); +REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp); +REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::CPUPlace, float>); diff --git a/paddle/operators/mean_op.cu b/paddle/operators/mean_op.cu index 740157cbc57a64cafcf109186c630691620f542b..e15de2fd0dd84e4015ee0e3b5343d7651b027a88 100644 --- a/paddle/operators/mean_op.cu +++ b/paddle/operators/mean_op.cu @@ -3,3 +3,4 @@ #include "paddle/operators/mean_op.h" REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel<ops::GPUPlace, float>); +REGISTER_OP_GPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::GPUPlace, float>); \ No newline at end of file diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index 5f7d443751d1cdd7de3b67b0de2758ba1d566fb3..a89cb422f9b296dba6eb5358043f73d00aefc5d3 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -27,8 +27,28 @@ public: output->mutable_data<T>(context.GetPlace()); - EigenScalar<T>::From(*output).device(*(context.GetEigenDevice<Place>())) = - EigenVector<T>::Flatten(*input).mean(); + auto X = EigenVector<T>::Flatten(*input); + auto y = EigenScalar<T>::From(*output); + auto place = context.GetEigenDevice<Place>(); + + y.device(place) = X.mean(); + } +}; + +template <typename Place, typename T> +class MeanGradKernel : public OpKernel { +public: + void Compute(const ExecutionContext& context) const override { + auto OG = context.Input<Tensor>("Out" + OperatorBase::GRAD_VAR_SUFFIX()); + PADDLE_ENFORCE(framework::product(OG->dims()) == 1, + "Mean Gradient should be scalar"); + auto IG = context.Output<Tensor>("X" + OperatorBase::GRAD_VAR_SUFFIX()); + IG->mutable_data<T>(context.GetPlace()); + + T ig_size = (T)framework::product(IG->dims()); + + EigenVector<T>::Flatten(*IG).device(*(context.GetEigenDevice<Place>())) = + EigenScalar<T>::From(*OG) / ig_size; } }; diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index c27fc886ce7238a13c8ef86bce673a2b54949a9d..dc9236701627dc9335b844d2a82e18eb1f7dfd42 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -12,6 +12,7 @@ See the License for the specific language governing permissions and limitations under the 
License. */ +#define EIGEN_USE_GPU #include "paddle/operators/mul_op.h" REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>); \ No newline at end of file diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index eef72ab293e13a9d05ce0013be41ec4bb75d6077..c7b78ad39045d25d73bfc2c930063c255a514864 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -26,13 +26,18 @@ public: Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = { {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}}; + auto input0 = context.Input<Tensor>("X"); + auto input1 = context.Input<Tensor>("Y"); auto output = context.Output<Tensor>(0); + output->mutable_data<T>(context.GetPlace()); - EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) = - EigenMatrix<T>::From(*context.Input<Tensor>("X")) - .contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")), - dim_pair); + auto X = EigenMatrix<T>::From(*input0); + auto Y = EigenMatrix<T>::From(*input1); + auto Z = EigenMatrix<T>::From(*output); + auto place = context.GetEigenDevice<Place>(); + + Z.device(place) = X.contract(Y, dim_pair); } }; } // namespace operators diff --git a/paddle/operators/recurrent_network_op.cc b/paddle/operators/recurrent_op.cc similarity index 67% rename from paddle/operators/recurrent_network_op.cc rename to paddle/operators/recurrent_op.cc index 60d065fc4789f76370840328870165579aa73b67..e5b76e3724b5b0287071c90d26235b8e1a1d80cf 100644 --- a/paddle/operators/recurrent_network_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/recurrent_network_op.h" +#include "paddle/operators/recurrent_op.h" #include <glog/logging.h> #include <cstring> @@ -29,11 +29,15 @@ namespace rnn { void SegmentInputs(const std::vector<Scope*>& step_scopes, const std::vector<Link>& inlinks, - const size_t seq_len) { + const size_t seq_len, + bool infer_shape_mode) { PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided."); for (size_t i = 0; i < inlinks.size(); ++i) { - Tensor* input = - step_scopes[0]->FindVar(inlinks[i].external)->GetMutable<Tensor>(); + auto input_var = step_scopes[0]->FindVar(inlinks[i].external); + PADDLE_ENFORCE(input_var != nullptr, + "input link [%s] is not in scope.", + inlinks[i].external); + Tensor* input = input_var->GetMutable<Tensor>(); DDim dims = input->dims(); PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len, "all the inlinks must have same length"); @@ -41,7 +45,9 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes, for (size_t j = 0; j < seq_len; j++) { Tensor* step_input = step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>(); - *step_input = input->Slice<float>(j, j + 1); + if (!infer_shape_mode) { + *step_input = input->Slice<float>(j, j + 1); + } step_input->Resize(step_dims); } } @@ -49,36 +55,41 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes, void ConcatOutputs(const std::vector<Scope*>& step_scopes, const std::vector<Link>& outlinks, - const size_t seq_len) { + const size_t seq_len, + bool infer_shape_mode) { for (size_t i = 0; i < outlinks.size(); i++) { - Tensor* output = - step_scopes[0]->FindVar(outlinks[i].external)->GetMutable<Tensor>(); - - // TODO(qingiqng) remove following code after adding - // InferShape in RecurrentGradientOp - DDim step_dims = step_scopes[0] - ->FindVar(outlinks[i].internal) - ->GetMutable<Tensor>() - ->dims(); - std::vector<int> dims_vec = 
vectorize(step_dims); - dims_vec.insert(dims_vec.begin(), seq_len); - output->mutable_data<float>(make_ddim(dims_vec), platform::CPUPlace()); - - for (size_t j = 0; j < seq_len; j++) { - Tensor* step_output = - step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>(); - // TODO(luotao02) data type and platform::DeviceContext() should set - // correctly - (output->Slice<float>(j, j + 1)) - .CopyFrom<float>(*step_output, platform::CPUPlace()); + auto output_var = step_scopes[0]->FindVar(outlinks[i].external); + PADDLE_ENFORCE(output_var != nullptr, + "output link [%s] is not in scope.", + outlinks[i].external); + Tensor* output = output_var->GetMutable<Tensor>(); + if (infer_shape_mode) { + DDim step_dims = step_scopes[0] + ->FindVar(outlinks[i].internal) + ->GetMutable<Tensor>() + ->dims(); + std::vector<int> dims_vec = vectorize(step_dims); + dims_vec.insert(dims_vec.begin(), seq_len); + output->Resize(make_ddim(dims_vec)); + } else { + output->mutable_data<float>(platform::CPUPlace()); + for (size_t j = 0; j < seq_len; j++) { + Tensor* step_output = + step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>(); + // TODO(luotao02) data type and platform::DeviceContext() should set + // correctly + (output->Slice<float>(j, j + 1)) + .CopyFrom<float>(*step_output, platform::CPUPlace()); + } } } } void LinkMemories(const std::vector<Scope*>& scopes, const std::vector<rnn::MemoryAttr>& memories, - size_t step_id, - int offset) { + const size_t step_id, + const int offset, + bool infer_shape_mode) { PADDLE_ENFORCE(step_id < scopes.size(), "step [%d] is out of range of step scopes' size [%d]", step_id, @@ -95,18 +106,13 @@ void LinkMemories(const std::vector<Scope*>& scopes, auto scope = scopes[step_id]; auto linked_scope = scopes[step_id + offset]; for (auto& attr : memories) { - auto mem = scope->NewVar(attr.pre_var)->GetMutable<Tensor>(); - // maybe share variable is better? + auto mem = scope->FindVar(attr.pre_var)->GetMutable<Tensor>(); auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<Tensor>(); - mem->ShareDataWith<float>(*linked_mem); - - // TODO(qingqing) remove following code - // the memory of current step should be allocated in step net - auto m = scope->NewVar(attr.var)->GetMutable<Tensor>(); - // for unit test, as addOp and mulOp are null currently, if not - // mutable_data, mem.data() in output will be error. We will - // remove this line after merge the correct addOp and mulOp. - m->mutable_data<float>(mem->dims(), platform::CPUPlace()); + if (infer_shape_mode) { + mem->Resize(linked_mem->dims()); + } else { + mem->ShareDataWith<float>(*linked_mem); + } } } @@ -175,60 +181,39 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const { ->dims()[0]; CreateScopes(scope); auto step_scopes = GetStepScopes(scope); - - // SegmentInputs is called in InferShape. The input must hold memory in - // SegmentInputs. But the other op only set dimension for the output in - // InferShape. That's a problem. Wether the RNN op needs InferShape or not? - // Wether the following functions (SegmentInputs, InitMemories, ...) need - // to rewrite for RNN op? 
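The question deleted above is resolved by the new infer_shape_mode flag threaded through rnn::SegmentInputs, rnn::ConcatOutputs, and rnn::LinkMemories: the same traversal now runs twice, propagating only shapes during InferShape and only data during Run, as the replacement code below shows. A minimal standalone sketch of that two-pass convention (FakeTensor/LinkTensor are hypothetical names for illustration, not this patch's APIs):

#include <cassert>
#include <vector>

// Illustrative stand-in for a tensor: a shape plus a flat payload.
struct FakeTensor {
  std::vector<int> dims;
  std::vector<float> data;
};

// Link src into dst: in shape-inference mode only dimensions are
// propagated (the Resize analogue); in run mode the payload follows
// (the ShareDataWith/CopyFrom analogue).
void LinkTensor(const FakeTensor& src, FakeTensor* dst, bool infer_shape_mode) {
  if (infer_shape_mode) {
    dst->dims = src.dims;
  } else {
    dst->data = src.data;
  }
}

int main() {
  FakeTensor src{{2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}};
  FakeTensor dst;
  LinkTensor(src, &dst, true);   // InferShape pass: shapes only
  assert(dst.dims == src.dims && dst.data.empty());
  LinkTensor(src, &dst, false);  // Run pass: data follows
  assert(dst.data == src.data);
  return 0;
}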
- rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); - - InitMemories(step_scopes[0]); - - PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr, - "stepnet [%s] is not in scope.", - arg_->step_net); + rnn::SegmentInputs( + step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/); + InitMemories(step_scopes[0], true /*infer_shape_mode*/); Variable* net = scope.FindVar(arg_->step_net); PADDLE_ENFORCE(net != nullptr, "failed to get step net"); - // If the InferShape is called in OperatorBase's run function, - // the rnn op only needs to do InferShape for the first time step for (size_t i = 0; i < seq_len_; i++) { if (i > 0) { - rnn::LinkMemories(step_scopes, arg_->memories, i, -1); + rnn::LinkMemories( + step_scopes, arg_->memories, i, -1, true /*infer_shape_mode*/); } net->GetMutable<NetOp>()->InferShape(*step_scopes[i]); } - - auto outlinks = arg_->outlinks; - for (size_t i = 0; i < outlinks.size(); i++) { - DDim step_dims = step_scopes[0] - ->FindVar(outlinks[i].internal) - ->GetMutable<Tensor>() - ->dims(); - std::vector<int> dims_vec = vectorize(step_dims); - // now only support fixed length - dims_vec.insert(dims_vec.begin(), seq_len_); - Tensor* output = - step_scopes[0]->FindVar(outlinks[i].external)->GetMutable<Tensor>(); - output->Resize(make_ddim(dims_vec)); - } + rnn::ConcatOutputs( + step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); } void RecurrentAlgorithm::Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const { auto step_scopes = GetStepScopes(scope); - + rnn::SegmentInputs( + step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/); + InitMemories(step_scopes[0], false /*infer_shape_mode*/); Variable* net = scope.FindVar(arg_->step_net); + for (size_t step_id = 0; step_id < seq_len_; step_id++) { - // the link memory is done in InferShape - // maybe remove following code after testing if (step_id > 0) { - rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1); + rnn::LinkMemories( + step_scopes, arg_->memories, step_id, -1, false /*infer_shape_mode*/); } net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx); } - - rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); + rnn::ConcatOutputs( + step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/); } void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { @@ -244,18 +229,19 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { // Now all variables in scope must be created outside of op. 
auto net_op = scope.FindVar(arg_->step_net)->GetMutable<NetOp>(); for (auto& input : net_op->inputs_) { + // the weights are located in the parent scope if (!step_scope.FindVar(input)) step_scope.NewVar(input); } for (auto& output : net_op->outputs_) { step_scope.NewVar(output); } - step_scopes->emplace_back(&step_scope); } } } -void RecurrentAlgorithm::InitMemories(Scope* step_scope) const { +void RecurrentAlgorithm::InitMemories(Scope* step_scope, + bool infer_shape_mode) const { for (auto& attr : arg_->memories) { Tensor* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<Tensor>(); PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr, @@ -263,13 +249,11 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope) const { attr.var, attr.boot_var); Tensor* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable<Tensor>(); - pre_mem->ShareDataWith<float>(*boot_mem); - - // TODO(qingqing) remove following code - // the memory of current step should be allocated in step net - // here for unit test - auto cur_step_mem = step_scope->NewVar(attr.var)->GetMutable<Tensor>(); - cur_step_mem->mutable_data<float>(boot_mem->dims(), platform::CPUPlace()); + if (infer_shape_mode) { + pre_mem->Resize(boot_mem->dims()); + } else { + pre_mem->ShareDataWith<float>(*boot_mem); + } } } @@ -307,13 +291,14 @@ public: : OpProtoAndCheckerMaker(proto, op_checker) { const auto& name = RecurrentOp::kArgName; // inputs and outputs stored in proto - AddInput(name.inlinks, "the input that need to be segmented for each step.") + AddInput(name.inlinks, + "the inputs that need to be segmented for each step.") .SetMultiple(); AddInput(name.boot_memories, "variables to initialize memories.") .SetMultiple(); AddInput(name.step_net, "network shared by all steps."); - AddOutput(name.outlinks, "the output that need to concated for all steps.") + AddOutput(name.outlinks, "the outputs that need to be concatenated for all steps.") .SetMultiple(); AddOutput(name.step_scopes, "step scopes"); @@ -331,34 +316,39 @@ public: void RecurrentGradientAlgorithm::Run( const Scope& scope, const platform::DeviceContext& dev_ctx) const { auto step_scopes = GetStepScopes(scope); - rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); - PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr, - "step net is not in scope."); + rnn::SegmentInputs( + step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/); Variable* net = scope.FindVar(arg_->step_net); PADDLE_ENFORCE(net != nullptr, "failed to get step net"); for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { if (static_cast<size_t>(step_id) != seq_len_ - 1) { - rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); + rnn::LinkMemories( + step_scopes, arg_->memories, step_id, 1, false /*infer_shape_mode*/); } net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx); } - LinkBootMemoryGradients(step_scopes[0]); - rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); + LinkBootMemoryGradients(step_scopes[0], false /*infer_shape_mode*/); + rnn::ConcatOutputs( + step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/); } void RecurrentGradientAlgorithm::LinkBootMemoryGradients( - Scope* step_scope) const { + Scope* step_scope, bool infer_shape_mode) const { for (auto& attr : arg_->memories) { - Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable<Tensor>(); - PADDLE_ENFORCE(mem_grad != nullptr, - "boot_tensor should be retrieved before"); + PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr, + "memory variable [%s] does not exist", + attr.var);
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr, - "memory [%s]'s boot variable [%s] not exists", - attr.var, + "boot variable [%s] does not exist", attr.boot_var); + Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable<Tensor>(); Tensor* boot_mem_grad = step_scope->NewVar(attr.boot_var)->GetMutable<Tensor>(); - boot_mem_grad->ShareDataWith<float>(*mem_grad); + if (infer_shape_mode) { + boot_mem_grad->Resize(mem_grad->dims()); + } else { + boot_mem_grad->ShareDataWith<float>(*mem_grad); + } } } @@ -367,34 +357,20 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { ->GetMutable<Tensor>() ->dims()[0]; auto step_scopes = GetStepScopes(scope); - rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); - - PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr, - "step net is not in scope."); + rnn::SegmentInputs( + step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/); Variable* net = scope.FindVar(arg_->step_net); PADDLE_ENFORCE(net != nullptr, "failed to get step net"); - for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { if (static_cast<size_t>(step_id) != seq_len_ - 1) { - rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); + rnn::LinkMemories( + step_scopes, arg_->memories, step_id, 1, true /*infer_shape_mode*/); } net->GetMutable<NetOp>()->InferShape(*step_scopes[step_id]); } - - auto outlinks = arg_->outlinks; - for (size_t i = 0; i < outlinks.size(); i++) { - DDim step_dims = step_scopes[0] - ->FindVar(outlinks[i].internal) - ->GetMutable<Tensor>() - ->dims(); - std::vector<int> dims_vec = vectorize(step_dims); - // now only support fixed length - dims_vec.insert(dims_vec.begin(), seq_len_); - Tensor* output = - step_scopes[0]->FindVar(outlinks[i].external)->GetMutable<Tensor>(); - output->Resize(make_ddim(dims_vec)); - } - LinkBootMemoryGradients(step_scopes[0]); + rnn::ConcatOutputs( + step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); + LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/); } void RecurrentGradientOp::Init() { diff --git a/paddle/operators/recurrent_network_op.h b/paddle/operators/recurrent_op.h similarity index 92% rename from paddle/operators/recurrent_network_op.h rename to paddle/operators/recurrent_op.h index d57a1a2e51cbed22549ab6ebce79223e2d4e3bcf..2a0964fff326500b6215dd4afac63c75d64c4a06 100644 --- a/paddle/operators/recurrent_network_op.h +++ b/paddle/operators/recurrent_op.h @@ -72,19 +72,22 @@ struct ArgumentName { */ void SegmentInputs(const std::vector<Scope*>& step_scopes, const std::vector<Link>& inlinks, - const size_t seq_len); + const size_t seq_len, + bool infer_shape_mode); /** * Process outputs of step nets and merge to variables.
*/ void ConcatOutputs(const std::vector<Scope*>& step_scopes, const std::vector<Link>& outlinks, - const size_t seq_len); + const size_t seq_len, + bool infer_shape_mode); void LinkMemories(const std::vector<Scope*>& step_scopes, const std::vector<MemoryAttr>& memories, - size_t step_id, - int offset); + const size_t step_id, + const int offset, + bool infer_shape_mode); void InitArgument(const ArgumentName& name, Argument* arg); @@ -122,7 +125,7 @@ protected: return *scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>(); } - void InitMemories(Scope* step_scopes) const; + void InitMemories(Scope* step_scopes, bool infer_shape_mode) const; private: std::unique_ptr<rnn::Argument> arg_; @@ -145,7 +148,7 @@ public: void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const; - void LinkBootMemoryGradients(Scope* step_scopes) const; + void LinkBootMemoryGradients(Scope* step_scopes, bool infer_shape_mode) const; /** * InferShape must be called before Run. diff --git a/paddle/operators/recurrent_network_op_test.cc b/paddle/operators/recurrent_op_test.cc similarity index 90% rename from paddle/operators/recurrent_network_op_test.cc rename to paddle/operators/recurrent_op_test.cc index b0e61fbee611744adb85b498b1c3540f059afc8c..91f2972ca49953fd7a627289fa37db32916d85cd 100644 --- a/paddle/operators/recurrent_network_op_test.cc +++ b/paddle/operators/recurrent_op_test.cc @@ -18,7 +18,7 @@ #include "paddle/framework/op_registry.h" #include "paddle/framework/operator.h" #include "paddle/framework/tensor.h" -#include "paddle/operators/recurrent_network_op.h" +#include "paddle/operators/recurrent_op.h" namespace paddle { namespace operators { @@ -55,7 +55,7 @@ protected: w->GetMutable<Tensor>()->mutable_data<float>( make_ddim(std::vector<int>{30, 30}), platform::CPUPlace()); - for (auto boot : std::vector<std::string>{"x_boot", "h_boot"}) { + for (auto boot : std::vector<std::string>{"h_boot"}) { LOG(INFO) << "create global variable " << boot; Variable* h_boot = scope_.NewVar(boot); h_boot->GetMutable<Tensor>()->mutable_data<float>( @@ -79,7 +79,6 @@ protected: op_desc.add_inputs("x0"); op_desc.add_inputs("x1"); // boot_memories 3 - op_desc.add_inputs("x_boot"); op_desc.add_inputs("h_boot"); // step net 5 op_desc.add_inputs("step_net"); @@ -91,7 +90,7 @@ protected: auto _input_format = std::vector<int>{ 0, // in_link 3, // memories - 5 // step_net + 4 // step_net }; auto input_format = op_desc.add_attrs(); input_format->set_name("input_format"); @@ -129,12 +128,11 @@ protected: inlink_alias->add_strings(item); } // pre memories - for (const auto& item : - std::vector<std::string>{"rnn/x@pre", "rnn/h@pre"}) { + for (const auto& item : std::vector<std::string>{"rnn/h@pre"}) { pre_memories->add_strings(item); } // memories - for (const auto& item : std::vector<std::string>{"rnn/x", "rnn/h"}) { + for (const auto& item : std::vector<std::string>{"rnn/h"}) { memories->add_strings(item); } // output alias @@ -151,14 +149,11 @@ protected: LOG(INFO) << "create variable step_net"; Variable* var = scope_.NewVar("step_net"); auto net = var->GetMutable<NetOp>(); - // rnn/s is net's input or output? 
- net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"}; - net->inputs_ = {"rnn/s", "rnn/h"}; net->AddOp( OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {})); net->AddOp( - OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {})); + OpRegistry::CreateOp("add_two", {"x@alias", "rnn/s"}, {"rnn/h"}, {})); net->CompleteAddOp(); } @@ -297,7 +292,10 @@ protected: inlink.internal = "rnn/x"; auto step_scopes = scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>(); - rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10); + rnn::SegmentInputs(*step_scopes, + std::vector<rnn::Link>{inlink}, + 10, + true /*infer_shape_mode*/); } void LinkeMemories() { @@ -311,7 +309,8 @@ protected: auto step_scopes = scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>(); for (int i = 1; i < 10; ++i) { - rnn::LinkMemories(*step_scopes, memories, i, -1); + rnn::LinkMemories( + *step_scopes, memories, i, -1, true /*infer_shape_mode*/); } } @@ -333,14 +332,14 @@ TEST(RecurrentOp, LinkMemories) { using namespace paddle::operators; // create and init step scopes - int len = 10; + size_t len = 10; std::vector<Scope*> step_scopes; - for (int i = 0; i < len; ++i) { + for (size_t i = 0; i < len; ++i) { auto scope = new Scope(); scope->NewVar("pre_h"); auto tensor = scope->NewVar("h")->GetMutable<Tensor>(); float* data = tensor->mutable_data<float>({15, 20}, CPUPlace()); - for (int j = 0; j < 15 * 20; ++j) { + for (size_t j = 0; j < 15 * 20; ++j) { data[j] = rand() * (1. / (double)RAND_MAX); } step_scopes.push_back(scope); @@ -354,24 +353,24 @@ TEST(RecurrentOp, LinkMemories) { std::vector<rnn::MemoryAttr> memories; memories.push_back(mem_attr); - for (int i = 1; i < len; ++i) { - rnn::LinkMemories(step_scopes, memories, i, -1); + for (size_t i = 1; i < len; ++i) { + rnn::LinkMemories(step_scopes, memories, i, -1, false /*infer_shape_mode*/); } // check - for (int i = 0; i < len - 1; ++i) { + for (size_t i = 0; i < len - 1; ++i) { const float* a = step_scopes[i]->FindVar("h")->GetMutable<Tensor>()->data<float>(); const float* b = step_scopes[i + 1] ->FindVar("pre_h") ->GetMutable<Tensor>() ->data<float>(); - for (size_t i = 0; i < 15 * 20; ++i) { - ASSERT_FLOAT_EQ(a[i], b[i]); + for (size_t j = 0; j < 15 * 20; ++j) { + ASSERT_FLOAT_EQ(a[j], b[j]); } } for (int i = len - 2; i >= 0; --i) { - rnn::LinkMemories(step_scopes, memories, i, 1); + rnn::LinkMemories(step_scopes, memories, i, 1, false /*infer_shape_mode*/); } // check for (int i = len - 2; i >= 0; --i) { @@ -379,8 +378,8 @@ TEST(RecurrentOp, LinkMemories) { step_scopes[i]->FindVar("pre_h")->GetMutable<Tensor>()->data<float>(); const float* b = step_scopes[i + 1]->FindVar("h")->GetMutable<Tensor>()->data<float>(); - for (size_t i = 0; i < 15 * 20; ++i) { - ASSERT_FLOAT_EQ(a[i], b[i]); + for (size_t j = 0; j < 15 * 20; ++j) { + ASSERT_FLOAT_EQ(a[j], b[j]); } } @@ -391,9 +390,3 @@ TEST(RecurrentOp, LinkMemories) { USE_OP(add_two); USE_OP(mul); - -// int main() { -// //! TODO(yuyang18): Temporary disable this unit-test because implementation -// //! error. 
diff --git a/paddle/operators/rowwise_add_op.cu b/paddle/operators/rowwise_add_op.cu
index 4b33e38ebabe853e179fe70ef7fde0a80b9050e2..82338ceccc06653791b26472e18d804f62735649 100644
--- a/paddle/operators/rowwise_add_op.cu
+++ b/paddle/operators/rowwise_add_op.cu
@@ -1,3 +1,4 @@
+#define EIGEN_USE_GPU
 #include "paddle/operators/rowwise_add_op.h"

 REGISTER_OP_GPU_KERNEL(rowwise_add,
diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h
index b86dd5463436bf521f9939b1c421b39f11102769..bd4d1128955fb718d3a84dfd96d8c68d7196e9cc 100644
--- a/paddle/operators/rowwise_add_op.h
+++ b/paddle/operators/rowwise_add_op.h
@@ -33,7 +33,7 @@ public:
     const int rest_size = input.size() / bias_size;
     Eigen::DSizes<int, 1> one_d(input.size());
     Eigen::DSizes<int, 1> bcast(rest_size);
-    output.reshape(one_d).device(*(context.GetEigenDevice<Place>())) =
+    output.reshape(one_d).device(context.GetEigenDevice<Place>()) =
         input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d);
   }
 };
diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu
index f8f5b90cab460b4457cfb0a88bfc012bafe0fbc2..d79258cbf13c699cfb2afaee229cf96a3e377b5e 100644
--- a/paddle/operators/sgd_op.cu
+++ b/paddle/operators/sgd_op.cu
@@ -1,3 +1,4 @@
+#define EIGEN_USE_GPU
 #include "paddle/operators/sgd_op.h"

 REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>);
\ No newline at end of file
diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h
index af1dfdd756ceb9991bee6b85c3281c05f0fb5a9f..0c3a240f9a4a5fc7bc4898e82786810cee2f7010 100644
--- a/paddle/operators/sgd_op.h
+++ b/paddle/operators/sgd_op.h
@@ -29,8 +29,12 @@ public:
     param_out->mutable_data<T>(ctx.GetPlace());

-    EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) =
-        EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad);
+    auto p = EigenVector<T>::Flatten(*param);
+    auto g = EigenVector<T>::Flatten(*grad);
+    auto o = EigenVector<T>::Flatten(*param_out);
+    auto place = ctx.GetEigenDevice<Place>();
+
+    o.device(place) = p - lr * g;
   }
 };
diff --git a/paddle/operators/sigmoid_op.cu b/paddle/operators/sigmoid_op.cu
index f679b20418f04eff4310efe4e121963ce5a235e0..c9d11a2e1f9dcc563765c9e8cc1bae6beff57f18 100644
--- a/paddle/operators/sigmoid_op.cu
+++ b/paddle/operators/sigmoid_op.cu
@@ -1,3 +1,4 @@
+#define EIGEN_USE_GPU
 #include "paddle/operators/sigmoid_op.h"

 REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h
index 3dd23a9ebc7ac0972d6ee07b9ac051d59e66f62f..1412e4398440c8e946d3ab434a50e978079637ab 100644
--- a/paddle/operators/sigmoid_op.h
+++ b/paddle/operators/sigmoid_op.h
@@ -27,9 +27,11 @@ public:
     auto output = context.Output<Tensor>(0);
     output->mutable_data<T>(context.GetPlace());

-    EigenVector<T>::Flatten(*output).device(
-        *(context.GetEigenDevice<Place>())) =
-        1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
+    auto X = EigenVector<T>::Flatten(*input);
+    auto Y = EigenVector<T>::Flatten(*output);
+    auto place = context.GetEigenDevice<Place>();
+
+    Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp());
   }
 };
 }  // namespace operators
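The kernels above all follow the same pattern: flatten the tensors into 1-D Eigen vectors, then assign an elementwise expression on the chosen device. A rough numpy rendering of what each expression computes, as a sketch only (the variable names and shapes are illustrative, borrowed from the unit tests later in this patch):

```python
import numpy as np

x = np.random.random((32, 84)).astype("float32")
b = np.random.random(84).astype("float32")
lr = 0.1

# rowwise_add: broadcast the bias over every row of the input.
rowwise_out = x + b

# sgd: o = p - lr * g, the plain gradient-descent parameter update.
param = np.random.random((102, 105)).astype("float32")
grad = np.random.random((102, 105)).astype("float32")
param_out = param - lr * grad

# sigmoid: Y = 1 / (1 + exp(-X)), applied elementwise.
sigmoid_out = 1.0 / (1.0 + np.exp(-x))
```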
"paddle/framework/op_registry.h" #include "paddle/operators/softmax_op.h" diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index a5c19c5fc7c6f5909dbb355aff09bf15405b6957..75c5197697dada58e09f4cda41cea13af56e79a3 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -46,9 +46,9 @@ public: .reshape(batch_by_one) .broadcast(one_by_class)); - softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp(); + softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp(); - softmax.device(*(context.GetEigenDevice<Place>())) = + softmax.device(context.GetEigenDevice<Place>()) = (softmax * softmax.sum(along_class) .inverse() diff --git a/paddle/operators/type_alias.h b/paddle/operators/type_alias.h index 93b62cddc819e0d1fd48323e474a294ff0d327e1..9049ffda1da5408411687474c5ed0c76c2394623 100644 --- a/paddle/operators/type_alias.h +++ b/paddle/operators/type_alias.h @@ -51,6 +51,7 @@ using CPUPlace = platform::CPUPlace; using GPUPlace = platform::GPUPlace; using NetOp = framework::NetOp; using OpRegistry = framework::OpRegistry; +using OperatorBase = framework::OperatorBase; } // namespace operators } // namespace paddle diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 26c8eb78e614a68ec9728aad727d8fe3e08547ae..60a42c777d1c2ebbc22fdb77b1100cc6fcf7ff35 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -144,12 +144,12 @@ inline void throw_on_error(T e) { throw_on_error(e, ""); } -#define PADDLE_THROW(...) \ - do { \ - throw ::paddle::platform::EnforceNotMet( \ - std::make_exception_ptr( \ - std::runtime_error(string::Sprintf(__VA_ARGS__))), \ - __FILE__, __LINE__); \ +#define PADDLE_THROW(...) \ + do { \ + throw ::paddle::platform::EnforceNotMet( \ + std::make_exception_ptr( \ + std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \ + __FILE__, __LINE__); \ } while (0) #define PADDLE_ENFORCE(...) 
diff --git a/paddle/operators/type_alias.h b/paddle/operators/type_alias.h
index 93b62cddc819e0d1fd48323e474a294ff0d327e1..9049ffda1da5408411687474c5ed0c76c2394623 100644
--- a/paddle/operators/type_alias.h
+++ b/paddle/operators/type_alias.h
@@ -51,6 +51,7 @@ using CPUPlace = platform::CPUPlace;
 using GPUPlace = platform::GPUPlace;
 using NetOp = framework::NetOp;
 using OpRegistry = framework::OpRegistry;
+using OperatorBase = framework::OperatorBase;
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h
index 26c8eb78e614a68ec9728aad727d8fe3e08547ae..60a42c777d1c2ebbc22fdb77b1100cc6fcf7ff35 100644
--- a/paddle/platform/enforce.h
+++ b/paddle/platform/enforce.h
@@ -144,12 +144,12 @@ inline void throw_on_error(T e) {
   throw_on_error(e, "");
 }

-#define PADDLE_THROW(...)                                       \
-  do {                                                          \
-    throw ::paddle::platform::EnforceNotMet(                    \
-        std::make_exception_ptr(                                \
-            std::runtime_error(string::Sprintf(__VA_ARGS__))),  \
-        __FILE__, __LINE__);                                    \
+#define PADDLE_THROW(...)                                              \
+  do {                                                                 \
+    throw ::paddle::platform::EnforceNotMet(                           \
+        std::make_exception_ptr(                                       \
+            std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \
+        __FILE__, __LINE__);                                           \
   } while (0)

 #define PADDLE_ENFORCE(...) \
diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..29dd0ded0ac75893da7e244d92725cd5e285efce
--- /dev/null
+++ b/paddle/pybind/CMakeLists.txt
@@ -0,0 +1,9 @@
+cc_library(paddle_pybind SHARED
+    SRCS pybind.cc
+    DEPS pybind python backward
+         fc_op
+         sgd_op
+         add_op
+         mean_op
+         cross_entropy_op
+         recurrent_op)
diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh
index 3860facb099950a5287d3f6b89c3de38f588f568..69ae0ea2d72c199a8e17c0595693e5e0b2f79ee1 100644
--- a/paddle/scripts/docker/build.sh
+++ b/paddle/scripts/docker/build.sh
@@ -148,7 +148,7 @@ cat >> /paddle/build/Dockerfile <<EOF
 ADD *.deb /
 # run paddle version to install python packages first
 RUN apt-get update &&\
-    apt-get install -y python-pip && pip install -U pip && \
+    apt-get install -y wget python-pip && pip install -U pip && \
     dpkg -i /*.deb ; apt-get install -f -y && \
     apt-get clean -y && \
     rm -f /*.deb && \
diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto
index 3bee5b572ae42750332b69e28af980ae325532da..b50b73c7e169f3e8ae75322d9a0a3cad5072a9c7 100644
--- a/proto/ModelConfig.proto
+++ b/proto/ModelConfig.proto
@@ -298,6 +298,11 @@ message DetectionOutputConfig {
   optional uint32 width = 9 [default = 1];
 }

+message ClipConfig {
+  required double min = 1;
+  required double max = 2;
+}
+
 message LayerInputConfig {
   required string input_layer_name = 1;
   optional string input_parameter_name = 2;
@@ -318,6 +323,7 @@ message LayerInputConfig {
   optional RowConvConfig row_conv_conf = 15;
   optional MultiBoxLossConfig multibox_loss_conf = 16;
   optional DetectionOutputConfig detection_output_conf = 17;
+  optional ClipConfig clip_conf = 18;
 }

 message LayerConfig {
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index f71fefffb59d4a53dda092ff83a61d9eec4b601f..9ea69fc5e57636c22fb20d5d97de760b9cc3bcde 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -2198,6 +2198,20 @@ class RowConvLayer(LayerBase):
         self.create_input_parameter(0, psize, dims)

+@config_layer('clip')
+class ClipLayer(LayerBase):
+    def __init__(self, name, inputs, min, max, **xargs):
+        super(ClipLayer, self).__init__(name, 'clip', 0, inputs=inputs, **xargs)
+        config_assert(
+            len(self.inputs) == 1,
+            'ClipLayer must have one and only one input.')
+        config_assert(min < max, 'min must be less than max.')
+        input_layer = self.get_input_layer(0)
+        self.set_layer_size(input_layer.size)
+        self.config.inputs[0].clip_conf.min = min
+        self.config.inputs[0].clip_conf.max = max
+
+
 # key: cost type
 # value: cost class
 g_cost_map = {}
@@ -2754,6 +2768,16 @@ class SumToOneNormLayer(LayerBase):
         self.set_layer_size(input_layer0.size)

+@config_layer('row_l2_norm')
+class RowL2NormLayer(LayerBase):
+    def __init__(self, name, inputs, **xargs):
+        super(RowL2NormLayer, self).__init__(
+            name, 'row_l2_norm', 0, inputs=inputs, **xargs)
+        config_assert(len(self.inputs) == 1, 'RowL2NormLayer must have 1 input')
+        input_layer = self.get_input_layer(0)
+        self.set_layer_size(input_layer.size)
+
+
 @config_layer('cos_vm')
 class CosSimVecMatLayer(LayerBase):
     def __init__(self, name, size, inputs, cos_scale=1.0, device=None):
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 965874ddf632a83d00065c2d40037930a6e604a8..ea5fdcc50f6abbc67fb61b7fd56c100d9f9811d0 100755
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -76,6 +76,7 @@ __all__ = [
     'trans_layer',
     'rotate_layer',
     'sum_to_one_norm_layer',
+    'row_l2_norm_layer',
     'get_output_layer',
     'LayerType',
     'context_projection',
@@ -128,6 +129,7 @@ __all__ = [
     'prelu_layer',
     'gated_unit_layer',
     'crop_layer',
+    'clip_layer',
     'slice_projection',
 ]
@@ -160,6 +162,7 @@ class LayerType(object):
     BATCH_NORM_LAYER = 'batch_norm'
     NORM_LAYER = 'norm'
     SUM_TO_ONE_NORM_LAYER = 'sum_to_one_norm'
+    ROW_L2_NORM_LAYER = 'row_l2_norm'
     ADDTO_LAYER = 'addto'
     CONCAT_LAYER = 'concat'
@@ -221,6 +224,7 @@ class LayerType(object):
     PRELU = 'prelu'
     CROP_LAYER = 'crop'
+    CLIP_LAYER = 'clip'

     @staticmethod
     def is_layer_type(type_name):
@@ -2889,6 +2893,42 @@ def sum_to_one_norm_layer(input, name=None, layer_attr=None):
         name, LayerType.SUM_TO_ONE_NORM_LAYER, parents=[input], size=input.size)

+@wrap_name_default()
+@layer_support()
+def row_l2_norm_layer(input, name=None, layer_attr=None):
+    """
+    A layer for L2-normalization in each row.
+
+    .. math::
+       out[i] = \frac{in[i]}{\sqrt{\sum_{k=1}^N in[k]^{2}}}
+
+    where both :math:`in` and :math:`out` have size (batchSize x dataDim).
+
+    The example usage is:
+
+    .. code-block:: python
+
+        norm = row_l2_norm_layer(input=layer)
+
+    :param input: Input layer.
+    :type input: LayerOutput
+    :param name: Layer name.
+    :type name: basestring
+    :param layer_attr: extra layer attributes.
+    :type layer_attr: ExtraLayerAttribute
+    :return: LayerOutput object.
+    :rtype: LayerOutput
+    """
+    Layer(
+        name=name,
+        type=LayerType.ROW_L2_NORM_LAYER,
+        inputs=[input.name],
+        **ExtraAttr.to_kwargs(layer_attr))
+    return LayerOutput(
+        name, LayerType.ROW_L2_NORM_LAYER, parents=[input], size=input.size)
+
+
 @wrap_name_default("addto")
 @wrap_act_default(act=LinearActivation())
 @wrap_bias_attr_default(has_bias=False)
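A quick numpy rendering of the formula in the `row_l2_norm_layer` docstring above, to make the row-wise normalization concrete (a sketch under illustrative shapes, not the layer implementation):

```python
import numpy as np

x = np.random.random((8, 300)).astype("float32")  # (batchSize x dataDim)

# Divide each row by its L2 norm, as in the docstring formula.
out = x / np.sqrt((x ** 2).sum(axis=1, keepdims=True))

# Every output row now has unit L2 norm.
assert np.allclose(np.linalg.norm(out, axis=1), 1.0, atol=1e-5)
```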
@@ -6046,3 +6086,36 @@ def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None):
         layer_type=LayerType.CROP_LAYER,
         parents=input,
         size=l.config.size)
+
+
+@wrap_name_default("clip")
+def clip_layer(input, min, max, name=None):
+    """
+    A layer that clips each input value to the interval [min, max].
+
+    .. math::
+
+        out[i] = \min\left(\max\left(in[i], p_{1}\right), p_{2}\right)
+
+    where :math:`p_{1}` is *min* and :math:`p_{2}` is *max*.
+
+    The example usage is:
+
+    .. code-block:: python
+
+        clip = clip_layer(input=input_layer, min=-10, max=10)
+
+    :param name: The layer name.
+    :type name: basestring
+    :param input: The input layer.
+    :type input: LayerOutput
+    :param min: The lower threshold for clipping.
+    :type min: double
+    :param max: The upper threshold for clipping.
+    :type max: double
+    :return: LayerOutput object.
+    :rtype: LayerOutput
+    """
+    Layer(
+        name=name,
+        type=LayerType.CLIP_LAYER,
+        inputs=[input.name],
+        min=min,
+        max=max)
+    return LayerOutput(
+        name, LayerType.CLIP_LAYER, parents=[input], size=input.size)
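The clip math above is the familiar two-sided threshold. A minimal numpy sketch, which also notes that numpy's built-in `np.clip` computes the identical result (illustrative values only):

```python
import numpy as np

x = np.array([-25.0, -3.5, 0.0, 7.2, 42.0], dtype="float32")

# Same math as the docstring: min(max(x, lower), upper).
clipped = np.minimum(np.maximum(x, -10), 10)

# numpy's built-in does the identical thing.
assert np.array_equal(clipped, np.clip(x, -10, 10))
```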
diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
index cdf9b2eab733adb173cf33cd6a93ef7b5abefc50..0ffa58bc1e2088f75e7cd25c7ecdffbe270825a4 100755
--- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
+++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
@@ -7,6 +7,6 @@ test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight
 test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
 test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
 test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer
-test_recursive_topology test_gated_unit_layer)
+test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer)

 export whole_configs=(test_split_datasource)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_clip_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_clip_layer.protostr
new file mode 100644
index 0000000000000000000000000000000000000000..4b9578a0c050ef74f186485fec3f6c1f7a0f0814
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_clip_layer.protostr
@@ -0,0 +1,31 @@
+type: "nn"
+layers {
+  name: "input"
+  type: "data"
+  size: 300
+  active_type: ""
+}
+layers {
+  name: "__clip_0__"
+  type: "clip"
+  size: 300
+  active_type: ""
+  inputs {
+    input_layer_name: "input"
+    clip_conf {
+      min: -10
+      max: 10
+    }
+  }
+}
+input_layer_names: "input"
+output_layer_names: "__clip_0__"
+sub_models {
+  name: "root"
+  layer_names: "input"
+  layer_names: "__clip_0__"
+  input_layer_names: "input"
+  output_layer_names: "__clip_0__"
+  is_recurrent_layer_group: false
+}
+
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_row_l2_norm_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_row_l2_norm_layer.protostr
new file mode 100644
index 0000000000000000000000000000000000000000..c2786ff55c7023d856d739face5e747cc5fee870
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_row_l2_norm_layer.protostr
@@ -0,0 +1,27 @@
+type: "nn"
+layers {
+  name: "input"
+  type: "data"
+  size: 300
+  active_type: ""
+}
+layers {
+  name: "__row_l2_norm_layer_0__"
+  type: "row_l2_norm"
+  size: 300
+  active_type: ""
+  inputs {
+    input_layer_name: "input"
+  }
+}
+input_layer_names: "input"
+output_layer_names: "__row_l2_norm_layer_0__"
+sub_models {
+  name: "root"
+  layer_names: "input"
+  layer_names: "__row_l2_norm_layer_0__"
+  input_layer_names: "input"
+  output_layer_names: "__row_l2_norm_layer_0__"
+  is_recurrent_layer_group: false
+}
+
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_clip_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_clip_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f066fe1fb30877bf40bb6299d35546f7427989a5
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_clip_layer.py
@@ -0,0 +1,6 @@
+from paddle.trainer_config_helpers import *
+
+data = data_layer(name='input', size=300)
+clip = clip_layer(input=data, min=-10, max=10)
+
+outputs(clip)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_row_l2_norm_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_row_l2_norm_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac8badb26a40e96e75225e6f61aa536cd28e9098
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_row_l2_norm_layer.py
@@ -0,0 +1,6 @@
+from paddle.trainer_config_helpers import *
+
+data = data_layer(name='input', size=300)
+row_l2_norm = row_l2_norm_layer(input=data)
+
+outputs(row_l2_norm)
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index 540636a0e8100fbf97231bd548dbc1176b07daca..4619b0edc3dd7e253e01f7fee5e6a8641340d291 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -8,7 +8,6 @@ add_python_test(test_framework
     test_fc_op.py
     test_add_two_op.py
     test_sgd_op.py
-    test_cross_entropy_op.py
     test_mul_op.py
     test_mean_op.py
     test_sigmoid_op.py
diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py
index 99085c367221150c8386a24e8d90d58fd63894c4..98fae1b975ad6243b20e5c19ec6ff68d5536cd74 100644
--- a/python/paddle/v2/framework/tests/op_test_util.py
+++ b/python/paddle/v2/framework/tests/op_test_util.py
@@ -26,40 +26,45 @@ class OpTestMeta(type):
             scope = core.Scope()
             kwargs = dict()
+            places = []
+            places.append(core.CPUPlace())
+            if core.is_compile_gpu():
+                places.append(core.GPUPlace(0))

-            for in_name in func.all_input_args:
-                if hasattr(self, in_name):
-                    kwargs[in_name] = in_name
-                    var = scope.new_var(in_name).get_tensor()
-                    arr = getattr(self, in_name)
-                    var.set_dims(arr.shape)
-                    var.set(arr)
-                else:
-                    kwargs[in_name] = "@EMPTY@"
+            for place in places:
+                for in_name in func.all_input_args:
+                    if hasattr(self, in_name):
+                        kwargs[in_name] = in_name
+                        var = scope.new_var(in_name).get_tensor()
+                        arr = getattr(self, in_name)
+                        var.set_dims(arr.shape)
+                        var.set(arr, place)
+                    else:
+                        kwargs[in_name] = "@EMPTY@"

-            for out_name in func.all_output_args:
-                if hasattr(self, out_name):
-                    kwargs[out_name] = out_name
-                    scope.new_var(out_name).get_tensor()
+                for out_name in func.all_output_args:
+                    if hasattr(self, out_name):
+                        kwargs[out_name] = out_name
+                        scope.new_var(out_name).get_tensor()

-            for attr_name in func.all_attr_args:
-                if hasattr(self, attr_name):
-                    kwargs[attr_name] = getattr(self, attr_name)
+                for attr_name in func.all_attr_args:
+                    if hasattr(self, attr_name):
+                        kwargs[attr_name] = getattr(self, attr_name)

-            op = func(**kwargs)
+                op = func(**kwargs)

-            op.infer_shape(scope)
+                op.infer_shape(scope)

-            ctx = core.DeviceContext.cpu_context()
-            op.run(scope, ctx)
+                ctx = core.DeviceContext.create(place)
+                op.run(scope, ctx)

-            for out_name in func.all_output_args:
-                actual = numpy.array(scope.find_var(out_name).get_tensor())
-                expect = getattr(self, out_name)
-                # TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul
-                # has some diff, and could not pass unittest. So I set decimal 3 here.
-                # And I will check this in future.
-                numpy.testing.assert_almost_equal(actual, expect, decimal=3)
+                for out_name in func.all_output_args:
+                    actual = numpy.array(scope.find_var(out_name).get_tensor())
+                    expect = getattr(self, out_name)
+                    # TODO(qijun) The default decimal is 7, but numpy.dot and
+                    # eigen.mul differ slightly and cannot pass the unit test
+                    # at that precision, so decimal is set to 3 here; revisit.
+                    numpy.testing.assert_almost_equal(actual, expect, decimal=3)

         obj.test_all = test_all
         return obj
diff --git a/python/paddle/v2/framework/tests/test_add_two_op.py b/python/paddle/v2/framework/tests/test_add_two_op.py
index a06d7a78ecf838a49e5f2808d3686c6b92faa8ce..6e6643201bf361fce1bad7de10b2562f0525e00a 100644
--- a/python/paddle/v2/framework/tests/test_add_two_op.py
+++ b/python/paddle/v2/framework/tests/test_add_two_op.py
@@ -1,6 +1,10 @@
 import unittest
-from op_test_util import OpTestMeta
+
 import numpy
+import paddle.v2.framework.core as core
+import paddle.v2.framework.create_op_creation_methods as creation
+
+from op_test_util import OpTestMeta

 class TestAddOp(unittest.TestCase):
@@ -8,10 +12,19 @@ class TestAddOp(unittest.TestCase):
     def setUp(self):
         self.type = "add_two"
-        self.X = numpy.random.random((342, 345)).astype("float32")
-        self.Y = numpy.random.random((342, 345)).astype("float32")
+        self.X = numpy.random.random((102, 105)).astype("float32")
+        self.Y = numpy.random.random((102, 105)).astype("float32")
         self.Out = self.X + self.Y

+class TestAddGradOp(unittest.TestCase):
+    def test_add_grad(self):
+        op = creation.op_creations.add_two(X="X", Y="Y", Out="Out")
+        backward_op = core.Operator.backward(op, set())
+        self.assertEqual(backward_op.type(), "add_two_grad")
+        expected = '''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).'''
+        self.assertEqual(expected, str(backward_op))
+
+
 if __name__ == '__main__':
     unittest.main()
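TestAddGradOp above only checks the generated grad op's type and signature. For intuition, the math that `add_two_grad` has to implement is simple: for Out = X + Y, the upstream gradient of Out passes through unchanged to both inputs. A hedged numpy illustration of that fact, not Paddle's kernel:

```python
import numpy as np

x = np.random.random((102, 105)).astype("float32")
y = np.random.random((102, 105)).astype("float32")
out = x + y

# Upstream gradient dL/dOut; for addition, dL/dX == dL/dY == dL/dOut.
d_out = np.ones_like(out)
d_x, d_y = d_out, d_out

assert np.array_equal(d_x, d_out) and np.array_equal(d_y, d_out)
```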
diff --git a/python/paddle/v2/framework/tests/test_fc_op.py b/python/paddle/v2/framework/tests/test_fc_op.py
index 43931aac406cd93beede008066aa1c0c00eba6ea..00dc4399aaf59e6382692c3a4356f89a7e79a0c5 100644
--- a/python/paddle/v2/framework/tests/test_fc_op.py
+++ b/python/paddle/v2/framework/tests/test_fc_op.py
@@ -7,17 +7,19 @@ import paddle.v2.framework.create_op_creation_methods as creation
 class TestFc(unittest.TestCase):
     def test_fc(self):
         scope = core.Scope()
+        place = core.CPUPlace()
         x = scope.new_var("X")
+
         x_tensor = x.get_tensor()
         x_tensor.set_dims([1000, 784])
-        x_tensor.alloc_float()
+        x_tensor.alloc_float(place)

         w = scope.new_var("W")
         w_tensor = w.get_tensor()
         w_tensor.set_dims([784, 100])
-        w_tensor.alloc_float()
+        w_tensor.alloc_float(place)

-        w_tensor.set(numpy.random.random((784, 100)).astype("float32"))
+        w_tensor.set(numpy.random.random((784, 100)).astype("float32"), place)

         # Set a real numpy array here.
         # x_tensor.set(numpy.array([]))
@@ -32,7 +34,7 @@ class TestFc(unittest.TestCase):
         op.infer_shape(scope)
         self.assertEqual([1000, 100], tensor.shape())

-        ctx = core.DeviceContext.cpu_context()
+        ctx = core.DeviceContext.create(place)
         op.run(scope, ctx)
diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py
index 0a87e66cd03af1bf84be8ffe111e4a8c3a24d6dc..e1ac66d3a4d23d617f7c5a4d97d070b2660954c8 100644
--- a/python/paddle/v2/framework/tests/test_mul_op.py
+++ b/python/paddle/v2/framework/tests/test_mul_op.py
@@ -8,8 +8,8 @@ class TestMulOp(unittest.TestCase):
     def setUp(self):
         self.type = "mul"
-        self.X = np.random.random((32, 784)).astype("float32")
-        self.Y = np.random.random((784, 100)).astype("float32")
+        self.X = np.random.random((32, 84)).astype("float32")
+        self.Y = np.random.random((84, 100)).astype("float32")
         self.Out = np.dot(self.X, self.Y)
diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
index ef1514983c03f822f84b85437d1cfe653b6a1a2e..04abc14ee198fe4e2307e009c696a2b40ec271b6 100644
--- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py
+++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
@@ -8,8 +8,8 @@ class TestRowwiseAddOp(unittest.TestCase):
     def setUp(self):
         self.type = "rowwise_add"
-        self.X = np.random.random((32, 784)).astype("float32")
-        self.b = np.random.random(784).astype("float32")
+        self.X = np.random.random((32, 84)).astype("float32")
+        self.b = np.random.random(84).astype("float32")
         self.Out = np.add(self.X, self.b)
diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py
index 405d73b224fa153e50b4ec408a921f2bdaab46aa..ca03cc11abe2ceb31b33a87797aa752943dd2a7d 100644
--- a/python/paddle/v2/framework/tests/test_sgd_op.py
+++ b/python/paddle/v2/framework/tests/test_sgd_op.py
@@ -8,8 +8,8 @@ class TestSGD(unittest.TestCase):
     def setUp(self):
         self.type = "sgd"
-        self.param = numpy.random.random((342, 345)).astype("float32")
-        self.grad = numpy.random.random((342, 345)).astype("float32")
+        self.param = numpy.random.random((102, 105)).astype("float32")
+        self.grad = numpy.random.random((102, 105)).astype("float32")
         self.learning_rate = 0.1
         self.param_out = self.param - self.learning_rate * self.grad
diff --git a/python/paddle/v2/framework/tests/test_tensor.py b/python/paddle/v2/framework/tests/test_tensor.py
index 6d59863cea29832f648139e07a134050e22bfa21..1af39818a305215b45219b8c5f0a10630fd64279 100644
--- a/python/paddle/v2/framework/tests/test_tensor.py
+++ b/python/paddle/v2/framework/tests/test_tensor.py
@@ -7,16 +7,17 @@ class TestScope(unittest.TestCase):
     def test_int_tensor(self):
         scope = core.Scope()
         var = scope.new_var("test_tensor")
+        place = core.CPUPlace()
+
         tensor = var.get_tensor()
         tensor.set_dims([1000, 784])
-        tensor.alloc_int()
-
+        tensor.alloc_int(place)
         tensor_array = numpy.array(tensor)
         self.assertEqual((1000, 784), tensor_array.shape)
         tensor_array[3, 9] = 1
         tensor_array[19, 11] = 2
-        tensor.set(tensor_array)
+        tensor.set(tensor_array, place)

         tensor_array_2 = numpy.array(tensor)
         self.assertEqual(1.0, tensor_array_2[3, 9])
@@ -25,16 +26,18 @@ class TestScope(unittest.TestCase):
     def test_float_tensor(self):
         scope = core.Scope()
         var = scope.new_var("test_tensor")
+        place = core.CPUPlace()
+
         tensor = var.get_tensor()
         tensor.set_dims([1000, 784])
-        tensor.alloc_float()
+        tensor.alloc_float(place)

         tensor_array = numpy.array(tensor)
         self.assertEqual((1000, 784), tensor_array.shape)
         tensor_array[3, 9] = 1.0
         tensor_array[19, 11] = 2.0
-        tensor.set(tensor_array)
+        tensor.set(tensor_array, place)

         tensor_array_2 = numpy.array(tensor)
         self.assertAlmostEqual(1.0, tensor_array_2[3, 9])