Commit efa4526c authored by caoying03

finish implementation and fix unittest.

Parent 8d88c52d
@@ -43,8 +43,6 @@ template <typename Place, typename T>
class SoftmaxGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>();
auto Y = context.Input<Tensor>("Y");
auto dY = context.Input<Tensor>(framework::GradVarName("Y"));
auto dX = context.Output<Tensor>(framework::GradVarName("X"));
@@ -17,31 +17,16 @@
namespace paddle {
namespace operators {
class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto logits = ctx.Input<Tensor>("Logits");
PADDLE_ENFORCE(
logits->dims().size() == 2UL,
"The input of softmax_with_cross_entropy should be a 2-d tensor.");
PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 1UL,
"The label should be a 1-d tensor.");
ctx.Output<Tensor>("Label")->Resize({logits->dims()[0]});
}
};
class SoftmaxWithCrossEntropyOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
SoftmaxWithCrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Logits",
"The unscaled log probabilities which is a 2-D tensor<float> with"
"shape [N x K]. N is the batch_size, and K is the class number.");
"shape [N x K]. N is the batch_size, and K is the class number.")
.NotInGradient();
AddInput("Label", "The ground truth. A 1-D tensor<int> with shape N.");
AddOutput("Softmax",
"Store the outputs of softmax function, "
@@ -70,22 +55,34 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Loss"),
"Input(Loss) should be not null.");
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
"Input(Loss@GRAD) should be not null.");
PADDLE_ENFORCE_EQ(
ctx.Input<Tensor>("Logits")->dims(),
ctx.Input<Tensor>(framework::GradVarName("Logits"))->dims(),
"Input(Logits) and its gradients should have a same shape.");
PADDLE_ENFORCE_EQ(
ctx.Input<Tensor>("Logits")->dims(),
ctx.Input<Tensor>(framework::GradVarName("Logits"))->dims(),
"Input(Logits) and its gradients should have a same shape.");
"Input(Loss@Grad) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
"Input(Softmax) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Lable) should be not null.");
ctx.Output<Tensor>(framework::GradVarName("Logits"))
->Resize(ctx.Input<Tensor>("Softmax")->dims());
}
};
class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
const Tensor* logits = ctx.Input<Tensor>("Logits");
PADDLE_ENFORCE(
logits->dims().size() == 2UL,
"The input of softmax_with_cross_entropy should be a 2-d tensor.");
PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 1UL,
"The label should be a 1-d tensor.");
ctx.Output<Tensor>("Softmax")->Resize(logits->dims());
ctx.Output<Tensor>("Loss")->Resize({logits->dims()[0], 1});
}
};
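For reference, the shape checks above mean the forward op takes `Logits` of shape [N x K] and an integer `Label` of shape [N], and produces `Softmax` with the same shape as `Logits` plus a column-vector `Loss` of shape [N x 1]. A minimal NumPy sketch of that reference computation, assuming hard integer labels (the helper name `reference_forward` is illustrative, not from this commit):

```python
import numpy as np

def reference_forward(logits, labels):
    """Row-wise softmax plus per-row negative log-likelihood for hard labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)    # shift by row max for stability
    exps = np.exp(shifted)
    softmax = exps / exps.sum(axis=1, keepdims=True)        # shape [N, K]
    nll = -np.log(softmax[np.arange(len(labels)), labels])  # shape [N]
    return softmax, nll.reshape(-1, 1)                      # Loss has shape [N, 1]

softmax, loss = reference_forward(
    np.random.rand(3, 5).astype("float32"), np.array([0, 2, 4], dtype="int32"))
assert softmax.shape == (3, 5) and loss.shape == (3, 1)
```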
@@ -98,9 +95,7 @@ REGISTER_OP(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
ops::SoftmaxWithCrossEntropyOpMaker,
softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyOpGrad);
REGISTER_OP_CPU_KERNEL(
softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyKernel<float>);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradKernel<float>);
@@ -17,9 +17,4 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradKernel<paddle::platform::GPUPlace, float>);
// TODO(caoying) add GPU kernel
@@ -26,20 +26,24 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
template <typename T>
class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto place = context.GetPlace();
PADDLE_ENFORCE(platform::is_cpu_place(place),
"This kernel only runs on CPU.");
// Calculate the softmax outputs.
const Tensor* logits = context.Input<Tensor>("Logits");
Tensor* softmax = context.Output<Tensor>("Softmax");
// allocate memory on device.
softmax->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<Place, T>()(logits, softmax, context);
math::SoftmaxFunctor<platform::CPUPlace, T>()(logits, softmax, context);
// Calculate the cross entropy loss based on hard labels.
T* softmax_out = softmax->data<T>();
const int* label_data = context.Input<Tensor>("label")->data<int>();
const int* label_data = context.Input<Tensor>("Label")->data<int>();
Tensor* loss = context.Output<Tensor>("Loss");
loss->mutable_data<T>(context.GetPlace());
@@ -55,10 +59,24 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
}
};
template <typename Place, typename T>
template <typename T>
class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {}
void Compute(const framework::ExecutionContext& context) const override {
Tensor* logit_grad =
context.Output<Tensor>(framework::GradVarName("Logits"));
logit_grad->ShareDataWith<T>(*context.Input<Tensor>("Softmax"));
T* logit_grad_data = logit_grad->data<T>();
const int batch_size = logit_grad->dims()[0];
const int class_num = logit_grad->dims()[1];
const int* label_data = context.Input<Tensor>("Label")->data<int>();
// dLoss/dLogits = Softmax - OneHot(Label): subtract 1 at each ground-truth class.
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
logit_grad_data[index] -= 1.;
}
}
};
} // namespace operators
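For reference: with hard labels, the gradient of the per-row loss with respect to the logits is softmax(logits) - one_hot(label), which is why the backward kernel subtracts 1 at each row's ground-truth index. A hedged NumPy sketch of the same computation (the helper name is illustrative; like the kernel above, it assumes the incoming Loss gradient is all ones):

```python
import numpy as np

def reference_logits_grad(softmax, labels):
    """Gradient of the summed cross-entropy loss w.r.t. the logits."""
    grad = softmax.copy()                        # start from the saved forward softmax
    grad[np.arange(len(labels)), labels] -= 1.0  # subtract 1 at each ground-truth class
    return grad
```

Scaling each row by the corresponding element of Loss@GRAD would give the fully general backward pass.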
@@ -39,7 +39,6 @@ USE_OP(elementwise_mul);
USE_OP(mean);
USE_OP(sigmoid);
USE_OP(softmax);
USE_OP(softmax_with_cross_entropy);
USE_OP(rowwise_add);
USE_OP(fill_zeros_like);
USE_NO_KERNEL_OP(recurrent);
@@ -53,6 +52,7 @@ USE_OP(cos_sim);
USE_CPU_ONLY_OP(gather);
USE_CPU_ONLY_OP(scatter);
USE_CPU_ONLY_OP(concat);
USE_CPU_ONLY_OP(softmax_with_cross_entropy);
USE_OP(top_k);
USE_OP(squared_l2_distance);
USE_OP(sum);
@@ -166,7 +166,7 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place,
class OpTest(unittest.TestCase):
def check_output_with_place(self, place):
def check_output_with_place(self, place, atol):
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
@@ -188,22 +188,23 @@ class OpTest(unittest.TestCase):
expect = sub_out[sub_out_name]
self.assertTrue(
np.allclose(
actual, expect, atol=1e-05),
"output name: " + out_name + "has diff")
actual, expect, atol=atol),
"output name: " + out_name + " has diff.")
else:
actual = np.array(self.scope.find_var(out_name).get_tensor())
expect = self.outputs[out_name]
self.assertTrue(
np.allclose(
actual, expect, atol=1e-05),
"output name: " + out_name + "has diff")
actual, expect, atol=atol),
"output name: " + out_name + " has diff.")
def check_output(self):
def check_output(self, atol=1e-5):
places = [core.CPUPlace()]
if core.is_compile_gpu():
places.append(core.GPUPlace(0))
for place in places:
self.check_output_with_place(place)
self.check_output_with_place(place, atol)
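With the new signature, individual tests can loosen the absolute tolerance while 1e-5 stays the default; an illustrative call:

```python
def test_check_output(self):
    # atol is now configurable per test; omit it to keep the 1e-5 default.
    self.check_output(atol=1e-3)
```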
def __assert_is_close(self, numeric_grads, analytic_grads, names,
max_relative_error, msg_prefix):
@@ -217,9 +218,10 @@ class OpTest(unittest.TestCase):
def err_msg():
offset = np.argmax(diff_mat > max_relative_error)
return "%s Variable %s max gradient diff %f over limit %f, the first " \
"error element is %d" % (
msg_prefix, name, max_diff, max_relative_error, offset)
return ("%s Variable %s max gradient diff %f over limit %f, "
"the first error element is %d") % (
msg_prefix, name, max_diff, max_relative_error,
offset)
self.assertLessEqual(max_diff, max_relative_error, err_msg())
@@ -11,7 +11,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self.op_type = "softmax_with_cross_entropy"
MAX_BATCH_SIZE = 23
MAX_CLASS_NUM = 255
MAX_CLASS_NUM = 10
batch_size = np.random.randint(1, MAX_BATCH_SIZE, 1)[0]
class_num = np.random.randint(2, MAX_CLASS_NUM, 1)[0]
@@ -21,18 +21,18 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
softmax = np.apply_along_axis(stable_softmax, 1, logits)
labels = np.random.randint(0, class_num, batch_size, dtype="int32")
cross_entropy = [
-np.log(softmax[i][labels[i]]) for i in range(softmax.shape[0])
]
cross_entropy = np.asmatrix(
[[-np.log(softmax[i][labels[i]])] for i in range(softmax.shape[0])],
dtype="float32")
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {"Loss": cross_entropy}
self.outputs = {"Softmax": softmax, "Loss": cross_entropy}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
pass
self.check_grad(["Logits"], "Loss")
if __name__ == "__main__":
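The test calls a `stable_softmax` helper that is not shown in this diff; a typical definition, assumed here to match the shift-by-max form used in the sketches above, is:

```python
import numpy as np

def stable_softmax(x):
    """Softmax of a 1-D array, shifted by its max to avoid overflow in exp."""
    shiftx = x - np.max(x)
    exps = np.exp(shiftx)
    return exps / np.sum(exps)
```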