diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h
index 6d14542a7218a87317c9450feb952104a69d64a1..68d05fc21547e202f0109774656b29a7078ac0b7 100644
--- a/paddle/operators/softmax_op.h
+++ b/paddle/operators/softmax_op.h
@@ -43,8 +43,6 @@ template
 class SoftmaxGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    std::shared_ptr scale_ = std::make_shared();
-
     auto Y = context.Input("Y");
     auto dY = context.Input(framework::GradVarName("Y"));
     auto dX = context.Output(framework::GradVarName("X"));
diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc
index b4aa9aab4b06319b56ecec42aa9614da87638ed3..fd75494ff80af8daf49f9dffe50d0c954f74c790 100644
--- a/paddle/operators/softmax_with_cross_entropy_op.cc
+++ b/paddle/operators/softmax_with_cross_entropy_op.cc
@@ -17,31 +17,16 @@ namespace paddle {
 namespace operators {
 
-class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto logits = ctx.Input("Logits");
-    PADDLE_ENFORCE(
-        logits->dims().size() == 2UL,
-        "The input of softmax_with_cross_entropy should be a 2-d tensor.");
-    PADDLE_ENFORCE(ctx.Input("Label")->dims().size() == 1UL,
-                   "The label should be a 1-d tensor.");
-    ctx.Output("Label")->Resize({logits->dims()[0]});
-  }
-};
-
 class SoftmaxWithCrossEntropyOpMaker
     : public framework::OpProtoAndCheckerMaker {
  public:
-  SoftmaxWithCrossEntropyOpMaker(framework::OpProto *proto,
-                                 framework::OpAttrChecker *op_checker)
+  SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
+                                 framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Logits",
              "The unscaled log probabilities which is a 2-D tensor with"
-             "shape [N x K]. N is the batch_size, and K is the class number.");
+             "shape [N x K]. N is the batch_size, and K is the class number.")
+        .NotInGradient();
     AddInput("Label",
A 1-D tensor with shape N."); AddOutput("Softmax", "Store the outputs of softmax function, " @@ -70,22 +55,34 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Loss"), - "Input(Loss) should be not null."); + void InferShape(const framework::InferShapeContext& ctx) const override { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")), - "Input(Loss@GRAD) should be not null."); - PADDLE_ENFORCE_EQ( - ctx.Input("Logits")->dims(), - ctx.Input(framework::GradVarName("Logits"))->dims(), - "Input(Logits) and its gradients should have a same shape."); - PADDLE_ENFORCE_EQ( - ctx.Input("Logits")->dims(), - ctx.Input(framework::GradVarName("Logits"))->dims(), - "Input(Logits) and its gradients should have a same shape."); - + "Input(Loss@Grad) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"), + "Input(Softmax) should be not null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), "Input(Lable) should be not null."); + + ctx.Output(framework::GradVarName("Logits")) + ->Resize(ctx.Input("Softmax")->dims()); + } +}; + +class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext& ctx) const override { + const Tensor* logits = ctx.Input("Logits"); + PADDLE_ENFORCE( + logits->dims().size() == 2UL, + "The input of softmax_with_cross_entropy should be a 2-d tensor."); + PADDLE_ENFORCE(ctx.Input("Label")->dims().size() == 1UL, + "The label should be a 1-d tensor."); + + ctx.Output("Softmax")->Resize(logits->dims()); + ctx.Output("Loss")->Resize({logits->dims()[0], 1}); } }; @@ -98,9 +95,7 @@ REGISTER_OP(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp, ops::SoftmaxWithCrossEntropyOpMaker, softmax_with_cross_entropy_grad, ops::SoftmaxWithCrossEntropyOpGrad); -REGISTER_OP_CPU_KERNEL( - softmax_with_cross_entropy, - ops::SoftmaxWithCrossEntropyKernel); -REGISTER_OP_CPU_KERNEL( - softmax_with_cross_entropy_grad, - ops::SoftmaxWithCrossEntropyGradKernel); +REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy, + ops::SoftmaxWithCrossEntropyKernel); +REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad, + ops::SoftmaxWithCrossEntropyGradKernel); diff --git a/paddle/operators/softmax_with_cross_entropy_op.cu b/paddle/operators/softmax_with_cross_entropy_op.cu index c9d47cc4aae0e3cc1e9f40da33daf95336ac754f..922bb19d4de6c6b7d3e09e84b88e877b53cb0892 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/operators/softmax_with_cross_entropy_op.cu @@ -17,9 +17,4 @@ namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL( - softmax_with_cross_entropy, - ops::SoftmaxWithCrossEntropyKernel); -REGISTER_OP_GPU_KERNEL( - softmax_with_cross_entropy_grad, - ops::SoftmaxWithCrossEntropyGradKernel); +// TODO(caoying) add GPU kernel diff --git a/paddle/operators/softmax_with_cross_entropy_op.h b/paddle/operators/softmax_with_cross_entropy_op.h index 4c019a75992c316b56f5ecccc248bdc57a399e5d..e147cdb815bd7dad087a0e1054fccf3b14df17cc 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.h +++ b/paddle/operators/softmax_with_cross_entropy_op.h @@ -26,20 +26,24 @@ template using EigenMatrix = framework::EigenMatrix; -template +template class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { 
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto place = context.GetPlace();
+    PADDLE_ENFORCE(platform::is_cpu_place(place),
+                   "This kernel only runs on CPU.");
+
+    // Calculate the softmax outputs.
     const Tensor* logits = context.Input("Logits");
     Tensor* softmax = context.Output("Softmax");
-    // allocate memory on device.
     softmax->mutable_data(context.GetPlace());
-    math::SoftmaxFunctor()(logits, softmax, context);
+
+    math::SoftmaxFunctor()(logits, softmax, context);
 
     // Calculate the cross entropy loss based on hard labels.
     T* softmax_out = softmax->data();
-    const int* label_data = context.Input("label")->data();
+    const int* label_data = context.Input("Label")->data();
 
     Tensor* loss = context.Output("Loss");
     loss->mutable_data(context.GetPlace());
@@ -55,10 +59,24 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
   }
 };
 
-template
+template
 class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {}
+  void Compute(const framework::ExecutionContext& context) const override {
+    Tensor* logit_grad =
+        context.Output(framework::GradVarName("Logits"));
+    logit_grad->ShareDataWith(*context.Input("Softmax"));
+    T* logit_grad_data = logit_grad->data();
+
+    const int batch_size = logit_grad->dims()[0];
+    const int class_num = logit_grad->dims()[1];
+
+    const int* label_data = context.Input("Label")->data();
+    for (int i = 0; i < batch_size; ++i) {
+      int index = i * class_num + label_data[i];
+      logit_grad_data[index] -= 1.;
+    }
+  }
 };
 
 }  // namespace operators
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 93792c568e00048e1bfe5cd571dc4a34d3e07d5f..cb361596aea757046946bc45ed38bb0d25416947 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -39,7 +39,6 @@ USE_OP(elementwise_mul);
 USE_OP(mean);
 USE_OP(sigmoid);
 USE_OP(softmax);
-USE_OP(softmax_with_cross_entropy);
 USE_OP(rowwise_add);
 USE_OP(fill_zeros_like);
 USE_NO_KERNEL_OP(recurrent);
@@ -53,6 +52,7 @@ USE_OP(cos_sim);
 USE_CPU_ONLY_OP(gather);
 USE_CPU_ONLY_OP(scatter);
 USE_CPU_ONLY_OP(concat);
+USE_CPU_ONLY_OP(softmax_with_cross_entropy);
 USE_OP(top_k);
 USE_OP(squared_l2_distance);
 USE_OP(sum);
diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py
index 4fec4c9109bf247abb2068177583acb47a8ebd97..f5f11aa93d7b22e3bd30f4100cf4cfde977e807a 100644
--- a/python/paddle/v2/framework/tests/op_test.py
+++ b/python/paddle/v2/framework/tests/op_test.py
@@ -166,7 +166,7 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place,
 
 
 class OpTest(unittest.TestCase):
-    def check_output_with_place(self, place):
+    def check_output_with_place(self, place, atol):
         self.scope = core.Scope()
         op_inputs = self.inputs if hasattr(self, "inputs") else dict()
         op_attrs = self.attrs if hasattr(self, "attrs") else dict()
@@ -188,22 +188,23 @@ class OpTest(unittest.TestCase):
                     expect = sub_out[sub_out_name]
                     self.assertTrue(
                         np.allclose(
-                            actual, expect, atol=1e-05),
-                        "output name: " + out_name + "has diff")
+                            actual, expect, atol=atol),
+                        "output name: " + out_name + " has diff.")
             else:
                 actual = np.array(self.scope.find_var(out_name).get_tensor())
                 expect = self.outputs[out_name]
+
                 self.assertTrue(
                     np.allclose(
-                        actual, expect, atol=1e-05),
-                    "output name: " + out_name + "has diff")
+                        actual, expect, atol=atol),
+                    "output name: " + out_name + " has diff.")
 
-    def check_output(self):
+    def check_output(self, atol=1e-5):
         places = [core.CPUPlace()]
         if core.is_compile_gpu():
             places.append(core.GPUPlace(0))
         for place in places:
-            self.check_output_with_place(place)
+            self.check_output_with_place(place, atol)
 
     def __assert_is_close(self, numeric_grads, analytic_grads, names,
                           max_relative_error, msg_prefix):
@@ -217,9 +218,10 @@ class OpTest(unittest.TestCase):
 
             def err_msg():
                 offset = np.argmax(diff_mat > max_relative_error)
-                return "%s Variable %s max gradient diff %f over limit %f, the first " \
-                       "error element is %d" % (
-                       msg_prefix, name, max_diff, max_relative_error, offset)
+                return ("%s Variable %s max gradient diff %f over limit %f, "
+                        "the first error element is %d") % (
+                            msg_prefix, name, max_diff, max_relative_error,
+                            offset)
 
             self.assertLessEqual(max_diff, max_relative_error, err_msg())
diff --git a/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
index 6116110569143fc547325d98f215c909de9622ea..4e35c063b96734845ed87a548dff13e9ce941485 100644
--- a/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
@@ -11,7 +11,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
         self.op_type = "softmax_with_cross_entropy"
 
         MAX_BATCH_SIZE = 23
-        MAX_CLASS_NUM = 255
+        MAX_CLASS_NUM = 10
 
         batch_size = np.random.randint(1, MAX_BATCH_SIZE, 1)[0]
         class_num = np.random.randint(2, MAX_CLASS_NUM, 1)[0]
@@ -21,18 +21,18 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
         softmax = np.apply_along_axis(stable_softmax, 1, logits)
         labels = np.random.randint(0, class_num, batch_size, dtype="int32")
 
-        cross_entropy = [
-            -np.log(softmax[i][labels[i]]) for i in range(softmax.shape[0])
-        ]
+        cross_entropy = np.asmatrix(
+            [[-np.log(softmax[i][labels[i]])] for i in range(softmax.shape[0])],
+            dtype="float32")
 
         self.inputs = {"Logits": logits, "Label": labels}
-        self.outputs = {"Loss": cross_entropy}
+        self.outputs = {"Softmax": softmax, "Loss": cross_entropy}
 
     def test_check_output(self):
         self.check_output()
 
     def test_check_grad(self):
-        pass
+        self.check_grad(["Logits"], "Loss")
 
 
 if __name__ == "__main__":
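
For reference, a minimal NumPy sketch (an illustration, not part of the patch) of the hard-label forward loss and the logit gradient that the CPU kernels above compute. The function names, shapes, and the unit upstream gradient below are assumptions made for this example only.

import numpy as np


def softmax_with_cross_entropy_np(logits, labels):
    # Row-wise, numerically stable softmax followed by -log(p[label]);
    # returns the softmax output and an [N x 1] loss, as the op does.
    shifted = logits - logits.max(axis=1, keepdims=True)
    softmax = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    loss = -np.log(softmax[np.arange(len(labels)), labels]).reshape(-1, 1)
    return softmax, loss


def logits_grad_np(softmax, labels):
    # dLoss/dLogits = softmax - one_hot(label), assuming a unit upstream
    # gradient; the CPU grad kernel above applies the same update in place.
    grad = softmax.copy()
    grad[np.arange(len(labels)), labels] -= 1.0
    return grad


if __name__ == "__main__":
    logits = np.random.uniform(0.1, 1.0, (4, 5)).astype("float32")
    labels = np.random.randint(0, 5, 4)
    softmax, loss = softmax_with_cross_entropy_np(logits, labels)
    print(loss.shape, logits_grad_np(softmax, labels).shape)  # (4, 1) (4, 5)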