diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 55435103489ace11868eed61c38018d8ba357e65..0b588297169540417586d7c167a1265827b683ac 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -55,6 +55,10 @@ class OperatorBase { /// e.g. Variable "x@GRAD" is the gradient of varibale "x". static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; } + static std::string GRAD_VAR_NAME(const std::string& name) { + return name + GRAD_VAR_SUFFIX(); + } + /// Variables with this suffix are supposed to be filled up with zeros. static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; } diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 5b59fad7d5f9729b0862f8cd78cb32f94f87f513..5cbb96ab754467ea6ddab9380ca25987c9376980 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -1,16 +1,17 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ #include "paddle/operators/softmax_op.h" namespace paddle { @@ -19,12 +20,13 @@ namespace operators { class SoftmaxOp : public OperatorWithKernel { protected: void InferShape(const InferShapeContext &ctx) const override { - PADDLE_ENFORCE(ctx.InputSize() == 1, "Only one input is need for softmax"); - PADDLE_ENFORCE(ctx.Input(0)->dims().size() == 2, + PADDLE_ENFORCE(ctx.InputSize() == 1UL, + "Only one input is need for softmax"); + PADDLE_ENFORCE(ctx.Input("X")->dims().size() == 2UL, "The input of softmax op must be matrix"); - PADDLE_ENFORCE(ctx.OutputSize() == 1, + PADDLE_ENFORCE(ctx.OutputSize() == 1UL, "Only one output is need for softmax"); - ctx.Output(0)->Resize(ctx.Input(0)->dims()); + ctx.Output("Y")->Resize(ctx.Input("X")->dims()); } }; @@ -40,10 +42,19 @@ public: class SoftmaxOpGrad : public OperatorWithKernel { protected: - void InferShape(const InferShapeContext &ctx) const override {} - std::string DebugString() const override { - LOG(INFO) << "SoftmaxOpGrad"; - return ""; + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 3UL, + "Input of SoftmaxOpGrad should be 3, X, Y, YG"); + PADDLE_ENFORCE(ctx.OutputSize() == 1UL, + "Output of SoftmaxOpGrad should be 1"); + PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null"); + PADDLE_ENFORCE(ctx.InputVar(GRAD_VAR_NAME("Y")) != nullptr, + "Input(Y@GRAD) should not be null"); + PADDLE_ENFORCE(ctx.Input("Y")->dims() == + ctx.Input(GRAD_VAR_NAME("Y"))->dims(), + "the shape of Input(0) and Input(1) should be the same"); + ctx.Output(GRAD_VAR_NAME("X")) + ->Resize(ctx.Input("Y")->dims()); } }; @@ -51,5 +62,7 @@ protected: } // namespace paddle REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker); -REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad); REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel); +REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad); +REGISTER_OP_CPU_KERNEL(softmax_grad, + ops::SoftmaxGradKernel); diff --git a/paddle/operators/softmax_op.cu b/paddle/operators/softmax_op.cu index ddf8f6e913ccf450185f377f531bf978f69ed1fc..8c652213f2e4c0e0ea1a31987fcb37c86374cd2a 100644 --- a/paddle/operators/softmax_op.cu +++ b/paddle/operators/softmax_op.cu @@ -3,3 +3,4 @@ #include "paddle/operators/softmax_op.h" REGISTER_OP_GPU_KERNEL(softmax, ops::SoftmaxKernel); +REGISTER_OP_GPU_KERNEL(softmax_grad, ops::SoftmaxGradKernel); diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 75c5197697dada58e09f4cda41cea13af56e79a3..13e74a79077982e9fba5d90f40986e699c1ed897 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -1,19 +1,22 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #pragma once +#include "paddle/framework/ddim.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/tensor.h" #include "paddle/operators/type_alias.h" namespace paddle { @@ -23,8 +26,8 @@ template class SoftmaxKernel : public OpKernel { public: void Compute(const ExecutionContext& context) const override { - auto input = context.Input(0); - auto output = context.Output(0); + auto input = context.Input("X"); + auto output = context.Output("Y"); output->mutable_data(context.GetPlace()); auto logits = EigenMatrix::From(*input); @@ -57,5 +60,38 @@ public: .broadcast(one_by_class)); } }; + +template +class SoftmaxGradKernel : public OpKernel { +public: + void Compute(const ExecutionContext& context) const override { + std::shared_ptr scale_ = std::make_shared(); + + auto Y = context.Input("Y"); + auto dY = context.Input(OperatorBase::GRAD_VAR_NAME("Y")); + auto dX = context.Output(OperatorBase::GRAD_VAR_NAME("X")); + dX->mutable_data(context.GetPlace()); + + const int batch_size = Y->dims()[0]; + const int class_num = Y->dims()[1]; + + Eigen::DSizes along_class(1); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, class_num); + + auto Y_eigen = EigenMatrix::From(*Y); + auto dY_eigen = EigenMatrix::From(*dY); + auto dX_eigen = EigenMatrix::From(*dX); + auto place = context.GetEigenDevice(); + + auto dot = (Y_eigen * dY_eigen) + .sum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class); + dX_eigen.device(place) = (dY_eigen - dot) * Y_eigen; + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/operators/type_alias.h b/paddle/operators/type_alias.h index 9049ffda1da5408411687474c5ed0c76c2394623..4ee08a099d877795a0292185de3099fa84473819 100644 --- a/paddle/operators/type_alias.h +++ b/paddle/operators/type_alias.h @@ -22,6 +22,7 @@ namespace paddle { namespace operators { using OpKernel = framework::OpKernel; +using OperatorBase = framework::OperatorBase; using InferShapeContext = framework::InferShapeContext; using ExecutionContext = framework::ExecutionContext; using Variable = framework::Variable; diff --git a/python/paddle/v2/framework/tests/test_softmax_op.py b/python/paddle/v2/framework/tests/test_softmax_op.py index 191b698c1cdec9b86b4ded6b1f743586867ca62f..c80888128781d98e4ed30d845a30b39121f66459 100644 --- a/python/paddle/v2/framework/tests/test_softmax_op.py +++ b/python/paddle/v2/framework/tests/test_softmax_op.py @@ -1,6 +1,10 @@ import unittest -from op_test_util import OpTestMeta + import numpy as np +import paddle.v2.framework.core as core +import paddle.v2.framework.create_op_creation_methods as creation + +from op_test_util import OpTestMeta def stable_softmax(x): @@ -19,5 +23,63 @@ class TestSoftmaxOp(unittest.TestCase): self.Y = np.apply_along_axis(stable_softmax, 1, self.X) +class TestSoftmaxGradOp(unittest.TestCase): + def test_softmax_grad(self): + op = creation.op_creations.softmax(X="X", Y="Y") + backward_op = core.Operator.backward(op, set()) + self.assertEqual(backward_op.type(), "softmax_grad") + expected = '''Op(softmax_grad), inputs:(X, Y, Y@GRAD), outputs:(X@GRAD).''' + self.assertEqual(expected, str(backward_op)) + + batch_size = 3 + class_num = 5 + # Initialize X and add 1e-2 for numerical stability + Y = np.random.rand(batch_size, class_num).astype(np.float32) + Y = Y + 1e-2 + dY = np.random.rand(batch_size, class_num).astype(np.float32) + + # Reference implementation of cross entropy with soft labels + def label_softmax_grad(Y, dY): + dX = Y * 0.0 + for i in range(batch_size): + d = np.dot(Y[i, :], dY[i, :]) + dX[i, :] = Y[i, :] * (dY[i, :] - d) + return dX + + expected = label_softmax_grad(Y, dY) + + scope = core.Scope() + places = [] + places.append(core.CPUPlace()) + if core.is_compile_gpu(): + places.append(core.GPUPlace(0)) + + for place in places: + y = scope.new_var("Y") + y_tensor = y.get_tensor() + y_tensor.set_dims([batch_size, class_num]) + y_tensor.alloc_float(place) + y_tensor.set(Y, place) + + dy = scope.new_var("Y@GRAD") + dy_tensor = dy.get_tensor() + dy_tensor.set_dims([batch_size, class_num]) + dy_tensor.alloc_float(place) + dy_tensor.set(dY, place) + + x = scope.new_var("X") + dx = scope.new_var("X@GRAD") + + tensor = scope.find_var("X@GRAD").get_tensor() + backward_op.infer_shape(scope) + self.assertEqual([batch_size, class_num], tensor.shape()) + + ctx = core.DeviceContext.create(place) + backward_op.run(scope, ctx) + actual = np.array(tensor) + + np.testing.assert_almost_equal(actual, expected, decimal=3) + + if __name__ == '__main__': unittest.main()