diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..953367eb8bcd1282ab6c7e1189d778f0ce3da541 --- /dev/null +++ b/paddle/operators/cross_entropy_op.cc @@ -0,0 +1,147 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/cross_entropy_op.h" + +namespace paddle { +namespace operators { + +using framework::LoDTensor; + +class CrossEntropyOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), + "Input(Label) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null."); + + auto x = ctx.Input("X"); + auto label = ctx.Input("Label"); + PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); + PADDLE_ENFORCE_EQ(label->dims().size(), 2, + "Input(Label)'s rank must be 2."); + // TODO(xinghai-sun): remove this check after swtiching to bool + PADDLE_ENFORCE(ctx.Attr("soft_label") == 0 || + ctx.Attr("soft_label") == 1); + PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0], + "The 1st dimension of Input(X) and Input(Label) must " + "be equal."); + if (ctx.Attr("soft_label") == 1) { + PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], + "If Attr(soft_label) == 1, The 2nd dimension of " + "Input(X) and Input(Label) must be equal."); + } else { + PADDLE_ENFORCE_EQ(label->dims()[1], 1, + "If Attr(soft_label) == 0, The 2nd dimension of " + "Input(Label) must be 1."); + } + + ctx.Output("Y")->Resize({x->dims()[0], 1}); + } +}; + +class CrossEntropyGradientOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), + "Input(Label) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), + "Input(Y@GRAD) must not be null."); + + auto x = ctx.Input("X"); + auto label = ctx.Input("Label"); + auto dy = ctx.Input(framework::GradVarName("Y")); + PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); + PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2."); + PADDLE_ENFORCE_EQ(label->dims().size(), 2, + "Input(Label)'s rank must be 2."); + // TODO(xinghai-sun): remove this check after swtiching to bool + PADDLE_ENFORCE(ctx.Attr("soft_label") == 0 || + ctx.Attr("soft_label") == 1); + PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0], + "The 1st dimension of Input(X) and Input(Label) must " + "be equal."); + PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0], + "The 1st dimension of Input(X) and Input(Y@Grad) must " + "be equal."); + PADDLE_ENFORCE_EQ(dy->dims()[1], 1, + "The 2nd dimension of Input(Y@Grad) must be 1."); + if (ctx.Attr("soft_label") == 1) { + PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], + "If Attr(soft_label) == 1, The 2nd dimension of " + "Input(X) and Input(Label) must be equal."); + } else { + PADDLE_ENFORCE_EQ(label->dims()[1], 1, + "If Attr(soft_label) == 0, The 2nd dimension of " + "Input(Label) must be 1."); + } + + auto dx = ctx.Output(framework::GradVarName("X")); + dx->Resize(x->dims()); + } +}; + +class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CrossEntropyOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The first input of CrossEntropyOp"); + AddInput("Label", "The second input of CrossEntropyOp"); + AddOutput("Y", "The output of CrossEntropyOp"); + AddAttr("soft_label", "Is soft label. Default zero.").SetDefault(0); + + AddComment(R"DOC( +CrossEntropy Operator. + +It supports both standard cross-entropy and soft-label cross-entropy loss +computation. +1) One-hot cross-entropy: + soft_label = 0, Label[i, 0] indicates the class index for sample i: + + Y[i] = -log(X[i, Label[i]]) + +2) Soft-label cross-entropy: + soft_label = 1, Label[i, j] indicates the soft label of class j + for sample i: + + Y[i] = \sum_j{-Label[i, j] * log(X[i, j])} + + Please make sure that in this case the summuation of each row of Label + equals one. + +3) One-hot cross-entropy with vecterized Input(Label): + As a special case of 2), when each row of Input(Label) has only one + non-zero element (equals 1), soft-label cross-entropy degenerates to a + one-hot cross-entropy with one-hot label representation. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker, + cross_entropy_grad, ops::CrossEntropyGradientOp); +REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel); +REGISTER_OP_CPU_KERNEL(cross_entropy_grad, + ops::CrossEntropyGradientOpKernel); diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..ab6ad0e062269483948bf70e492c9431991221fb --- /dev/null +++ b/paddle/operators/cross_entropy_op.cu @@ -0,0 +1,158 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/cross_entropy_op.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +template +__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, + const int N, const int D) { + // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. + // CUDA_1D_KERNEL_LOOP(i, N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + PADDLE_ASSERT(label[i] >= 0 && label[i] < D); + Y[i] = -tolerable_value(log(X[i * D + label[i]])); + } +} + +template +__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, + const int N, const int D) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + T sum = static_cast(0); + for (int j = 0; j < D; j++) { + sum += label[i * D + j] * tolerable_value(log(X[i * D + j])); + } + Y[i] = -sum; + } +} + +// TODO(qingqing): make zero setting an common function. +template +__global__ void zero(T* X, const int N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + X[i] = 0.0; + } +} + +template +__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, + const int* label, const int N, + const int D) { + // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. + // CUDA_1D_KERNEL_LOOP(i, N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + int idx = i * D + label[i]; + dX[idx] = -dY[i] / X[idx]; + } +} + +template +__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X, + const T* label, const int N, + const int D) { + // TOOD(qingqing): optimize for this kernel + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + for (int j = 0; j < D; ++j) { + int idx = i * D + j; + dX[idx] = -label[idx] * dY[i] / X[idx]; + } + } +} + +template +class CrossEntropyOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto x = ctx.Input("X"); + auto y = ctx.Output("Y"); + auto label = ctx.Input("Label"); + + auto* x_data = x->data(); + y->mutable_data(ctx.GetPlace()); + auto* y_data = y->data(); + + int n = x->dims()[0]; + int d = x->dims()[1]; + int block = 512; + int grid = (n + block - 1) / block; + // TODO(qingqing) launch kernel on specified stream + // base on ExecutionContext. + if (ctx.Attr("soft_label") == 1) { + auto* label_data = ctx.Input("Label")->data(); + SoftCrossEntropyKernel<<>>(y_data, x_data, label_data, n, + d); + } else { + auto* label_data = ctx.Input("Label")->data(); + CrossEntropyKernel<<>>(y_data, x_data, label_data, n, d); + } + } +}; + +template +class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto x = ctx.Input("X"); + auto dx = ctx.Output(framework::GradVarName("X")); + auto dy = ctx.Input(framework::GradVarName("Y")); + auto label = ctx.Input("Label"); + + auto* dx_data = dx->mutable_data(ctx.GetPlace()); + auto* dy_data = dy->data(); + auto* x_data = x->data(); + + int n = x->dims()[0]; + int d = x->dims()[1]; + int block = 512; + int grid = (n * d + block - 1) / block; + zero<<>>(dx_data, n * d); + grid = (n + block - 1) / block; + // TODO(qingqing): launch kernel on specified stream + // base on ExecutionContext. + if (ctx.Attr("soft_label") == 1) { + auto* label_data = label->data(); + SoftCrossEntropyGradientKernel<<>>( + dx_data, dy_data, x_data, label_data, n, d); + } else { + auto* label_data = label->data(); + CrossEntropyGradientKernel<<>>(dx_data, dy_data, x_data, + label_data, n, d); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel); +REGISTER_OP_GPU_KERNEL(cross_entropy_grad, + ops::CrossEntropyGradientOpCUDAKernel); diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h new file mode 100644 index 0000000000000000000000000000000000000000..1b4b23ac2029138afadef0168262203ac2e20430 --- /dev/null +++ b/paddle/operators/cross_entropy_op.h @@ -0,0 +1,117 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +HOSTDEVICE T tolerable_value(const T x) { + PADDLE_ASSERT(std::is_floating_point::value); + const T kApproInf = 1e20; + if (x == INFINITY) { + return kApproInf; + } + if (x == -INFINITY) { + return -kApproInf; + } + return x; +} + +template +class CrossEntropyOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto x = ctx.Input("X"); + auto y = ctx.Output("Y"); + + auto* x_data = x->data(); + y->mutable_data(ctx.GetPlace()); + auto* y_data = y->data(); + + int batch_size = x->dims()[0]; + int class_num = x->dims()[1]; + + if (ctx.Attr("soft_label") == 1) { + auto* label_data = ctx.Input("Label")->data(); + int index = 0; + for (int i = 0; i < batch_size; ++i) { + T sum = static_cast(0); + for (int j = 0; j < class_num; ++j) { + sum += label_data[index] * tolerable_value(std::log(x_data[index])); + y_data[i] = -sum; + index++; + } + } + } else { + auto* label_data = ctx.Input("Label")->data(); + for (int i = 0; i < batch_size; ++i) { + int index = i * class_num + label_data[i]; + y_data[i] = -tolerable_value(std::log(x_data[index])); + } + } + } +}; + +template +class CrossEntropyGradientOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto x = ctx.Input("X"); + auto dx = ctx.Output(framework::GradVarName("X")); + auto dy = ctx.Input(framework::GradVarName("Y")); + auto label = ctx.Input("Label"); + + auto* dx_data = dx->mutable_data(ctx.GetPlace()); + auto* dy_data = dy->data(); + auto* x_data = x->data(); + + int batch_size = x->dims()[0]; + int class_num = x->dims()[1]; + + // TODO(qingqing): make zero setting an common function. + if (ctx.Attr("soft_label") == 1) { + auto* label_data = ctx.Input("Label")->data(); + int index = 0; + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < class_num; ++j) { + dx_data[index] = -label_data[index] * dy_data[i] / x_data[index]; + index++; + } + } + } else { + auto* label_data = label->data(); + memset(dx_data, 0, sizeof(T) * batch_size * class_num); + for (int i = 0; i < batch_size; ++i) { + PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num); + int index = i * class_num + label_data[i]; + dx_data[index] = -dy_data[i] / x_data[index]; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b111b9fccb2310bd5fb92bda878a497c51f62ce0 --- /dev/null +++ b/paddle/operators/dropout_op.cc @@ -0,0 +1,113 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/dropout_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; +using framework::LoDTensor; + +class DropoutOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_GE(ctx.Attr("dropout_prob"), 0); + PADDLE_ENFORCE_LE(ctx.Attr("dropout_prob"), 1); + // TODO(xinghai-sun): remove this check after swtiching to bool + PADDLE_ENFORCE(ctx.Attr("is_training") == 0 || + ctx.Attr("is_training") == 1); + + auto dims = ctx.Input("X")->dims(); + ctx.Output("Out")->Resize(dims); + if (ctx.Attr("is_training") == 1) { + ctx.Output("Mask")->Resize(dims); + } + } +}; + +template +class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { + public: + DropoutOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("dropout_prob", "Probability of setting units to zero.") + .SetDefault(.5f); + // TODO(xinghai-sun): use bool for is_training after bool is supported. + AddAttr("is_training", "Whether in training phase.").SetDefault(1); + AddAttr("seed", "Dropout random seed.").SetDefault(0); + AddInput("X", "The input of dropout op."); + AddOutput("Out", "The output of dropout op."); + AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate(); + + AddComment(R"DOC( +Dropout Operator. + +"Dropout" refers to randomly dropping out units in a nerual network. It is a +regularization technique for reducing overfitting by preventing neuron +co-adaption during training. The dropout operator randomly set (according to +the given dropout probability) the outputs of some units to zero, while others +being set to their inputs. +)DOC"); + } +}; + +template +class DropoutOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_EQ(ctx.Attr("is_training"), 1, + "GradOp is only callable when is_training is true"); + + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) must not be null."); + + PADDLE_ENFORCE_GE(ctx.Attr("dropout_prob"), 0); + PADDLE_ENFORCE_LE(ctx.Attr("dropout_prob"), 1); + // TODO(xinghai-sun): remove this check after swtiching to bool + PADDLE_ENFORCE(ctx.Attr("is_training") == 0 || + ctx.Attr("is_training") == 1); + auto x_dims = ctx.Input("X")->dims(); + auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); + PADDLE_ENFORCE_EQ(x_dims, out_dims, + "Dimensions of Input(X) and Out@Grad must be the same."); + auto mask_dims = ctx.Input("Mask")->dims(); + PADDLE_ENFORCE_EQ(x_dims, mask_dims, + "Dimensions of Input(X) and Mask must be the same."); + + auto *x_grad = ctx.Output(framework::GradVarName("X")); + x_grad->Resize(x_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad, + ops::DropoutOpGrad); +REGISTER_OP_CPU_KERNEL( + dropout, ops::CPUDropoutKernel); +REGISTER_OP_CPU_KERNEL( + dropout_grad, ops::DropoutGradKernel); diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..186237fb238add37f32403309a0f7e8a9846d335 --- /dev/null +++ b/paddle/operators/dropout_op.cu @@ -0,0 +1,86 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include +#include +#include +#include +#include "paddle/operators/dropout_op.h" + +namespace paddle { +namespace operators { + +template +struct MaskGenerator { + AttrType dropout_prob; + int seed; + + __host__ __device__ MaskGenerator(AttrType dropout_prob, int seed) + : dropout_prob(dropout_prob), seed(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed); + thrust::uniform_real_distribution dist(0, 1); + rng.discard(n); + if (dist(rng) < dropout_prob) { + return static_cast(0); + } else { + return static_cast(1); + } + } +}; + +// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT. +// Use std::random and thrust::random(thrust is a std library in CUDA) to +// implement uniform random. +template +class GPUDropoutKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* y = context.Output("Out"); + y->mutable_data(context.GetPlace()); + AttrType dropout_prob = context.Attr("dropout_prob"); + + auto X = EigenMatrix::Reshape(*x, 1); + auto Y = EigenMatrix::Reshape(*y, 1); + + auto place = context.GetEigenDevice(); + if (context.Attr("is_training") == 1) { + auto* mask = context.Output("Mask"); + auto* mask_data = mask->mutable_data(context.GetPlace()); + int size = framework::product(mask->dims()); + int seed = context.Attr("seed"); + thrust::counting_iterator index_sequence_begin(0); + thrust::transform(index_sequence_begin, index_sequence_begin + size, + thrust::device_ptr(mask_data), + MaskGenerator(dropout_prob, seed)); + auto M = EigenMatrix::Reshape(*mask, 1); + Y.device(place) = X * M; + } else { + Y.device(place) = X * dropout_prob; + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + dropout, ops::GPUDropoutKernel); +REGISTER_OP_GPU_KERNEL( + dropout_grad, ops::DropoutGradKernel); diff --git a/paddle/operators/dropout_op.h b/paddle/operators/dropout_op.h new file mode 100644 index 0000000000000000000000000000000000000000..82eafee0e0e7db7b4b4ae5405f37146d061aefd5 --- /dev/null +++ b/paddle/operators/dropout_op.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +template +class CPUDropoutKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* y = context.Output("Out"); + const auto* x_data = x->data(); + auto* y_data = y->mutable_data(context.GetPlace()); + AttrType dropout_prob = context.Attr("dropout_prob"); + + if (context.Attr("is_training") == 1) { + auto* mask = context.Output("Mask"); + auto* mask_data = mask->mutable_data(context.GetPlace()); + int seed = context.Attr("seed"); + std::minstd_rand engine; + engine.seed(seed); + std::uniform_real_distribution dist(0, 1); + size_t size = framework::product(mask->dims()); + for (size_t i = 0; i < size; ++i) { + if (dist(engine) < dropout_prob) { + mask_data[i] = 0; + y_data[i] = 0; + } else { + mask_data[i] = 1; + y_data[i] = x_data[i]; + } + } + } else { + auto X = EigenMatrix::Reshape(*x, 1); + auto Y = EigenMatrix::Reshape(*y, 1); + auto place = context.GetEigenDevice(); + Y.device(place) = X * dropout_prob; + } + } +}; + +template +class DropoutGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE_EQ(context.Attr("is_training"), 1, + "GradOp is only callable when is_training is true"); + + auto* grad_x = context.Output(framework::GradVarName("X")); + auto* grad_y = context.Input(framework::GradVarName("Out")); + auto* mask = context.Input("Mask"); + grad_x->mutable_data(context.GetPlace()); + + auto M = EigenMatrix::Reshape(*mask, 1); + auto dX = EigenMatrix::Reshape(*grad_x, 1); + auto dY = EigenMatrix::Reshape(*grad_y, 1); + + auto place = context.GetEigenDevice(); + dX.device(place) = dY * M; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/onehot_cross_entropy_op.cc b/paddle/operators/onehot_cross_entropy_op.cc deleted file mode 100644 index f38be3549f3c5d2443f61739fc32cdca74197649..0000000000000000000000000000000000000000 --- a/paddle/operators/onehot_cross_entropy_op.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/operators/onehot_cross_entropy_op.h" - -namespace paddle { -namespace operators { - -class OnehotCrossEntropyOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL( - ctx.InputVar("X"), - "Input(X) of OnehotCrossEntropyOp should not be null."); - PADDLE_ENFORCE_NOT_NULL( - ctx.InputVar("label"), - "Input(label) of OnehotCrossEntropyOp should not be null."); - PADDLE_ENFORCE_NOT_NULL( - ctx.OutputVar("Y"), - "Output(Y) of OnehotCrossEntropyOp should not be null."); - - auto *X = ctx.Input("X"); - auto *label = ctx.Input("label"); - - PADDLE_ENFORCE_EQ(X->dims().size(), 2, "X's dimension must be 2."); - PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label's dimension must be 1."); - PADDLE_ENFORCE_EQ(X->dims()[0], label->dims()[0]); - ctx.Output("Y")->Resize({X->dims()[0], 1}); - } -}; - -class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - auto dX = ctx.Output(framework::GradVarName("X")); - auto X = ctx.Input("X"); - - dX->Resize(X->dims()); - } -}; - -class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { - public: - OnehotCrossEntropyOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The first input of OnehotCrossEntropyOp"); - AddInput("label", "The second input of OnehotCrossEntropyOp"); - AddOutput("Y", "The output of OnehotCrossEntropyOp"); - AddComment(R"DOC( -OnehotCrossEntropy Operator. - - Y[i] = -log(X[i][j]) - -)DOC"); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, - ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOp); -REGISTER_OP_CPU_KERNEL(onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); -REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOpKernel); diff --git a/paddle/operators/onehot_cross_entropy_op.cu b/paddle/operators/onehot_cross_entropy_op.cu deleted file mode 100644 index d999bfce58c8a6db5c811aad677c07094b881841..0000000000000000000000000000000000000000 --- a/paddle/operators/onehot_cross_entropy_op.cu +++ /dev/null @@ -1,133 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/framework/op_registry.h" -#include "paddle/platform/assert.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -__host__ __device__ T clipping_log(const T x) { - PADDLE_ASSERT(std::is_floating_point::value); - const T kApproInf = 1e20; - T v = log(x); - if (v == INFINITY) { - return kApproInf; - } - if (v == -INFINITY) { - return -kApproInf; - } - return v; -} - -template -__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, - const int N, const int D) { - // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. - // CUDA_1D_KERNEL_LOOP(i, N) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - PADDLE_ASSERT(label[i] >= 0 && label[i] < D); - Y[i] = -clipping_log(X[i * D + label[i]]); - } -} - -// TODO(qingqing): make zero setting an common function. -template -__global__ void zero(T* X, const int N) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - X[i] = 0.0; - } -} - -template -__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, - const int* label, const int N, - const int D) { - // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. - // CUDA_1D_KERNEL_LOOP(i, N) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - int idx = i * D + label[i]; - dX[idx] = -dY[i] / X[idx]; - } -} - -template -class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use GPUPlace."); - - auto X = ctx.Input("X"); - const T* Xdata = X->data(); - const int* label_data = ctx.Input("label")->data(); - auto Y = ctx.Output("Y"); - Y->mutable_data(ctx.GetPlace()); - T* Ydata = Y->data(); - - int N = X->dims()[0]; - int D = X->dims()[1]; - int block = 512; - int grid = (N + block - 1) / block; - // TODO(qingqing) launch kernel on specified stream - // base on ExecutionContext. - CrossEntropyKernel<<>>(Ydata, Xdata, label_data, N, D); - } -}; - -template -class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use GPUPlace."); - - auto X = ctx.Input("X"); - auto dX = ctx.Output(framework::GradVarName("X")); - auto dY = ctx.Input(framework::GradVarName("Y")); - auto label = ctx.Input("label"); - - auto* dXdata = dX->template mutable_data(ctx.GetPlace()); - auto* dYdata = dY->template data(); - auto* Xdata = X->template data(); - auto* label_data = label->data(); - - int N = X->dims()[0]; - int D = X->dims()[1]; - int block = 512; - int grid = (N * D + block - 1) / block; - zero<<>>(dXdata, N * D); - - grid = (N + block - 1) / block; - // TODO(qingqing): launch kernel on specified stream - // base on ExecutionContext. - CrossEntropyGradientKernel<<>>(dXdata, dYdata, Xdata, - label_data, N, D); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, - ops::OnehotCrossEntropyOpCUDAKernel); -REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOpCUDAKernel); diff --git a/paddle/operators/onehot_cross_entropy_op.h b/paddle/operators/onehot_cross_entropy_op.h deleted file mode 100644 index eb4d1348de1d940e2648c83c8ba94b289f10c5b2..0000000000000000000000000000000000000000 --- a/paddle/operators/onehot_cross_entropy_op.h +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -inline T tolerable_value(const T x) { - static_assert(std::is_floating_point::value, - "tolerable_value works only on float, " - "double and double double."); - - const T kApproInf = 1e20; - - if (x == INFINITY) { - return kApproInf; - } - - if (x == -INFINITY) { - return -kApproInf; - } - - return x; -} - -template -class OnehotCrossEntropyOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "It must use CPUPlace."); - - auto X = ctx.Input("X"); - const T* Xdata = X->data(); - const int* label_data = ctx.Input("label")->data(); - auto Y = ctx.Output("Y"); - - Y->mutable_data(ctx.GetPlace()); - - T* Ydata = Y->data(); - - int batch_size = X->dims()[0]; - int class_num = X->dims()[1]; - - for (int i = 0; i < batch_size; ++i) { - int index = i * class_num + label_data[i]; - Ydata[i] = -tolerable_value(std::log(Xdata[index])); - } - } -}; - -template -class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "It must use CPUPlace."); - - auto X = ctx.Input("X"); - auto dX = ctx.Output(framework::GradVarName("X")); - auto dY = ctx.Input(framework::GradVarName("Y")); - auto label = ctx.Input("label"); - - auto* dXdata = dX->template mutable_data(ctx.GetPlace()); - auto* dYdata = dY->template data(); - auto* Xdata = X->template data(); - auto* label_data = label->data(); - - const int batch_size = X->dims()[0]; - const int class_num = X->dims()[1]; - - // TODO(qingqing): make zero setting an common function. - memset(dXdata, 0, sizeof(T) * batch_size * class_num); - for (int i = 0; i < batch_size; ++i) { - int index = i * class_num + label_data[i]; - dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/prelu_op.cc b/paddle/operators/prelu_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ae80b296850f2f433c89d904ebf32355b2a29c7 --- /dev/null +++ b/paddle/operators/prelu_op.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/prelu_op.h" +#include "paddle/operators/net_op.h" + +namespace paddle { +namespace operators { + +class PReluOp : public framework::OperatorWithKernel { + public: + PReluOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + auto *in = ctx.Input("X"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Alpha"), + "Input(Alpha) should not be null"); + auto *alpha = ctx.Input("Alpha"); + PADDLE_ENFORCE(alpha->numel() == 1, "Size of weight Alpha must be one."); + + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) should not be null"); + auto *out = ctx.Output("Out"); + out->Resize(in->dims()); + } +}; + +class PReluOpMaker : public framework::OpProtoAndCheckerMaker { + public: + PReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input tensor of prelu operator."); + AddInput("Alpha", "The alpha weight of PRelu operator."); + AddOutput("Out", "The output tensor of PRelu operator."); + AddComment(R"DOC(PRelu operator + +The equation is: + + f(x) = alpha * x , for x < 0 + f(x) = x , for x >= 0 + +)DOC"); + } +}; + +// The operator to calculate gradients of a prelu operator. +class PReluGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *x = ctx.Input("X"); + + auto *dalpha = + ctx.Output(framework::GradVarName("Alpha")); + auto *alpha = ctx.Input("Alpha"); + + dx->Resize(x->dims()); + dalpha->Resize(alpha->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP(prelu, ops::PReluOp, ops::PReluOpMaker, prelu_grad, + ops::PReluGradOp); +REGISTER_OP_CPU_KERNEL(prelu, + ops::PReluKernel); +REGISTER_OP_CPU_KERNEL(prelu_grad, + ops::PReluGradKernel); diff --git a/paddle/operators/prelu_op.cu b/paddle/operators/prelu_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..9e391dabae735cc8a740b46b50d31d271f99b65d --- /dev/null +++ b/paddle/operators/prelu_op.cu @@ -0,0 +1,21 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/prelu_op.h" + +REGISTER_OP_GPU_KERNEL( + prelu, paddle::operators::PReluKernel); +REGISTER_OP_GPU_KERNEL( + prelu_grad, + paddle::operators::PReluGradKernel); diff --git a/paddle/operators/prelu_op.h b/paddle/operators/prelu_op.h new file mode 100644 index 0000000000000000000000000000000000000000..63031c25cc3570cf40440726ea76976953d5417a --- /dev/null +++ b/paddle/operators/prelu_op.h @@ -0,0 +1,102 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/platform/transform.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using platform::Transform; + +template +class PReluFunctor { + public: + explicit PReluFunctor(const T* alpha) : alpha_(alpha) {} + + HOSTDEVICE T operator()(const T& x) const { + if (x > 0) + return x; + else + return x * (*alpha_); + } + + private: + const T* alpha_; +}; + +template +class PReluKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* alpha = context.Input("Alpha"); + auto* out = context.Output("Out"); + + const T* x_ptr = x->data(); + T* o_ptr = out->mutable_data(context.GetPlace()); + + auto* alpha_ptr = alpha->data(); + + int numel = x->numel(); + + Transform(context.device_context(), x_ptr, x_ptr + numel, o_ptr, + PReluFunctor(alpha_ptr)); + } +}; + +template +class PReluGradFunctor { + public: + explicit PReluGradFunctor(const T* alpha) : alpha_(alpha) {} + + HOSTDEVICE T operator()(const T& out, const T& dout) const { + if (out > 0) + return dout; + else + return dout * (*alpha_); + } + + private: + const T* alpha_; +}; + +template +class PReluGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* dx = context.Output(framework::GradVarName("X")); + auto* dout = context.Input(framework::GradVarName("Out")); + + auto* out = context.Input("Out"); + auto* alpha = context.Input("Alpha"); + auto* alpha_ptr = alpha->data(); + + T* dx_ptr = dx->mutable_data(context.GetPlace()); + const T* dout_ptr = dout->data(); + const T* out_ptr = out->data(); + int numel = dx->numel(); + + Transform(context.device_context(), out_ptr, out_ptr + numel, dout_ptr, + dx_ptr, PReluGradFunctor(alpha_ptr)); + + // TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py new file mode 100644 index 0000000000000000000000000000000000000000..0206ca064be87afe204aa99021979b7ddc3c5d63 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py @@ -0,0 +1,89 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestCrossEntropyOp1(OpTest): + """Test standard cross-entropy, with index representation of labels. + """ + + def setUp(self): + self.op_type = "cross_entropy" + batch_size = 30 + class_num = 10 + X = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32") + cross_entropy = np.asmatrix( + [[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])], + dtype="float32") + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": cross_entropy} + self.attrs = {'soft_label': 0} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Y") + + +class TestCrossEntropyOp2(OpTest): + """Test soft-label cross-entropy, with vecterized soft labels. + """ + + def setUp(self): + self.op_type = "cross_entropy" + batch_size = 10 + class_num = 5 + X = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + label = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + label /= label.sum(axis=1, keepdims=True) + cross_entropy = (-label * np.log(X)).sum( + axis=1, keepdims=True).astype("float32") + self.inputs = {'X': X, 'Label': label} + self.outputs = {'Y': cross_entropy} + self.attrs = {'soft_label': 1} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y') + + +class TestCrossEntropyOp3(OpTest): + """Test one-hot cross-entropy, with vecterized one-hot representation of + labels. + """ + + def setUp(self): + self.op_type = "cross_entropy" + batch_size = 30 + class_num = 10 + X = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + label_index = np.random.randint( + 0, class_num, (batch_size), dtype="int32") + label = np.zeros(X.shape) + label[np.arange(batch_size), label_index] = 1 + cross_entropy = np.asmatrix( + [[-np.log(X[i][label_index[i]])] for i in range(X.shape[0])], + dtype="float32") + cross_entropy2 = (-label * np.log(X)).sum( + axis=1, keepdims=True).astype("float32") + self.inputs = {'X': X, 'Label': label} + self.outputs = {'Y': cross_entropy} + self.attrs = {'soft_label': 1} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_dropout_op.py b/python/paddle/v2/framework/tests/test_dropout_op.py new file mode 100644 index 0000000000000000000000000000000000000000..3638fee1a1c26195791bc1f5a46dd749da0aee95 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_dropout_op.py @@ -0,0 +1,59 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestDropoutOp(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = {'dropout_prob': 0.0, 'is_training': 1} + self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X'], 'Out', max_relative_error=0.05) + + +class TestDropoutOp2(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = {'dropout_prob': 1.0, 'is_training': 1} + self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))} + + +class TestDropoutOp3(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} + self.attrs = {'dropout_prob': 0.0, 'is_training': 1} + self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64, 2))} + + +class TestDropoutOp4(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = {'dropout_prob': 0.35, 'is_training': 0} + self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} + + def test_check_output(self): + self.check_output() + + +class TestDropoutOp5(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")} + self.attrs = {'dropout_prob': 0.75, 'is_training': 0} + self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_mnist.py b/python/paddle/v2/framework/tests/test_mnist.py index f6f8f49b797fb6e5016a5e309f12f192d5096431..66452cb3965d28fd15e814833079621410775c17 100644 --- a/python/paddle/v2/framework/tests/test_mnist.py +++ b/python/paddle/v2/framework/tests/test_mnist.py @@ -128,7 +128,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None): def cross_entropy_layer(net, input, label): cost_name = "cross_entropy_%d" % uniq_id() cross_entropy_op = Operator( - "onehot_cross_entropy", X=input, label=label, Y=cost_name) + "cross_entropy", X=input, Label=label, Y=cost_name) net.append_op(cross_entropy_op) scope.new_var(cost_name) net.infer_shape(scope) @@ -181,7 +181,7 @@ def error_rate(predict, label): images = data_layer(name="pixel", dims=[BATCH_SIZE, 784]) -labels = data_layer(name="label", dims=[BATCH_SIZE]) +labels = data_layer(name="label", dims=[BATCH_SIZE, 1]) fc1 = fc_layer(net=forward_net, input=images, size=100, act="sigmoid") fc2 = fc_layer(net=forward_net, input=fc1, size=100, act="sigmoid") predict = fc_layer(net=forward_net, input=fc2, size=10, act="softmax") @@ -215,6 +215,7 @@ def test(cost_name): for data in test_reader(): image_data = numpy.array(map(lambda x: x[0], data)).astype("float32") label_data = numpy.array(map(lambda x: x[1], data)).astype("int32") + label_data = numpy.expand_dims(label_data, axis=1) feed_data(images, image_data) feed_data(labels, label_data) @@ -235,6 +236,7 @@ for pass_id in range(PASS_NUM): for data in train_reader(): image_data = numpy.array(map(lambda x: x[0], data)).astype("float32") label_data = numpy.array(map(lambda x: x[1], data)).astype("int32") + label_data = numpy.expand_dims(label_data, axis=1) feed_data(images, image_data) feed_data(labels, label_data) diff --git a/python/paddle/v2/framework/tests/test_onehot_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_onehot_cross_entropy_op.py deleted file mode 100644 index fd3cbdb80374865ccf113768856096bf49dce643..0000000000000000000000000000000000000000 --- a/python/paddle/v2/framework/tests/test_onehot_cross_entropy_op.py +++ /dev/null @@ -1,30 +0,0 @@ -import unittest -import numpy -from op_test import OpTest - - -class TestOnehotCrossEntropyOp(OpTest): - def setUp(self): - self.op_type = "onehot_cross_entropy" - batch_size = 30 - class_num = 10 - - X = numpy.random.uniform(0.1, 1.0, - [batch_size, class_num]).astype("float32") - labels = numpy.random.randint(0, class_num, batch_size, dtype="int32") - - cross_entropy = numpy.asmatrix( - [[-numpy.log(X[i][labels[i]])] for i in range(X.shape[0])], - dtype="float32") - self.inputs = {"X": X, "label": labels} - self.outputs = {"Y": cross_entropy} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Y") - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/v2/framework/tests/test_prelu_op.py b/python/paddle/v2/framework/tests/test_prelu_op.py new file mode 100644 index 0000000000000000000000000000000000000000..76d1f1d5a418b7a2a91b36360a79317d063a72e7 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_prelu_op.py @@ -0,0 +1,28 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class PReluTest(OpTest): + def setUp(self): + self.op_type = "prelu" + x_np = np.random.normal(size=(10, 10)).astype("float32") + x_np_sign = np.sign(x_np) + x_np = x_np_sign * np.maximum(x_np, .005) + alpha_np = np.array([.1]) + self.inputs = {'X': x_np, 'Alpha': alpha_np} + out_np = np.maximum(self.inputs['X'], 0.) + out_np = out_np + np.minimum(self.inputs['X'], + 0.) * self.inputs['Alpha'] + assert out_np is not self.inputs['X'] + self.outputs = {'Out': out_np} + + def not_test_check_output(self): + self.check_output() + + def not_test_check_grad(self): + self.check_grad(['X'], 'Out') + + +if __name__ == "__main__": + unittest.main()