提交 000d7511 编写于 作者: C caoying03

fix backward op.

上级 201c2bcf
...@@ -37,13 +37,13 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ...@@ -37,13 +37,13 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0], PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) should " "The 1st dimension of Input(X) and Input(Label) should "
"be equal."); "be equal.");
if (ctx.Attr<bool>("soft_label")) { if (ctx.Attr<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
"If Attr(soft_label) == true, the 2nd dimension of " "If Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal."); "Input(X) and Input(Label) should be equal.");
} else { } else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1, PADDLE_ENFORCE_EQ(label->dims()[1], 1,
"If Attr(soft_label) == false, the 2nd dimension of " "If Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1."); "Input(Label) should be 1.");
} }
...@@ -63,6 +63,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -63,6 +63,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
"Input(Label) should be not null."); "Input(Label) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
"Input(Y@GRAD) shoudl be not null."); "Input(Y@GRAD) shoudl be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("X")),
"Output(X@GRAD) should be not null.");
auto x = ctx.Input<Tensor>("X"); auto x = ctx.Input<Tensor>("X");
auto label = ctx.Input<Tensor>("Label"); auto label = ctx.Input<Tensor>("Label");
...@@ -80,13 +82,13 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -80,13 +82,13 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
"be equal."); "be equal.");
PADDLE_ENFORCE_EQ(dy->dims()[1], 1, PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
"The 2nd dimension of Input(Y@Grad) should be 1."); "The 2nd dimension of Input(Y@Grad) should be 1.");
if (ctx.Attr<bool>("soft_label")) { if (ctx.Attr<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
"When Attr(soft_label) == true, the 2nd dimension of " "When Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal."); "Input(X) and Input(Label) should be equal.");
} else { } else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1, PADDLE_ENFORCE_EQ(label->dims()[1], 1,
"When Attr(soft_label) == false, the 2nd dimension of " "When Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1."); "Input(Label) should be 1.");
} }
...@@ -105,18 +107,19 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -105,18 +107,19 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
"where N is the batch size and D is the number of classes. " "where N is the batch size and D is the number of classes. "
"This input is a probability computed by the previous operator, " "This input is a probability computed by the previous operator, "
"which is almost always the result of a softmax operator."); "which is almost always the result of a softmax operator.");
AddInput("Label", AddInput(
"Label",
"(Tensor, default Tensor<int>), the ground truth which is " "(Tensor, default Tensor<int>), the ground truth which is "
"a 1-D or 2-D tensor. " "a 2-D tensor. "
"When soft_label is set to 0, `Label` is a Tensor<int> with shape " "When softLabel is set to false, `Label` is a Tensor<int> with shape "
"[N x 1]. " "[N x 1]. "
"When soft_label is set to 1, `Label` is a Tensor<float/double> " "When softLabel is set to true, `Label` is a Tensor<float/double> "
"with shape [N x K]."); "with shape [N x K].");
AddOutput("Y", AddOutput("Y",
"(Tensor, default Tensor<float>), a 1-D tensor " "(Tensor, default Tensor<float>), a 2-D tensor "
"with shape [N x 1]. The cross entropy loss."); "with shape [N x 1]. The cross entropy loss.");
AddAttr<bool>( AddAttr<bool>(
"soft_label", "softLabel",
"(bool, default false), a flag to indicate whether to interpretate " "(bool, default false), a flag to indicate whether to interpretate "
"the given labels as soft labels.") "the given labels as soft labels.")
.SetDefault(false); .SetDefault(false);
...@@ -126,12 +129,12 @@ CrossEntropy Operator. ...@@ -126,12 +129,12 @@ CrossEntropy Operator.
It supports both standard cross-entropy and soft-label cross-entropy loss It supports both standard cross-entropy and soft-label cross-entropy loss
computation. computation.
1) One-hot cross-entropy: 1) One-hot cross-entropy:
soft_label = False, Label[i, 0] indicates the class index for sample i: softLabel = false, Label[i, 0] indicates the class index for sample i:
Y[i] = -log(X[i, Label[i]]) Y[i] = -log(X[i, Label[i]])
2) Soft-label cross-entropy: 2) Soft-label cross-entropy:
soft_label = True, Label[i, j] indicates the soft label of class j softLabel = true, Label[i, j] indicates the soft label of class j
for sample i: for sample i:
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])} Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
......
...@@ -70,7 +70,7 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, ...@@ -70,7 +70,7 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
// TODO(qingqing): make zero setting a common function. // TODO(qingqing): make zero setting a common function.
template <typename T> template <typename T>
__global__ void zero(T* X, const int N) { __global__ void Zero(T* X, const int N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) { i += blockDim.x * gridDim.x) {
X[i] = 0.0; X[i] = 0.0;
...@@ -108,18 +108,17 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { ...@@ -108,18 +108,17 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device."); "This kernel only runs on GPU device.");
auto x = ctx.Input<Tensor>("X"); const Tensor* x = ctx.Input<Tensor>("X");
auto y = ctx.Output<Tensor>("Y"); const Tensor* label = ctx.Input<Tensor>("Label");
auto label = ctx.Input<Tensor>("Label"); Tensor* y = ctx.Output<Tensor>("Y");
auto* x_data = x->data<T>(); const T* x_data = x->data<T>();
y->mutable_data<T>(ctx.GetPlace()); T* y_data = y->mutable_data<T>(ctx.GetPlace());
auto* y_data = y->data<T>();
int batch_size = x->dims()[0]; int batch_size = x->dims()[0];
int class_num = x->dims()[1]; int class_num = x->dims()[1];
if (ctx.Attr<bool>("soft_label")) { if (ctx.Attr<bool>("softLabel")) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>(); auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num))); int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
...@@ -148,38 +147,41 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { ...@@ -148,38 +147,41 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device."); "This kernel only runs on GPU device.");
auto x = ctx.Input<Tensor>("X"); const Tensor* x = ctx.Input<Tensor>("X");
auto dx = ctx.Output<Tensor>(framework::GradVarName("X")); const Tensor* label = ctx.Input<Tensor>("Label");
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y")); Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto label = ctx.Input<Tensor>("Label");
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace()); const T* dy_data =
auto* dy_data = dy->data<T>(); ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
auto* x_data = x->data<T>(); T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
const T* x_data = x->data<T>();
int n = x->dims()[0]; int batch_size = x->dims()[0];
int d = x->dims()[1]; int class_num = x->dims()[1];
int block = 512; int block = 512;
int grid = (n * d + block - 1) / block; int grid = (batch_size * class_num + block - 1) / block;
zero<T><<<grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>( if (ctx.Attr<bool>("softLabel")) {
ctx.device_context())
.stream()>>>(dx_data, n * d);
if (ctx.Attr<bool>("soft_label")) {
auto* label_data = label->data<T>(); auto* label_data = label->data<T>();
SoftCrossEntropyGradientKernel<T><<< SoftCrossEntropyGradientKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>( grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context()) ctx.device_context())
.stream()>>>(dx_data, dy_data, x_data, label_data, .stream()>>>(dx_data, dy_data, x_data, label_data,
n, d); batch_size, class_num);
} else { } else {
Zero<T><<<grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(dx_data, batch_size * class_num);
auto* label_data = label->data<int>(); auto* label_data = label->data<int>();
grid = (batch_size + block - 1) / block;
CrossEntropyGradientKernel<T><<< CrossEntropyGradientKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>( grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context()) ctx.device_context())
.stream()>>>(dx_data, dy_data, x_data, label_data, .stream()>>>(dx_data, dy_data, x_data, label_data,
n, d); batch_size, class_num);
} }
} }
}; };
......
...@@ -42,14 +42,14 @@ class CrossEntropyOpKernel : public framework::OpKernel { ...@@ -42,14 +42,14 @@ class CrossEntropyOpKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace."); "This kernel only runs on CPU.");
const Tensor* x = ctx.Input<Tensor>("X"); const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* labels = ctx.Input<Tensor>("Label"); const Tensor* labels = ctx.Input<Tensor>("Label");
Tensor* y = ctx.Output<Tensor>("Y"); Tensor* y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace()); T* y_data = y->mutable_data<T>(ctx.GetPlace());
const int batch_size = x->dims()[0]; const int batch_size = x->dims()[0];
if (ctx.Attr<bool>("soft_label")) { if (ctx.Attr<bool>("softLabel")) {
auto prob = EigenMatrix<T>::From(*x); auto prob = EigenMatrix<T>::From(*x);
auto lbl_mat = EigenMatrix<T>::From(*labels); auto lbl_mat = EigenMatrix<T>::From(*labels);
auto loss = EigenMatrix<T>::From(*y); auto loss = EigenMatrix<T>::From(*y);
...@@ -60,9 +60,7 @@ class CrossEntropyOpKernel : public framework::OpKernel { ...@@ -60,9 +60,7 @@ class CrossEntropyOpKernel : public framework::OpKernel {
.reshape(Eigen::DSizes<int, 2>(batch_size, 1))); .reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
} else { } else {
const int class_num = x->dims()[1]; const int class_num = x->dims()[1];
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
T* y_data = y->data<T>();
const int* label_data = labels->data<int>(); const int* label_data = labels->data<int>();
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
...@@ -78,33 +76,32 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { ...@@ -78,33 +76,32 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace."); "This kernel only runs on CPU.");
const Tensor* x = ctx.Input<Tensor>("X");
auto x = ctx.Input<Tensor>("X"); const Tensor* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto dx = ctx.Output<Tensor>(framework::GradVarName("X")); const Tensor* label = ctx.Input<Tensor>("Label");
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y")); Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto label = ctx.Input<Tensor>("Label"); T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto* dy_data = dy->data<T>();
auto* x_data = x->data<T>();
int batch_size = x->dims()[0];
int class_num = x->dims()[1]; int class_num = x->dims()[1];
if (ctx.Attr<bool>("softLabel")) {
// TODO(qingqing): make zero setting an common function. auto x_mat = EigenMatrix<T>::From(*x);
if (ctx.Attr<bool>("soft_label")) { auto dy_mat = EigenMatrix<T>::From(*dy);
auto* label_data = ctx.Input<Tensor>("Label")->data<T>(); auto lbl_mat = EigenMatrix<T>::From(*label);
int index = 0; auto dx_mat = EigenMatrix<T>::From(*dx);
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < class_num; ++j) { dx_mat.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
dx_data[index] = -label_data[index] * dy_data[i] / x_data[index]; -(lbl_mat * dy_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) /
index++; x_mat);
}
}
} else { } else {
auto* label_data = label->data<int>(); int batch_size = x->dims()[0];
const T* dy_data = dy->data<T>();
const T* x_data = x->data<T>();
const int* label_data = label->data<int>();
// TODO(qingqing): make zero setting a common function.
memset(dx_data, 0, sizeof(T) * batch_size * class_num); memset(dx_data, 0, sizeof(T) * batch_size * class_num);
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num); PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
int index = i * class_num + label_data[i]; int index = i * class_num + label_data[i];
......
...@@ -21,7 +21,7 @@ class TestCrossEntropyOp1(OpTest): ...@@ -21,7 +21,7 @@ class TestCrossEntropyOp1(OpTest):
self.inputs = {"X": X, "Label": label} self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy} self.outputs = {"Y": cross_entropy}
self.attrs = {"soft_label": False} self.attrs = {"softLabel": False}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -49,7 +49,7 @@ class TestCrossEntropyOp2(OpTest): ...@@ -49,7 +49,7 @@ class TestCrossEntropyOp2(OpTest):
self.inputs = {"X": X, "Label": label} self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy} self.outputs = {"Y": cross_entropy}
self.attrs = {"soft_label": True} self.attrs = {"softLabel": True}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -73,6 +73,7 @@ class TestCrossEntropyOp3(OpTest): ...@@ -73,6 +73,7 @@ class TestCrossEntropyOp3(OpTest):
0, class_num, (batch_size), dtype="int32") 0, class_num, (batch_size), dtype="int32")
label = np.zeros(X.shape) label = np.zeros(X.shape)
label[np.arange(batch_size), label_index] = 1 label[np.arange(batch_size), label_index] = 1
cross_entropy = np.asmatrix( cross_entropy = np.asmatrix(
[[-np.log(X[i][label_index[i]])] for i in range(X.shape[0])], [[-np.log(X[i][label_index[i]])] for i in range(X.shape[0])],
dtype="float32") dtype="float32")
...@@ -81,7 +82,7 @@ class TestCrossEntropyOp3(OpTest): ...@@ -81,7 +82,7 @@ class TestCrossEntropyOp3(OpTest):
self.inputs = {"X": X, "Label": label} self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy} self.outputs = {"Y": cross_entropy}
self.attrs = {"soft_label": True} self.attrs = {"softLabel": True}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册