diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index cb330ea5e1b914587a725c9b90a33053f3fbbc3d..a4a843c610feb2b378c22c4b4097cd238ccd61ab 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -362,6 +362,11 @@ trans .. autoclass:: paddle.v2.layer.trans :noindex: +scale_shift +----------- +.. autoclass:: paddle.v2.layer.scale_shift + :noindex: + Sampling Layers =============== diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 68304c9fc8b8fa13cb1f99b82517abc87c71496c..2cdf323c53a8ba729ec74c1eacb9fa3ef272f44a 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -45,6 +45,7 @@ cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python backward sgd_op + gather_op add_op mul_op rowwise_add_op diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 9d30887224fe0020ff5665f362e7403bf5c724ee..bfda18724cc8ed23a40e0626ff07a290d26aa9d2 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -110,7 +110,7 @@ static std::unique_ptr BackwardRecursive( dup_output_ops[out].emplace_back(local_op_id); return false; }); - net->AddOp(std::move(bwd)); + net->AppendOp(std::move(bwd)); } // Get unique ID for this method. auto uid = uniq_id++; @@ -163,8 +163,9 @@ static std::unique_ptr BackwardRecursive( // If part of input gradient of that operator is not calculated, fill // zero variables to that input gradient. - net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {{"Src", {prefix}}}, - {{"Dst", {grad_input}}}, {})); + net->AppendOp(OpRegistry::CreateOp("fill_zeros_like", + {{"Src", {prefix}}}, + {{"Dst", {grad_input}}}, {})); } return false; }); @@ -195,7 +196,7 @@ static std::unique_ptr BackwardRecursive( if (net->ops_.empty()) { // Current no aux op is added to network return grad_op; } - net->AddOp(std::move(grad_op)); + net->AppendOp(std::move(grad_op)); } net->SetType("@GENERATED_BACKWARD@"); net->CompleteAddOp(); diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 2c5ec76dfeb8b8485951e4d94896b6758e0cb930..b93ab66f2f5b9cffa6d51b6e36afe552125970e4 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -75,13 +75,13 @@ class FcOp : public operators::NetOp { FcOp(const std::string &type, const VarNameMap &inputs, const VarNameMap &outputs, const AttributeMap &attrs) : NetOp(type, inputs, outputs, attrs) { - AddOp(OpRegistry::CreateOp("mul", - {{"X", {Input("X")}}, {"Y", {Input("W")}}}, - {{"Out", {Output("mul_result")}}}, {})); + AppendOp(OpRegistry::CreateOp("mul", + {{"X", {Input("X")}}, {"Y", {Input("W")}}}, + {{"Out", {Output("mul_result")}}}, {})); auto input_b = Inputs("b"); std::string before_act = "mul_result"; if (input_b.size() != 0) { - AddOp(OpRegistry::CreateOp( + AppendOp(OpRegistry::CreateOp( "rowwise_add", {{"X", {Output("mul_result")}}, {"b", {input_b[0]}}}, {{"Out", {Output("add_result")}}}, {})); before_act = "add_result"; @@ -92,8 +92,8 @@ class FcOp : public operators::NetOp { } } - AddOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}}, - {{"Out", {Output("Out")}}}, {})); + AppendOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}}, + {{"Out", {Output("Out")}}}, {})); CompleteAddOp(false); } }; @@ -234,13 +234,13 @@ TEST(Backward, net_fc_backward_not_have_b) { TEST(Backward, net_input_of_network_not_need_grad) { ops::NetOp net; - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"x"}}, {"W", {"W1"}}, {"b", {"b1"}}}, {{"mul_result", {"mul_tmp_0"}}, {"add_result", {"add_tmp_0"}}, {"Out", {"hidden0"}}}, {})); - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"hidden0"}}, {"W", {"W2"}}, {"b", {"b2"}}}, {{"mul_result", {"mul_tmp_1"}}, {"add_result", {"add_tmp_1"}}, @@ -273,10 +273,10 @@ TEST(Backward, net_input_of_network_not_need_grad) { TEST(Backward, net_shared_weight) { ops::NetOp net; - net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}}, - {{"Out", {"out"}}}, {})); - net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}}, - {{"Out", {"FinalOut"}}}, {})); + net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}}, + {{"Out", {"out"}}}, {})); + net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}}, + {{"Out", {"FinalOut"}}}, {})); net.CompleteAddOp(); auto bwd = f::Backward(net, {}); @@ -357,19 +357,19 @@ TEST(Backward, op_part_of_input_are_not_need) { TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ops::NetOp net; - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"x1"}}, {"W", {"w1"}}, {"b", {"b1"}}}, {{"mul_result", {"mul_out1"}}, {"add_result", {"add_out1"}}, {"Out", {"out1"}}}, {})); - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"out1"}}, {"W", {"w2"}}, {"b", {"b2"}}}, {{"mul_result", {"mul_out2"}}, {"add_result", {"tmp_out2"}}, {"Out", {"out2"}}}, {})); - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"out2"}}, {"W", {"w3"}}, {"b", {"b3"}}}, {{"mul_result", {"mul_out3"}}, {"add_result", {"tmp_out3"}}, diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index f0114b9e4908d65b3fddb493230777f9e500b4e1..4539a1903eb430eb0d76a787adb32984342a468d 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -31,7 +31,7 @@ limitations under the License. */ namespace py = pybind11; USE_OP(add_two); -USE_CPU_ONLY_OP(onehot_cross_entropy); +USE_OP(onehot_cross_entropy); USE_OP(sgd); USE_OP(mul); USE_OP(mean); @@ -42,6 +42,7 @@ USE_OP(fill_zeros_like); USE_OP_ITSELF(recurrent_op); USE_OP(gaussian_random); USE_OP(uniform_random); +USE_CPU_ONLY_OP(gather); namespace paddle { namespace framework { @@ -222,8 +223,8 @@ All parameter, weight, gradient are variables in Paddle. retv->SetType("plain_net"); return retv; }) - .def("add_op", [](operators::NetOp &self, - const OperatorBase &op) { self.AddOp(op); }) + .def("append_op", [](operators::NetOp &self, + const OperatorBase &op) { self.AppendOp(op); }) .def("complete_add_op", &operators::NetOp::CompleteAddOp) .def("complete_add_op", [](std::shared_ptr &self) { self->CompleteAddOp(); diff --git a/paddle/function/GemmFunctor.cpp b/paddle/function/GemmFunctor.cpp index dc83278d8e3b70101962521b76b902046629c2dd..9e25ee58a12490a1454436b3fe4a89176478d5c0 100644 --- a/paddle/function/GemmFunctor.cpp +++ b/paddle/function/GemmFunctor.cpp @@ -84,7 +84,7 @@ struct BlasGemm { } }; -template class BlasGemm; -template class BlasGemm; +template struct BlasGemm; +template struct BlasGemm; } // namespace paddle diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp index cfa80a89365af5111746eec9599d16e37532a9f7..26cff3e67710b2f38d93572c5d58849aa94a5135 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp +++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp @@ -202,7 +202,7 @@ void NeuralNetwork::prefetch(const std::vector& inArgs) { auto mat = dynamic_cast( para->getMat(PARAMETER_VALUE).get()); para->clearGradient(); - mat->clearIndices(); + if (mat) mat->clearIndices(); } } } diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp index f98bf95064fa539b990309dfe0bff10c1e99d096..157b1ab45163a94a81d859dbcb7a52ae8edae439 100644 --- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp @@ -184,7 +184,7 @@ public: } void backward(const UpdateCallback& callback) override { - if (biases_) { + if (biases_ && biases_->getWGrad()) { backwardActivation(); biases_->getWGrad()->collectBias(*getOutputGrad(), 1); biases_->getParameterPtr()->incUpdate(callback); diff --git a/paddle/gserver/layers/ScaleShiftLayer.cpp b/paddle/gserver/layers/ScaleShiftLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..35fd038ab43a8a8b08bc328b3d1b08a7bbedd0a1 --- /dev/null +++ b/paddle/gserver/layers/ScaleShiftLayer.cpp @@ -0,0 +1,107 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Layer.h" + +namespace paddle { + +/** + * A layer applies a linear transformation to each element in each row of + * the input matrix. For each element, the layer first re-scale it and then + * adds a bias to it. + * + * \f[ + * y = wx + b + * \f] + * + * Here, w is the scale and b is the bias. Both w and b are trainable scalars. + * + */ + +class ScaleShiftLayer : public Layer { +protected: + std::unique_ptr scale_; + std::unique_ptr offset_; + +public: + explicit ScaleShiftLayer(const LayerConfig& config) : Layer(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; +}; + +REGISTER_LAYER(scale_shift, ScaleShiftLayer); + +bool ScaleShiftLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + CHECK_EQ(inputLayers_.size(), 1U); + scale_.reset(new Weight(1, 1, parameters_[0])); + if (biasParameter_.get() != NULL) { + offset_ = std::unique_ptr(new Weight(1, 1, biasParameter_)); + } + return true; +} + +void ScaleShiftLayer::forward(PassType passType) { + Layer::forward(passType); + + MatrixPtr inV = getInputValue(0); + resetOutput(inV->getHeight(), inV->getWidth()); + MatrixPtr outV = getOutputValue(); + real scaleValue = scale_->getW()->getElement(0, 0); + outV->mulScalar(*inV, scaleValue); + if (offset_) { + real offsetValue = offset_->getW()->getElement(0, 0); + outV->add(offsetValue); + } +} + +void ScaleShiftLayer::backward(const UpdateCallback& callback) { + MatrixPtr inV = getInputValue(0); + MatrixPtr inG = getInputGrad(0); + MatrixPtr outV = getOutputValue(); + MatrixPtr outG = getOutputGrad(); + + /* Calculate the parameter gradient for the current layer */ + if (scale_->getWGrad()) { + MatrixPtr rowSumMtx; + Matrix::resizeOrCreate(rowSumMtx, outG->getHeight(), 1, false, useGpu_); + // this_i = scaleDest * this_i + scaleSum * \sum_j b_{ij} * c_{ij} + rowSumMtx->sumOfProducts( + /* b= */ *inV, /* c= */ *outG, /* scaleSum= */ 1, /* scaleDest= */ 0.); + // this_i = scaleDest * this_i + scaleSum * \sum_j b_{ji} + scale_->getWGrad()->sumCols( + /* b= */ *rowSumMtx, /* scaleSum= */ 1., /* scaleDest= */ 1.); + scale_->getParameterPtr()->incUpdate(callback); + } + if (offset_ && offset_->getWGrad()) { + MatrixPtr rowSumMtx; + Matrix::resizeOrCreate(rowSumMtx, outG->getHeight(), 1, false, useGpu_); + rowSumMtx->sumRows(*outG, 1., 0.); + offset_->getWGrad()->sumCols(*rowSumMtx, 1., 1.); + offset_->getParameterPtr()->incUpdate(callback); + } + + /* Calculate the input layers error */ + if (inG) { + real scaleValue = scale_->getW()->getElement(0, 0); + inG->add(*outG, scaleValue); + } +} + +} // namespace paddle diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index b3913d3a2808b646ee18ee22f787ef16b09b2753..9348c47bd4f0a54a345b794674f1daf80a9664f3 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -2027,6 +2027,21 @@ TEST(Layer, RowL2NormLayer) { } } +TEST(Layer, ScaleShiftLayer) { + const size_t batchSize = 16; + const size_t size = 32; + TestConfig config; + config.layerConfig.set_type("scale_shift"); + config.layerConfig.set_size(size); + config.biasSize = 1; + config.inputDefs.push_back( + {INPUT_DATA, "input", /* dim= */ size, /* paraSize= */ 1}); + config.layerConfig.add_inputs(); + for (auto useGpu : {false, true}) { + testLayerGrad(config, "scale_shift", batchSize, false, useGpu, false); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index a7c89787e43df6173791bc54b3faffc034867f7d..ba1362e8bf38ef4735ffeea29bea12f6eff99982 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -43,6 +43,7 @@ endfunction() add_subdirectory(math) cc_test(gather_test SRCS gather_test.cc DEPS tensor) +op_library(gather_op SRCS gather_op.cc gather_op.cu) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index a623c551e1088365ade6f73bc6149977b6ef017e..ab1e1c101a10e09a81f7785d2f1514822e3bdf15 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -39,11 +39,10 @@ class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto X_grad = ctx.Output(framework::GradVarName("X")); + auto dX = ctx.Output(framework::GradVarName("X")); auto X = ctx.Input("X"); - // TODO(superjom) add enforce here after helper functions ready - X_grad->Resize(X->dims()); + dX->Resize(X->dims()); } }; @@ -70,9 +69,7 @@ namespace ops = paddle::operators; REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad, ops::OnehotCrossEntropyGradientOp); -REGISTER_OP_CPU_KERNEL( - onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); -REGISTER_OP_CPU_KERNEL( - onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOpKernel); +REGISTER_OP_CPU_KERNEL(onehot_cross_entropy, + ops::OnehotCrossEntropyOpKernel); +REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad, + ops::OnehotCrossEntropyGradientOpKernel); diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 4bbc8f093a794d46737a16488684a6a0cc25e285..d999bfce58c8a6db5c811aad677c07094b881841 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -12,10 +12,122 @@ See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU -#include "paddle/operators/cross_entropy_op.h" +#include "paddle/framework/op_registry.h" +#include "paddle/platform/assert.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +__host__ __device__ T clipping_log(const T x) { + PADDLE_ASSERT(std::is_floating_point::value); + const T kApproInf = 1e20; + T v = log(x); + if (v == INFINITY) { + return kApproInf; + } + if (v == -INFINITY) { + return -kApproInf; + } + return v; +} + +template +__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, + const int N, const int D) { + // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. + // CUDA_1D_KERNEL_LOOP(i, N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + PADDLE_ASSERT(label[i] >= 0 && label[i] < D); + Y[i] = -clipping_log(X[i * D + label[i]]); + } +} + +// TODO(qingqing): make zero setting an common function. +template +__global__ void zero(T* X, const int N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + X[i] = 0.0; + } +} + +template +__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, + const int* label, const int N, + const int D) { + // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. + // CUDA_1D_KERNEL_LOOP(i, N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + int idx = i * D + label[i]; + dX[idx] = -dY[i] / X[idx]; + } +} + +template +class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto X = ctx.Input("X"); + const T* Xdata = X->data(); + const int* label_data = ctx.Input("label")->data(); + auto Y = ctx.Output("Y"); + Y->mutable_data(ctx.GetPlace()); + T* Ydata = Y->data(); + + int N = X->dims()[0]; + int D = X->dims()[1]; + int block = 512; + int grid = (N + block - 1) / block; + // TODO(qingqing) launch kernel on specified stream + // base on ExecutionContext. + CrossEntropyKernel<<>>(Ydata, Xdata, label_data, N, D); + } +}; + +template +class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto X = ctx.Input("X"); + auto dX = ctx.Output(framework::GradVarName("X")); + auto dY = ctx.Input(framework::GradVarName("Y")); + auto label = ctx.Input("label"); + + auto* dXdata = dX->template mutable_data(ctx.GetPlace()); + auto* dYdata = dY->template data(); + auto* Xdata = X->template data(); + auto* label_data = label->data(); + + int N = X->dims()[0]; + int D = X->dims()[1]; + int block = 512; + int grid = (N * D + block - 1) / block; + zero<<>>(dXdata, N * D); + + grid = (N + block - 1) / block; + // TODO(qingqing): launch kernel on specified stream + // base on ExecutionContext. + CrossEntropyGradientKernel<<>>(dXdata, dYdata, Xdata, + label_data, N, D); + } +}; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL( - onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); +REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, + ops::OnehotCrossEntropyOpCUDAKernel); +REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad, + ops::OnehotCrossEntropyGradientOpCUDAKernel); diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index b7df92c9a98ebf12b72a8d3d8e8e4e1a950f06c9..eb4d1348de1d940e2648c83c8ba94b289f10c5b2 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -21,7 +21,7 @@ namespace operators { using Tensor = framework::Tensor; template -T tolerable_value(T x) { +inline T tolerable_value(const T x) { static_assert(std::is_floating_point::value, "tolerable_value works only on float, " "double and double double."); @@ -39,10 +39,13 @@ T tolerable_value(T x) { return x; } -template +template class OnehotCrossEntropyOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + auto X = ctx.Input("X"); const T* Xdata = X->data(); const int* label_data = ctx.Input("label")->data(); @@ -62,10 +65,13 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel { } }; -template +template class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + auto X = ctx.Input("X"); auto dX = ctx.Output(framework::GradVarName("X")); auto dY = ctx.Input(framework::GradVarName("Y")); @@ -79,6 +85,8 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { const int batch_size = X->dims()[0]; const int class_num = X->dims()[1]; + // TODO(qingqing): make zero setting an common function. + memset(dXdata, 0, sizeof(T) * batch_size * class_num); for (int i = 0; i < batch_size; ++i) { int index = i * class_num + label_data[i]; dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]); diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h index d6e6990394e46ba06c4bacfe33ca522f3ff1413a..92fb51ec17709bc6f8abb2f516a9240fb5dc3a77 100644 --- a/paddle/operators/gather.h +++ b/paddle/operators/gather.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/framework/ddim.h" +#include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" #include "paddle/platform/place.h" @@ -25,13 +26,13 @@ namespace operators { // Implementation of CPU copy template -void CPUGather(const T* params, const int* indices, const int slice_size, +void CPUGather(const T* src, const int* indices, const int slice_size, const int index_size, T* output) { const size_t slice_bytes = slice_size * sizeof(T); for (int i = 0; i < index_size; ++i) { int index_ = indices[i]; - memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes); + memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes); } } @@ -55,7 +56,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src, int index_size = index->dims()[0]; auto src_dims = src->dims(); - paddle::framework::DDim output_dims(src_dims); + framework::DDim output_dims(src_dims); output_dims[0] = index_size; // slice size diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..123bed296c462c30bddd3bfbd530098fdbfe4856 --- /dev/null +++ b/paddle/operators/gather_op.cc @@ -0,0 +1,72 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/gather_op.h" +#include "paddle/framework/ddim.h" + +namespace paddle { +namespace operators { + +class GatherOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + int batch_size = ctx.Input("Index")->dims()[0]; + PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0"); + framework::DDim output_dims(ctx.Input("X")->dims()); + output_dims[0] = batch_size; + ctx.Output("Out")->Resize(output_dims); + } +}; + +class GatherGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + auto X_grad = ctx.Output(framework::GradVarName("X")); + auto X = ctx.Input("X"); + + X_grad->Resize(X->dims()); + } +}; + +class GatherOpMaker : public framework::OpProtoAndCheckerMaker { + public: + GatherOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The source input of gather op"); + AddInput("Index", "The index input of gather op"); + AddOutput("Out", "The output of add op"); + AddComment(R"DOC( +Gather Operator by selecting from the first axis, + +Out = X[Index] +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad, + ops::GatherGradOp); +REGISTER_OP_CPU_KERNEL(gather, + ops::GatherOpKernel); +REGISTER_OP_CPU_KERNEL( + gather_grad, + ops::GatherGradientOpKernel); diff --git a/paddle/operators/gather_op.cu b/paddle/operators/gather_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..3f04a7b3f8142106917975cd1e0413fa1633a298 --- /dev/null +++ b/paddle/operators/gather_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/gather_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(gather, + ops::GatherOpKernel); diff --git a/paddle/operators/gather_op.h b/paddle/operators/gather_op.h new file mode 100644 index 0000000000000000000000000000000000000000..381854f301870beadb72d9e9b4eb17ff199960fb --- /dev/null +++ b/paddle/operators/gather_op.h @@ -0,0 +1,53 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "gather.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "scatter.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class GatherOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *X = ctx.Input("X"); + auto *Index = ctx.Input("Index"); + auto *Y = ctx.Output("Out"); + + Y->mutable_data(ctx.GetPlace()); + Gather(ctx.GetPlace(), X, Index, Y); + } +}; + +template +class GatherGradientOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *Index = ctx.Input("Index"); + auto *dX = ctx.Output(framework::GradVarName("X")); + auto *dO = ctx.Input(framework::GradVarName("Out")); + + dX->mutable_data(ctx.GetPlace()); + ScatterUpdate(ctx.GetPlace(), dO, Index, dX); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index f30bbce9586d61063b4b61d98695bb568ef73c8d..a85363ad81d2a23e7267026c067f74f8c94c4786 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,25 +16,25 @@ namespace paddle { namespace operators { template -class GaussianRandomKernel : public framework::OpKernel { +class CPUGaussianRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { float mean = context.op_.GetAttr("mean"); float std = context.op_.GetAttr("std"); - auto* tensor = context.Output(0); + auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); - // TODO(dzh): attribute does not support unsigned int. - // And we need a global random seed configuration. - int seed = context.op_.GetAttr("seed"); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + std::minstd_rand engine; if (seed == 0) { seed = std::random_device()(); } - std::mt19937 g(seed); - std::normal_distribution distribution(mean, std); + engine.seed(seed); + std::normal_distribution dist(mean, std); ssize_t size = framework::product(tensor->dims()); - for (int i = 0; i < size; ++i) { - data[i] = distribution(g); + for (ssize_t i = 0; i < size; ++i) { + data[i] = dist(engine); } } }; @@ -48,7 +45,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext& context) const override { - auto* tensor = context.Output(0); + auto* tensor = context.Output("Out"); auto dims = GetAttr>("dims"); PADDLE_ENFORCE(dims.size() > 0UL, "dims can be one int or array. dims must be set."); @@ -68,8 +65,8 @@ Use to initialize tensor with gaussian random generator. )DOC"); AddAttr>("dims", "The dimension of random tensor."); - AddAttr("mean", "mean value of random.").SetDefault(.0f); - AddAttr("std", "minimum value of random value.").SetDefault(1.0f); + AddAttr("mean", "mean of random tensor.").SetDefault(.0f); + AddAttr("std", "std of random tensor.").SetDefault(1.0f); AddAttr("seed", "Random seed of generator." "0 means use system wide seed") @@ -83,4 +80,4 @@ Use to initialize tensor with gaussian random generator. namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); -REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel); +REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel); diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu index 1340b1e1e9f19fd96ced9e57fab75fe9d33bc84e..018a4bfcb26b9008c054000c91edf01e371fd82b 100644 --- a/paddle/operators/gaussian_random_op.cu +++ b/paddle/operators/gaussian_random_op.cu @@ -1,53 +1,65 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include "paddle/platform/dynload/curand.h" -#include "paddle/platform/gpu_info.h" - +#include +#include +#include +#include #include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" namespace paddle { namespace operators { template -class GaussianRandomKernel : public framework::OpKernel { +struct GaussianGenerator { + T mean_, std_; + unsigned int seed_; + + __host__ __device__ GaussianGenerator(T mean, T std, int seed) + : mean_(mean), std_(std), seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::normal_distribution dist(mean_, std_); + rng.discard(n); + return dist(rng); + } +}; + +template +class GPUGaussianRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - float mean = context.op_.GetAttr("mean"); - float std = context.op_.GetAttr("std"); - auto* tensor = context.Output(0); + auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); - - int seed = context.op_.GetAttr("seed"); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); if (seed == 0) { std::random_device rd; seed = rd(); } - curandGenerator_t g; - PADDLE_ENFORCE(platform::dynload::curandCreateGenerator( - &g, CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed)); - platform::dynload::curandGenerateNormal( - g, data, framework::product(tensor->dims()), mean, std); + T mean = static_cast(context.op_.GetAttr("mean")); + T std = static_cast(context.op_.GetAttr("std")); + thrust::counting_iterator index_sequence_begin(0); + ssize_t N = framework::product(tensor->dims()); + thrust::transform(index_sequence_begin, index_sequence_begin + N, + thrust::device_ptr(data), + GaussianGenerator(mean, std, seed)); } }; } // namespace operators } // namespace paddle -namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel); +REGISTER_OP_GPU_KERNEL(gaussian_random, + paddle::operators::GPUGaussianRandomKernel); diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 460e458ca4f7f40746f0dbf7e258a165faa88e1a..173cc3850ca9d97200e272ec59d1bd3fe09b5053 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -13,7 +13,6 @@ limitations under the License. */ #include "paddle/operators/mul_op.h" -#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 885ac6eeca65998dea62c1db40b9261cceb97805..3d3f996ef52b6c1136425ca9de0f60e7e155458f 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -84,13 +84,14 @@ class NetOp : public framework::OperatorBase { return true; } - void AddOp(const framework::OperatorBase& op) { AddOp(op.Clone()); } + void AppendOp(const framework::OperatorBase& op) { AppendOp(op.Clone()); } /** * @brief Add an operator by ptr */ - void AddOp(std::unique_ptr op) { - PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); + void AppendOp(std::unique_ptr op) { + PADDLE_ENFORCE(!add_op_done_, + "Cannot AppendOp when this network is sealed"); PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op"); ops_.push_back(std::move(op)); } diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc index e9598610c0a74e08a613a397109ad65994821498..99019754a965e5e7aeb74c6bfc10c9646289651b 100644 --- a/paddle/operators/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -38,10 +38,10 @@ TEST(OpKernel, all) { auto net = std::make_shared(); ASSERT_NE(net, nullptr); - net->AddOp(std::unique_ptr( + net->AppendOp(std::unique_ptr( new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, {{"Out", {"y"}}}, {}))); - net->AddOp(std::unique_ptr( + net->AppendOp(std::unique_ptr( new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}}, {{"Out", {"z"}}}, {}))); @@ -61,7 +61,7 @@ TEST(NetOp, insert_op) { auto op1 = std::unique_ptr( new framework::NOP("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, {{"Out", {"y"}}}, {})); - net.AddOp(*op1); + net.AppendOp(*op1); net.InsertOp(0, *op1); ASSERT_EQ(2UL, net.ops_.size()); net.InsertOp(2, std::move(op1)); @@ -70,9 +70,9 @@ TEST(NetOp, insert_op) { TEST(NetOp, Clone) { NetOp net; - net.AddOp( + net.AppendOp( std::unique_ptr(new framework::NOP{"empty", {}, {}, {}})); - net.AddOp(std::unique_ptr( + net.AppendOp(std::unique_ptr( new framework::NOP{"empty2", {}, {}, {}})); net.CompleteAddOp(true); auto new_net_op = net.Clone(); diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index 232135c38de68d4002e044972b282b43a1374c72..1cbd8bb31ad90a32d8a4e3bb59617d0b5384e470 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" @@ -63,7 +63,7 @@ class RowwiseAddGradKernel : public framework::OpKernel { // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html // colwise add - Eigen::array dims{{1}}; /* dimension to reduce */ + Eigen::array dims{{0}}; /* dimension to reduce */ EigenVector::Flatten(*db).device(place) = OutGrad.sum(dims); } }; diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index a0a0d4d914b37fca4250e5218a953f573611a086..29491137e6d8b4bfa2d0d07d48ffed1212a6131f 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -39,7 +36,8 @@ class CPUUniformRandomKernel : public framework::OpKernel { std::uniform_real_distribution dist( static_cast(context.op_.GetAttr("min")), static_cast(context.op_.GetAttr("max"))); - for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) { + ssize_t size = framework::product(tensor->dims()); + for (ssize_t i = 0; i < size; ++i) { data[i] = dist(engine); } } @@ -66,7 +64,6 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddOutput("Out", "The output tensor of uniform random op"); AddComment(R"DOC(Uniform random operator. - Used to initialize tensor with uniform random generator. )DOC"); AddAttr>("dims", "the dimension of random tensor"); diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 7a243555b6385af690e9632dfa81bf96d70f925d..1d6709934cbbcf50265eabef87c857654f783ed8 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/paddle/parameter/Parameter.h b/paddle/parameter/Parameter.h index e31cbc3dee6c57851c241e117dbbd9b701db9d2c..321f4275d8e68d7d3fbbc19acf0afacf689474e5 100644 --- a/paddle/parameter/Parameter.h +++ b/paddle/parameter/Parameter.h @@ -65,7 +65,10 @@ public: size_t getSize() const { return config_.size(); } bool isFullSize() const { - return this->getSize() == bufs_[PARAMETER_VALUE]->getSize(); + if (bufs_[PARAMETER_VALUE]) { + return this->getSize() == bufs_[PARAMETER_VALUE]->getSize(); + } + return false; } inline bool useGpu() const { return useGpu_; } diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index f92c15ae450e94de44d27e77763e791e6bae4426..ad212c5b2c47312743362db4926c80bf056e100d 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -114,9 +114,6 @@ CUDADeviceContext::~CUDADeviceContext() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); } - if (curand_generator_) { - PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_)); - } eigen_stream_.reset(); eigen_device_.reset(); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); @@ -152,19 +149,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { cudaStream_t CUDADeviceContext::stream() { return stream_; } -curandGenerator_t CUDADeviceContext::curand_generator() { - if (!curand_generator_) { - SetDeviceId(place_.device); - PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, - CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); - - PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); - } - return curand_generator_; -} - #endif // PADDLE_ONLY_CPU } // namespace platform diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index c5042ae33e47e04521e59e0d91ddd8d4efffe50a..11528e1194e4516891034fa8febdac3ba6eed204 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -17,7 +17,6 @@ limitations under the License. */ #ifndef PADDLE_ONLY_CPU #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" -#include "paddle/platform/dynload/curand.h" #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif @@ -40,7 +39,7 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); - explicit CPUDeviceContext(CPUPlace); + explicit CPUDeviceContext(CPUPlace place); virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; @@ -56,7 +55,7 @@ class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { public: - explicit CUDADeviceContext(GPUPlace); + explicit CUDADeviceContext(GPUPlace place); virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ @@ -75,9 +74,6 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle(); - /*! \brief Return curand handle in the device context. */ - curandGenerator_t curand_generator(); - /*! \brief Return cuda stream in the device context. */ cudaStream_t stream(); // clang-format on @@ -85,18 +81,13 @@ class CUDADeviceContext : public DeviceContext { private: GPUPlace place_; - private: std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; - private: - uint64_t seed_; - // clang-format off cudaStream_t stream_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr}; cublasHandle_t cublas_handle_{nullptr}; - curandGenerator_t curand_generator_{nullptr}; // clang-format on }; diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 8b764bdcd9d92e6b2203e45160acee35ec110538..5883a55272f0f24c94d48bc43c62ddb7bef15465 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) { ASSERT_NE(nullptr, cudnn_handle); cublasHandle_t cublas_handle = device_context->cublas_handle(); ASSERT_NE(nullptr, cublas_handle); - curandGenerator_t curand_handle = device_context->curand_generator(); - ASSERT_NE(nullptr, curand_handle); ASSERT_NE(nullptr, device_context->stream()); delete device_context; } diff --git a/paddle/pserver/ParameterClient2.cpp b/paddle/pserver/ParameterClient2.cpp index f7e391f76324a09c203dfbbb449feb050caa8fb4..54063a809a4f9e558f8d364f5c437f2b6d98925b 100644 --- a/paddle/pserver/ParameterClient2.cpp +++ b/paddle/pserver/ParameterClient2.cpp @@ -65,7 +65,6 @@ void ParameterClient2::initThreads() { LOG(INFO) << "parallel_thread_num dosent need to set"; } syncThreadPool_.reset(new SyncThreadPool(threadNum_)); - startThreads(); } @@ -224,6 +223,14 @@ void ParameterClient2::prepareSendData( request.set_cost(cost); request.set_batch_status(batchStatus); CHECK_EQ(request.blocks_size(), 0); + VLOG(10) << "request: trainer_id: " << request.trainer_id() + << " update_mode" << request.update_mode() + << " send_back_parameter: " << request.send_back_parameter() + << " send_back_parameter_type: " + << request.send_back_parameter_type() + << " num_samples: " << request.num_samples() + << " cost: " << request.cost() + << " batch_status: " << request.batch_status(); } for (const auto& segments : parameterSegments) { const auto it = parameterMap_.find(segments.id); @@ -251,11 +258,17 @@ void ParameterClient2::prepareSendData( CHECK(sendMat != nullptr) << "sendMat is nullptr"; syncThreadPool_->exec([&](int tid, size_t numThreads) { + std::lock_guard guard(sparseAutoGrowthMutex_); const auto& localIndices = prefetchMat->getLocalIndices(); /// num of sparse rows size_t nLocalBlocks = localIndices.size(); uint64_t beginDim = 0; uint64_t endDim = 0; + + // FIXME(typhoonzero): let it resize first + prefetchMat->getLocalRow(nLocalBlocks + 1); + sendMat->getLocalRow(nLocalBlocks + 1); + for (size_t row = 0; row < nLocalBlocks; ++row) { int64_t blockId = localIndices[row]; // local row -> sparse row int serverId = std::abs((blockId + nameHash) % serviceNum_); @@ -275,7 +288,6 @@ void ParameterClient2::prepareSendData( block->set_begin_pos(row * blockSize); /// block len block->set_block_size(endDim - beginDim); - if (sendingPara) { sendJob->parallelInputIovs[serverId].push_back( {sendMat->getLocalRow(row), sizeof(real) * (size_t)blockSize}); diff --git a/paddle/pserver/ParameterClient2.h b/paddle/pserver/ParameterClient2.h index 89b3ddd502151e537b81bdbb09f171dd6e13ba26..29b9eeacddf2945dd22b7b17fc87c7c74b868896 100644 --- a/paddle/pserver/ParameterClient2.h +++ b/paddle/pserver/ParameterClient2.h @@ -583,6 +583,7 @@ protected: #ifndef PADDLE_DISABLE_TIMER uint64_t forwardbackwordTime_; #endif + std::mutex sparseAutoGrowthMutex_; /// map id to parameter used for decoding protobuf data std::unordered_map parameterMap_; diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 2d96901ed4c12e70090b560f40688898ab942eea..33a20afb183240c027b4b4a6663fd5ff8de6f79f 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -338,7 +338,8 @@ def RecurrentLayerGroupWithoutOutLinksBegin(name, in_links_count += 1 layer_name = MakeLayerNameInParentSubmodel(name) layer = g_layer_map[layer_name] - ScatterAgentLayer(name=name, size=layer.size) + ScatterAgentLayer( + name=name, size=layer.size, width=layer.width, height=layer.height) pair = g_current_submodel.in_links.add() pair.layer_name = layer_name @@ -2201,8 +2202,8 @@ class MaxOutLayer(LayerBase): maxout_conf = self.config.inputs[0].maxout_conf parse_maxout(self.inputs[0].maxout, input_layer.name, maxout_conf) out_channels = maxout_conf.image_conf.channels / maxout_conf.groups - self.set_cnn_layer(name, g_layer_map[input_layer.name].height, - g_layer_map[input_layer.name].width, out_channels) + self.set_cnn_layer(name, maxout_conf.image_conf.img_size_y, + maxout_conf.image_conf.img_size, out_channels) @config_layer('row_conv') @@ -2236,6 +2237,20 @@ class ClipLayer(LayerBase): self.config.inputs[0].clip_conf.max = max +@config_layer('scale_shift') +class ScaleShiftLayer(LayerBase): + def __init__(self, name, inputs, bias=True, **xargs): + super(ScaleShiftLayer, self).__init__( + name, 'scale_shift', 0, inputs=inputs, **xargs) + config_assert( + len(self.inputs) == 1, + 'ScaleShiftLayer must have one and only one input.') + input_layer = self.get_input_layer(0) + self.set_layer_size(input_layer.size) + self.create_input_parameter(0, 1, [1, 1]) + self.create_bias_parameter(bias, 1) + + # key: cost type # value: cost class g_cost_map = {} @@ -2395,9 +2410,11 @@ class GatherAgentLayer(LayerBase): @config_layer('scatter_agent') class ScatterAgentLayer(LayerBase): - def __init__(self, name, size, device=None): + def __init__(self, name, size, width=None, height=None, device=None): super(ScatterAgentLayer, self).__init__( name, 'scatter_agent', size, inputs=[], device=device) + if height and width: + self.set_layer_height_width(height, width) @config_layer('multiplex') diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index de7f31a20a823da57aa97ed1788b7f744eeb4e32..74b88cd4f876f7596916f4c46006fc8b57f98297 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -16,11 +16,13 @@ import functools import collections import inspect +import paddle.trainer.config_parser as cp from paddle.trainer.config_parser import * from .activations import LinearActivation, SigmoidActivation, TanhActivation, \ ReluActivation, IdentityActivation, SoftmaxActivation, BaseActivation from .evaluators import * -from .poolings import MaxPooling, AvgPooling, BasePoolingType +from .poolings import MaxPooling, AvgPooling, BasePoolingType, \ + CudnnAvgPooling, CudnnMaxPooling from .attrs import * from .default_decorators import * @@ -133,6 +135,7 @@ __all__ = [ 'clip_layer', 'slice_projection', 'kmax_sequence_score_layer', + 'scale_shift_layer', ] @@ -230,6 +233,7 @@ class LayerType(object): CLIP_LAYER = 'clip' KMAX_SEQ_SCORE = 'kmax_seq_score' + SCALE_SHIFT_LAYER = 'scale_shift' @staticmethod def is_layer_type(type_name): @@ -328,6 +332,14 @@ class LayerOutput(object): self.outputs = outputs self.reverse = reverse + @property + def width(self): + return cp.g_layer_map[self.full_name].width + + @property + def height(self): + return cp.g_layer_map[self.full_name].height + def set_input(self, input): """ Set the input for a memory layer. Can only be used for memory layer @@ -909,7 +921,13 @@ def data_layer(name, size, height=None, width=None, layer_attr=None): width=width, **ExtraLayerAttribute.to_kwargs(layer_attr)) - return LayerOutput(name, LayerType.DATA, size=size) + num_filters = None + if height is not None and width is not None: + num_filters = size / (width * height) + assert num_filters * width * height == size, \ + "size=%s width=%s height=%s" % (size, width, height) + + return LayerOutput(name, LayerType.DATA, size=size, num_filters=num_filters) @wrap_name_default("embedding") @@ -2588,6 +2606,10 @@ def img_pool_layer(input, assert input.num_filters is not None num_channels = input.num_filters + assert type(pool_type) in [AvgPooling, MaxPooling, CudnnAvgPooling, + CudnnMaxPooling], \ + "only (Cudnn)AvgPooling, (Cudnn)MaxPooling are supported" + if pool_type is None: pool_type = MaxPooling() elif isinstance(pool_type, AvgPooling): @@ -2597,7 +2619,6 @@ def img_pool_layer(input, if ( isinstance(pool_type, AvgPooling) or isinstance(pool_type, MaxPooling)) \ else pool_type.name - pool_size_y = pool_size if pool_size_y is None else pool_size_y stride_y = stride if stride_y is None else stride_y padding_y = padding if padding_y is None else padding_y @@ -4221,8 +4242,7 @@ def conv_operator(img, num_channels = img.num_filters assert isinstance(filter, LayerOutput) - if filter.size is not None: - filter.size = filter_size * filter_size_y * num_filters * num_channels + assert filter.size is not None opCls = ConvTransOperator if trans else ConvOperator @@ -4933,7 +4953,6 @@ def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None): :return: LayerOutput object. :rtype: LayerOutput """ - assert input.layer_type == LayerType.CONV_LAYER assert isinstance(input.activation, LinearActivation) assert groups > 1 if num_channels is None: @@ -6229,3 +6248,43 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1): return LayerOutput( name, LayerType.KMAX_SEQ_SCORE, parents=[input], size=input.size) + + +@wrap_name_default("scale_shift") +@wrap_param_attr_default() +@wrap_bias_attr_default() +def scale_shift_layer(input, name=None, param_attr=None, bias_attr=None): + """ + A layer applies a linear transformation to each element in each row of + the input matrix. For each element, the layer first re-scale it and then + adds a bias to it. + + This layer is very like the SlopeInterceptLayer, except the scale and + bias are trainable. + + .. math:: + + y = w * x + b + + .. code-block:: python + + scale_shift = scale_shift_layer(input=input_layer, bias_attr=False) + + :param name: The Layer Name. + :type name: basestring + :param input: The input layer. + :type input: LayerOutput. + :param param_attr: The parameter attribute of scaling. + :type param_attr: ParameterAttribute + :param bias_attr: The parameter attribute of shifting. + :type bias_attr: ParameterAttribute + :return: LayerOutput object. + :rtype: LayerOutput + """ + Layer( + name=name, + type=LayerType.SCALE_SHIFT_LAYER, + inputs=Input(input.name, **param_attr.attr), + bias=ParamAttr.to_bias(bias_attr)) + return LayerOutput( + name, LayerType.SCALE_SHIFT_LAYER, parents=[input], size=input.size) diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index a61beb871ad064c617fa141451afcb2a5ac64854..3860699f6fb76871ee8ea100fce1b3118bbc041d 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -8,6 +8,6 @@ test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer -test_kmax_seq_socre_layer test_seq_select_layers) +test_kmax_seq_socre_layer test_seq_select_layers test_scale_shift_layer) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr new file mode 100644 index 0000000000000000000000000000000000000000..35ade126a2586a8e3eee6f0ac3c7e49523c8f5c5 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr @@ -0,0 +1,72 @@ +type: "nn" +layers { + name: "data" + type: "data" + size: 100 + active_type: "" +} +layers { + name: "__scale_shift_0__" + type: "scale_shift" + size: 100 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "___scale_shift_0__.w0" + } +} +layers { + name: "__scale_shift_1__" + type: "scale_shift" + size: 100 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "___scale_shift_1__.w0" + } + bias_parameter_name: "___scale_shift_1__.wbias" +} +parameters { + name: "___scale_shift_0__.w0" + size: 1 + initial_mean: 0.0 + initial_std: 1.0 + dims: 1 + dims: 1 + initial_strategy: 0 + initial_smart: true +} +parameters { + name: "___scale_shift_1__.w0" + size: 1 + initial_mean: 0.0 + initial_std: 1.0 + dims: 1 + dims: 1 + initial_strategy: 0 + initial_smart: true +} +parameters { + name: "___scale_shift_1__.wbias" + size: 1 + initial_mean: 0.0 + initial_std: 0.0 + dims: 1 + dims: 1 + initial_strategy: 0 + initial_smart: false +} +input_layer_names: "data" +output_layer_names: "__scale_shift_0__" +output_layer_names: "__scale_shift_1__" +sub_models { + name: "root" + layer_names: "data" + layer_names: "__scale_shift_0__" + layer_names: "__scale_shift_1__" + input_layer_names: "data" + output_layer_names: "__scale_shift_0__" + output_layer_names: "__scale_shift_1__" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd589116fa9932144ca066d3fa4c929d1433a7f1 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py @@ -0,0 +1,9 @@ +from paddle.trainer_config_helpers import * + +data = data_layer(name='data', size=100) + +scale = scale_shift_layer(input=data, bias_attr=False) + +scale_shift = scale_shift_layer(input=data) + +outputs(scale, scale_shift) diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index ce57a0713092723b6a99b2416e06ff1a436f043b..8a2b7c54d3ef481712bef1e1a39fb336b23eb1b2 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -13,6 +13,7 @@ py_test(test_add_two_op SRCS test_add_two_op.py) py_test(test_sigmoid_op SRCS test_sigmoid_op.py) py_test(test_softmax_op SRCS test_softmax_op.py) py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py) +py_test(test_gather_op SRCS test_gather_op.py) py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py) py_test(gradient_checker SRCS gradient_checker.py) @@ -22,7 +23,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) py_test(test_operator SRCS test_operator.py) -# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) +py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py) py_test(test_sgd_op SRCS test_sgd_op.py) diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py index dd65e0f2dc23d3f657ff16c55fb297dae210b2d7..3bc05a0feccbbd3d5e7852d85bd3dc8edaccfd07 100644 --- a/python/paddle/v2/framework/tests/op_test_util.py +++ b/python/paddle/v2/framework/tests/op_test_util.py @@ -64,7 +64,8 @@ class OpTestMeta(type): actual = numpy.array(scope.find_var(out_name).get_tensor()) expect = self.outputs[out_name] self.assertTrue( - numpy.allclose(actual, expect), + numpy.allclose( + actual, expect, atol=1e-05), "output name: " + out_name + "has diff") obj.test_all = test_all diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py index 4815192e255c6e0429db3f50918a76a773b30131..d4277f2a42ce2e66e37405ccd3b2ee444d403d1a 100644 --- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py +++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py @@ -8,9 +8,8 @@ class TestCrossEntropy(unittest.TestCase): __metaclass__ = OpTestMeta def setUp(self): - # TODO this unit test is not passed self.type = "onehot_cross_entropy" - batch_size = 100 + batch_size = 30 class_num = 10 X = numpy.random.random((batch_size, class_num)).astype("float32") label = 5 * numpy.ones(batch_size).astype("int32") @@ -22,9 +21,9 @@ class TestCrossEntropy(unittest.TestCase): class CrossEntropyGradOpTest(GradientChecker): - def test_softmax_grad(self): + def test_check_grad(self): op = create_op("onehot_cross_entropy") - batch_size = 100 + batch_size = 30 class_num = 10 inputs = { "X": numpy.random.uniform( diff --git a/python/paddle/v2/framework/tests/test_gather_op.py b/python/paddle/v2/framework/tests/test_gather_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e86898304252d08be718e40fed46c5e921596af7 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_gather_op.py @@ -0,0 +1,34 @@ +import unittest +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op +import numpy +import paddle.v2.framework.core as core +from paddle.v2.framework.op import Operator + + +class TestGatherOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "gather" + xnp = numpy.random.random((10, 20)).astype("float32") + self.inputs = { + 'X': xnp, + 'Index': numpy.array([1, 3, 5]).astype("int32") + } + self.outputs = {'Out': self.inputs['X'][self.inputs['Index']]} + + +class TestGatherGradOp(GradientChecker): + def test_gather_grad(self): + print 'creating op' + op = create_op("gather") + print 'creating op done' + xnp = numpy.random.random((10, 20)).astype("float32") + inputs = {'X': xnp, 'Index': numpy.array([1, 3, 5]).astype("int32")} + print 'correct before check gradient' + self.check_grad(op, inputs, set("X"), "Out") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_net.py b/python/paddle/v2/framework/tests/test_net.py index b42cadd11ab75abbc35763c8d12e8c27e995f0dc..9339cf28dabc95b46b958777200fb1db9dcf284f 100644 --- a/python/paddle/v2/framework/tests/test_net.py +++ b/python/paddle/v2/framework/tests/test_net.py @@ -6,8 +6,8 @@ import unittest def fc(X, W, Y): ret_v = core.Net.create() - ret_v.add_op(Operator("mul", X="X", Y="W", Out="pre_activation")) - ret_v.add_op(Operator("sigmoid", X="pre_activation", Y=Y)) + ret_v.append_op(Operator("mul", X="X", Y="W", Out="pre_activation")) + ret_v.append_op(Operator("sigmoid", X="pre_activation", Y=Y)) ret_v.complete_add_op(True) return ret_v @@ -16,12 +16,12 @@ class TestNet(unittest.TestCase): def test_net_all(self): net = core.Net.create() op1 = Operator("add_two", X="X", Y="Y", Out="Out") - net.add_op(op1) + net.append_op(op1) net2 = core.Net.create() - net2.add_op(fc(X="X", W="w", Y="fc.out")) + net2.append_op(fc(X="X", W="w", Y="fc.out")) net2.complete_add_op(True) - net.add_op(net2) + net.append_op(net2) net.complete_add_op(True) expected = ''' diff --git a/python/paddle/v2/framework/tests/test_recurrent_op.py b/python/paddle/v2/framework/tests/test_recurrent_op.py index 3d4a34d8d713ff1beeeba8ac48ad95176f7a29f2..d6000ab9f9d5b969f96128b183f48d49000c8a5e 100644 --- a/python/paddle/v2/framework/tests/test_recurrent_op.py +++ b/python/paddle/v2/framework/tests/test_recurrent_op.py @@ -150,7 +150,7 @@ class TestRecurrentOp(unittest.TestCase): sig_op = Operator("sigmoid", X="sum", Y="h@alias") for op in [x_fc_op, h_fc_op, sum_op, sig_op]: - stepnet.add_op(op) + stepnet.append_op(op) stepnet.complete_add_op(True) self.rnnop.set_stepnet(stepnet) diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py index 29d72e850099734a9828ccceed47cc0a57fc3d6b..45d569da29d13cf8e2a3cb9d67c2d01e8b365453 100644 --- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py +++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py @@ -20,7 +20,7 @@ class RowwiseAddGradOpTest(GradientChecker): def test_rowwise_add(self): op = create_op("rowwise_add") inputs = { - "X": np.random.uniform(0.1, 1, [10, 10]).astype("float32"), + "X": np.random.uniform(0.1, 1, [5, 10]).astype("float32"), "b": np.random.uniform(0.1, 1, [10]).astype("float32") } self.check_grad(op, inputs, set(["X", "b"]), "Out")