diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index 83649d9562a9f825600b7529ee8240e1cd0ac92b..265f06103eddab23c86aba43424b58554faf9cf9 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -31,7 +31,7 @@ limitations under the License. */ namespace py = pybind11; USE_OP(add_two); -USE_CPU_ONLY_OP(onehot_cross_entropy); +USE_OP(onehot_cross_entropy); USE_OP(sgd); USE_OP(mul); USE_OP(mean); diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index a623c551e1088365ade6f73bc6149977b6ef017e..ab1e1c101a10e09a81f7785d2f1514822e3bdf15 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -39,11 +39,10 @@ class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto X_grad = ctx.Output(framework::GradVarName("X")); + auto dX = ctx.Output(framework::GradVarName("X")); auto X = ctx.Input("X"); - // TODO(superjom) add enforce here after helper functions ready - X_grad->Resize(X->dims()); + dX->Resize(X->dims()); } }; @@ -70,9 +69,7 @@ namespace ops = paddle::operators; REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad, ops::OnehotCrossEntropyGradientOp); -REGISTER_OP_CPU_KERNEL( - onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); -REGISTER_OP_CPU_KERNEL( - onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOpKernel); +REGISTER_OP_CPU_KERNEL(onehot_cross_entropy, + ops::OnehotCrossEntropyOpKernel); +REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad, + ops::OnehotCrossEntropyGradientOpKernel); diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 4bbc8f093a794d46737a16488684a6a0cc25e285..d999bfce58c8a6db5c811aad677c07094b881841 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -12,10 +12,122 @@ See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU -#include "paddle/operators/cross_entropy_op.h" +#include "paddle/framework/op_registry.h" +#include "paddle/platform/assert.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +__host__ __device__ T clipping_log(const T x) { + PADDLE_ASSERT(std::is_floating_point::value); + const T kApproInf = 1e20; + T v = log(x); + if (v == INFINITY) { + return kApproInf; + } + if (v == -INFINITY) { + return -kApproInf; + } + return v; +} + +template +__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, + const int N, const int D) { + // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. + // CUDA_1D_KERNEL_LOOP(i, N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + PADDLE_ASSERT(label[i] >= 0 && label[i] < D); + Y[i] = -clipping_log(X[i * D + label[i]]); + } +} + +// TODO(qingqing): make zero setting an common function. +template +__global__ void zero(T* X, const int N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + X[i] = 0.0; + } +} + +template +__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, + const int* label, const int N, + const int D) { + // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. + // CUDA_1D_KERNEL_LOOP(i, N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + int idx = i * D + label[i]; + dX[idx] = -dY[i] / X[idx]; + } +} + +template +class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto X = ctx.Input("X"); + const T* Xdata = X->data(); + const int* label_data = ctx.Input("label")->data(); + auto Y = ctx.Output("Y"); + Y->mutable_data(ctx.GetPlace()); + T* Ydata = Y->data(); + + int N = X->dims()[0]; + int D = X->dims()[1]; + int block = 512; + int grid = (N + block - 1) / block; + // TODO(qingqing) launch kernel on specified stream + // base on ExecutionContext. + CrossEntropyKernel<<>>(Ydata, Xdata, label_data, N, D); + } +}; + +template +class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto X = ctx.Input("X"); + auto dX = ctx.Output(framework::GradVarName("X")); + auto dY = ctx.Input(framework::GradVarName("Y")); + auto label = ctx.Input("label"); + + auto* dXdata = dX->template mutable_data(ctx.GetPlace()); + auto* dYdata = dY->template data(); + auto* Xdata = X->template data(); + auto* label_data = label->data(); + + int N = X->dims()[0]; + int D = X->dims()[1]; + int block = 512; + int grid = (N * D + block - 1) / block; + zero<<>>(dXdata, N * D); + + grid = (N + block - 1) / block; + // TODO(qingqing): launch kernel on specified stream + // base on ExecutionContext. + CrossEntropyGradientKernel<<>>(dXdata, dYdata, Xdata, + label_data, N, D); + } +}; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL( - onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); +REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, + ops::OnehotCrossEntropyOpCUDAKernel); +REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad, + ops::OnehotCrossEntropyGradientOpCUDAKernel); diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index b7df92c9a98ebf12b72a8d3d8e8e4e1a950f06c9..eb4d1348de1d940e2648c83c8ba94b289f10c5b2 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -21,7 +21,7 @@ namespace operators { using Tensor = framework::Tensor; template -T tolerable_value(T x) { +inline T tolerable_value(const T x) { static_assert(std::is_floating_point::value, "tolerable_value works only on float, " "double and double double."); @@ -39,10 +39,13 @@ T tolerable_value(T x) { return x; } -template +template class OnehotCrossEntropyOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + auto X = ctx.Input("X"); const T* Xdata = X->data(); const int* label_data = ctx.Input("label")->data(); @@ -62,10 +65,13 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel { } }; -template +template class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + auto X = ctx.Input("X"); auto dX = ctx.Output(framework::GradVarName("X")); auto dY = ctx.Input(framework::GradVarName("Y")); @@ -79,6 +85,8 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { const int batch_size = X->dims()[0]; const int class_num = X->dims()[1]; + // TODO(qingqing): make zero setting an common function. + memset(dXdata, 0, sizeof(T) * batch_size * class_num); for (int i = 0; i < batch_size; ++i) { int index = i * class_num + label_data[i]; dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]); diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index dcd2237459d3bbabf449d30e4279b637c569555b..a85363ad81d2a23e7267026c067f74f8c94c4786 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -80,4 +80,4 @@ Use to initialize tensor with gaussian random generator. namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); -REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel); \ No newline at end of file +REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel); diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index 876b3ef5575325daed8bebc0c964bd80064a7bb3..29491137e6d8b4bfa2d0d07d48ffed1212a6131f 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -81,4 +81,4 @@ Used to initialize tensor with uniform random generator. REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, paddle::operators::UniformRandomOpMaker); REGISTER_OP_CPU_KERNEL(uniform_random, - paddle::operators::CPUUniformRandomKernel); \ No newline at end of file + paddle::operators::CPUUniformRandomKernel); diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 6716b7c7f20371f2970efaebe94cb1381983783c..1d6709934cbbcf50265eabef87c857654f783ed8 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -65,4 +65,4 @@ class GPUUniformRandomKernel : public framework::OpKernel { } // namespace paddle REGISTER_OP_GPU_KERNEL(uniform_random, - paddle::operators::GPUUniformRandomKernel); \ No newline at end of file + paddle::operators::GPUUniformRandomKernel); diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py index dd65e0f2dc23d3f657ff16c55fb297dae210b2d7..3bc05a0feccbbd3d5e7852d85bd3dc8edaccfd07 100644 --- a/python/paddle/v2/framework/tests/op_test_util.py +++ b/python/paddle/v2/framework/tests/op_test_util.py @@ -64,7 +64,8 @@ class OpTestMeta(type): actual = numpy.array(scope.find_var(out_name).get_tensor()) expect = self.outputs[out_name] self.assertTrue( - numpy.allclose(actual, expect), + numpy.allclose( + actual, expect, atol=1e-05), "output name: " + out_name + "has diff") obj.test_all = test_all diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py index 4815192e255c6e0429db3f50918a76a773b30131..d4277f2a42ce2e66e37405ccd3b2ee444d403d1a 100644 --- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py +++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py @@ -8,9 +8,8 @@ class TestCrossEntropy(unittest.TestCase): __metaclass__ = OpTestMeta def setUp(self): - # TODO this unit test is not passed self.type = "onehot_cross_entropy" - batch_size = 100 + batch_size = 30 class_num = 10 X = numpy.random.random((batch_size, class_num)).astype("float32") label = 5 * numpy.ones(batch_size).astype("int32") @@ -22,9 +21,9 @@ class TestCrossEntropy(unittest.TestCase): class CrossEntropyGradOpTest(GradientChecker): - def test_softmax_grad(self): + def test_check_grad(self): op = create_op("onehot_cross_entropy") - batch_size = 100 + batch_size = 30 class_num = 10 inputs = { "X": numpy.random.uniform(